a = jnp.array([1, 0, 0, 0]) # identity quaternion
# jaxpr = jax.make_jaxpr(mjx._src.math.quat_mul)(a, a)
hlo = jax.jit(mjx._src.math.quat_mul).lower(a, a).as_text()

lowered_as_text:
module @jit_non_mmm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<4xi32> {mhlo.layout_mode = "default"}, %arg1: tensor<4xi32> {mhlo.layout_mode = "default"}) -> (tensor<4xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.slice %arg0 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %1 = stablehlo.reshape %0 : (tensor<1xi32>) -> tensor<i32>
    %2 = stablehlo.slice %arg1 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %3 = stablehlo.reshape %2 : (tensor<1xi32>) -> tensor<i32>
    %4 = stablehlo.multiply %1, %3 : tensor<i32>
    %5 = stablehlo.slice %arg0 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %6 = stablehlo.reshape %5 : (tensor<1xi32>) -> tensor<i32>
    %7 = stablehlo.slice %arg1 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %8 = stablehlo.reshape %7 : (tensor<1xi32>) -> tensor<i32>
    %9 = stablehlo.multiply %6, %8 : tensor<i32>
    %10 = stablehlo.subtract %4, %9 : tensor<i32>
    %11 = stablehlo.slice %arg0 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
    %13 = stablehlo.slice %arg1 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %14 = stablehlo.reshape %13 : (tensor<1xi32>) -> tensor<i32>
    %15 = stablehlo.multiply %12, %14 : tensor<i32>
    %16 = stablehlo.subtract %10, %15 : tensor<i32>
    %17 = stablehlo.slice %arg0 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %18 = stablehlo.reshape %17 : (tensor<1xi32>) -> tensor<i32>
    %19 = stablehlo.slice %arg1 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %20 = stablehlo.reshape %19 : (tensor<1xi32>) -> tensor<i32>
    %21 = stablehlo.multiply %18, %20 : tensor<i32>
    %22 = stablehlo.subtract %16, %21 : tensor<i32>
    %23 = stablehlo.slice %arg0 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %24 = stablehlo.reshape %23 : (tensor<1xi32>) -> tensor<i32>
    %25 = stablehlo.slice %arg1 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %26 = stablehlo.reshape %25 : (tensor<1xi32>) -> tensor<i32>
    %27 = stablehlo.multiply %24, %26 : tensor<i32>
    %28 = stablehlo.slice %arg0 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %29 = stablehlo.reshape %28 : (tensor<1xi32>) -> tensor<i32>
    %30 = stablehlo.slice %arg1 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %31 = stablehlo.reshape %30 : (tensor<1xi32>) -> tensor<i32>
    %32 = stablehlo.multiply %29, %31 : tensor<i32>
    %33 = stablehlo.add %27, %32 : tensor<i32>
    %34 = stablehlo.slice %arg0 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %35 = stablehlo.reshape %34 : (tensor<1xi32>) -> tensor<i32>
    %36 = stablehlo.slice %arg1 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %37 = stablehlo.reshape %36 : (tensor<1xi32>) -> tensor<i32>
    %38 = stablehlo.multiply %35, %37 : tensor<i32>
    %39 = stablehlo.add %33, %38 : tensor<i32>
    %40 = stablehlo.slice %arg0 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %41 = stablehlo.reshape %40 : (tensor<1xi32>) -> tensor<i32>
    %42 = stablehlo.slice %arg1 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %43 = stablehlo.reshape %42 : (tensor<1xi32>) -> tensor<i32>
    %44 = stablehlo.multiply %41, %43 : tensor<i32>
    %45 = stablehlo.subtract %39, %44 : tensor<i32>
    %46 = stablehlo.slice %arg0 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %47 = stablehlo.reshape %46 : (tensor<1xi32>) -> tensor<i32>
    %48 = stablehlo.slice %arg1 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %49 = stablehlo.reshape %48 : (tensor<1xi32>) -> tensor<i32>
    %50 = stablehlo.multiply %47, %49 : tensor<i32>
    %51 = stablehlo.slice %arg0 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %52 = stablehlo.reshape %51 : (tensor<1xi32>) -> tensor<i32>
    %53 = stablehlo.slice %arg1 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %54 = stablehlo.reshape %53 : (tensor<1xi32>) -> tensor<i32>
    %55 = stablehlo.multiply %52, %54 : tensor<i32>
    %56 = stablehlo.subtract %50, %55 : tensor<i32>
    %57 = stablehlo.slice %arg0 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %58 = stablehlo.reshape %57 : (tensor<1xi32>) -> tensor<i32>
    %59 = stablehlo.slice %arg1 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %60 = stablehlo.reshape %59 : (tensor<1xi32>) -> tensor<i32>
    %61 = stablehlo.multiply %58, %60 : tensor<i32>
    %62 = stablehlo.add %56, %61 : tensor<i32>
    %63 = stablehlo.slice %arg0 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %64 = stablehlo.reshape %63 : (tensor<1xi32>) -> tensor<i32>
    %65 = stablehlo.slice %arg1 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %66 = stablehlo.reshape %65 : (tensor<1xi32>) -> tensor<i32>
    %67 = stablehlo.multiply %64, %66 : tensor<i32>
    %68 = stablehlo.add %62, %67 : tensor<i32>
    %69 = stablehlo.slice %arg0 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %70 = stablehlo.reshape %69 : (tensor<1xi32>) -> tensor<i32>
    %71 = stablehlo.slice %arg1 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %72 = stablehlo.reshape %71 : (tensor<1xi32>) -> tensor<i32>
    %73 = stablehlo.multiply %70, %72 : tensor<i32>
    %74 = stablehlo.slice %arg0 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %75 = stablehlo.reshape %74 : (tensor<1xi32>) -> tensor<i32>
    %76 = stablehlo.slice %arg1 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %77 = stablehlo.reshape %76 : (tensor<1xi32>) -> tensor<i32>
    %78 = stablehlo.multiply %75, %77 : tensor<i32>
    %79 = stablehlo.add %73, %78 : tensor<i32>
    %80 = stablehlo.slice %arg0 [2:3] : (tensor<4xi32>) -> tensor<1xi32>
    %81 = stablehlo.reshape %80 : (tensor<1xi32>) -> tensor<i32>
    %82 = stablehlo.slice %arg1 [1:2] : (tensor<4xi32>) -> tensor<1xi32>
    %83 = stablehlo.reshape %82 : (tensor<1xi32>) -> tensor<i32>
    %84 = stablehlo.multiply %81, %83 : tensor<i32>
    %85 = stablehlo.subtract %79, %84 : tensor<i32>
    %86 = stablehlo.slice %arg0 [3:4] : (tensor<4xi32>) -> tensor<1xi32>
    %87 = stablehlo.reshape %86 : (tensor<1xi32>) -> tensor<i32>
    %88 = stablehlo.slice %arg1 [0:1] : (tensor<4xi32>) -> tensor<1xi32>
    %89 = stablehlo.reshape %88 : (tensor<1xi32>) -> tensor<i32>
    %90 = stablehlo.multiply %87, %89 : tensor<i32>
    %91 = stablehlo.add %85, %90 : tensor<i32>
    %92 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %93 = stablehlo.broadcast_in_dim %45, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %94 = stablehlo.broadcast_in_dim %68, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %95 = stablehlo.broadcast_in_dim %91, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %96 = stablehlo.concatenate %92, %93, %94, %95, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    return %96 : tensor<4xi32>
  }
}