// Module @jit_mmm: one public entry point, @main, which contracts two length-4
// i32 vectors against a constant 4x4x4 f32 coefficient tensor:
//
//   result[k] = sum over i, j of  cst[i, j, k] * f32(arg0[i]) * f32(arg1[j])
//
// The coefficient entries are all 0, +1 or -1. They coincide with the
// structure constants of the Hamilton quaternion product in the basis
// (1, i, j, k), so the result equals the quaternion product arg0 * arg1 if the
// inputs are read as (w, x, y, z) components.
// NOTE(review): the quaternion reading is derived from the constant's values
// only; confirm the intent against the JAX source that produced this module.
module @jit_mmm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  // %arg0: left operand, 4 x i32.
  // %arg1: right operand, 4 x i32.
  // Returns a 4 x f32 tensor.
  func.func public @main(%arg0: tensor<4xi32> {mhlo.layout_mode = "default"}, %arg1: tensor<4xi32> {mhlo.layout_mode = "default"}) -> (tensor<4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // Coefficient tensor, indexed [i, j, k] = [arg0 index, arg1 index, output index].
    %cst = stablehlo.constant dense<[
      [[1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
       [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00],
       [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00],
       [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]],
      [[0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00],
       [-1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
       [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00],
       [0.000000e+00, 0.000000e+00, -1.000000e+00, 0.000000e+00]],
      [[0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00],
       [0.000000e+00, 0.000000e+00, 0.000000e+00, -1.000000e+00],
       [-1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
       [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00]],
      [[0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00],
       [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00],
       [0.000000e+00, -1.000000e+00, 0.000000e+00, 0.000000e+00],
       [-1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]]
    ]> : tensor<4x4x4xf32>

    // Place arg0 along axis 0, convert it to f32, expand it to the full
    // 4x4x4 shape and scale the coefficients by it.
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<4xi32>) -> tensor<4x1x1xi32>
    %1 = stablehlo.convert %0 : (tensor<4x1x1xi32>) -> tensor<4x1x1xf32>
    %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1, 2] : (tensor<4x1x1xf32>) -> tensor<4x4x4xf32>
    %3 = stablehlo.multiply %cst, %2 : tensor<4x4x4xf32>

    // Same treatment for arg1, placed along axis 1.
    %4 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<4xi32>) -> tensor<1x4x1xi32>
    %5 = stablehlo.convert %4 : (tensor<1x4x1xi32>) -> tensor<1x4x1xf32>
    %6 = stablehlo.broadcast_in_dim %5, dims = [0, 1, 2] : (tensor<1x4x1xf32>) -> tensor<4x4x4xf32>
    %7 = stablehlo.multiply %3, %6 : tensor<4x4x4xf32>

    // Sum over axes 0 and 1, leaving only the output axis. The init value is
    // a 0-d f32 tensor whose element type matches the operand being reduced.
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %8 = stablehlo.reduce(%7 init: %cst_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<4xf32>
    return %8 : tensor<4xf32>
  }
}