output of jax.make_jaxpr(mjx.kinematics)(mjx_m, mjx_d) let _take = { lambda ; a:f64[7,3] b:i64[1,0]. let _:i64[1,0] = pjit[ name=remainder jaxpr={ lambda ; c:i64[1,0] d:i64[]. let e:bool[] = eq d 0 f:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] e 1 d k:i64[1,0] = rem c f l:bool[1,0] = ne k 0 m:bool[1,0] = lt k 0 n:bool[] = lt f 0 o:bool[1,0] = ne m n p:bool[1,0] = and o l q:i64[1,0] = add k f r:i64[1,0] = select_n p k q in (r,) } ] b 7 s:f64[1,0,3] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 0, 3)] 0.0 in (s,) } in let _take1 = { lambda ; t:f64[7] u:i64[1,0]. let _:i64[1,0] = pjit[ name=remainder jaxpr={ lambda ; c:i64[1,0] d:i64[]. let e:bool[] = eq d 0 f:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] e 1 d k:i64[1,0] = rem c f l:bool[1,0] = ne k 0 m:bool[1,0] = lt k 0 n:bool[] = lt f 0 o:bool[1,0] = ne m n p:bool[1,0] = and o l q:i64[1,0] = add k f r:i64[1,0] = select_n p k q in (r,) } ] u 7 v:f64[1,0] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 0)] 0.0 in (v,) } in let _take2 = { lambda ; w:f64[1,0] x:i64[1]. let y:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] x 1 bl:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] y bm:f64[1,0] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 0) unique_indices=False ] w bl in (bm,) } in let _take3 = { lambda ; bn:f64[1,0,3] bo:i64[1]. let bp:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] bo 1 bq:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bp br:f64[1,0,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 0, 3) unique_indices=False ] bn bq in (br,) } in let _take4 = { lambda ; bs:f64[1,3] bt:i64[1]. let bu:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] bt 1 bv:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bu bw:f64[1,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] bs bv in (bw,) } in let _take5 = { lambda ; bx:f64[1,4] by:i64[1]. let bz:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] by 1 ca:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bz cb:f64[1,4] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 4) unique_indices=False ] bx ca in (cb,) } in let _take6 = { lambda ; cc:f64[1,3,3] cd:i64[1]. let ce:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] cd 1 cf:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ce cg:f64[1,3,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3, 3) unique_indices=False ] cc cf in (cg,) } in let cross = { lambda ; ch:f64[1,3] ci:f64[1,3]. let cj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 ck:f64[1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=True ] ch cj cl:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1 cm:f64[1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=True ] ch cl cn:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2 co:f64[1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=True ] ch cn cp:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 cq:f64[1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=True ] ci cp cr:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1 cs:f64[1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=True ] ci cr ct:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2 cu:f64[1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=True ] ci ct cv:f64[1] = mul cm cu cw:f64[1] = mul co cs cx:f64[1] = sub cv cw cy:f64[1] = mul co cq cz:f64[1] = mul ck cu da:f64[1] = sub cy cz db:f64[1] = mul ck cs dc:f64[1] = mul cm cq dd:f64[1] = sub db dc de:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cx df:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] da dg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dd dh:f64[1,3] = concatenate[dimension=1] de df dg in (dh,) } in let _take7 = { lambda ; di:f64[7,3] dj:i64[1,1]. let dk:i64[1,1] = pjit[ name=remainder jaxpr={ lambda ; dl:i64[1,1] dm:i64[]. let dn:bool[] = eq dm 0 do:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] dn 1 dm dp:i64[1,1] = rem dl do dq:bool[1,1] = ne dp 0 dr:bool[1,1] = lt dp 0 ds:bool[] = lt do 0 dt:bool[1,1] = ne dr ds du:bool[1,1] = and dt dq dv:i64[1,1] = add dp do dw:i64[1,1] = select_n du dp dv in (dw,) } ] dj 7 dx:i64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(1, 1, 1) ] dk dy:f64[1,1,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] di dx in (dy,) } in let _take8 = { lambda ; dz:f64[7] ea:i64[1,1]. let eb:i64[1,1] = pjit[ name=remainder jaxpr={ lambda ; dl:i64[1,1] dm:i64[]. let dn:bool[] = eq dm 0 do:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] dn 1 dm dp:i64[1,1] = rem dl do dq:bool[1,1] = ne dp 0 dr:bool[1,1] = lt dp 0 ds:bool[] = lt do 0 dt:bool[1,1] = ne dr ds du:bool[1,1] = and dt dq dv:i64[1,1] = add dp do dw:i64[1,1] = select_n du dp dv in (dw,) } ] ea 7 ec:i64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(1, 1, 1) ] eb ed:f64[1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1,) unique_indices=False ] dz ec in (ed,) } in let _where = { lambda ; ee:bool[1] ef:i64[1] eg:i64[1]. let eh:i64[1] = select_n ee eg ef in (eh,) } in let clip = { lambda ; ei:i64[1] ej:i64[] ek:i64[]. let el:i64[1] = max ej ei em:i64[1] = min ek el in (em,) } in let argsort = { lambda ; en:i64[1]. let eo:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] _:i64[1] ep:i64[1] = sort[dimension=0 is_stable=True num_keys=1] en eo in (ep,) } in let jaxpr = { lambda ; eq:i64[] er:i64[]. let es:i64[] = add eq er in (es,) } in let _cumulative_reduction = { lambda ; et:bool[4]. let eu:i64[4] = convert_element_type[new_dtype=int64 weak_type=False] et ev:i64[4] = cumsum[axis=0 reverse=False] eu in (ev,) } in let clip1 = { lambda ; ew:i64[4] ex:i64[]. let ey:i64[] = convert_element_type[new_dtype=int64 weak_type=False] ex ez:i64[4] = max ey ew in (ez,) } in let _cumulative_reduction1 = { lambda ; fa:i64[3]. let fb:i64[3] = cumsum[axis=0 reverse=False] fa in (fb,) } in let floor_divide = { lambda ; fc:i64[3] fd:i64[]. let fe:i64[3] = div fc fd ff:i64[3] = sign fc fg:i64[] = sign fd fh:bool[3] = ne ff fg fi:i64[3] = rem fc fd fj:bool[3] = ne fi 0 fk:bool[3] = and fh fj fl:i64[3] = sub fe 1 fm:i64[3] = pjit[ name=_where jaxpr={ lambda ; fn:bool[3] fo:i64[3] fp:i64[3]. let fq:i64[3] = select_n fn fp fo in (fq,) } ] fk fl fe in (fm,) } in let remainder = { lambda ; fr:i64[3] fs:i64[]. let ft:i64[] = convert_element_type[new_dtype=int64 weak_type=False] fs fu:bool[] = eq ft 0 fv:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] fu 1 ft fw:i64[3] = rem fr fv fx:bool[3] = ne fw 0 fy:bool[3] = lt fw 0 fz:bool[] = lt fv 0 ga:bool[3] = ne fy fz gb:bool[3] = and ga fx gc:i64[3] = add fw fv gd:i64[3] = select_n gb fw gc in (gd,) } in let _take9 = { lambda ; ge:f64[1,1] gf:i64[1]. let gg:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] gf 1 gh:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gg gi:f64[1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] ge gh in (gi,) } in let _take10 = { lambda ; gj:f64[1,1,3] gk:i64[1]. let gl:i64[1] = pjit[ name=remainder jaxpr={ lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } ] gk 1 gm:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gl gn:f64[1,1,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1, 3) unique_indices=False ] gj gm in (gn,) } in let _take11 = { lambda ; go:f64[7,3] gp:i64[7]. let gq:i64[7] = pjit[ name=remainder jaxpr={ lambda ; gr:i64[7] gs:i64[]. let gt:bool[] = eq gs 0 gu:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] gt 1 gs gv:i64[7] = rem gr gu gw:bool[7] = ne gv 0 gx:bool[7] = lt gv 0 gy:bool[] = lt gu 0 gz:bool[7] = ne gx gy ha:bool[7] = and gz gw hb:i64[7] = add gv gu hc:i64[7] = select_n ha gv hb in (hc,) } ] gp 7 hd:i64[7,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(7, 1)] gq he:f64[7,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] go hd in (he,) } in let remainder1 = { lambda ; c:i64[1,0] d:i64[]. let e:bool[] = eq d 0 f:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] e 1 d k:i64[1,0] = rem c f l:bool[1,0] = ne k 0 m:bool[1,0] = lt k 0 n:bool[] = lt f 0 o:bool[1,0] = ne m n p:bool[1,0] = and o l q:i64[1,0] = add k f r:i64[1,0] = select_n p k q in (r,) } in let remainder2 = { lambda ; z:i64[1] ba:i64[]. let bb:bool[] = eq ba 0 bc:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] bb 1 ba bd:i64[1] = rem z bc be:bool[1] = ne bd 0 bf:bool[1] = lt bd 0 bg:bool[] = lt bc 0 bh:bool[1] = ne bf bg bi:bool[1] = and bh be bj:i64[1] = add bd bc bk:i64[1] = select_n bi bd bj in (bk,) } in let remainder3 = { lambda ; dl:i64[1,1] dm:i64[]. let dn:bool[] = eq dm 0 do:i64[] = pjit[ name=_where jaxpr={ lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } ] dn 1 dm dp:i64[1,1] = rem dl do dq:bool[1,1] = ne dp 0 dr:bool[1,1] = lt dp 0 ds:bool[] = lt do 0 dt:bool[1,1] = ne dr ds du:bool[1,1] = and dt dq dv:i64[1,1] = add dp do dw:i64[1,1] = select_n du dp dv in (dw,) } in let _where1 = { lambda ; fn:bool[3] fo:i64[3] fp:i64[3]. let fq:i64[3] = select_n fn fp fo in (fq,) } in let _where2 = { lambda ; g:bool[] h:i64[] i:i64[]. let j:i64[] = select_n g i h in (j,) } in let remainder4 = { lambda ; gr:i64[7] gs:i64[]. let gt:bool[] = eq gs 0 gu:i64[] = pjit[name=_where jaxpr=_where2] gt 1 gs gv:i64[7] = rem gr gu gw:bool[7] = ne gv 0 gx:bool[7] = lt gv 0 gy:bool[] = lt gu 0 gz:bool[7] = ne gx gy ha:bool[7] = and gz gw hb:i64[7] = add gv gu hc:i64[7] = select_n ha gv hb in (hc,) } in let remainder5 = { lambda ; hf:i64[9] hg:i64[]. let hh:bool[] = eq hg 0 hi:i64[] = pjit[name=_where jaxpr=_where2] hh 1 hg hj:i64[9] = rem hf hi hk:bool[9] = ne hj 0 hl:bool[9] = lt hj 0 hm:bool[] = lt hi 0 hn:bool[9] = ne hl hm ho:bool[9] = and hn hk hp:i64[9] = add hj hi hq:i64[9] = select_n ho hj hp in (hq,) } in { lambda hr:i64[1,0] hs:i64[1,0] ht:i64[1,0] hu:i64[1,0] hv:i64[1] hw:i64[1] hx:i64[1,0] hy:i64[1,0] hz:i64[1,0] ia:i64[1,0] ib:i64[1] ic:i64[1,1] id:i64[1,1] ie:i64[1,1] if:i64[1,1] ig:i64[1] ih:i64[1,1] ii:i64[1,1] ij:i64[1,1] ik:i64[1,1] il:i64[1] im:i64[1,1] in:i64[1,1] io:i64[1,1] ip:i64[1,1] iq:i64[1] ir:i64[1,1] is:i64[1,1] it:i64[1,1] iu:i64[1,1] iv:i64[1] iw:i64[1,1] ix:i64[1,1] iy:i64[1,1] iz:i64[1,1] ja:i64[1] jb:i64[1,1] jc:i64[1,1] jd:i64[1,1] je:i64[1,1] jf:i64[1] jg:i64[1,1] jh:i64[1,1] ji:i64[1,1] jj:i64[1,1] jk:i64[7] jl:i64[7] jm:i64[7] jn:i64[9] jo:i64[9] jp:i64[9] jq:i32[61] jr:i32[1]; js:f64[] jt:f64[] ju:f64[] jv:f64[] jw:f64[] jx:f64[] jy:f64[] jz:f64[3] ka:f64[3] kb:f64[3] kc:f64[] kd:f64[] ke:f64[] kf:f64[2] kg:f64[5] kh:f64[5] ki:f64[] kj:f64[] kk:f64[] kl:f64[] km:f64[3] kn:f64[7] ko:f64[7] kp:f64[9,3] kq:f64[9,4] kr:f64[9,3] ks:f64[9,4] kt:f64[9] ku:f64[9] kv:f64[9,3] kw:f64[9] kx:f64[9,2] ky:f64[7,2] kz:f64[7,5] la:f64[7,3] lb:f64[7,3] lc:f64[7] ld:f64[7,2] le:f64[7,2] lf:f64[7] lg:f64[7,2] lh:f64[7,5] li:f64[7] lj:f64[7] lk:f64[7] ll:f64[7] lm:f64[7] ln:f64[61] lo:f64[61,2] lp:f64[61,5] lq:f64[61,3] lr:f64[61] ls:f64[61,3] lt:f64[61,4] lu:f64[61,3] lv:f64[61] lw:f64[61] lx:f64[1,3] ly:f64[1,4] lz:f64[0,3] ma:f64[0,4] mb:f64[0,3] mc:f64[0,3] md:f64[0,3,3] me:f32[0] mf:f64[0,2] mg:f64[0,2] mh:f64[0,5] mi:f64[0] mj:f64[0] mk:f64[0,5] ml:f64[0,2] mm:f64[0,5] mn:f64[0,11] mo:f64[0,2] mp:f64[0,5] mq:f64[0,2] mr:f64[0,5] ms:f64[0,2] mt:f64[0] mu:f64[0] mv:f64[0] mw:f64[0] mx:f64[0,2] my:f64[0] mz:f64[0] na:f64[7,10] nb:f64[7,10] nc:f64[7,10] nd:f64[7,2] ne:f64[7,2] nf:f64[7,2] ng:f64[7,6] nh:i64[80] ni:i64[] nj:f64[] nk:f64[7] nl:f64[7] nm:f64[0] nn:f64[7] no:f64[7] np:f64[7] nq:f64[9,6] nr:u8[0] ns:f64[0] nt:f64[0] nu:f64[7] nv:f64[0] nw:f64[0] nx:f64[0] ny:f64[9,3] nz:f64[9,4] oa:f64[9,3,3] ob:f64[9,3] oc:f64[9,3,3] od:f64[7,3] oe:f64[7,3] of:f64[61,3] og:f64[61,3,3] oh:f64[1,3] oi:f64[1,3,3] oj:f64[0,3] ok:f64[0,3,3] ol:f64[0] om:f64[0] on:f64[9,3] oo:f64[7,6] op:f64[9,10] oq:f64[0] or:f64[0,6] os:i32[0] ot:i32[0] ou:i32[0] ov:f64[0] ow:f64[0] ox:i32[0] oy:i32[0] oz:i32[0] pa:i32[0] pb:i32[0] pc:f64[0,7] pd:f64[0] pe:i32[0] pf:f64[0] pg:f64[7] ph:f64[7,7] pi:f64[9,10] pj:f64[7,7] pk:f64[7,7] pl:f64[0] pm:f64[7] pn:f64[0] po:u8[0] pp:f64[0] pq:f64[0] pr:f64[7] ps:f64[9,6] pt:f64[7,6] pu:f64[7] pv:f64[0] pw:f64[0] px:f64[7] py:f64[7] pz:f64[7] qa:f64[9,3] qb:f64[9,3] qc:f64[0] qd:f64[0] qe:i32[0] qf:i32[0] qg:i32[0] qh:i32[0] qi:i32[0] qj:i32[0] qk:f64[0] ql:f64[0] qm:f64[7] qn:f64[7] qo:f64[7] qp:f64[7] qq:f64[7] qr:f64[7] qs:f64[9,6] qt:f64[9,6] qu:f64[9,6] qv:f64[396] qw:f64[396,3] qx:f64[396,3,3] qy:f64[396] qz:f64[396,5] ra:f64[396,2] rb:f64[396,2] rc:f64[396,5] rd:i32[396] re:i32[396] rf:i32[396,2] rg:i64[1591] rh:f64[1591,7] ri:f64[1591] rj:f64[1591] rk:f64[1591] rl:f64[1591] rm:f64[1591] rn:f64[1591] ro:f64[0] rp:f64[0] rq:f64[0]. let _:f64[1,0,3] = pjit[name=_take jaxpr=_take] la hr _:f64[1,0,3] = pjit[name=_take jaxpr=_take] lb hs rr:f64[1,0] = pjit[name=_take jaxpr=_take1] nk ht _:f64[1,0] = pjit[name=_take jaxpr=_take1] kn hu rs:f64[1,3] = pjit[ name=_take jaxpr={ lambda ; rt:f64[9,3] ru:i64[1]. let rv:i64[1] = pjit[name=remainder jaxpr=remainder2] ru 9 rw:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] rv rx:f64[1,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] rt rw in (rx,) } ] kp hv ry:f64[1,4] = pjit[ name=_take jaxpr={ lambda ; rz:f64[9,4] sa:i64[1]. let sb:i64[1] = pjit[name=remainder jaxpr=remainder2] sa 9 sc:i64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] sb sd:f64[1,4] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 4) unique_indices=False ] rz sc in (sd,) } ] kq hv se:f64[0,3] = broadcast_in_dim[broadcast_dimensions=() shape=(0, 3)] 0.0 sf:f64[0,3] = broadcast_in_dim[broadcast_dimensions=() shape=(0, 3)] 0.0 sg:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] ry sh:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] ry si:f64[1,4,4] = mul sg sh sj:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] si sk:f64[1] = squeeze[dimensions=(1, 2)] sj sl:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] si sm:f64[1] = squeeze[dimensions=(1, 2)] sl sn:f64[1] = add sk sm so:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] si sp:f64[1] = squeeze[dimensions=(1, 2)] so sq:f64[1] = sub sn sp sr:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] si ss:f64[1] = squeeze[dimensions=(1, 2)] sr st:f64[1] = sub sq ss su:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] si sv:f64[1] = squeeze[dimensions=(1, 2)] su sw:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] si sx:f64[1] = squeeze[dimensions=(1, 2)] sw sy:f64[1] = sub sv sx sz:f64[1] = mul 2.0 sy ta:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] si tb:f64[1] = squeeze[dimensions=(1, 2)] ta tc:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] si td:f64[1] = squeeze[dimensions=(1, 2)] tc te:f64[1] = add tb td tf:f64[1] = mul 2.0 te tg:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] si th:f64[1] = squeeze[dimensions=(1, 2)] tg ti:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] si tj:f64[1] = squeeze[dimensions=(1, 2)] ti tk:f64[1] = add th tj tl:f64[1] = mul 2.0 tk tm:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] si tn:f64[1] = squeeze[dimensions=(1, 2)] tm to:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] si tp:f64[1] = squeeze[dimensions=(1, 2)] to tq:f64[1] = sub tn tp tr:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] si ts:f64[1] = squeeze[dimensions=(1, 2)] tr tt:f64[1] = add tq ts tu:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] si tv:f64[1] = squeeze[dimensions=(1, 2)] tu tw:f64[1] = sub tt tv tx:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] si ty:f64[1] = squeeze[dimensions=(1, 2)] tx tz:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] si ua:f64[1] = squeeze[dimensions=(1, 2)] tz ub:f64[1] = sub ty ua uc:f64[1] = mul 2.0 ub ud:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] si ue:f64[1] = squeeze[dimensions=(1, 2)] ud uf:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] si ug:f64[1] = squeeze[dimensions=(1, 2)] uf uh:f64[1] = sub ue ug ui:f64[1] = mul 2.0 uh uj:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] si uk:f64[1] = squeeze[dimensions=(1, 2)] uj ul:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] si um:f64[1] = squeeze[dimensions=(1, 2)] ul un:f64[1] = add uk um uo:f64[1] = mul 2.0 un up:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] si uq:f64[1] = squeeze[dimensions=(1, 2)] up ur:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] si us:f64[1] = squeeze[dimensions=(1, 2)] ur ut:f64[1] = sub uq us uu:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] si uv:f64[1] = squeeze[dimensions=(1, 2)] uu uw:f64[1] = sub ut uv ux:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] si uy:f64[1] = squeeze[dimensions=(1, 2)] ux uz:f64[1] = add uw uy va:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] st vb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] sz vc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] tf vd:f64[1,3] = concatenate[dimension=1] va vb vc ve:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] tl vf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] tw vg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] uc vh:f64[1,3] = concatenate[dimension=1] ve vf vg vi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ui vj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] uo vk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] uz vl:f64[1,3] = concatenate[dimension=1] vi vj vk vm:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] vd vn:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] vh vo:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] vl vp:f64[1,3,3] = concatenate[dimension=1] vm vn vo vq:f64[1,0,3] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1), np.int64(2)) shape=(1, 0, 3) ] se vr:f64[1,0,3] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1), np.int64(2)) shape=(1, 0, 3) ] sf _:f64[1,0] = pjit[name=_take jaxpr=_take2] rr hw _:f64[1,0,3] = pjit[name=_take jaxpr=_take3] vq hw _:f64[1,0,3] = pjit[name=_take jaxpr=_take3] vr hw vs:f64[1,3] = pjit[name=_take jaxpr=_take4] rs hw vt:f64[1,4] = pjit[name=_take jaxpr=_take5] ry hw _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] vp hw _:f64[1,0,3] = pjit[name=_take jaxpr=_take] la hx _:f64[1,0,3] = pjit[name=_take jaxpr=_take] lb hy vu:f64[1,0] = pjit[name=_take jaxpr=_take1] nk hz _:f64[1,0] = pjit[name=_take jaxpr=_take1] kn ia vv:f64[1,3] = slice[limit_indices=(2, 3) start_indices=(1, 0) strides=None] kp vw:f64[1,4] = slice[limit_indices=(2, 4) start_indices=(1, 0) strides=None] kq vx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vt vy:f64[1] = squeeze[dimensions=(1,)] vx vz:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] vt wa:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] vz vv wb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] wa wc:f64[1,3] = mul wb vz wd:f64[1,3] = mul 2.0 wc we:f64[1] = mul vy vy wf:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] vz vz wg:f64[1] = sub we wf wh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] wg wi:f64[1,3] = mul wh vv wj:f64[1,3] = add wd wi wk:f64[1] = mul 2.0 vy wl:f64[1,3] = pjit[name=cross jaxpr=cross] vz vv wm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] wk wn:f64[1,3] = mul wm wl wo:f64[1,3] = add wj wn wp:f64[1,3] = add vs wo wq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vt wr:f64[1] = squeeze[dimensions=(1,)] wq ws:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vw wt:f64[1] = squeeze[dimensions=(1,)] ws wu:f64[1] = mul wr wt wv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vt ww:f64[1] = squeeze[dimensions=(1,)] wv wx:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vw wy:f64[1] = squeeze[dimensions=(1,)] wx wz:f64[1] = mul ww wy xa:f64[1] = sub wu wz xb:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vt xc:f64[1] = squeeze[dimensions=(1,)] xb xd:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vw xe:f64[1] = squeeze[dimensions=(1,)] xd xf:f64[1] = mul xc xe xg:f64[1] = sub xa xf xh:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vt xi:f64[1] = squeeze[dimensions=(1,)] xh xj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vw xk:f64[1] = squeeze[dimensions=(1,)] xj xl:f64[1] = mul xi xk xm:f64[1] = sub xg xl xn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vt xo:f64[1] = squeeze[dimensions=(1,)] xn xp:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vw xq:f64[1] = squeeze[dimensions=(1,)] xp xr:f64[1] = mul xo xq xs:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vt xt:f64[1] = squeeze[dimensions=(1,)] xs xu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vw xv:f64[1] = squeeze[dimensions=(1,)] xu xw:f64[1] = mul xt xv xx:f64[1] = add xr xw xy:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vt xz:f64[1] = squeeze[dimensions=(1,)] xy ya:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vw yb:f64[1] = squeeze[dimensions=(1,)] ya yc:f64[1] = mul xz yb yd:f64[1] = add xx yc ye:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vt yf:f64[1] = squeeze[dimensions=(1,)] ye yg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vw yh:f64[1] = squeeze[dimensions=(1,)] yg yi:f64[1] = mul yf yh yj:f64[1] = sub yd yi yk:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vt yl:f64[1] = squeeze[dimensions=(1,)] yk ym:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vw yn:f64[1] = squeeze[dimensions=(1,)] ym yo:f64[1] = mul yl yn yp:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vt yq:f64[1] = squeeze[dimensions=(1,)] yp yr:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vw ys:f64[1] = squeeze[dimensions=(1,)] yr yt:f64[1] = mul yq ys yu:f64[1] = sub yo yt yv:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vt yw:f64[1] = squeeze[dimensions=(1,)] yv yx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vw yy:f64[1] = squeeze[dimensions=(1,)] yx yz:f64[1] = mul yw yy za:f64[1] = add yu yz zb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vt zc:f64[1] = squeeze[dimensions=(1,)] zb zd:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vw ze:f64[1] = squeeze[dimensions=(1,)] zd zf:f64[1] = mul zc ze zg:f64[1] = add za zf zh:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vt zi:f64[1] = squeeze[dimensions=(1,)] zh zj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vw zk:f64[1] = squeeze[dimensions=(1,)] zj zl:f64[1] = mul zi zk zm:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vt zn:f64[1] = squeeze[dimensions=(1,)] zm zo:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vw zp:f64[1] = squeeze[dimensions=(1,)] zo zq:f64[1] = mul zn zp zr:f64[1] = add zl zq zs:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] vt zt:f64[1] = squeeze[dimensions=(1,)] zs zu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] vw zv:f64[1] = squeeze[dimensions=(1,)] zu zw:f64[1] = mul zt zv zx:f64[1] = sub zr zw zy:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] vt zz:f64[1] = squeeze[dimensions=(1,)] zy baa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] vw bab:f64[1] = squeeze[dimensions=(1,)] baa bac:f64[1] = mul zz bab bad:f64[1] = add zx bac bae:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] xm baf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] yj bag:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] zg bah:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bad bai:f64[1,4] = concatenate[dimension=1] bae baf bag bah baj:f64[0,3] = broadcast_in_dim[broadcast_dimensions=() shape=(0, 3)] 0.0 bak:f64[0,3] = broadcast_in_dim[broadcast_dimensions=() shape=(0, 3)] 0.0 bal:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] bai bam:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] bai ban:f64[1,4,4] = mul bal bam bao:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] ban bap:f64[1] = squeeze[dimensions=(1, 2)] bao baq:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] ban bar:f64[1] = squeeze[dimensions=(1, 2)] baq bas:f64[1] = add bap bar bat:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] ban bau:f64[1] = squeeze[dimensions=(1, 2)] bat bav:f64[1] = sub bas bau baw:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] ban bax:f64[1] = squeeze[dimensions=(1, 2)] baw bay:f64[1] = sub bav bax baz:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] ban bba:f64[1] = squeeze[dimensions=(1, 2)] baz bbb:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] ban bbc:f64[1] = squeeze[dimensions=(1, 2)] bbb bbd:f64[1] = sub bba bbc bbe:f64[1] = mul 2.0 bbd bbf:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] ban bbg:f64[1] = squeeze[dimensions=(1, 2)] bbf bbh:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] ban bbi:f64[1] = squeeze[dimensions=(1, 2)] bbh bbj:f64[1] = add bbg bbi bbk:f64[1] = mul 2.0 bbj bbl:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] ban bbm:f64[1] = squeeze[dimensions=(1, 2)] bbl bbn:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] ban bbo:f64[1] = squeeze[dimensions=(1, 2)] bbn bbp:f64[1] = add bbm bbo bbq:f64[1] = mul 2.0 bbp bbr:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] ban bbs:f64[1] = squeeze[dimensions=(1, 2)] bbr bbt:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] ban bbu:f64[1] = squeeze[dimensions=(1, 2)] bbt bbv:f64[1] = sub bbs bbu bbw:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] ban bbx:f64[1] = squeeze[dimensions=(1, 2)] bbw bby:f64[1] = add bbv bbx bbz:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] ban bca:f64[1] = squeeze[dimensions=(1, 2)] bbz bcb:f64[1] = sub bby bca bcc:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] ban bcd:f64[1] = squeeze[dimensions=(1, 2)] bcc bce:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] ban bcf:f64[1] = squeeze[dimensions=(1, 2)] bce bcg:f64[1] = sub bcd bcf bch:f64[1] = mul 2.0 bcg bci:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] ban bcj:f64[1] = squeeze[dimensions=(1, 2)] bci bck:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] ban bcl:f64[1] = squeeze[dimensions=(1, 2)] bck bcm:f64[1] = sub bcj bcl bcn:f64[1] = mul 2.0 bcm bco:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] ban bcp:f64[1] = squeeze[dimensions=(1, 2)] bco bcq:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] ban bcr:f64[1] = squeeze[dimensions=(1, 2)] bcq bcs:f64[1] = add bcp bcr bct:f64[1] = mul 2.0 bcs bcu:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] ban bcv:f64[1] = squeeze[dimensions=(1, 2)] bcu bcw:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] ban bcx:f64[1] = squeeze[dimensions=(1, 2)] bcw bcy:f64[1] = sub bcv bcx bcz:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] ban bda:f64[1] = squeeze[dimensions=(1, 2)] bcz bdb:f64[1] = sub bcy bda bdc:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] ban bdd:f64[1] = squeeze[dimensions=(1, 2)] bdc bde:f64[1] = add bdb bdd bdf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bay bdg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bbe bdh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bbk bdi:f64[1,3] = concatenate[dimension=1] bdf bdg bdh bdj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bbq bdk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bcb bdl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bch bdm:f64[1,3] = concatenate[dimension=1] bdj bdk bdl bdn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bcn bdo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bct bdp:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bde bdq:f64[1,3] = concatenate[dimension=1] bdn bdo bdp bdr:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bdi bds:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bdm bdt:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bdq bdu:f64[1,3,3] = concatenate[dimension=1] bdr bds bdt bdv:f64[1,0,3] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1), np.int64(2)) shape=(1, 0, 3) ] baj bdw:f64[1,0,3] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1), np.int64(2)) shape=(1, 0, 3) ] bak _:f64[1,0] = pjit[name=_take jaxpr=_take2] vu ib _:f64[1,0,3] = pjit[name=_take jaxpr=_take3] bdv ib _:f64[1,0,3] = pjit[name=_take jaxpr=_take3] bdw ib bdx:f64[1,3] = pjit[name=_take jaxpr=_take4] wp ib bdy:f64[1,4] = pjit[name=_take jaxpr=_take5] bai ib _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] bdu ib bdz:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la ic bea:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb id beb:f64[1,1] = pjit[name=_take jaxpr=_take8] nk ie bec:f64[1,1] = pjit[name=_take jaxpr=_take8] kn if bed:f64[1,3] = slice[limit_indices=(3, 3) start_indices=(2, 0) strides=None] kp bee:f64[1,4] = slice[limit_indices=(3, 4) start_indices=(2, 0) strides=None] kq bef:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bdy beg:f64[1] = squeeze[dimensions=(1,)] bef beh:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] bdy bei:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] beh bed bej:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bei bek:f64[1,3] = mul bej beh bel:f64[1,3] = mul 2.0 bek bem:f64[1] = mul beg beg ben:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] beh beh beo:f64[1] = sub bem ben bep:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] beo beq:f64[1,3] = mul bep bed ber:f64[1,3] = add bel beq bes:f64[1] = mul 2.0 beg bet:f64[1,3] = pjit[name=cross jaxpr=cross] beh bed beu:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bes bev:f64[1,3] = mul beu bet bew:f64[1,3] = add ber bev bex:f64[1,3] = add bdx bew bey:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bdy bez:f64[1] = squeeze[dimensions=(1,)] bey bfa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bee bfb:f64[1] = squeeze[dimensions=(1,)] bfa bfc:f64[1] = mul bez bfb bfd:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bdy bfe:f64[1] = squeeze[dimensions=(1,)] bfd bff:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bee bfg:f64[1] = squeeze[dimensions=(1,)] bff bfh:f64[1] = mul bfe bfg bfi:f64[1] = sub bfc bfh bfj:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bdy bfk:f64[1] = squeeze[dimensions=(1,)] bfj bfl:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bee bfm:f64[1] = squeeze[dimensions=(1,)] bfl bfn:f64[1] = mul bfk bfm bfo:f64[1] = sub bfi bfn bfp:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bdy bfq:f64[1] = squeeze[dimensions=(1,)] bfp bfr:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bee bfs:f64[1] = squeeze[dimensions=(1,)] bfr bft:f64[1] = mul bfq bfs bfu:f64[1] = sub bfo bft bfv:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bdy bfw:f64[1] = squeeze[dimensions=(1,)] bfv bfx:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bee bfy:f64[1] = squeeze[dimensions=(1,)] bfx bfz:f64[1] = mul bfw bfy bga:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bdy bgb:f64[1] = squeeze[dimensions=(1,)] bga bgc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bee bgd:f64[1] = squeeze[dimensions=(1,)] bgc bge:f64[1] = mul bgb bgd bgf:f64[1] = add bfz bge bgg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bdy bgh:f64[1] = squeeze[dimensions=(1,)] bgg bgi:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bee bgj:f64[1] = squeeze[dimensions=(1,)] bgi bgk:f64[1] = mul bgh bgj bgl:f64[1] = add bgf bgk bgm:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bdy bgn:f64[1] = squeeze[dimensions=(1,)] bgm bgo:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bee bgp:f64[1] = squeeze[dimensions=(1,)] bgo bgq:f64[1] = mul bgn bgp bgr:f64[1] = sub bgl bgq bgs:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bdy bgt:f64[1] = squeeze[dimensions=(1,)] bgs bgu:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bee bgv:f64[1] = squeeze[dimensions=(1,)] bgu bgw:f64[1] = mul bgt bgv bgx:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bdy bgy:f64[1] = squeeze[dimensions=(1,)] bgx bgz:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bee bha:f64[1] = squeeze[dimensions=(1,)] bgz bhb:f64[1] = mul bgy bha bhc:f64[1] = sub bgw bhb bhd:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bdy bhe:f64[1] = squeeze[dimensions=(1,)] bhd bhf:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bee bhg:f64[1] = squeeze[dimensions=(1,)] bhf bhh:f64[1] = mul bhe bhg bhi:f64[1] = add bhc bhh bhj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bdy bhk:f64[1] = squeeze[dimensions=(1,)] bhj bhl:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bee bhm:f64[1] = squeeze[dimensions=(1,)] bhl bhn:f64[1] = mul bhk bhm bho:f64[1] = add bhi bhn bhp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bdy bhq:f64[1] = squeeze[dimensions=(1,)] bhp bhr:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bee bhs:f64[1] = squeeze[dimensions=(1,)] bhr bht:f64[1] = mul bhq bhs bhu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bdy bhv:f64[1] = squeeze[dimensions=(1,)] bhu bhw:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bee bhx:f64[1] = squeeze[dimensions=(1,)] bhw bhy:f64[1] = mul bhv bhx bhz:f64[1] = add bht bhy bia:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bdy bib:f64[1] = squeeze[dimensions=(1,)] bia bic:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bee bid:f64[1] = squeeze[dimensions=(1,)] bic bie:f64[1] = mul bib bid bif:f64[1] = sub bhz bie big:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bdy bih:f64[1] = squeeze[dimensions=(1,)] big bii:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bee bij:f64[1] = squeeze[dimensions=(1,)] bii bik:f64[1] = mul bih bij bil:f64[1] = add bif bik bim:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bfu bin:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bgr bio:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bho bip:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bil biq:f64[1,4] = concatenate[dimension=1] bim bin bio bip bir:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] bdz bis:f64[1,3] = squeeze[dimensions=(1,)] bir bit:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] biq biu:f64[1] = squeeze[dimensions=(1,)] bit biv:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] biq biw:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] biv bis bix:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] biw biy:f64[1,3] = mul bix biv biz:f64[1,3] = mul 2.0 biy bja:f64[1] = mul biu biu bjb:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] biv biv bjc:f64[1] = sub bja bjb bjd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bjc bje:f64[1,3] = mul bjd bis bjf:f64[1,3] = add biz bje bjg:f64[1] = mul 2.0 biu bjh:f64[1,3] = pjit[name=cross jaxpr=cross] biv bis bji:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bjg bjj:f64[1,3] = mul bji bjh bjk:f64[1,3] = add bjf bjj bjl:f64[1,3] = add bjk bex bjm:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] bea bjn:f64[1,3] = squeeze[dimensions=(1,)] bjm bjo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] biq bjp:f64[1] = squeeze[dimensions=(1,)] bjo bjq:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] biq bjr:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bjq bjn bjs:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bjr bjt:f64[1,3] = mul bjs bjq bju:f64[1,3] = mul 2.0 bjt bjv:f64[1] = mul bjp bjp bjw:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bjq bjq bjx:f64[1] = sub bjv bjw bjy:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bjx bjz:f64[1,3] = mul bjy bjn bka:f64[1,3] = add bju bjz bkb:f64[1] = mul 2.0 bjp bkc:f64[1,3] = pjit[name=cross jaxpr=cross] bjq bjn bkd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bkb bke:f64[1,3] = mul bkd bkc bkf:f64[1,3] = add bka bke bkg:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] beb bkh:f64[1] = squeeze[dimensions=(1,)] bkg bki:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bec bkj:f64[1] = squeeze[dimensions=(1,)] bki bkk:f64[1] = sub bkh bkj bkl:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] bea bkm:f64[1,3] = squeeze[dimensions=(1,)] bkl bkn:f64[1] = mul bkk 0.5 bko:f64[1] = sin bkn bkp:f64[1] = mul bkk 0.5 bkq:f64[1] = cos bkp bkr:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bko bks:f64[1,3] = mul bkm bkr bkt:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bkq bku:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 bkv:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] bku bkw:i64[] = squeeze[dimensions=(0,)] bkv bkx:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bkw bky:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 bkz:bool[1] = lt bkx 0 bla:i64[1] = add bkx 3 blb:i64[1] = pjit[name=_where jaxpr=_where] bkz bla bkx blc:i64[1] = pjit[name=clip jaxpr=clip] blb 0 3 bld:i64[1] = pjit[name=argsort jaxpr=argsort] blc ble:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] blf:bool[1] = lt bld 0 blg:i64[1] = add bld 1 blh:i64[1] = select_n blf bld blg bli:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] blh blj:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bli blk:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] blc bll:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] blk blj ble blm:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] bll bln:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True blo:bool[1] = lt blm 0 blp:i64[1] = add blm 4 blq:i64[1] = select_n blo blm blp blr:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] blq bls:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] blr blt:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False blu:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] bln bls blt blv:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] blu blw:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 blx:i64[4] = pjit[name=clip jaxpr=clip1] blv 0 bly:i64[] = device_put[devices=[None] srcs=[None]] 1 blz:bool[4] = lt blx 0 bma:i64[4] = add blx 3 bmb:i64[4] = select_n blz blx bma bmc:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] bmb bmd:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] bmc bme:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] bly bmf:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] blw bmd bme bmg:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] bmf bmh:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] bmg 1 bmi:i64[3] = pjit[name=remainder jaxpr=remainder] bmh 4 bmj:bool[1] = lt blm 0 bmk:i64[1] = add blm 4 bml:i64[1] = select_n bmj blm bmk bmm:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] bml bmn:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bmm bmo:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] bky bmp:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] bmo bmn bkt bmq:bool[3] = lt bmi 0 bmr:i64[3] = add bmi 4 bms:i64[3] = select_n bmq bmi bmr bmt:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] bms bmu:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] bmt bmv:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] bmp bmu bks bmw:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] biq bmx:f64[1] = squeeze[dimensions=(1,)] bmw bmy:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bmv bmz:f64[1] = squeeze[dimensions=(1,)] bmy bna:f64[1] = mul bmx bmz bnb:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] biq bnc:f64[1] = squeeze[dimensions=(1,)] bnb bnd:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bmv bne:f64[1] = squeeze[dimensions=(1,)] bnd bnf:f64[1] = mul bnc bne bng:f64[1] = sub bna bnf bnh:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] biq bni:f64[1] = squeeze[dimensions=(1,)] bnh bnj:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bmv bnk:f64[1] = squeeze[dimensions=(1,)] bnj bnl:f64[1] = mul bni bnk bnm:f64[1] = sub bng bnl bnn:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] biq bno:f64[1] = squeeze[dimensions=(1,)] bnn bnp:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bmv bnq:f64[1] = squeeze[dimensions=(1,)] bnp bnr:f64[1] = mul bno bnq bns:f64[1] = sub bnm bnr bnt:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] biq bnu:f64[1] = squeeze[dimensions=(1,)] bnt bnv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bmv bnw:f64[1] = squeeze[dimensions=(1,)] bnv bnx:f64[1] = mul bnu bnw bny:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] biq bnz:f64[1] = squeeze[dimensions=(1,)] bny boa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bmv bob:f64[1] = squeeze[dimensions=(1,)] boa boc:f64[1] = mul bnz bob bod:f64[1] = add bnx boc boe:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] biq bof:f64[1] = squeeze[dimensions=(1,)] boe bog:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bmv boh:f64[1] = squeeze[dimensions=(1,)] bog boi:f64[1] = mul bof boh boj:f64[1] = add bod boi bok:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] biq bol:f64[1] = squeeze[dimensions=(1,)] bok bom:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bmv bon:f64[1] = squeeze[dimensions=(1,)] bom boo:f64[1] = mul bol bon bop:f64[1] = sub boj boo boq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] biq bor:f64[1] = squeeze[dimensions=(1,)] boq bos:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bmv bot:f64[1] = squeeze[dimensions=(1,)] bos bou:f64[1] = mul bor bot bov:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] biq bow:f64[1] = squeeze[dimensions=(1,)] bov box:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bmv boy:f64[1] = squeeze[dimensions=(1,)] box boz:f64[1] = mul bow boy bpa:f64[1] = sub bou boz bpb:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] biq bpc:f64[1] = squeeze[dimensions=(1,)] bpb bpd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bmv bpe:f64[1] = squeeze[dimensions=(1,)] bpd bpf:f64[1] = mul bpc bpe bpg:f64[1] = add bpa bpf bph:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] biq bpi:f64[1] = squeeze[dimensions=(1,)] bph bpj:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bmv bpk:f64[1] = squeeze[dimensions=(1,)] bpj bpl:f64[1] = mul bpi bpk bpm:f64[1] = add bpg bpl bpn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] biq bpo:f64[1] = squeeze[dimensions=(1,)] bpn bpp:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bmv bpq:f64[1] = squeeze[dimensions=(1,)] bpp bpr:f64[1] = mul bpo bpq bps:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] biq bpt:f64[1] = squeeze[dimensions=(1,)] bps bpu:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bmv bpv:f64[1] = squeeze[dimensions=(1,)] bpu bpw:f64[1] = mul bpt bpv bpx:f64[1] = add bpr bpw bpy:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] biq bpz:f64[1] = squeeze[dimensions=(1,)] bpy bqa:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bmv bqb:f64[1] = squeeze[dimensions=(1,)] bqa bqc:f64[1] = mul bpz bqb bqd:f64[1] = sub bpx bqc bqe:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] biq bqf:f64[1] = squeeze[dimensions=(1,)] bqe bqg:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bmv bqh:f64[1] = squeeze[dimensions=(1,)] bqg bqi:f64[1] = mul bqf bqh bqj:f64[1] = add bqd bqi bqk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bns bql:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bop bqm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bpm bqn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bqj bqo:f64[1,4] = concatenate[dimension=1] bqk bql bqm bqn bqp:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] bdz bqq:f64[1,3] = squeeze[dimensions=(1,)] bqp bqr:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bqo bqs:f64[1] = squeeze[dimensions=(1,)] bqr bqt:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] bqo bqu:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bqt bqq bqv:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bqu bqw:f64[1,3] = mul bqv bqt bqx:f64[1,3] = mul 2.0 bqw bqy:f64[1] = mul bqs bqs bqz:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bqt bqt bra:f64[1] = sub bqy bqz brb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bra brc:f64[1,3] = mul brb bqq brd:f64[1,3] = add bqx brc bre:f64[1] = mul 2.0 bqs brf:f64[1,3] = pjit[name=cross jaxpr=cross] bqt bqq brg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bre brh:f64[1,3] = mul brg brf bri:f64[1,3] = add brd brh brj:f64[1,3] = sub bjl bri brk:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bjl brl:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bkf brm:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] bqo brn:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] bqo bro:f64[1,4,4] = mul brm brn brp:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] bro brq:f64[1] = squeeze[dimensions=(1, 2)] brp brr:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] bro brs:f64[1] = squeeze[dimensions=(1, 2)] brr brt:f64[1] = add brq brs bru:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] bro brv:f64[1] = squeeze[dimensions=(1, 2)] bru brw:f64[1] = sub brt brv brx:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] bro bry:f64[1] = squeeze[dimensions=(1, 2)] brx brz:f64[1] = sub brw bry bsa:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] bro bsb:f64[1] = squeeze[dimensions=(1, 2)] bsa bsc:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] bro bsd:f64[1] = squeeze[dimensions=(1, 2)] bsc bse:f64[1] = sub bsb bsd bsf:f64[1] = mul 2.0 bse bsg:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] bro bsh:f64[1] = squeeze[dimensions=(1, 2)] bsg bsi:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] bro bsj:f64[1] = squeeze[dimensions=(1, 2)] bsi bsk:f64[1] = add bsh bsj bsl:f64[1] = mul 2.0 bsk bsm:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] bro bsn:f64[1] = squeeze[dimensions=(1, 2)] bsm bso:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] bro bsp:f64[1] = squeeze[dimensions=(1, 2)] bso bsq:f64[1] = add bsn bsp bsr:f64[1] = mul 2.0 bsq bss:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] bro bst:f64[1] = squeeze[dimensions=(1, 2)] bss bsu:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] bro bsv:f64[1] = squeeze[dimensions=(1, 2)] bsu bsw:f64[1] = sub bst bsv bsx:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] bro bsy:f64[1] = squeeze[dimensions=(1, 2)] bsx bsz:f64[1] = add bsw bsy bta:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] bro btb:f64[1] = squeeze[dimensions=(1, 2)] bta btc:f64[1] = sub bsz btb btd:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] bro bte:f64[1] = squeeze[dimensions=(1, 2)] btd btf:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] bro btg:f64[1] = squeeze[dimensions=(1, 2)] btf bth:f64[1] = sub bte btg bti:f64[1] = mul 2.0 bth btj:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] bro btk:f64[1] = squeeze[dimensions=(1, 2)] btj btl:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] bro btm:f64[1] = squeeze[dimensions=(1, 2)] btl btn:f64[1] = sub btk btm bto:f64[1] = mul 2.0 btn btp:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] bro btq:f64[1] = squeeze[dimensions=(1, 2)] btp btr:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] bro bts:f64[1] = squeeze[dimensions=(1, 2)] btr btt:f64[1] = add btq bts btu:f64[1] = mul 2.0 btt btv:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] bro btw:f64[1] = squeeze[dimensions=(1, 2)] btv btx:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] bro bty:f64[1] = squeeze[dimensions=(1, 2)] btx btz:f64[1] = sub btw bty bua:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] bro bub:f64[1] = squeeze[dimensions=(1, 2)] bua buc:f64[1] = sub btz bub bud:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] bro bue:f64[1] = squeeze[dimensions=(1, 2)] bud buf:f64[1] = add buc bue bug:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] brz buh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bsf bui:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bsl buj:f64[1,3] = concatenate[dimension=1] bug buh bui buk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bsr bul:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] btc bum:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bti bun:f64[1,3] = concatenate[dimension=1] buk bul bum buo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bto bup:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] btu buq:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] buf bur:f64[1,3] = concatenate[dimension=1] buo bup buq bus:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] buj but:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bun buu:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] bur buv:f64[1,3,3] = concatenate[dimension=1] bus but buu _:f64[1,1] = pjit[name=_take jaxpr=_take9] beb ig _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] brk ig _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] brl ig buw:f64[1,3] = pjit[name=_take jaxpr=_take4] brj ig bux:f64[1,4] = pjit[name=_take jaxpr=_take5] bqo ig _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] buv ig buy:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la ih buz:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb ii bva:f64[1,1] = pjit[name=_take jaxpr=_take8] nk ij bvb:f64[1,1] = pjit[name=_take jaxpr=_take8] kn ik bvc:f64[1,3] = slice[limit_indices=(4, 3) start_indices=(3, 0) strides=None] kp bvd:f64[1,4] = slice[limit_indices=(4, 4) start_indices=(3, 0) strides=None] kq bve:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bux bvf:f64[1] = squeeze[dimensions=(1,)] bve bvg:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] bux bvh:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bvg bvc bvi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bvh bvj:f64[1,3] = mul bvi bvg bvk:f64[1,3] = mul 2.0 bvj bvl:f64[1] = mul bvf bvf bvm:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bvg bvg bvn:f64[1] = sub bvl bvm bvo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bvn bvp:f64[1,3] = mul bvo bvc bvq:f64[1,3] = add bvk bvp bvr:f64[1] = mul 2.0 bvf bvs:f64[1,3] = pjit[name=cross jaxpr=cross] bvg bvc bvt:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bvr bvu:f64[1,3] = mul bvt bvs bvv:f64[1,3] = add bvq bvu bvw:f64[1,3] = add buw bvv bvx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bux bvy:f64[1] = squeeze[dimensions=(1,)] bvx bvz:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bvd bwa:f64[1] = squeeze[dimensions=(1,)] bvz bwb:f64[1] = mul bvy bwa bwc:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bux bwd:f64[1] = squeeze[dimensions=(1,)] bwc bwe:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bvd bwf:f64[1] = squeeze[dimensions=(1,)] bwe bwg:f64[1] = mul bwd bwf bwh:f64[1] = sub bwb bwg bwi:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bux bwj:f64[1] = squeeze[dimensions=(1,)] bwi bwk:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bvd bwl:f64[1] = squeeze[dimensions=(1,)] bwk bwm:f64[1] = mul bwj bwl bwn:f64[1] = sub bwh bwm bwo:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bux bwp:f64[1] = squeeze[dimensions=(1,)] bwo bwq:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bvd bwr:f64[1] = squeeze[dimensions=(1,)] bwq bws:f64[1] = mul bwp bwr bwt:f64[1] = sub bwn bws bwu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bux bwv:f64[1] = squeeze[dimensions=(1,)] bwu bww:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bvd bwx:f64[1] = squeeze[dimensions=(1,)] bww bwy:f64[1] = mul bwv bwx bwz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bux bxa:f64[1] = squeeze[dimensions=(1,)] bwz bxb:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bvd bxc:f64[1] = squeeze[dimensions=(1,)] bxb bxd:f64[1] = mul bxa bxc bxe:f64[1] = add bwy bxd bxf:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bux bxg:f64[1] = squeeze[dimensions=(1,)] bxf bxh:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bvd bxi:f64[1] = squeeze[dimensions=(1,)] bxh bxj:f64[1] = mul bxg bxi bxk:f64[1] = add bxe bxj bxl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bux bxm:f64[1] = squeeze[dimensions=(1,)] bxl bxn:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bvd bxo:f64[1] = squeeze[dimensions=(1,)] bxn bxp:f64[1] = mul bxm bxo bxq:f64[1] = sub bxk bxp bxr:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bux bxs:f64[1] = squeeze[dimensions=(1,)] bxr bxt:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bvd bxu:f64[1] = squeeze[dimensions=(1,)] bxt bxv:f64[1] = mul bxs bxu bxw:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bux bxx:f64[1] = squeeze[dimensions=(1,)] bxw bxy:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bvd bxz:f64[1] = squeeze[dimensions=(1,)] bxy bya:f64[1] = mul bxx bxz byb:f64[1] = sub bxv bya byc:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bux byd:f64[1] = squeeze[dimensions=(1,)] byc bye:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bvd byf:f64[1] = squeeze[dimensions=(1,)] bye byg:f64[1] = mul byd byf byh:f64[1] = add byb byg byi:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bux byj:f64[1] = squeeze[dimensions=(1,)] byi byk:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bvd byl:f64[1] = squeeze[dimensions=(1,)] byk bym:f64[1] = mul byj byl byn:f64[1] = add byh bym byo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bux byp:f64[1] = squeeze[dimensions=(1,)] byo byq:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bvd byr:f64[1] = squeeze[dimensions=(1,)] byq bys:f64[1] = mul byp byr byt:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bux byu:f64[1] = squeeze[dimensions=(1,)] byt byv:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bvd byw:f64[1] = squeeze[dimensions=(1,)] byv byx:f64[1] = mul byu byw byy:f64[1] = add bys byx byz:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bux bza:f64[1] = squeeze[dimensions=(1,)] byz bzb:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bvd bzc:f64[1] = squeeze[dimensions=(1,)] bzb bzd:f64[1] = mul bza bzc bze:f64[1] = sub byy bzd bzf:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bux bzg:f64[1] = squeeze[dimensions=(1,)] bzf bzh:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bvd bzi:f64[1] = squeeze[dimensions=(1,)] bzh bzj:f64[1] = mul bzg bzi bzk:f64[1] = add bze bzj bzl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bwt bzm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bxq bzn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] byn bzo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bzk bzp:f64[1,4] = concatenate[dimension=1] bzl bzm bzn bzo bzq:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] buy bzr:f64[1,3] = squeeze[dimensions=(1,)] bzq bzs:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bzp bzt:f64[1] = squeeze[dimensions=(1,)] bzs bzu:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] bzp bzv:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bzu bzr bzw:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] bzv bzx:f64[1,3] = mul bzw bzu bzy:f64[1,3] = mul 2.0 bzx bzz:f64[1] = mul bzt bzt caa:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] bzu bzu cab:f64[1] = sub bzz caa cac:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cab cad:f64[1,3] = mul cac bzr cae:f64[1,3] = add bzy cad caf:f64[1] = mul 2.0 bzt cag:f64[1,3] = pjit[name=cross jaxpr=cross] bzu bzr cah:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] caf cai:f64[1,3] = mul cah cag caj:f64[1,3] = add cae cai cak:f64[1,3] = add caj bvw cal:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] buz cam:f64[1,3] = squeeze[dimensions=(1,)] cal can:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bzp cao:f64[1] = squeeze[dimensions=(1,)] can cap:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] bzp caq:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cap cam car:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] caq cas:f64[1,3] = mul car cap cat:f64[1,3] = mul 2.0 cas cau:f64[1] = mul cao cao cav:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cap cap caw:f64[1] = sub cau cav cax:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] caw cay:f64[1,3] = mul cax cam caz:f64[1,3] = add cat cay cba:f64[1] = mul 2.0 cao cbb:f64[1,3] = pjit[name=cross jaxpr=cross] cap cam cbc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cba cbd:f64[1,3] = mul cbc cbb cbe:f64[1,3] = add caz cbd cbf:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bva cbg:f64[1] = squeeze[dimensions=(1,)] cbf cbh:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bvb cbi:f64[1] = squeeze[dimensions=(1,)] cbh cbj:f64[1] = sub cbg cbi cbk:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] buz cbl:f64[1,3] = squeeze[dimensions=(1,)] cbk cbm:f64[1] = mul cbj 0.5 cbn:f64[1] = sin cbm cbo:f64[1] = mul cbj 0.5 cbp:f64[1] = cos cbo cbq:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cbn cbr:f64[1,3] = mul cbl cbq cbs:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cbp cbt:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 cbu:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] cbt cbv:i64[] = squeeze[dimensions=(0,)] cbu cbw:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] cbv cbx:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 cby:bool[1] = lt cbw 0 cbz:i64[1] = add cbw 3 cca:i64[1] = pjit[name=_where jaxpr=_where] cby cbz cbw ccb:i64[1] = pjit[name=clip jaxpr=clip] cca 0 3 ccc:i64[1] = pjit[name=argsort jaxpr=argsort] ccb ccd:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] cce:bool[1] = lt ccc 0 ccf:i64[1] = add ccc 1 ccg:i64[1] = select_n cce ccc ccf cch:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] ccg cci:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cch ccj:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] ccb cck:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] ccj cci ccd ccl:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] cck ccm:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True ccn:bool[1] = lt ccl 0 cco:i64[1] = add ccl 4 ccp:i64[1] = select_n ccn ccl cco ccq:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] ccp ccr:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ccq ccs:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False cct:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] ccm ccr ccs ccu:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] cct ccv:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 ccw:i64[4] = pjit[name=clip jaxpr=clip1] ccu 0 ccx:i64[] = device_put[devices=[None] srcs=[None]] 1 ccy:bool[4] = lt ccw 0 ccz:i64[4] = add ccw 3 cda:i64[4] = select_n ccy ccw ccz cdb:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] cda cdc:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] cdb cdd:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] ccx cde:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] ccv cdc cdd cdf:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] cde cdg:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] cdf 1 cdh:i64[3] = pjit[name=remainder jaxpr=remainder] cdg 4 cdi:bool[1] = lt ccl 0 cdj:i64[1] = add ccl 4 cdk:i64[1] = select_n cdi ccl cdj cdl:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] cdk cdm:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cdl cdn:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] cbx cdo:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] cdn cdm cbs cdp:bool[3] = lt cdh 0 cdq:i64[3] = add cdh 4 cdr:i64[3] = select_n cdp cdh cdq cds:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] cdr cdt:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] cds cdu:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] cdo cdt cbr cdv:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bzp cdw:f64[1] = squeeze[dimensions=(1,)] cdv cdx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cdu cdy:f64[1] = squeeze[dimensions=(1,)] cdx cdz:f64[1] = mul cdw cdy cea:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bzp ceb:f64[1] = squeeze[dimensions=(1,)] cea cec:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cdu ced:f64[1] = squeeze[dimensions=(1,)] cec cee:f64[1] = mul ceb ced cef:f64[1] = sub cdz cee ceg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bzp ceh:f64[1] = squeeze[dimensions=(1,)] ceg cei:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cdu cej:f64[1] = squeeze[dimensions=(1,)] cei cek:f64[1] = mul ceh cej cel:f64[1] = sub cef cek cem:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bzp cen:f64[1] = squeeze[dimensions=(1,)] cem ceo:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cdu cep:f64[1] = squeeze[dimensions=(1,)] ceo ceq:f64[1] = mul cen cep cer:f64[1] = sub cel ceq ces:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bzp cet:f64[1] = squeeze[dimensions=(1,)] ces ceu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cdu cev:f64[1] = squeeze[dimensions=(1,)] ceu cew:f64[1] = mul cet cev cex:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bzp cey:f64[1] = squeeze[dimensions=(1,)] cex cez:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cdu cfa:f64[1] = squeeze[dimensions=(1,)] cez cfb:f64[1] = mul cey cfa cfc:f64[1] = add cew cfb cfd:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bzp cfe:f64[1] = squeeze[dimensions=(1,)] cfd cff:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cdu cfg:f64[1] = squeeze[dimensions=(1,)] cff cfh:f64[1] = mul cfe cfg cfi:f64[1] = add cfc cfh cfj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bzp cfk:f64[1] = squeeze[dimensions=(1,)] cfj cfl:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cdu cfm:f64[1] = squeeze[dimensions=(1,)] cfl cfn:f64[1] = mul cfk cfm cfo:f64[1] = sub cfi cfn cfp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bzp cfq:f64[1] = squeeze[dimensions=(1,)] cfp cfr:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cdu cfs:f64[1] = squeeze[dimensions=(1,)] cfr cft:f64[1] = mul cfq cfs cfu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bzp cfv:f64[1] = squeeze[dimensions=(1,)] cfu cfw:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cdu cfx:f64[1] = squeeze[dimensions=(1,)] cfw cfy:f64[1] = mul cfv cfx cfz:f64[1] = sub cft cfy cga:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bzp cgb:f64[1] = squeeze[dimensions=(1,)] cga cgc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cdu cgd:f64[1] = squeeze[dimensions=(1,)] cgc cge:f64[1] = mul cgb cgd cgf:f64[1] = add cfz cge cgg:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bzp cgh:f64[1] = squeeze[dimensions=(1,)] cgg cgi:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cdu cgj:f64[1] = squeeze[dimensions=(1,)] cgi cgk:f64[1] = mul cgh cgj cgl:f64[1] = add cgf cgk cgm:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] bzp cgn:f64[1] = squeeze[dimensions=(1,)] cgm cgo:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cdu cgp:f64[1] = squeeze[dimensions=(1,)] cgo cgq:f64[1] = mul cgn cgp cgr:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] bzp cgs:f64[1] = squeeze[dimensions=(1,)] cgr cgt:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cdu cgu:f64[1] = squeeze[dimensions=(1,)] cgt cgv:f64[1] = mul cgs cgu cgw:f64[1] = add cgq cgv cgx:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] bzp cgy:f64[1] = squeeze[dimensions=(1,)] cgx cgz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cdu cha:f64[1] = squeeze[dimensions=(1,)] cgz chb:f64[1] = mul cgy cha chc:f64[1] = sub cgw chb chd:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] bzp che:f64[1] = squeeze[dimensions=(1,)] chd chf:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cdu chg:f64[1] = squeeze[dimensions=(1,)] chf chh:f64[1] = mul che chg chi:f64[1] = add chc chh chj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cer chk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cfo chl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cgl chm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] chi chn:f64[1,4] = concatenate[dimension=1] chj chk chl chm cho:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] buy chp:f64[1,3] = squeeze[dimensions=(1,)] cho chq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] chn chr:f64[1] = squeeze[dimensions=(1,)] chq chs:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] chn cht:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] chs chp chu:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cht chv:f64[1,3] = mul chu chs chw:f64[1,3] = mul 2.0 chv chx:f64[1] = mul chr chr chy:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] chs chs chz:f64[1] = sub chx chy cia:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] chz cib:f64[1,3] = mul cia chp cic:f64[1,3] = add chw cib cid:f64[1] = mul 2.0 chr cie:f64[1,3] = pjit[name=cross jaxpr=cross] chs chp cif:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cid cig:f64[1,3] = mul cif cie cih:f64[1,3] = add cic cig cii:f64[1,3] = sub cak cih cij:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] cak cik:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] cbe cil:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] chn cim:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] chn cin:f64[1,4,4] = mul cil cim cio:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] cin cip:f64[1] = squeeze[dimensions=(1, 2)] cio ciq:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] cin cir:f64[1] = squeeze[dimensions=(1, 2)] ciq cis:f64[1] = add cip cir cit:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] cin ciu:f64[1] = squeeze[dimensions=(1, 2)] cit civ:f64[1] = sub cis ciu ciw:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] cin cix:f64[1] = squeeze[dimensions=(1, 2)] ciw ciy:f64[1] = sub civ cix ciz:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] cin cja:f64[1] = squeeze[dimensions=(1, 2)] ciz cjb:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] cin cjc:f64[1] = squeeze[dimensions=(1, 2)] cjb cjd:f64[1] = sub cja cjc cje:f64[1] = mul 2.0 cjd cjf:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] cin cjg:f64[1] = squeeze[dimensions=(1, 2)] cjf cjh:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] cin cji:f64[1] = squeeze[dimensions=(1, 2)] cjh cjj:f64[1] = add cjg cji cjk:f64[1] = mul 2.0 cjj cjl:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] cin cjm:f64[1] = squeeze[dimensions=(1, 2)] cjl cjn:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] cin cjo:f64[1] = squeeze[dimensions=(1, 2)] cjn cjp:f64[1] = add cjm cjo cjq:f64[1] = mul 2.0 cjp cjr:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] cin cjs:f64[1] = squeeze[dimensions=(1, 2)] cjr cjt:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] cin cju:f64[1] = squeeze[dimensions=(1, 2)] cjt cjv:f64[1] = sub cjs cju cjw:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] cin cjx:f64[1] = squeeze[dimensions=(1, 2)] cjw cjy:f64[1] = add cjv cjx cjz:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] cin cka:f64[1] = squeeze[dimensions=(1, 2)] cjz ckb:f64[1] = sub cjy cka ckc:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] cin ckd:f64[1] = squeeze[dimensions=(1, 2)] ckc cke:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] cin ckf:f64[1] = squeeze[dimensions=(1, 2)] cke ckg:f64[1] = sub ckd ckf ckh:f64[1] = mul 2.0 ckg cki:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] cin ckj:f64[1] = squeeze[dimensions=(1, 2)] cki ckk:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] cin ckl:f64[1] = squeeze[dimensions=(1, 2)] ckk ckm:f64[1] = sub ckj ckl ckn:f64[1] = mul 2.0 ckm cko:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] cin ckp:f64[1] = squeeze[dimensions=(1, 2)] cko ckq:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] cin ckr:f64[1] = squeeze[dimensions=(1, 2)] ckq cks:f64[1] = add ckp ckr ckt:f64[1] = mul 2.0 cks cku:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] cin ckv:f64[1] = squeeze[dimensions=(1, 2)] cku ckw:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] cin ckx:f64[1] = squeeze[dimensions=(1, 2)] ckw cky:f64[1] = sub ckv ckx ckz:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] cin cla:f64[1] = squeeze[dimensions=(1, 2)] ckz clb:f64[1] = sub cky cla clc:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] cin cld:f64[1] = squeeze[dimensions=(1, 2)] clc cle:f64[1] = add clb cld clf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ciy clg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cje clh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cjk cli:f64[1,3] = concatenate[dimension=1] clf clg clh clj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cjq clk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ckb cll:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ckh clm:f64[1,3] = concatenate[dimension=1] clj clk cll cln:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ckn clo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ckt clp:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cle clq:f64[1,3] = concatenate[dimension=1] cln clo clp clr:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] cli cls:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] clm clt:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] clq clu:f64[1,3,3] = concatenate[dimension=1] clr cls clt _:f64[1,1] = pjit[name=_take jaxpr=_take9] bva il _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] cij il _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] cik il clv:f64[1,3] = pjit[name=_take jaxpr=_take4] cii il clw:f64[1,4] = pjit[name=_take jaxpr=_take5] chn il _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] clu il clx:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la im cly:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb in clz:f64[1,1] = pjit[name=_take jaxpr=_take8] nk io cma:f64[1,1] = pjit[name=_take jaxpr=_take8] kn ip cmb:f64[1,3] = slice[limit_indices=(5, 3) start_indices=(4, 0) strides=None] kp cmc:f64[1,4] = slice[limit_indices=(5, 4) start_indices=(4, 0) strides=None] kq cmd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] clw cme:f64[1] = squeeze[dimensions=(1,)] cmd cmf:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] clw cmg:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cmf cmb cmh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cmg cmi:f64[1,3] = mul cmh cmf cmj:f64[1,3] = mul 2.0 cmi cmk:f64[1] = mul cme cme cml:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cmf cmf cmm:f64[1] = sub cmk cml cmn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cmm cmo:f64[1,3] = mul cmn cmb cmp:f64[1,3] = add cmj cmo cmq:f64[1] = mul 2.0 cme cmr:f64[1,3] = pjit[name=cross jaxpr=cross] cmf cmb cms:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cmq cmt:f64[1,3] = mul cms cmr cmu:f64[1,3] = add cmp cmt cmv:f64[1,3] = add clv cmu cmw:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] clw cmx:f64[1] = squeeze[dimensions=(1,)] cmw cmy:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cmc cmz:f64[1] = squeeze[dimensions=(1,)] cmy cna:f64[1] = mul cmx cmz cnb:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] clw cnc:f64[1] = squeeze[dimensions=(1,)] cnb cnd:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cmc cne:f64[1] = squeeze[dimensions=(1,)] cnd cnf:f64[1] = mul cnc cne cng:f64[1] = sub cna cnf cnh:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] clw cni:f64[1] = squeeze[dimensions=(1,)] cnh cnj:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cmc cnk:f64[1] = squeeze[dimensions=(1,)] cnj cnl:f64[1] = mul cni cnk cnm:f64[1] = sub cng cnl cnn:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] clw cno:f64[1] = squeeze[dimensions=(1,)] cnn cnp:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cmc cnq:f64[1] = squeeze[dimensions=(1,)] cnp cnr:f64[1] = mul cno cnq cns:f64[1] = sub cnm cnr cnt:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] clw cnu:f64[1] = squeeze[dimensions=(1,)] cnt cnv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cmc cnw:f64[1] = squeeze[dimensions=(1,)] cnv cnx:f64[1] = mul cnu cnw cny:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] clw cnz:f64[1] = squeeze[dimensions=(1,)] cny coa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cmc cob:f64[1] = squeeze[dimensions=(1,)] coa coc:f64[1] = mul cnz cob cod:f64[1] = add cnx coc coe:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] clw cof:f64[1] = squeeze[dimensions=(1,)] coe cog:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cmc coh:f64[1] = squeeze[dimensions=(1,)] cog coi:f64[1] = mul cof coh coj:f64[1] = add cod coi cok:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] clw col:f64[1] = squeeze[dimensions=(1,)] cok com:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cmc con:f64[1] = squeeze[dimensions=(1,)] com coo:f64[1] = mul col con cop:f64[1] = sub coj coo coq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] clw cor:f64[1] = squeeze[dimensions=(1,)] coq cos:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cmc cot:f64[1] = squeeze[dimensions=(1,)] cos cou:f64[1] = mul cor cot cov:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] clw cow:f64[1] = squeeze[dimensions=(1,)] cov cox:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cmc coy:f64[1] = squeeze[dimensions=(1,)] cox coz:f64[1] = mul cow coy cpa:f64[1] = sub cou coz cpb:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] clw cpc:f64[1] = squeeze[dimensions=(1,)] cpb cpd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cmc cpe:f64[1] = squeeze[dimensions=(1,)] cpd cpf:f64[1] = mul cpc cpe cpg:f64[1] = add cpa cpf cph:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] clw cpi:f64[1] = squeeze[dimensions=(1,)] cph cpj:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cmc cpk:f64[1] = squeeze[dimensions=(1,)] cpj cpl:f64[1] = mul cpi cpk cpm:f64[1] = add cpg cpl cpn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] clw cpo:f64[1] = squeeze[dimensions=(1,)] cpn cpp:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cmc cpq:f64[1] = squeeze[dimensions=(1,)] cpp cpr:f64[1] = mul cpo cpq cps:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] clw cpt:f64[1] = squeeze[dimensions=(1,)] cps cpu:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cmc cpv:f64[1] = squeeze[dimensions=(1,)] cpu cpw:f64[1] = mul cpt cpv cpx:f64[1] = add cpr cpw cpy:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] clw cpz:f64[1] = squeeze[dimensions=(1,)] cpy cqa:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cmc cqb:f64[1] = squeeze[dimensions=(1,)] cqa cqc:f64[1] = mul cpz cqb cqd:f64[1] = sub cpx cqc cqe:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] clw cqf:f64[1] = squeeze[dimensions=(1,)] cqe cqg:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cmc cqh:f64[1] = squeeze[dimensions=(1,)] cqg cqi:f64[1] = mul cqf cqh cqj:f64[1] = add cqd cqi cqk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cns cql:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cop cqm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cpm cqn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cqj cqo:f64[1,4] = concatenate[dimension=1] cqk cql cqm cqn cqp:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] clx cqq:f64[1,3] = squeeze[dimensions=(1,)] cqp cqr:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cqo cqs:f64[1] = squeeze[dimensions=(1,)] cqr cqt:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] cqo cqu:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cqt cqq cqv:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cqu cqw:f64[1,3] = mul cqv cqt cqx:f64[1,3] = mul 2.0 cqw cqy:f64[1] = mul cqs cqs cqz:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cqt cqt cra:f64[1] = sub cqy cqz crb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cra crc:f64[1,3] = mul crb cqq crd:f64[1,3] = add cqx crc cre:f64[1] = mul 2.0 cqs crf:f64[1,3] = pjit[name=cross jaxpr=cross] cqt cqq crg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cre crh:f64[1,3] = mul crg crf cri:f64[1,3] = add crd crh crj:f64[1,3] = add cri cmv crk:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] cly crl:f64[1,3] = squeeze[dimensions=(1,)] crk crm:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cqo crn:f64[1] = squeeze[dimensions=(1,)] crm cro:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] cqo crp:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cro crl crq:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] crp crr:f64[1,3] = mul crq cro crs:f64[1,3] = mul 2.0 crr crt:f64[1] = mul crn crn cru:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cro cro crv:f64[1] = sub crt cru crw:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] crv crx:f64[1,3] = mul crw crl cry:f64[1,3] = add crs crx crz:f64[1] = mul 2.0 crn csa:f64[1,3] = pjit[name=cross jaxpr=cross] cro crl csb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] crz csc:f64[1,3] = mul csb csa csd:f64[1,3] = add cry csc cse:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] clz csf:f64[1] = squeeze[dimensions=(1,)] cse csg:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cma csh:f64[1] = squeeze[dimensions=(1,)] csg csi:f64[1] = sub csf csh csj:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] cly csk:f64[1,3] = squeeze[dimensions=(1,)] csj csl:f64[1] = mul csi 0.5 csm:f64[1] = sin csl csn:f64[1] = mul csi 0.5 cso:f64[1] = cos csn csp:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] csm csq:f64[1,3] = mul csk csp csr:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cso css:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 cst:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] css csu:i64[] = squeeze[dimensions=(0,)] cst csv:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] csu csw:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 csx:bool[1] = lt csv 0 csy:i64[1] = add csv 3 csz:i64[1] = pjit[name=_where jaxpr=_where] csx csy csv cta:i64[1] = pjit[name=clip jaxpr=clip] csz 0 3 ctb:i64[1] = pjit[name=argsort jaxpr=argsort] cta ctc:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] ctd:bool[1] = lt ctb 0 cte:i64[1] = add ctb 1 ctf:i64[1] = select_n ctd ctb cte ctg:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] ctf cth:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ctg cti:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] cta ctj:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] cti cth ctc ctk:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] ctj ctl:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True ctm:bool[1] = lt ctk 0 ctn:i64[1] = add ctk 4 cto:i64[1] = select_n ctm ctk ctn ctp:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] cto ctq:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ctp ctr:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False cts:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] ctl ctq ctr ctt:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] cts ctu:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 ctv:i64[4] = pjit[name=clip jaxpr=clip1] ctt 0 ctw:i64[] = device_put[devices=[None] srcs=[None]] 1 ctx:bool[4] = lt ctv 0 cty:i64[4] = add ctv 3 ctz:i64[4] = select_n ctx ctv cty cua:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] ctz cub:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] cua cuc:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] ctw cud:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] ctu cub cuc cue:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] cud cuf:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] cue 1 cug:i64[3] = pjit[name=remainder jaxpr=remainder] cuf 4 cuh:bool[1] = lt ctk 0 cui:i64[1] = add ctk 4 cuj:i64[1] = select_n cuh ctk cui cuk:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] cuj cul:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cuk cum:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] csw cun:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] cum cul csr cuo:bool[3] = lt cug 0 cup:i64[3] = add cug 4 cuq:i64[3] = select_n cuo cug cup cur:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] cuq cus:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] cur cut:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] cun cus csq cuu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cqo cuv:f64[1] = squeeze[dimensions=(1,)] cuu cuw:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cut cux:f64[1] = squeeze[dimensions=(1,)] cuw cuy:f64[1] = mul cuv cux cuz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cqo cva:f64[1] = squeeze[dimensions=(1,)] cuz cvb:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cut cvc:f64[1] = squeeze[dimensions=(1,)] cvb cvd:f64[1] = mul cva cvc cve:f64[1] = sub cuy cvd cvf:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cqo cvg:f64[1] = squeeze[dimensions=(1,)] cvf cvh:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cut cvi:f64[1] = squeeze[dimensions=(1,)] cvh cvj:f64[1] = mul cvg cvi cvk:f64[1] = sub cve cvj cvl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cqo cvm:f64[1] = squeeze[dimensions=(1,)] cvl cvn:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cut cvo:f64[1] = squeeze[dimensions=(1,)] cvn cvp:f64[1] = mul cvm cvo cvq:f64[1] = sub cvk cvp cvr:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cqo cvs:f64[1] = squeeze[dimensions=(1,)] cvr cvt:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cut cvu:f64[1] = squeeze[dimensions=(1,)] cvt cvv:f64[1] = mul cvs cvu cvw:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cqo cvx:f64[1] = squeeze[dimensions=(1,)] cvw cvy:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cut cvz:f64[1] = squeeze[dimensions=(1,)] cvy cwa:f64[1] = mul cvx cvz cwb:f64[1] = add cvv cwa cwc:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cqo cwd:f64[1] = squeeze[dimensions=(1,)] cwc cwe:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cut cwf:f64[1] = squeeze[dimensions=(1,)] cwe cwg:f64[1] = mul cwd cwf cwh:f64[1] = add cwb cwg cwi:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cqo cwj:f64[1] = squeeze[dimensions=(1,)] cwi cwk:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cut cwl:f64[1] = squeeze[dimensions=(1,)] cwk cwm:f64[1] = mul cwj cwl cwn:f64[1] = sub cwh cwm cwo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cqo cwp:f64[1] = squeeze[dimensions=(1,)] cwo cwq:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cut cwr:f64[1] = squeeze[dimensions=(1,)] cwq cws:f64[1] = mul cwp cwr cwt:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cqo cwu:f64[1] = squeeze[dimensions=(1,)] cwt cwv:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cut cww:f64[1] = squeeze[dimensions=(1,)] cwv cwx:f64[1] = mul cwu cww cwy:f64[1] = sub cws cwx cwz:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cqo cxa:f64[1] = squeeze[dimensions=(1,)] cwz cxb:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cut cxc:f64[1] = squeeze[dimensions=(1,)] cxb cxd:f64[1] = mul cxa cxc cxe:f64[1] = add cwy cxd cxf:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cqo cxg:f64[1] = squeeze[dimensions=(1,)] cxf cxh:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cut cxi:f64[1] = squeeze[dimensions=(1,)] cxh cxj:f64[1] = mul cxg cxi cxk:f64[1] = add cxe cxj cxl:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cqo cxm:f64[1] = squeeze[dimensions=(1,)] cxl cxn:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cut cxo:f64[1] = squeeze[dimensions=(1,)] cxn cxp:f64[1] = mul cxm cxo cxq:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cqo cxr:f64[1] = squeeze[dimensions=(1,)] cxq cxs:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cut cxt:f64[1] = squeeze[dimensions=(1,)] cxs cxu:f64[1] = mul cxr cxt cxv:f64[1] = add cxp cxu cxw:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] cqo cxx:f64[1] = squeeze[dimensions=(1,)] cxw cxy:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] cut cxz:f64[1] = squeeze[dimensions=(1,)] cxy cya:f64[1] = mul cxx cxz cyb:f64[1] = sub cxv cya cyc:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] cqo cyd:f64[1] = squeeze[dimensions=(1,)] cyc cye:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cut cyf:f64[1] = squeeze[dimensions=(1,)] cye cyg:f64[1] = mul cyd cyf cyh:f64[1] = add cyb cyg cyi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cvq cyj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cwn cyk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cxk cyl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cyh cym:f64[1,4] = concatenate[dimension=1] cyi cyj cyk cyl cyn:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] clx cyo:f64[1,3] = squeeze[dimensions=(1,)] cyn cyp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] cym cyq:f64[1] = squeeze[dimensions=(1,)] cyp cyr:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] cym cys:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cyr cyo cyt:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cys cyu:f64[1,3] = mul cyt cyr cyv:f64[1,3] = mul 2.0 cyu cyw:f64[1] = mul cyq cyq cyx:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] cyr cyr cyy:f64[1] = sub cyw cyx cyz:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] cyy cza:f64[1,3] = mul cyz cyo czb:f64[1,3] = add cyv cza czc:f64[1] = mul 2.0 cyq czd:f64[1,3] = pjit[name=cross jaxpr=cross] cyr cyo cze:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] czc czf:f64[1,3] = mul cze czd czg:f64[1,3] = add czb czf czh:f64[1,3] = sub crj czg czi:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] crj czj:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] csd czk:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] cym czl:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] cym czm:f64[1,4,4] = mul czk czl czn:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] czm czo:f64[1] = squeeze[dimensions=(1, 2)] czn czp:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] czm czq:f64[1] = squeeze[dimensions=(1, 2)] czp czr:f64[1] = add czo czq czs:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] czm czt:f64[1] = squeeze[dimensions=(1, 2)] czs czu:f64[1] = sub czr czt czv:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] czm czw:f64[1] = squeeze[dimensions=(1, 2)] czv czx:f64[1] = sub czu czw czy:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] czm czz:f64[1] = squeeze[dimensions=(1, 2)] czy daa:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] czm dab:f64[1] = squeeze[dimensions=(1, 2)] daa dac:f64[1] = sub czz dab dad:f64[1] = mul 2.0 dac dae:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] czm daf:f64[1] = squeeze[dimensions=(1, 2)] dae dag:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] czm dah:f64[1] = squeeze[dimensions=(1, 2)] dag dai:f64[1] = add daf dah daj:f64[1] = mul 2.0 dai dak:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] czm dal:f64[1] = squeeze[dimensions=(1, 2)] dak dam:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] czm dan:f64[1] = squeeze[dimensions=(1, 2)] dam dao:f64[1] = add dal dan dap:f64[1] = mul 2.0 dao daq:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] czm dar:f64[1] = squeeze[dimensions=(1, 2)] daq das:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] czm dat:f64[1] = squeeze[dimensions=(1, 2)] das dau:f64[1] = sub dar dat dav:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] czm daw:f64[1] = squeeze[dimensions=(1, 2)] dav dax:f64[1] = add dau daw day:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] czm daz:f64[1] = squeeze[dimensions=(1, 2)] day dba:f64[1] = sub dax daz dbb:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] czm dbc:f64[1] = squeeze[dimensions=(1, 2)] dbb dbd:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] czm dbe:f64[1] = squeeze[dimensions=(1, 2)] dbd dbf:f64[1] = sub dbc dbe dbg:f64[1] = mul 2.0 dbf dbh:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] czm dbi:f64[1] = squeeze[dimensions=(1, 2)] dbh dbj:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] czm dbk:f64[1] = squeeze[dimensions=(1, 2)] dbj dbl:f64[1] = sub dbi dbk dbm:f64[1] = mul 2.0 dbl dbn:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] czm dbo:f64[1] = squeeze[dimensions=(1, 2)] dbn dbp:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] czm dbq:f64[1] = squeeze[dimensions=(1, 2)] dbp dbr:f64[1] = add dbo dbq dbs:f64[1] = mul 2.0 dbr dbt:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] czm dbu:f64[1] = squeeze[dimensions=(1, 2)] dbt dbv:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] czm dbw:f64[1] = squeeze[dimensions=(1, 2)] dbv dbx:f64[1] = sub dbu dbw dby:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] czm dbz:f64[1] = squeeze[dimensions=(1, 2)] dby dca:f64[1] = sub dbx dbz dcb:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] czm dcc:f64[1] = squeeze[dimensions=(1, 2)] dcb dcd:f64[1] = add dca dcc dce:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] czx dcf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dad dcg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] daj dch:f64[1,3] = concatenate[dimension=1] dce dcf dcg dci:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dap dcj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dba dck:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dbg dcl:f64[1,3] = concatenate[dimension=1] dci dcj dck dcm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dbm dcn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dbs dco:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dcd dcp:f64[1,3] = concatenate[dimension=1] dcm dcn dco dcq:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dch dcr:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dcl dcs:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dcp dct:f64[1,3,3] = concatenate[dimension=1] dcq dcr dcs _:f64[1,1] = pjit[name=_take jaxpr=_take9] clz iq _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] czi iq _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] czj iq dcu:f64[1,3] = pjit[name=_take jaxpr=_take4] czh iq dcv:f64[1,4] = pjit[name=_take jaxpr=_take5] cym iq _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] dct iq dcw:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la ir dcx:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb is dcy:f64[1,1] = pjit[name=_take jaxpr=_take8] nk it dcz:f64[1,1] = pjit[name=_take jaxpr=_take8] kn iu dda:f64[1,3] = slice[limit_indices=(6, 3) start_indices=(5, 0) strides=None] kp ddb:f64[1,4] = slice[limit_indices=(6, 4) start_indices=(5, 0) strides=None] kq ddc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcv ddd:f64[1] = squeeze[dimensions=(1,)] ddc dde:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dcv ddf:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dde dda ddg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ddf ddh:f64[1,3] = mul ddg dde ddi:f64[1,3] = mul 2.0 ddh ddj:f64[1] = mul ddd ddd ddk:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dde dde ddl:f64[1] = sub ddj ddk ddm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ddl ddn:f64[1,3] = mul ddm dda ddo:f64[1,3] = add ddi ddn ddp:f64[1] = mul 2.0 ddd ddq:f64[1,3] = pjit[name=cross jaxpr=cross] dde dda ddr:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ddp dds:f64[1,3] = mul ddr ddq ddt:f64[1,3] = add ddo dds ddu:f64[1,3] = add dcu ddt ddv:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcv ddw:f64[1] = squeeze[dimensions=(1,)] ddv ddx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ddb ddy:f64[1] = squeeze[dimensions=(1,)] ddx ddz:f64[1] = mul ddw ddy dea:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dcv deb:f64[1] = squeeze[dimensions=(1,)] dea dec:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ddb ded:f64[1] = squeeze[dimensions=(1,)] dec dee:f64[1] = mul deb ded def:f64[1] = sub ddz dee deg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dcv deh:f64[1] = squeeze[dimensions=(1,)] deg dei:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ddb dej:f64[1] = squeeze[dimensions=(1,)] dei dek:f64[1] = mul deh dej del:f64[1] = sub def dek dem:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dcv den:f64[1] = squeeze[dimensions=(1,)] dem deo:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ddb dep:f64[1] = squeeze[dimensions=(1,)] deo deq:f64[1] = mul den dep der:f64[1] = sub del deq des:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcv det:f64[1] = squeeze[dimensions=(1,)] des deu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ddb dev:f64[1] = squeeze[dimensions=(1,)] deu dew:f64[1] = mul det dev dex:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dcv dey:f64[1] = squeeze[dimensions=(1,)] dex dez:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ddb dfa:f64[1] = squeeze[dimensions=(1,)] dez dfb:f64[1] = mul dey dfa dfc:f64[1] = add dew dfb dfd:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dcv dfe:f64[1] = squeeze[dimensions=(1,)] dfd dff:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ddb dfg:f64[1] = squeeze[dimensions=(1,)] dff dfh:f64[1] = mul dfe dfg dfi:f64[1] = add dfc dfh dfj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dcv dfk:f64[1] = squeeze[dimensions=(1,)] dfj dfl:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ddb dfm:f64[1] = squeeze[dimensions=(1,)] dfl dfn:f64[1] = mul dfk dfm dfo:f64[1] = sub dfi dfn dfp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcv dfq:f64[1] = squeeze[dimensions=(1,)] dfp dfr:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ddb dfs:f64[1] = squeeze[dimensions=(1,)] dfr dft:f64[1] = mul dfq dfs dfu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dcv dfv:f64[1] = squeeze[dimensions=(1,)] dfu dfw:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ddb dfx:f64[1] = squeeze[dimensions=(1,)] dfw dfy:f64[1] = mul dfv dfx dfz:f64[1] = sub dft dfy dga:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dcv dgb:f64[1] = squeeze[dimensions=(1,)] dga dgc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ddb dgd:f64[1] = squeeze[dimensions=(1,)] dgc dge:f64[1] = mul dgb dgd dgf:f64[1] = add dfz dge dgg:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dcv dgh:f64[1] = squeeze[dimensions=(1,)] dgg dgi:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ddb dgj:f64[1] = squeeze[dimensions=(1,)] dgi dgk:f64[1] = mul dgh dgj dgl:f64[1] = add dgf dgk dgm:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcv dgn:f64[1] = squeeze[dimensions=(1,)] dgm dgo:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ddb dgp:f64[1] = squeeze[dimensions=(1,)] dgo dgq:f64[1] = mul dgn dgp dgr:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dcv dgs:f64[1] = squeeze[dimensions=(1,)] dgr dgt:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ddb dgu:f64[1] = squeeze[dimensions=(1,)] dgt dgv:f64[1] = mul dgs dgu dgw:f64[1] = add dgq dgv dgx:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dcv dgy:f64[1] = squeeze[dimensions=(1,)] dgx dgz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ddb dha:f64[1] = squeeze[dimensions=(1,)] dgz dhb:f64[1] = mul dgy dha dhc:f64[1] = sub dgw dhb dhd:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dcv dhe:f64[1] = squeeze[dimensions=(1,)] dhd dhf:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ddb dhg:f64[1] = squeeze[dimensions=(1,)] dhf dhh:f64[1] = mul dhe dhg dhi:f64[1] = add dhc dhh dhj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] der dhk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dfo dhl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dgl dhm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dhi dhn:f64[1,4] = concatenate[dimension=1] dhj dhk dhl dhm dho:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dcw dhp:f64[1,3] = squeeze[dimensions=(1,)] dho dhq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dhn dhr:f64[1] = squeeze[dimensions=(1,)] dhq dhs:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dhn dht:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dhs dhp dhu:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dht dhv:f64[1,3] = mul dhu dhs dhw:f64[1,3] = mul 2.0 dhv dhx:f64[1] = mul dhr dhr dhy:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dhs dhs dhz:f64[1] = sub dhx dhy dia:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dhz dib:f64[1,3] = mul dia dhp dic:f64[1,3] = add dhw dib did:f64[1] = mul 2.0 dhr die:f64[1,3] = pjit[name=cross jaxpr=cross] dhs dhp dif:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] did dig:f64[1,3] = mul dif die dih:f64[1,3] = add dic dig dii:f64[1,3] = add dih ddu dij:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dcx dik:f64[1,3] = squeeze[dimensions=(1,)] dij dil:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dhn dim:f64[1] = squeeze[dimensions=(1,)] dil din:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dhn dio:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] din dik dip:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dio diq:f64[1,3] = mul dip din dir:f64[1,3] = mul 2.0 diq dis:f64[1] = mul dim dim dit:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] din din diu:f64[1] = sub dis dit div:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] diu diw:f64[1,3] = mul div dik dix:f64[1,3] = add dir diw diy:f64[1] = mul 2.0 dim diz:f64[1,3] = pjit[name=cross jaxpr=cross] din dik dja:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] diy djb:f64[1,3] = mul dja diz djc:f64[1,3] = add dix djb djd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcy dje:f64[1] = squeeze[dimensions=(1,)] djd djf:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dcz djg:f64[1] = squeeze[dimensions=(1,)] djf djh:f64[1] = sub dje djg dji:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dcx djj:f64[1,3] = squeeze[dimensions=(1,)] dji djk:f64[1] = mul djh 0.5 djl:f64[1] = sin djk djm:f64[1] = mul djh 0.5 djn:f64[1] = cos djm djo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] djl djp:f64[1,3] = mul djj djo djq:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] djn djr:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 djs:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] djr djt:i64[] = squeeze[dimensions=(0,)] djs dju:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] djt djv:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 djw:bool[1] = lt dju 0 djx:i64[1] = add dju 3 djy:i64[1] = pjit[name=_where jaxpr=_where] djw djx dju djz:i64[1] = pjit[name=clip jaxpr=clip] djy 0 3 dka:i64[1] = pjit[name=argsort jaxpr=argsort] djz dkb:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] dkc:bool[1] = lt dka 0 dkd:i64[1] = add dka 1 dke:i64[1] = select_n dkc dka dkd dkf:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] dke dkg:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dkf dkh:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] djz dki:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] dkh dkg dkb dkj:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] dki dkk:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True dkl:bool[1] = lt dkj 0 dkm:i64[1] = add dkj 4 dkn:i64[1] = select_n dkl dkj dkm dko:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] dkn dkp:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dko dkq:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False dkr:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] dkk dkp dkq dks:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] dkr dkt:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 dku:i64[4] = pjit[name=clip jaxpr=clip1] dks 0 dkv:i64[] = device_put[devices=[None] srcs=[None]] 1 dkw:bool[4] = lt dku 0 dkx:i64[4] = add dku 3 dky:i64[4] = select_n dkw dku dkx dkz:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] dky dla:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] dkz dlb:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] dkv dlc:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] dkt dla dlb dld:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] dlc dle:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] dld 1 dlf:i64[3] = pjit[name=remainder jaxpr=remainder] dle 4 dlg:bool[1] = lt dkj 0 dlh:i64[1] = add dkj 4 dli:i64[1] = select_n dlg dkj dlh dlj:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] dli dlk:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dlj dll:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] djv dlm:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] dll dlk djq dln:bool[3] = lt dlf 0 dlo:i64[3] = add dlf 4 dlp:i64[3] = select_n dln dlf dlo dlq:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] dlp dlr:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] dlq dls:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] dlm dlr djp dlt:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dhn dlu:f64[1] = squeeze[dimensions=(1,)] dlt dlv:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dls dlw:f64[1] = squeeze[dimensions=(1,)] dlv dlx:f64[1] = mul dlu dlw dly:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dhn dlz:f64[1] = squeeze[dimensions=(1,)] dly dma:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dls dmb:f64[1] = squeeze[dimensions=(1,)] dma dmc:f64[1] = mul dlz dmb dmd:f64[1] = sub dlx dmc dme:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dhn dmf:f64[1] = squeeze[dimensions=(1,)] dme dmg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dls dmh:f64[1] = squeeze[dimensions=(1,)] dmg dmi:f64[1] = mul dmf dmh dmj:f64[1] = sub dmd dmi dmk:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dhn dml:f64[1] = squeeze[dimensions=(1,)] dmk dmm:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dls dmn:f64[1] = squeeze[dimensions=(1,)] dmm dmo:f64[1] = mul dml dmn dmp:f64[1] = sub dmj dmo dmq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dhn dmr:f64[1] = squeeze[dimensions=(1,)] dmq dms:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dls dmt:f64[1] = squeeze[dimensions=(1,)] dms dmu:f64[1] = mul dmr dmt dmv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dhn dmw:f64[1] = squeeze[dimensions=(1,)] dmv dmx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dls dmy:f64[1] = squeeze[dimensions=(1,)] dmx dmz:f64[1] = mul dmw dmy dna:f64[1] = add dmu dmz dnb:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dhn dnc:f64[1] = squeeze[dimensions=(1,)] dnb dnd:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dls dne:f64[1] = squeeze[dimensions=(1,)] dnd dnf:f64[1] = mul dnc dne dng:f64[1] = add dna dnf dnh:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dhn dni:f64[1] = squeeze[dimensions=(1,)] dnh dnj:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dls dnk:f64[1] = squeeze[dimensions=(1,)] dnj dnl:f64[1] = mul dni dnk dnm:f64[1] = sub dng dnl dnn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dhn dno:f64[1] = squeeze[dimensions=(1,)] dnn dnp:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dls dnq:f64[1] = squeeze[dimensions=(1,)] dnp dnr:f64[1] = mul dno dnq dns:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dhn dnt:f64[1] = squeeze[dimensions=(1,)] dns dnu:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dls dnv:f64[1] = squeeze[dimensions=(1,)] dnu dnw:f64[1] = mul dnt dnv dnx:f64[1] = sub dnr dnw dny:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dhn dnz:f64[1] = squeeze[dimensions=(1,)] dny doa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dls dob:f64[1] = squeeze[dimensions=(1,)] doa doc:f64[1] = mul dnz dob dod:f64[1] = add dnx doc doe:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dhn dof:f64[1] = squeeze[dimensions=(1,)] doe dog:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dls doh:f64[1] = squeeze[dimensions=(1,)] dog doi:f64[1] = mul dof doh doj:f64[1] = add dod doi dok:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dhn dol:f64[1] = squeeze[dimensions=(1,)] dok dom:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dls don:f64[1] = squeeze[dimensions=(1,)] dom doo:f64[1] = mul dol don dop:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dhn doq:f64[1] = squeeze[dimensions=(1,)] dop dor:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dls dos:f64[1] = squeeze[dimensions=(1,)] dor dot:f64[1] = mul doq dos dou:f64[1] = add doo dot dov:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dhn dow:f64[1] = squeeze[dimensions=(1,)] dov dox:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dls doy:f64[1] = squeeze[dimensions=(1,)] dox doz:f64[1] = mul dow doy dpa:f64[1] = sub dou doz dpb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dhn dpc:f64[1] = squeeze[dimensions=(1,)] dpb dpd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dls dpe:f64[1] = squeeze[dimensions=(1,)] dpd dpf:f64[1] = mul dpc dpe dpg:f64[1] = add dpa dpf dph:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dmp dpi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dnm dpj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] doj dpk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dpg dpl:f64[1,4] = concatenate[dimension=1] dph dpi dpj dpk dpm:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dcw dpn:f64[1,3] = squeeze[dimensions=(1,)] dpm dpo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dpl dpp:f64[1] = squeeze[dimensions=(1,)] dpo dpq:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dpl dpr:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dpq dpn dps:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dpr dpt:f64[1,3] = mul dps dpq dpu:f64[1,3] = mul 2.0 dpt dpv:f64[1] = mul dpp dpp dpw:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dpq dpq dpx:f64[1] = sub dpv dpw dpy:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dpx dpz:f64[1,3] = mul dpy dpn dqa:f64[1,3] = add dpu dpz dqb:f64[1] = mul 2.0 dpp dqc:f64[1,3] = pjit[name=cross jaxpr=cross] dpq dpn dqd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dqb dqe:f64[1,3] = mul dqd dqc dqf:f64[1,3] = add dqa dqe dqg:f64[1,3] = sub dii dqf dqh:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dii dqi:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] djc dqj:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] dpl dqk:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] dpl dql:f64[1,4,4] = mul dqj dqk dqm:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] dql dqn:f64[1] = squeeze[dimensions=(1, 2)] dqm dqo:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] dql dqp:f64[1] = squeeze[dimensions=(1, 2)] dqo dqq:f64[1] = add dqn dqp dqr:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] dql dqs:f64[1] = squeeze[dimensions=(1, 2)] dqr dqt:f64[1] = sub dqq dqs dqu:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] dql dqv:f64[1] = squeeze[dimensions=(1, 2)] dqu dqw:f64[1] = sub dqt dqv dqx:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] dql dqy:f64[1] = squeeze[dimensions=(1, 2)] dqx dqz:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] dql dra:f64[1] = squeeze[dimensions=(1, 2)] dqz drb:f64[1] = sub dqy dra drc:f64[1] = mul 2.0 drb drd:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] dql dre:f64[1] = squeeze[dimensions=(1, 2)] drd drf:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] dql drg:f64[1] = squeeze[dimensions=(1, 2)] drf drh:f64[1] = add dre drg dri:f64[1] = mul 2.0 drh drj:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] dql drk:f64[1] = squeeze[dimensions=(1, 2)] drj drl:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] dql drm:f64[1] = squeeze[dimensions=(1, 2)] drl drn:f64[1] = add drk drm dro:f64[1] = mul 2.0 drn drp:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] dql drq:f64[1] = squeeze[dimensions=(1, 2)] drp drr:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] dql drs:f64[1] = squeeze[dimensions=(1, 2)] drr drt:f64[1] = sub drq drs dru:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] dql drv:f64[1] = squeeze[dimensions=(1, 2)] dru drw:f64[1] = add drt drv drx:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] dql dry:f64[1] = squeeze[dimensions=(1, 2)] drx drz:f64[1] = sub drw dry dsa:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] dql dsb:f64[1] = squeeze[dimensions=(1, 2)] dsa dsc:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] dql dsd:f64[1] = squeeze[dimensions=(1, 2)] dsc dse:f64[1] = sub dsb dsd dsf:f64[1] = mul 2.0 dse dsg:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] dql dsh:f64[1] = squeeze[dimensions=(1, 2)] dsg dsi:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] dql dsj:f64[1] = squeeze[dimensions=(1, 2)] dsi dsk:f64[1] = sub dsh dsj dsl:f64[1] = mul 2.0 dsk dsm:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] dql dsn:f64[1] = squeeze[dimensions=(1, 2)] dsm dso:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] dql dsp:f64[1] = squeeze[dimensions=(1, 2)] dso dsq:f64[1] = add dsn dsp dsr:f64[1] = mul 2.0 dsq dss:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] dql dst:f64[1] = squeeze[dimensions=(1, 2)] dss dsu:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] dql dsv:f64[1] = squeeze[dimensions=(1, 2)] dsu dsw:f64[1] = sub dst dsv dsx:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] dql dsy:f64[1] = squeeze[dimensions=(1, 2)] dsx dsz:f64[1] = sub dsw dsy dta:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] dql dtb:f64[1] = squeeze[dimensions=(1, 2)] dta dtc:f64[1] = add dsz dtb dtd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dqw dte:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] drc dtf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dri dtg:f64[1,3] = concatenate[dimension=1] dtd dte dtf dth:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dro dti:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] drz dtj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dsf dtk:f64[1,3] = concatenate[dimension=1] dth dti dtj dtl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dsl dtm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dsr dtn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dtc dto:f64[1,3] = concatenate[dimension=1] dtl dtm dtn dtp:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dtg dtq:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dtk dtr:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dto dts:f64[1,3,3] = concatenate[dimension=1] dtp dtq dtr _:f64[1,1] = pjit[name=_take jaxpr=_take9] dcy iv _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] dqh iv _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] dqi iv dtt:f64[1,3] = pjit[name=_take jaxpr=_take4] dqg iv dtu:f64[1,4] = pjit[name=_take jaxpr=_take5] dpl iv _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] dts iv dtv:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la iw dtw:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb ix dtx:f64[1,1] = pjit[name=_take jaxpr=_take8] nk iy dty:f64[1,1] = pjit[name=_take jaxpr=_take8] kn iz dtz:f64[1,3] = slice[limit_indices=(7, 3) start_indices=(6, 0) strides=None] kp dua:f64[1,4] = slice[limit_indices=(7, 4) start_indices=(6, 0) strides=None] kq dub:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dtu duc:f64[1] = squeeze[dimensions=(1,)] dub dud:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dtu due:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dud dtz duf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] due dug:f64[1,3] = mul duf dud duh:f64[1,3] = mul 2.0 dug dui:f64[1] = mul duc duc duj:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dud dud duk:f64[1] = sub dui duj dul:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] duk dum:f64[1,3] = mul dul dtz dun:f64[1,3] = add duh dum duo:f64[1] = mul 2.0 duc dup:f64[1,3] = pjit[name=cross jaxpr=cross] dud dtz duq:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] duo dur:f64[1,3] = mul duq dup dus:f64[1,3] = add dun dur dut:f64[1,3] = add dtt dus duu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dtu duv:f64[1] = squeeze[dimensions=(1,)] duu duw:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dua dux:f64[1] = squeeze[dimensions=(1,)] duw duy:f64[1] = mul duv dux duz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dtu dva:f64[1] = squeeze[dimensions=(1,)] duz dvb:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dua dvc:f64[1] = squeeze[dimensions=(1,)] dvb dvd:f64[1] = mul dva dvc dve:f64[1] = sub duy dvd dvf:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dtu dvg:f64[1] = squeeze[dimensions=(1,)] dvf dvh:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dua dvi:f64[1] = squeeze[dimensions=(1,)] dvh dvj:f64[1] = mul dvg dvi dvk:f64[1] = sub dve dvj dvl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dtu dvm:f64[1] = squeeze[dimensions=(1,)] dvl dvn:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dua dvo:f64[1] = squeeze[dimensions=(1,)] dvn dvp:f64[1] = mul dvm dvo dvq:f64[1] = sub dvk dvp dvr:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dtu dvs:f64[1] = squeeze[dimensions=(1,)] dvr dvt:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dua dvu:f64[1] = squeeze[dimensions=(1,)] dvt dvv:f64[1] = mul dvs dvu dvw:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dtu dvx:f64[1] = squeeze[dimensions=(1,)] dvw dvy:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dua dvz:f64[1] = squeeze[dimensions=(1,)] dvy dwa:f64[1] = mul dvx dvz dwb:f64[1] = add dvv dwa dwc:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dtu dwd:f64[1] = squeeze[dimensions=(1,)] dwc dwe:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dua dwf:f64[1] = squeeze[dimensions=(1,)] dwe dwg:f64[1] = mul dwd dwf dwh:f64[1] = add dwb dwg dwi:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dtu dwj:f64[1] = squeeze[dimensions=(1,)] dwi dwk:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dua dwl:f64[1] = squeeze[dimensions=(1,)] dwk dwm:f64[1] = mul dwj dwl dwn:f64[1] = sub dwh dwm dwo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dtu dwp:f64[1] = squeeze[dimensions=(1,)] dwo dwq:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dua dwr:f64[1] = squeeze[dimensions=(1,)] dwq dws:f64[1] = mul dwp dwr dwt:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dtu dwu:f64[1] = squeeze[dimensions=(1,)] dwt dwv:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dua dww:f64[1] = squeeze[dimensions=(1,)] dwv dwx:f64[1] = mul dwu dww dwy:f64[1] = sub dws dwx dwz:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dtu dxa:f64[1] = squeeze[dimensions=(1,)] dwz dxb:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dua dxc:f64[1] = squeeze[dimensions=(1,)] dxb dxd:f64[1] = mul dxa dxc dxe:f64[1] = add dwy dxd dxf:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dtu dxg:f64[1] = squeeze[dimensions=(1,)] dxf dxh:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dua dxi:f64[1] = squeeze[dimensions=(1,)] dxh dxj:f64[1] = mul dxg dxi dxk:f64[1] = add dxe dxj dxl:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dtu dxm:f64[1] = squeeze[dimensions=(1,)] dxl dxn:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dua dxo:f64[1] = squeeze[dimensions=(1,)] dxn dxp:f64[1] = mul dxm dxo dxq:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dtu dxr:f64[1] = squeeze[dimensions=(1,)] dxq dxs:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dua dxt:f64[1] = squeeze[dimensions=(1,)] dxs dxu:f64[1] = mul dxr dxt dxv:f64[1] = add dxp dxu dxw:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dtu dxx:f64[1] = squeeze[dimensions=(1,)] dxw dxy:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dua dxz:f64[1] = squeeze[dimensions=(1,)] dxy dya:f64[1] = mul dxx dxz dyb:f64[1] = sub dxv dya dyc:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dtu dyd:f64[1] = squeeze[dimensions=(1,)] dyc dye:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dua dyf:f64[1] = squeeze[dimensions=(1,)] dye dyg:f64[1] = mul dyd dyf dyh:f64[1] = add dyb dyg dyi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dvq dyj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dwn dyk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dxk dyl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dyh dym:f64[1,4] = concatenate[dimension=1] dyi dyj dyk dyl dyn:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dtv dyo:f64[1,3] = squeeze[dimensions=(1,)] dyn dyp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dym dyq:f64[1] = squeeze[dimensions=(1,)] dyp dyr:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dym dys:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dyr dyo dyt:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dys dyu:f64[1,3] = mul dyt dyr dyv:f64[1,3] = mul 2.0 dyu dyw:f64[1] = mul dyq dyq dyx:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dyr dyr dyy:f64[1] = sub dyw dyx dyz:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dyy dza:f64[1,3] = mul dyz dyo dzb:f64[1,3] = add dyv dza dzc:f64[1] = mul 2.0 dyq dzd:f64[1,3] = pjit[name=cross jaxpr=cross] dyr dyo dze:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dzc dzf:f64[1,3] = mul dze dzd dzg:f64[1,3] = add dzb dzf dzh:f64[1,3] = add dzg dut dzi:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dtw dzj:f64[1,3] = squeeze[dimensions=(1,)] dzi dzk:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dym dzl:f64[1] = squeeze[dimensions=(1,)] dzk dzm:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] dym dzn:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dzm dzj dzo:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dzn dzp:f64[1,3] = mul dzo dzm dzq:f64[1,3] = mul 2.0 dzp dzr:f64[1] = mul dzl dzl dzs:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] dzm dzm dzt:f64[1] = sub dzr dzs dzu:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dzt dzv:f64[1,3] = mul dzu dzj dzw:f64[1,3] = add dzq dzv dzx:f64[1] = mul 2.0 dzl dzy:f64[1,3] = pjit[name=cross jaxpr=cross] dzm dzj dzz:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] dzx eaa:f64[1,3] = mul dzz dzy eab:f64[1,3] = add dzw eaa eac:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dtx ead:f64[1] = squeeze[dimensions=(1,)] eac eae:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dty eaf:f64[1] = squeeze[dimensions=(1,)] eae eag:f64[1] = sub ead eaf eah:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dtw eai:f64[1,3] = squeeze[dimensions=(1,)] eah eaj:f64[1] = mul eag 0.5 eak:f64[1] = sin eaj eal:f64[1] = mul eag 0.5 eam:f64[1] = cos eal ean:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eak eao:f64[1,3] = mul eai ean eap:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eam eaq:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 ear:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] eaq eas:i64[] = squeeze[dimensions=(0,)] ear eat:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] eas eau:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 eav:bool[1] = lt eat 0 eaw:i64[1] = add eat 3 eax:i64[1] = pjit[name=_where jaxpr=_where] eav eaw eat eay:i64[1] = pjit[name=clip jaxpr=clip] eax 0 3 eaz:i64[1] = pjit[name=argsort jaxpr=argsort] eay eba:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] ebb:bool[1] = lt eaz 0 ebc:i64[1] = add eaz 1 ebd:i64[1] = select_n ebb eaz ebc ebe:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] ebd ebf:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ebe ebg:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] eay ebh:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] ebg ebf eba ebi:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] ebh ebj:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True ebk:bool[1] = lt ebi 0 ebl:i64[1] = add ebi 4 ebm:i64[1] = select_n ebk ebi ebl ebn:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] ebm ebo:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ebn ebp:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False ebq:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] ebj ebo ebp ebr:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] ebq ebs:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 ebt:i64[4] = pjit[name=clip jaxpr=clip1] ebr 0 ebu:i64[] = device_put[devices=[None] srcs=[None]] 1 ebv:bool[4] = lt ebt 0 ebw:i64[4] = add ebt 3 ebx:i64[4] = select_n ebv ebt ebw eby:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] ebx ebz:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] eby eca:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] ebu ecb:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] ebs ebz eca ecc:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] ecb ecd:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] ecc 1 ece:i64[3] = pjit[name=remainder jaxpr=remainder] ecd 4 ecf:bool[1] = lt ebi 0 ecg:i64[1] = add ebi 4 ech:i64[1] = select_n ecf ebi ecg eci:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] ech ecj:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eci eck:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] eau ecl:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] eck ecj eap ecm:bool[3] = lt ece 0 ecn:i64[3] = add ece 4 eco:i64[3] = select_n ecm ece ecn ecp:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] eco ecq:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] ecp ecr:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] ecl ecq eao ecs:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dym ect:f64[1] = squeeze[dimensions=(1,)] ecs ecu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ecr ecv:f64[1] = squeeze[dimensions=(1,)] ecu ecw:f64[1] = mul ect ecv ecx:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dym ecy:f64[1] = squeeze[dimensions=(1,)] ecx ecz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ecr eda:f64[1] = squeeze[dimensions=(1,)] ecz edb:f64[1] = mul ecy eda edc:f64[1] = sub ecw edb edd:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dym ede:f64[1] = squeeze[dimensions=(1,)] edd edf:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ecr edg:f64[1] = squeeze[dimensions=(1,)] edf edh:f64[1] = mul ede edg edi:f64[1] = sub edc edh edj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dym edk:f64[1] = squeeze[dimensions=(1,)] edj edl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ecr edm:f64[1] = squeeze[dimensions=(1,)] edl edn:f64[1] = mul edk edm edo:f64[1] = sub edi edn edp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dym edq:f64[1] = squeeze[dimensions=(1,)] edp edr:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ecr eds:f64[1] = squeeze[dimensions=(1,)] edr edt:f64[1] = mul edq eds edu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dym edv:f64[1] = squeeze[dimensions=(1,)] edu edw:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ecr edx:f64[1] = squeeze[dimensions=(1,)] edw edy:f64[1] = mul edv edx edz:f64[1] = add edt edy eea:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dym eeb:f64[1] = squeeze[dimensions=(1,)] eea eec:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ecr eed:f64[1] = squeeze[dimensions=(1,)] eec eee:f64[1] = mul eeb eed eef:f64[1] = add edz eee eeg:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dym eeh:f64[1] = squeeze[dimensions=(1,)] eeg eei:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ecr eej:f64[1] = squeeze[dimensions=(1,)] eei eek:f64[1] = mul eeh eej eel:f64[1] = sub eef eek eem:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dym een:f64[1] = squeeze[dimensions=(1,)] eem eeo:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ecr eep:f64[1] = squeeze[dimensions=(1,)] eeo eeq:f64[1] = mul een eep eer:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dym ees:f64[1] = squeeze[dimensions=(1,)] eer eet:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ecr eeu:f64[1] = squeeze[dimensions=(1,)] eet eev:f64[1] = mul ees eeu eew:f64[1] = sub eeq eev eex:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dym eey:f64[1] = squeeze[dimensions=(1,)] eex eez:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ecr efa:f64[1] = squeeze[dimensions=(1,)] eez efb:f64[1] = mul eey efa efc:f64[1] = add eew efb efd:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dym efe:f64[1] = squeeze[dimensions=(1,)] efd eff:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ecr efg:f64[1] = squeeze[dimensions=(1,)] eff efh:f64[1] = mul efe efg efi:f64[1] = add efc efh efj:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] dym efk:f64[1] = squeeze[dimensions=(1,)] efj efl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ecr efm:f64[1] = squeeze[dimensions=(1,)] efl efn:f64[1] = mul efk efm efo:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] dym efp:f64[1] = squeeze[dimensions=(1,)] efo efq:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ecr efr:f64[1] = squeeze[dimensions=(1,)] efq efs:f64[1] = mul efp efr eft:f64[1] = add efn efs efu:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] dym efv:f64[1] = squeeze[dimensions=(1,)] efu efw:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ecr efx:f64[1] = squeeze[dimensions=(1,)] efw efy:f64[1] = mul efv efx efz:f64[1] = sub eft efy ega:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] dym egb:f64[1] = squeeze[dimensions=(1,)] ega egc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ecr egd:f64[1] = squeeze[dimensions=(1,)] egc ege:f64[1] = mul egb egd egf:f64[1] = add efz ege egg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] edo egh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eel egi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] efi egj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] egf egk:f64[1,4] = concatenate[dimension=1] egg egh egi egj egl:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] dtv egm:f64[1,3] = squeeze[dimensions=(1,)] egl egn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] egk ego:f64[1] = squeeze[dimensions=(1,)] egn egp:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] egk egq:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] egp egm egr:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] egq egs:f64[1,3] = mul egr egp egt:f64[1,3] = mul 2.0 egs egu:f64[1] = mul ego ego egv:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] egp egp egw:f64[1] = sub egu egv egx:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] egw egy:f64[1,3] = mul egx egm egz:f64[1,3] = add egt egy eha:f64[1] = mul 2.0 ego ehb:f64[1,3] = pjit[name=cross jaxpr=cross] egp egm ehc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eha ehd:f64[1,3] = mul ehc ehb ehe:f64[1,3] = add egz ehd ehf:f64[1,3] = sub dzh ehe ehg:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] dzh ehh:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] eab ehi:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] egk ehj:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] egk ehk:f64[1,4,4] = mul ehi ehj ehl:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] ehk ehm:f64[1] = squeeze[dimensions=(1, 2)] ehl ehn:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] ehk eho:f64[1] = squeeze[dimensions=(1, 2)] ehn ehp:f64[1] = add ehm eho ehq:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] ehk ehr:f64[1] = squeeze[dimensions=(1, 2)] ehq ehs:f64[1] = sub ehp ehr eht:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] ehk ehu:f64[1] = squeeze[dimensions=(1, 2)] eht ehv:f64[1] = sub ehs ehu ehw:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] ehk ehx:f64[1] = squeeze[dimensions=(1, 2)] ehw ehy:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] ehk ehz:f64[1] = squeeze[dimensions=(1, 2)] ehy eia:f64[1] = sub ehx ehz eib:f64[1] = mul 2.0 eia eic:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] ehk eid:f64[1] = squeeze[dimensions=(1, 2)] eic eie:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] ehk eif:f64[1] = squeeze[dimensions=(1, 2)] eie eig:f64[1] = add eid eif eih:f64[1] = mul 2.0 eig eii:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] ehk eij:f64[1] = squeeze[dimensions=(1, 2)] eii eik:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] ehk eil:f64[1] = squeeze[dimensions=(1, 2)] eik eim:f64[1] = add eij eil ein:f64[1] = mul 2.0 eim eio:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] ehk eip:f64[1] = squeeze[dimensions=(1, 2)] eio eiq:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] ehk eir:f64[1] = squeeze[dimensions=(1, 2)] eiq eis:f64[1] = sub eip eir eit:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] ehk eiu:f64[1] = squeeze[dimensions=(1, 2)] eit eiv:f64[1] = add eis eiu eiw:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] ehk eix:f64[1] = squeeze[dimensions=(1, 2)] eiw eiy:f64[1] = sub eiv eix eiz:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] ehk eja:f64[1] = squeeze[dimensions=(1, 2)] eiz ejb:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] ehk ejc:f64[1] = squeeze[dimensions=(1, 2)] ejb ejd:f64[1] = sub eja ejc eje:f64[1] = mul 2.0 ejd ejf:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] ehk ejg:f64[1] = squeeze[dimensions=(1, 2)] ejf ejh:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] ehk eji:f64[1] = squeeze[dimensions=(1, 2)] ejh ejj:f64[1] = sub ejg eji ejk:f64[1] = mul 2.0 ejj ejl:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] ehk ejm:f64[1] = squeeze[dimensions=(1, 2)] ejl ejn:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] ehk ejo:f64[1] = squeeze[dimensions=(1, 2)] ejn ejp:f64[1] = add ejm ejo ejq:f64[1] = mul 2.0 ejp ejr:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] ehk ejs:f64[1] = squeeze[dimensions=(1, 2)] ejr ejt:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] ehk eju:f64[1] = squeeze[dimensions=(1, 2)] ejt ejv:f64[1] = sub ejs eju ejw:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] ehk ejx:f64[1] = squeeze[dimensions=(1, 2)] ejw ejy:f64[1] = sub ejv ejx ejz:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] ehk eka:f64[1] = squeeze[dimensions=(1, 2)] ejz ekb:f64[1] = add ejy eka ekc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ehv ekd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eib eke:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eih ekf:f64[1,3] = concatenate[dimension=1] ekc ekd eke ekg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ein ekh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eiy eki:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eje ekj:f64[1,3] = concatenate[dimension=1] ekg ekh eki ekk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ejk ekl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ejq ekm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ekb ekn:f64[1,3] = concatenate[dimension=1] ekk ekl ekm eko:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] ekf ekp:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] ekj ekq:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] ekn ekr:f64[1,3,3] = concatenate[dimension=1] eko ekp ekq _:f64[1,1] = pjit[name=_take jaxpr=_take9] dtx ja _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] ehg ja _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] ehh ja eks:f64[1,3] = pjit[name=_take jaxpr=_take4] ehf ja ekt:f64[1,4] = pjit[name=_take jaxpr=_take5] egk ja _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] ekr ja eku:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la jb ekv:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb jc ekw:f64[1,1] = pjit[name=_take jaxpr=_take8] nk jd ekx:f64[1,1] = pjit[name=_take jaxpr=_take8] kn je eky:f64[1,3] = slice[limit_indices=(8, 3) start_indices=(7, 0) strides=None] kp ekz:f64[1,4] = slice[limit_indices=(8, 4) start_indices=(7, 0) strides=None] kq ela:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekt elb:f64[1] = squeeze[dimensions=(1,)] ela elc:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] ekt eld:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] elc eky ele:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eld elf:f64[1,3] = mul ele elc elg:f64[1,3] = mul 2.0 elf elh:f64[1] = mul elb elb eli:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] elc elc elj:f64[1] = sub elh eli elk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] elj ell:f64[1,3] = mul elk eky elm:f64[1,3] = add elg ell eln:f64[1] = mul 2.0 elb elo:f64[1,3] = pjit[name=cross jaxpr=cross] elc eky elp:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eln elq:f64[1,3] = mul elp elo elr:f64[1,3] = add elm elq els:f64[1,3] = add eks elr elt:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekt elu:f64[1] = squeeze[dimensions=(1,)] elt elv:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekz elw:f64[1] = squeeze[dimensions=(1,)] elv elx:f64[1] = mul elu elw ely:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekt elz:f64[1] = squeeze[dimensions=(1,)] ely ema:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekz emb:f64[1] = squeeze[dimensions=(1,)] ema emc:f64[1] = mul elz emb emd:f64[1] = sub elx emc eme:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekt emf:f64[1] = squeeze[dimensions=(1,)] eme emg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekz emh:f64[1] = squeeze[dimensions=(1,)] emg emi:f64[1] = mul emf emh emj:f64[1] = sub emd emi emk:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekt eml:f64[1] = squeeze[dimensions=(1,)] emk emm:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekz emn:f64[1] = squeeze[dimensions=(1,)] emm emo:f64[1] = mul eml emn emp:f64[1] = sub emj emo emq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekt emr:f64[1] = squeeze[dimensions=(1,)] emq ems:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekz emt:f64[1] = squeeze[dimensions=(1,)] ems emu:f64[1] = mul emr emt emv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekt emw:f64[1] = squeeze[dimensions=(1,)] emv emx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekz emy:f64[1] = squeeze[dimensions=(1,)] emx emz:f64[1] = mul emw emy ena:f64[1] = add emu emz enb:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekt enc:f64[1] = squeeze[dimensions=(1,)] enb end:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekz ene:f64[1] = squeeze[dimensions=(1,)] end enf:f64[1] = mul enc ene eng:f64[1] = add ena enf enh:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekt eni:f64[1] = squeeze[dimensions=(1,)] enh enj:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekz enk:f64[1] = squeeze[dimensions=(1,)] enj enl:f64[1] = mul eni enk enm:f64[1] = sub eng enl enn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekt eno:f64[1] = squeeze[dimensions=(1,)] enn enp:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekz enq:f64[1] = squeeze[dimensions=(1,)] enp enr:f64[1] = mul eno enq ens:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekt ent:f64[1] = squeeze[dimensions=(1,)] ens enu:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekz env:f64[1] = squeeze[dimensions=(1,)] enu enw:f64[1] = mul ent env enx:f64[1] = sub enr enw eny:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekt enz:f64[1] = squeeze[dimensions=(1,)] eny eoa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekz eob:f64[1] = squeeze[dimensions=(1,)] eoa eoc:f64[1] = mul enz eob eod:f64[1] = add enx eoc eoe:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekt eof:f64[1] = squeeze[dimensions=(1,)] eoe eog:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekz eoh:f64[1] = squeeze[dimensions=(1,)] eog eoi:f64[1] = mul eof eoh eoj:f64[1] = add eod eoi eok:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekt eol:f64[1] = squeeze[dimensions=(1,)] eok eom:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekz eon:f64[1] = squeeze[dimensions=(1,)] eom eoo:f64[1] = mul eol eon eop:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekt eoq:f64[1] = squeeze[dimensions=(1,)] eop eor:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekz eos:f64[1] = squeeze[dimensions=(1,)] eor eot:f64[1] = mul eoq eos eou:f64[1] = add eoo eot eov:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ekt eow:f64[1] = squeeze[dimensions=(1,)] eov eox:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ekz eoy:f64[1] = squeeze[dimensions=(1,)] eox eoz:f64[1] = mul eow eoy epa:f64[1] = sub eou eoz epb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ekt epc:f64[1] = squeeze[dimensions=(1,)] epb epd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekz epe:f64[1] = squeeze[dimensions=(1,)] epd epf:f64[1] = mul epc epe epg:f64[1] = add epa epf eph:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] emp epi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] enm epj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eoj epk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] epg epl:f64[1,4] = concatenate[dimension=1] eph epi epj epk epm:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] eku epn:f64[1,3] = squeeze[dimensions=(1,)] epm epo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] epl epp:f64[1] = squeeze[dimensions=(1,)] epo epq:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] epl epr:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] epq epn eps:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] epr ept:f64[1,3] = mul eps epq epu:f64[1,3] = mul 2.0 ept epv:f64[1] = mul epp epp epw:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] epq epq epx:f64[1] = sub epv epw epy:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] epx epz:f64[1,3] = mul epy epn eqa:f64[1,3] = add epu epz eqb:f64[1] = mul 2.0 epp eqc:f64[1,3] = pjit[name=cross jaxpr=cross] epq epn eqd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eqb eqe:f64[1,3] = mul eqd eqc eqf:f64[1,3] = add eqa eqe eqg:f64[1,3] = add eqf els eqh:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] ekv eqi:f64[1,3] = squeeze[dimensions=(1,)] eqh eqj:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] epl eqk:f64[1] = squeeze[dimensions=(1,)] eqj eql:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] epl eqm:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] eql eqi eqn:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eqm eqo:f64[1,3] = mul eqn eql eqp:f64[1,3] = mul 2.0 eqo eqq:f64[1] = mul eqk eqk eqr:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] eql eql eqs:f64[1] = sub eqq eqr eqt:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eqs equ:f64[1,3] = mul eqt eqi eqv:f64[1,3] = add eqp equ eqw:f64[1] = mul 2.0 eqk eqx:f64[1,3] = pjit[name=cross jaxpr=cross] eql eqi eqy:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eqw eqz:f64[1,3] = mul eqy eqx era:f64[1,3] = add eqv eqz erb:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekw erc:f64[1] = squeeze[dimensions=(1,)] erb erd:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ekx ere:f64[1] = squeeze[dimensions=(1,)] erd erf:f64[1] = sub erc ere erg:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] ekv erh:f64[1,3] = squeeze[dimensions=(1,)] erg eri:f64[1] = mul erf 0.5 erj:f64[1] = sin eri erk:f64[1] = mul erf 0.5 erl:f64[1] = cos erk erm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] erj ern:f64[1,3] = mul erh erm ero:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] erl erp:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 erq:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] erp err:i64[] = squeeze[dimensions=(0,)] erq ers:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] err ert:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 eru:bool[1] = lt ers 0 erv:i64[1] = add ers 3 erw:i64[1] = pjit[name=_where jaxpr=_where] eru erv ers erx:i64[1] = pjit[name=clip jaxpr=clip] erw 0 3 ery:i64[1] = pjit[name=argsort jaxpr=argsort] erx erz:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] esa:bool[1] = lt ery 0 esb:i64[1] = add ery 1 esc:i64[1] = select_n esa ery esb esd:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] esc ese:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] esd esf:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] erx esg:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] esf ese erz esh:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] esg esi:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True esj:bool[1] = lt esh 0 esk:i64[1] = add esh 4 esl:i64[1] = select_n esj esh esk esm:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] esl esn:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] esm eso:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False esp:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] esi esn eso esq:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] esp esr:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 ess:i64[4] = pjit[name=clip jaxpr=clip1] esq 0 est:i64[] = device_put[devices=[None] srcs=[None]] 1 esu:bool[4] = lt ess 0 esv:i64[4] = add ess 3 esw:i64[4] = select_n esu ess esv esx:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] esw esy:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] esx esz:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] est eta:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] esr esy esz etb:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] eta etc:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] etb 1 etd:i64[3] = pjit[name=remainder jaxpr=remainder] etc 4 ete:bool[1] = lt esh 0 etf:i64[1] = add esh 4 etg:i64[1] = select_n ete esh etf eth:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] etg eti:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eth etj:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] ert etk:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] etj eti ero etl:bool[3] = lt etd 0 etm:i64[3] = add etd 4 etn:i64[3] = select_n etl etd etm eto:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] etn etp:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] eto etq:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] etk etp ern etr:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] epl ets:f64[1] = squeeze[dimensions=(1,)] etr ett:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] etq etu:f64[1] = squeeze[dimensions=(1,)] ett etv:f64[1] = mul ets etu etw:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] epl etx:f64[1] = squeeze[dimensions=(1,)] etw ety:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] etq etz:f64[1] = squeeze[dimensions=(1,)] ety eua:f64[1] = mul etx etz eub:f64[1] = sub etv eua euc:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] epl eud:f64[1] = squeeze[dimensions=(1,)] euc eue:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] etq euf:f64[1] = squeeze[dimensions=(1,)] eue eug:f64[1] = mul eud euf euh:f64[1] = sub eub eug eui:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] epl euj:f64[1] = squeeze[dimensions=(1,)] eui euk:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] etq eul:f64[1] = squeeze[dimensions=(1,)] euk eum:f64[1] = mul euj eul eun:f64[1] = sub euh eum euo:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] epl eup:f64[1] = squeeze[dimensions=(1,)] euo euq:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] etq eur:f64[1] = squeeze[dimensions=(1,)] euq eus:f64[1] = mul eup eur eut:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] epl euu:f64[1] = squeeze[dimensions=(1,)] eut euv:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] etq euw:f64[1] = squeeze[dimensions=(1,)] euv eux:f64[1] = mul euu euw euy:f64[1] = add eus eux euz:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] epl eva:f64[1] = squeeze[dimensions=(1,)] euz evb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] etq evc:f64[1] = squeeze[dimensions=(1,)] evb evd:f64[1] = mul eva evc eve:f64[1] = add euy evd evf:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] epl evg:f64[1] = squeeze[dimensions=(1,)] evf evh:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] etq evi:f64[1] = squeeze[dimensions=(1,)] evh evj:f64[1] = mul evg evi evk:f64[1] = sub eve evj evl:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] epl evm:f64[1] = squeeze[dimensions=(1,)] evl evn:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] etq evo:f64[1] = squeeze[dimensions=(1,)] evn evp:f64[1] = mul evm evo evq:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] epl evr:f64[1] = squeeze[dimensions=(1,)] evq evs:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] etq evt:f64[1] = squeeze[dimensions=(1,)] evs evu:f64[1] = mul evr evt evv:f64[1] = sub evp evu evw:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] epl evx:f64[1] = squeeze[dimensions=(1,)] evw evy:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] etq evz:f64[1] = squeeze[dimensions=(1,)] evy ewa:f64[1] = mul evx evz ewb:f64[1] = add evv ewa ewc:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] epl ewd:f64[1] = squeeze[dimensions=(1,)] ewc ewe:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] etq ewf:f64[1] = squeeze[dimensions=(1,)] ewe ewg:f64[1] = mul ewd ewf ewh:f64[1] = add ewb ewg ewi:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] epl ewj:f64[1] = squeeze[dimensions=(1,)] ewi ewk:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] etq ewl:f64[1] = squeeze[dimensions=(1,)] ewk ewm:f64[1] = mul ewj ewl ewn:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] epl ewo:f64[1] = squeeze[dimensions=(1,)] ewn ewp:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] etq ewq:f64[1] = squeeze[dimensions=(1,)] ewp ewr:f64[1] = mul ewo ewq ews:f64[1] = add ewm ewr ewt:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] epl ewu:f64[1] = squeeze[dimensions=(1,)] ewt ewv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] etq eww:f64[1] = squeeze[dimensions=(1,)] ewv ewx:f64[1] = mul ewu eww ewy:f64[1] = sub ews ewx ewz:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] epl exa:f64[1] = squeeze[dimensions=(1,)] ewz exb:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] etq exc:f64[1] = squeeze[dimensions=(1,)] exb exd:f64[1] = mul exa exc exe:f64[1] = add ewy exd exf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eun exg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] evk exh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ewh exi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] exe exj:f64[1,4] = concatenate[dimension=1] exf exg exh exi exk:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] eku exl:f64[1,3] = squeeze[dimensions=(1,)] exk exm:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] exj exn:f64[1] = squeeze[dimensions=(1,)] exm exo:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] exj exp:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] exo exl exq:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] exp exr:f64[1,3] = mul exq exo exs:f64[1,3] = mul 2.0 exr ext:f64[1] = mul exn exn exu:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] exo exo exv:f64[1] = sub ext exu exw:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] exv exx:f64[1,3] = mul exw exl exy:f64[1,3] = add exs exx exz:f64[1] = mul 2.0 exn eya:f64[1,3] = pjit[name=cross jaxpr=cross] exo exl eyb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] exz eyc:f64[1,3] = mul eyb eya eyd:f64[1,3] = add exy eyc eye:f64[1,3] = sub eqg eyd eyf:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] eqg eyg:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] era eyh:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] exj eyi:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] exj eyj:f64[1,4,4] = mul eyh eyi eyk:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] eyj eyl:f64[1] = squeeze[dimensions=(1, 2)] eyk eym:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] eyj eyn:f64[1] = squeeze[dimensions=(1, 2)] eym eyo:f64[1] = add eyl eyn eyp:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] eyj eyq:f64[1] = squeeze[dimensions=(1, 2)] eyp eyr:f64[1] = sub eyo eyq eys:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] eyj eyt:f64[1] = squeeze[dimensions=(1, 2)] eys eyu:f64[1] = sub eyr eyt eyv:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] eyj eyw:f64[1] = squeeze[dimensions=(1, 2)] eyv eyx:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] eyj eyy:f64[1] = squeeze[dimensions=(1, 2)] eyx eyz:f64[1] = sub eyw eyy eza:f64[1] = mul 2.0 eyz ezb:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] eyj ezc:f64[1] = squeeze[dimensions=(1, 2)] ezb ezd:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] eyj eze:f64[1] = squeeze[dimensions=(1, 2)] ezd ezf:f64[1] = add ezc eze ezg:f64[1] = mul 2.0 ezf ezh:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] eyj ezi:f64[1] = squeeze[dimensions=(1, 2)] ezh ezj:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] eyj ezk:f64[1] = squeeze[dimensions=(1, 2)] ezj ezl:f64[1] = add ezi ezk ezm:f64[1] = mul 2.0 ezl ezn:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] eyj ezo:f64[1] = squeeze[dimensions=(1, 2)] ezn ezp:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] eyj ezq:f64[1] = squeeze[dimensions=(1, 2)] ezp ezr:f64[1] = sub ezo ezq ezs:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] eyj ezt:f64[1] = squeeze[dimensions=(1, 2)] ezs ezu:f64[1] = add ezr ezt ezv:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] eyj ezw:f64[1] = squeeze[dimensions=(1, 2)] ezv ezx:f64[1] = sub ezu ezw ezy:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] eyj ezz:f64[1] = squeeze[dimensions=(1, 2)] ezy faa:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] eyj fab:f64[1] = squeeze[dimensions=(1, 2)] faa fac:f64[1] = sub ezz fab fad:f64[1] = mul 2.0 fac fae:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] eyj faf:f64[1] = squeeze[dimensions=(1, 2)] fae fag:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] eyj fah:f64[1] = squeeze[dimensions=(1, 2)] fag fai:f64[1] = sub faf fah faj:f64[1] = mul 2.0 fai fak:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] eyj fal:f64[1] = squeeze[dimensions=(1, 2)] fak fam:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] eyj fan:f64[1] = squeeze[dimensions=(1, 2)] fam fao:f64[1] = add fal fan fap:f64[1] = mul 2.0 fao faq:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] eyj far:f64[1] = squeeze[dimensions=(1, 2)] faq fas:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] eyj fat:f64[1] = squeeze[dimensions=(1, 2)] fas fau:f64[1] = sub far fat fav:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] eyj faw:f64[1] = squeeze[dimensions=(1, 2)] fav fax:f64[1] = sub fau faw fay:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] eyj faz:f64[1] = squeeze[dimensions=(1, 2)] fay fba:f64[1] = add fax faz fbb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eyu fbc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] eza fbd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ezg fbe:f64[1,3] = concatenate[dimension=1] fbb fbc fbd fbf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ezm fbg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ezx fbh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fad fbi:f64[1,3] = concatenate[dimension=1] fbf fbg fbh fbj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] faj fbk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fap fbl:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fba fbm:f64[1,3] = concatenate[dimension=1] fbj fbk fbl fbn:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fbe fbo:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fbi fbp:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fbm fbq:f64[1,3,3] = concatenate[dimension=1] fbn fbo fbp _:f64[1,1] = pjit[name=_take jaxpr=_take9] ekw jf _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] eyf jf _:f64[1,1,3] = pjit[name=_take jaxpr=_take10] eyg jf fbr:f64[1,3] = pjit[name=_take jaxpr=_take4] eye jf fbs:f64[1,4] = pjit[name=_take jaxpr=_take5] exj jf _:f64[1,3,3] = pjit[name=_take jaxpr=_take6] fbq jf fbt:f64[1,1,3] = pjit[name=_take jaxpr=_take7] la jg fbu:f64[1,1,3] = pjit[name=_take jaxpr=_take7] lb jh fbv:f64[1,1] = pjit[name=_take jaxpr=_take8] nk ji fbw:f64[1,1] = pjit[name=_take jaxpr=_take8] kn jj fbx:f64[1,3] = slice[limit_indices=(9, 3) start_indices=(8, 0) strides=None] kp fby:f64[1,4] = slice[limit_indices=(9, 4) start_indices=(8, 0) strides=None] kq fbz:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbs fca:f64[1] = squeeze[dimensions=(1,)] fbz fcb:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] fbs fcc:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fcb fbx fcd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fcc fce:f64[1,3] = mul fcd fcb fcf:f64[1,3] = mul 2.0 fce fcg:f64[1] = mul fca fca fch:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fcb fcb fci:f64[1] = sub fcg fch fcj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fci fck:f64[1,3] = mul fcj fbx fcl:f64[1,3] = add fcf fck fcm:f64[1] = mul 2.0 fca fcn:f64[1,3] = pjit[name=cross jaxpr=cross] fcb fbx fco:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fcm fcp:f64[1,3] = mul fco fcn fcq:f64[1,3] = add fcl fcp fcr:f64[1,3] = add fbr fcq fcs:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbs fct:f64[1] = squeeze[dimensions=(1,)] fcs fcu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fby fcv:f64[1] = squeeze[dimensions=(1,)] fcu fcw:f64[1] = mul fct fcv fcx:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fbs fcy:f64[1] = squeeze[dimensions=(1,)] fcx fcz:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fby fda:f64[1] = squeeze[dimensions=(1,)] fcz fdb:f64[1] = mul fcy fda fdc:f64[1] = sub fcw fdb fdd:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fbs fde:f64[1] = squeeze[dimensions=(1,)] fdd fdf:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fby fdg:f64[1] = squeeze[dimensions=(1,)] fdf fdh:f64[1] = mul fde fdg fdi:f64[1] = sub fdc fdh fdj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fbs fdk:f64[1] = squeeze[dimensions=(1,)] fdj fdl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fby fdm:f64[1] = squeeze[dimensions=(1,)] fdl fdn:f64[1] = mul fdk fdm fdo:f64[1] = sub fdi fdn fdp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbs fdq:f64[1] = squeeze[dimensions=(1,)] fdp fdr:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fby fds:f64[1] = squeeze[dimensions=(1,)] fdr fdt:f64[1] = mul fdq fds fdu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fbs fdv:f64[1] = squeeze[dimensions=(1,)] fdu fdw:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fby fdx:f64[1] = squeeze[dimensions=(1,)] fdw fdy:f64[1] = mul fdv fdx fdz:f64[1] = add fdt fdy fea:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fbs feb:f64[1] = squeeze[dimensions=(1,)] fea fec:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fby fed:f64[1] = squeeze[dimensions=(1,)] fec fee:f64[1] = mul feb fed fef:f64[1] = add fdz fee feg:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fbs feh:f64[1] = squeeze[dimensions=(1,)] feg fei:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fby fej:f64[1] = squeeze[dimensions=(1,)] fei fek:f64[1] = mul feh fej fel:f64[1] = sub fef fek fem:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbs fen:f64[1] = squeeze[dimensions=(1,)] fem feo:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fby fep:f64[1] = squeeze[dimensions=(1,)] feo feq:f64[1] = mul fen fep fer:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fbs fes:f64[1] = squeeze[dimensions=(1,)] fer fet:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fby feu:f64[1] = squeeze[dimensions=(1,)] fet fev:f64[1] = mul fes feu few:f64[1] = sub feq fev fex:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fbs fey:f64[1] = squeeze[dimensions=(1,)] fex fez:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fby ffa:f64[1] = squeeze[dimensions=(1,)] fez ffb:f64[1] = mul fey ffa ffc:f64[1] = add few ffb ffd:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fbs ffe:f64[1] = squeeze[dimensions=(1,)] ffd fff:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fby ffg:f64[1] = squeeze[dimensions=(1,)] fff ffh:f64[1] = mul ffe ffg ffi:f64[1] = add ffc ffh ffj:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbs ffk:f64[1] = squeeze[dimensions=(1,)] ffj ffl:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fby ffm:f64[1] = squeeze[dimensions=(1,)] ffl ffn:f64[1] = mul ffk ffm ffo:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fbs ffp:f64[1] = squeeze[dimensions=(1,)] ffo ffq:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fby ffr:f64[1] = squeeze[dimensions=(1,)] ffq ffs:f64[1] = mul ffp ffr fft:f64[1] = add ffn ffs ffu:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fbs ffv:f64[1] = squeeze[dimensions=(1,)] ffu ffw:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fby ffx:f64[1] = squeeze[dimensions=(1,)] ffw ffy:f64[1] = mul ffv ffx ffz:f64[1] = sub fft ffy fga:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fbs fgb:f64[1] = squeeze[dimensions=(1,)] fga fgc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fby fgd:f64[1] = squeeze[dimensions=(1,)] fgc fge:f64[1] = mul fgb fgd fgf:f64[1] = add ffz fge fgg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fdo fgh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fel fgi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] ffi fgj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fgf fgk:f64[1,4] = concatenate[dimension=1] fgg fgh fgi fgj fgl:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] fbt fgm:f64[1,3] = squeeze[dimensions=(1,)] fgl fgn:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fgk fgo:f64[1] = squeeze[dimensions=(1,)] fgn fgp:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] fgk fgq:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fgp fgm fgr:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fgq fgs:f64[1,3] = mul fgr fgp fgt:f64[1,3] = mul 2.0 fgs fgu:f64[1] = mul fgo fgo fgv:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fgp fgp fgw:f64[1] = sub fgu fgv fgx:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fgw fgy:f64[1,3] = mul fgx fgm fgz:f64[1,3] = add fgt fgy fha:f64[1] = mul 2.0 fgo fhb:f64[1,3] = pjit[name=cross jaxpr=cross] fgp fgm fhc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fha fhd:f64[1,3] = mul fhc fhb fhe:f64[1,3] = add fgz fhd fhf:f64[1,3] = add fhe fcr fhg:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] fbu fhh:f64[1,3] = squeeze[dimensions=(1,)] fhg fhi:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fgk fhj:f64[1] = squeeze[dimensions=(1,)] fhi fhk:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] fgk fhl:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fhk fhh fhm:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fhl fhn:f64[1,3] = mul fhm fhk fho:f64[1,3] = mul 2.0 fhn fhp:f64[1] = mul fhj fhj fhq:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fhk fhk fhr:f64[1] = sub fhp fhq fhs:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fhr fht:f64[1,3] = mul fhs fhh fhu:f64[1,3] = add fho fht fhv:f64[1] = mul 2.0 fhj fhw:f64[1,3] = pjit[name=cross jaxpr=cross] fhk fhh fhx:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fhv fhy:f64[1,3] = mul fhx fhw fhz:f64[1,3] = add fhu fhy fia:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbv fib:f64[1] = squeeze[dimensions=(1,)] fia fic:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fbw fid:f64[1] = squeeze[dimensions=(1,)] fic fie:f64[1] = sub fib fid fif:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] fbu fig:f64[1,3] = squeeze[dimensions=(1,)] fif fih:f64[1] = mul fie 0.5 fii:f64[1] = sin fih fij:f64[1] = mul fie 0.5 fik:f64[1] = cos fij fil:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fii fim:f64[1,3] = mul fig fil fin:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fik fio:i64[1] = reshape[dimensions=None new_sizes=(1,)] 0 fip:i64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] fio fiq:i64[] = squeeze[dimensions=(0,)] fip fir:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] fiq fis:f64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 fit:bool[1] = lt fir 0 fiu:i64[1] = add fir 3 fiv:i64[1] = pjit[name=_where jaxpr=_where] fit fiu fir fiw:i64[1] = pjit[name=clip jaxpr=clip] fiv 0 3 fix:i64[1] = pjit[name=argsort jaxpr=argsort] fiw fiy:i64[1] = iota[dimension=0 dtype=int64 shape=(1,)] fiz:bool[1] = lt fix 0 fja:i64[1] = add fix 1 fjb:i64[1] = select_n fiz fix fja fjc:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] fjb fjd:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fjc fje:i64[1] = convert_element_type[new_dtype=int64 weak_type=False] fiw fjf:i64[1] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] fje fjd fiy fjg:i64[1] = convert_element_type[new_dtype=int64 weak_type=True] fjf fjh:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True fji:bool[1] = lt fjg 0 fjj:i64[1] = add fjg 4 fjk:i64[1] = select_n fji fjg fjj fjl:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] fjk fjm:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fjl fjn:bool[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] False fjo:bool[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] fjh fjm fjn fjp:i64[4] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction] fjo fjq:i64[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 0 fjr:i64[4] = pjit[name=clip jaxpr=clip1] fjp 0 fjs:i64[] = device_put[devices=[None] srcs=[None]] 1 fjt:bool[4] = lt fjr 0 fju:i64[4] = add fjr 3 fjv:i64[4] = select_n fjt fjr fju fjw:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] fjv fjx:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] fjw fjy:i64[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] fjs fjz:i64[3] = scatter-add[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=jaxpr ] fjq fjx fjy fka:i64[3] = pjit[name=_cumulative_reduction jaxpr=_cumulative_reduction1] fjz fkb:i64[3] = pjit[name=floor_divide jaxpr=floor_divide] fka 1 fkc:i64[3] = pjit[name=remainder jaxpr=remainder] fkb 4 fkd:bool[1] = lt fjg 0 fke:i64[1] = add fjg 4 fkf:i64[1] = select_n fkd fjg fke fkg:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] fkf fkh:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fkg fki:f64[1,4] = broadcast_in_dim[ broadcast_dimensions=(np.int64(1),) shape=(1, 4) ] fis fkj:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] fki fkh fin fkk:bool[3] = lt fkc 0 fkl:i64[3] = add fkc 4 fkm:i64[3] = select_n fkk fkc fkl fkn:i32[3] = convert_element_type[new_dtype=int32 weak_type=False] fkm fko:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] fkn fkp:f64[1,4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(np.int64(1),), scatter_dims_to_operand_dims=(np.int64(1),)) indices_are_sorted=False mode=GatherScatterMode.FILL_OR_DROP unique_indices=False update_consts=() update_jaxpr=None ] fkj fko fim fkq:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fgk fkr:f64[1] = squeeze[dimensions=(1,)] fkq fks:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fkp fkt:f64[1] = squeeze[dimensions=(1,)] fks fku:f64[1] = mul fkr fkt fkv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fgk fkw:f64[1] = squeeze[dimensions=(1,)] fkv fkx:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fkp fky:f64[1] = squeeze[dimensions=(1,)] fkx fkz:f64[1] = mul fkw fky fla:f64[1] = sub fku fkz flb:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fgk flc:f64[1] = squeeze[dimensions=(1,)] flb fld:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fkp fle:f64[1] = squeeze[dimensions=(1,)] fld flf:f64[1] = mul flc fle flg:f64[1] = sub fla flf flh:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fgk fli:f64[1] = squeeze[dimensions=(1,)] flh flj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fkp flk:f64[1] = squeeze[dimensions=(1,)] flj fll:f64[1] = mul fli flk flm:f64[1] = sub flg fll fln:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fgk flo:f64[1] = squeeze[dimensions=(1,)] fln flp:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fkp flq:f64[1] = squeeze[dimensions=(1,)] flp flr:f64[1] = mul flo flq fls:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fgk flt:f64[1] = squeeze[dimensions=(1,)] fls flu:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fkp flv:f64[1] = squeeze[dimensions=(1,)] flu flw:f64[1] = mul flt flv flx:f64[1] = add flr flw fly:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fgk flz:f64[1] = squeeze[dimensions=(1,)] fly fma:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fkp fmb:f64[1] = squeeze[dimensions=(1,)] fma fmc:f64[1] = mul flz fmb fmd:f64[1] = add flx fmc fme:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fgk fmf:f64[1] = squeeze[dimensions=(1,)] fme fmg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fkp fmh:f64[1] = squeeze[dimensions=(1,)] fmg fmi:f64[1] = mul fmf fmh fmj:f64[1] = sub fmd fmi fmk:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fgk fml:f64[1] = squeeze[dimensions=(1,)] fmk fmm:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fkp fmn:f64[1] = squeeze[dimensions=(1,)] fmm fmo:f64[1] = mul fml fmn fmp:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fgk fmq:f64[1] = squeeze[dimensions=(1,)] fmp fmr:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fkp fms:f64[1] = squeeze[dimensions=(1,)] fmr fmt:f64[1] = mul fmq fms fmu:f64[1] = sub fmo fmt fmv:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fgk fmw:f64[1] = squeeze[dimensions=(1,)] fmv fmx:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fkp fmy:f64[1] = squeeze[dimensions=(1,)] fmx fmz:f64[1] = mul fmw fmy fna:f64[1] = add fmu fmz fnb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fgk fnc:f64[1] = squeeze[dimensions=(1,)] fnb fnd:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fkp fne:f64[1] = squeeze[dimensions=(1,)] fnd fnf:f64[1] = mul fnc fne fng:f64[1] = add fna fnf fnh:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fgk fni:f64[1] = squeeze[dimensions=(1,)] fnh fnj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fkp fnk:f64[1] = squeeze[dimensions=(1,)] fnj fnl:f64[1] = mul fni fnk fnm:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fgk fnn:f64[1] = squeeze[dimensions=(1,)] fnm fno:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fkp fnp:f64[1] = squeeze[dimensions=(1,)] fno fnq:f64[1] = mul fnn fnp fnr:f64[1] = add fnl fnq fns:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] fgk fnt:f64[1] = squeeze[dimensions=(1,)] fns fnu:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] fkp fnv:f64[1] = squeeze[dimensions=(1,)] fnu fnw:f64[1] = mul fnt fnv fnx:f64[1] = sub fnr fnw fny:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] fgk fnz:f64[1] = squeeze[dimensions=(1,)] fny foa:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] fkp fob:f64[1] = squeeze[dimensions=(1,)] foa foc:f64[1] = mul fnz fob fod:f64[1] = add fnx foc foe:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] flm fof:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fmj fog:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fng foh:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fod foi:f64[1,4] = concatenate[dimension=1] foe fof fog foh foj:f64[1,1,3] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 0) strides=None ] fbt fok:f64[1,3] = squeeze[dimensions=(1,)] foj fol:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] foi fom:f64[1] = squeeze[dimensions=(1,)] fol fon:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] foi foo:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fon fok fop:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] foo foq:f64[1,3] = mul fop fon for:f64[1,3] = mul 2.0 foq fos:f64[1] = mul fom fom fot:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fon fon fou:f64[1] = sub fos fot fov:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fou fow:f64[1,3] = mul fov fok fox:f64[1,3] = add for fow foy:f64[1] = mul 2.0 fom foz:f64[1,3] = pjit[name=cross jaxpr=cross] fon fok fpa:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] foy fpb:f64[1,3] = mul fpa foz fpc:f64[1,3] = add fox fpb fpd:f64[1,3] = sub fhf fpc fpe:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fhf fpf:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fhz fpg:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] foi fph:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] foi fpi:f64[1,4,4] = mul fpg fph fpj:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] fpi fpk:f64[1] = squeeze[dimensions=(1, 2)] fpj fpl:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] fpi fpm:f64[1] = squeeze[dimensions=(1, 2)] fpl fpn:f64[1] = add fpk fpm fpo:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] fpi fpp:f64[1] = squeeze[dimensions=(1, 2)] fpo fpq:f64[1] = sub fpn fpp fpr:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] fpi fps:f64[1] = squeeze[dimensions=(1, 2)] fpr fpt:f64[1] = sub fpq fps fpu:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] fpi fpv:f64[1] = squeeze[dimensions=(1, 2)] fpu fpw:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] fpi fpx:f64[1] = squeeze[dimensions=(1, 2)] fpw fpy:f64[1] = sub fpv fpx fpz:f64[1] = mul 2.0 fpy fqa:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] fpi fqb:f64[1] = squeeze[dimensions=(1, 2)] fqa fqc:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] fpi fqd:f64[1] = squeeze[dimensions=(1, 2)] fqc fqe:f64[1] = add fqb fqd fqf:f64[1] = mul 2.0 fqe fqg:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] fpi fqh:f64[1] = squeeze[dimensions=(1, 2)] fqg fqi:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] fpi fqj:f64[1] = squeeze[dimensions=(1, 2)] fqi fqk:f64[1] = add fqh fqj fql:f64[1] = mul 2.0 fqk fqm:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] fpi fqn:f64[1] = squeeze[dimensions=(1, 2)] fqm fqo:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] fpi fqp:f64[1] = squeeze[dimensions=(1, 2)] fqo fqq:f64[1] = sub fqn fqp fqr:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] fpi fqs:f64[1] = squeeze[dimensions=(1, 2)] fqr fqt:f64[1] = add fqq fqs fqu:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] fpi fqv:f64[1] = squeeze[dimensions=(1, 2)] fqu fqw:f64[1] = sub fqt fqv fqx:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] fpi fqy:f64[1] = squeeze[dimensions=(1, 2)] fqx fqz:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] fpi fra:f64[1] = squeeze[dimensions=(1, 2)] fqz frb:f64[1] = sub fqy fra frc:f64[1] = mul 2.0 frb frd:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] fpi fre:f64[1] = squeeze[dimensions=(1, 2)] frd frf:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] fpi frg:f64[1] = squeeze[dimensions=(1, 2)] frf frh:f64[1] = sub fre frg fri:f64[1] = mul 2.0 frh frj:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] fpi frk:f64[1] = squeeze[dimensions=(1, 2)] frj frl:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] fpi frm:f64[1] = squeeze[dimensions=(1, 2)] frl frn:f64[1] = add frk frm fro:f64[1] = mul 2.0 frn frp:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] fpi frq:f64[1] = squeeze[dimensions=(1, 2)] frp frr:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] fpi frs:f64[1] = squeeze[dimensions=(1, 2)] frr frt:f64[1] = sub frq frs fru:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] fpi frv:f64[1] = squeeze[dimensions=(1, 2)] fru frw:f64[1] = sub frt frv frx:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] fpi fry:f64[1] = squeeze[dimensions=(1, 2)] frx frz:f64[1] = add frw fry fsa:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fpt fsb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fpz fsc:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fqf fsd:f64[1,3] = concatenate[dimension=1] fsa fsb fsc fse:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fql fsf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fqw fsg:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] frc fsh:f64[1,3] = concatenate[dimension=1] fse fsf fsg fsi:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fri fsj:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] fro fsk:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] frz fsl:f64[1,3] = concatenate[dimension=1] fsi fsj fsk fsm:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fsd fsn:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fsh fso:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] fsl fsp:f64[1,3,3] = concatenate[dimension=1] fsm fsn fso fsq:f64[0] = reshape[dimensions=None new_sizes=(0,)] rr fsr:f64[0] = reshape[dimensions=None new_sizes=(0,)] vu fss:f64[1] = reshape[dimensions=None new_sizes=(1,)] beb fst:f64[1] = reshape[dimensions=None new_sizes=(1,)] bva fsu:f64[1] = reshape[dimensions=None new_sizes=(1,)] clz fsv:f64[1] = reshape[dimensions=None new_sizes=(1,)] dcy fsw:f64[1] = reshape[dimensions=None new_sizes=(1,)] dtx fsx:f64[1] = reshape[dimensions=None new_sizes=(1,)] ekw fsy:f64[1] = reshape[dimensions=None new_sizes=(1,)] fbv fsz:f64[7] = concatenate[dimension=0] fsq fsr fss fst fsu fsv fsw fsx fsy fta:f64[7] = pjit[ name=_take jaxpr={ lambda ; ftb:f64[7] ftc:i64[7]. let ftd:i64[7] = pjit[name=remainder jaxpr=remainder4] ftc 7 fte:i64[7,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(7, 1) ] ftd ftf:f64[7] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1,) unique_indices=False ] ftb fte in (ftf,) } ] fsz jk ftg:f64[0,3] = reshape[dimensions=None new_sizes=(0, 3)] vq fth:f64[0,3] = reshape[dimensions=None new_sizes=(0, 3)] bdv fti:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] brk ftj:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] cij ftk:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] czi ftl:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] dqh ftm:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] ehg ftn:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] eyf fto:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] fpe ftp:f64[7,3] = concatenate[dimension=0] ftg fth fti ftj ftk ftl ftm ftn fto ftq:f64[7,3] = pjit[name=_take jaxpr=_take11] ftp jl ftr:f64[0,3] = reshape[dimensions=None new_sizes=(0, 3)] vr fts:f64[0,3] = reshape[dimensions=None new_sizes=(0, 3)] bdw ftt:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] brl ftu:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] cik ftv:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] czj ftw:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] dqi ftx:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] ehh fty:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] eyg ftz:f64[1,3] = reshape[dimensions=None new_sizes=(1, 3)] fpf fua:f64[7,3] = concatenate[dimension=0] ftr fts ftt ftu ftv ftw ftx fty ftz fub:f64[7,3] = pjit[name=_take jaxpr=_take11] fua jm fuc:f64[9,3] = concatenate[dimension=0] rs wp brj cii czh dqg ehf eye fpd fud:f64[9,3] = pjit[ name=_take jaxpr={ lambda ; fue:f64[9,3] fuf:i64[9]. let fug:i64[9] = pjit[name=remainder jaxpr=remainder5] fuf 9 fuh:i64[9,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(9, 1) ] fug fui:f64[9,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] fue fuh in (fui,) } ] fuc jn fuj:f64[9,4] = concatenate[dimension=0] ry bai bqo chn cym dpl egk exj foi fuk:f64[9,4] = pjit[ name=_take jaxpr={ lambda ; ful:f64[9,4] fum:i64[9]. let fun:i64[9] = pjit[name=remainder jaxpr=remainder5] fum 9 fuo:i64[9,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(9, 1) ] fun fup:f64[9,4] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 4) unique_indices=False ] ful fuo in (fup,) } ] fuj jo fuq:f64[9,3,3] = concatenate[dimension=0] vp bdu buv clu dct dts ekr fbq fsp fur:f64[9,3,3] = pjit[ name=_take jaxpr={ lambda ; fus:f64[9,3,3] fut:i64[9]. let fuu:i64[9] = pjit[name=remainder jaxpr=remainder5] fut 9 fuv:i64[9,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(9, 1) ] fuu fuw:f64[9,3,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3, 3) unique_indices=False ] fus fuv in (fuw,) } ] fuq jp fux:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] fuk fuy:f64[9] = squeeze[dimensions=(1,)] fux fuz:f64[9,3] = slice[limit_indices=(9, 4) start_indices=(0, 1) strides=None] fuk fva:f64[9] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fuz kr fvb:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] fva fvc:f64[9,3] = mul fvb fuz fvd:f64[9,3] = mul 2.0 fvc fve:f64[9] = mul fuy fuy fvf:f64[9] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] fuz fuz fvg:f64[9] = sub fve fvf fvh:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] fvg fvi:f64[9,3] = mul fvh kr fvj:f64[9,3] = add fvd fvi fvk:f64[9] = mul 2.0 fuy fvl:f64[9,3] = pjit[ name=cross jaxpr={ lambda ; fvm:f64[9,3] fvn:f64[9,3]. let fvo:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 fvp:f64[9] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(9, 1) unique_indices=True ] fvm fvo fvq:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1 fvr:f64[9] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(9, 1) unique_indices=True ] fvm fvq fvs:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2 fvt:f64[9] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(9, 1) unique_indices=True ] fvm fvs fvu:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 fvv:f64[9] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(9, 1) unique_indices=True ] fvn fvu fvw:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1 fvx:f64[9] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(9, 1) unique_indices=True ] fvn fvw fvy:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2 fvz:f64[9] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(9, 1) unique_indices=True ] fvn fvy fwa:f64[9] = mul fvr fvz fwb:f64[9] = mul fvt fvx fwc:f64[9] = sub fwa fwb fwd:f64[9] = mul fvt fvv fwe:f64[9] = mul fvp fvz fwf:f64[9] = sub fwd fwe fwg:f64[9] = mul fvp fvx fwh:f64[9] = mul fvr fvv fwi:f64[9] = sub fwg fwh fwj:f64[9,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(9, 1) ] fwc fwk:f64[9,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(9, 1) ] fwf fwl:f64[9,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(9, 1) ] fwi fwm:f64[9,3] = concatenate[dimension=1] fwj fwk fwl in (fwm,) } ] fuz kr fwn:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] fvk fwo:f64[9,3] = mul fwn fvl fwp:f64[9,3] = add fvj fwo fwq:f64[9,3] = add fud fwp fwr:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] fuk fws:f64[9] = squeeze[dimensions=(1,)] fwr fwt:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] ks fwu:f64[9] = squeeze[dimensions=(1,)] fwt fwv:f64[9] = mul fws fwu fww:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] fuk fwx:f64[9] = squeeze[dimensions=(1,)] fww fwy:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] ks fwz:f64[9] = squeeze[dimensions=(1,)] fwy fxa:f64[9] = mul fwx fwz fxb:f64[9] = sub fwv fxa fxc:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] fuk fxd:f64[9] = squeeze[dimensions=(1,)] fxc fxe:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] ks fxf:f64[9] = squeeze[dimensions=(1,)] fxe fxg:f64[9] = mul fxd fxf fxh:f64[9] = sub fxb fxg fxi:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] fuk fxj:f64[9] = squeeze[dimensions=(1,)] fxi fxk:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] ks fxl:f64[9] = squeeze[dimensions=(1,)] fxk fxm:f64[9] = mul fxj fxl fxn:f64[9] = sub fxh fxm fxo:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] fuk fxp:f64[9] = squeeze[dimensions=(1,)] fxo fxq:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] ks fxr:f64[9] = squeeze[dimensions=(1,)] fxq fxs:f64[9] = mul fxp fxr fxt:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] fuk fxu:f64[9] = squeeze[dimensions=(1,)] fxt fxv:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] ks fxw:f64[9] = squeeze[dimensions=(1,)] fxv fxx:f64[9] = mul fxu fxw fxy:f64[9] = add fxs fxx fxz:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] fuk fya:f64[9] = squeeze[dimensions=(1,)] fxz fyb:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] ks fyc:f64[9] = squeeze[dimensions=(1,)] fyb fyd:f64[9] = mul fya fyc fye:f64[9] = add fxy fyd fyf:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] fuk fyg:f64[9] = squeeze[dimensions=(1,)] fyf fyh:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] ks fyi:f64[9] = squeeze[dimensions=(1,)] fyh fyj:f64[9] = mul fyg fyi fyk:f64[9] = sub fye fyj fyl:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] fuk fym:f64[9] = squeeze[dimensions=(1,)] fyl fyn:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] ks fyo:f64[9] = squeeze[dimensions=(1,)] fyn fyp:f64[9] = mul fym fyo fyq:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] fuk fyr:f64[9] = squeeze[dimensions=(1,)] fyq fys:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] ks fyt:f64[9] = squeeze[dimensions=(1,)] fys fyu:f64[9] = mul fyr fyt fyv:f64[9] = sub fyp fyu fyw:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] fuk fyx:f64[9] = squeeze[dimensions=(1,)] fyw fyy:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] ks fyz:f64[9] = squeeze[dimensions=(1,)] fyy fza:f64[9] = mul fyx fyz fzb:f64[9] = add fyv fza fzc:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] fuk fzd:f64[9] = squeeze[dimensions=(1,)] fzc fze:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] ks fzf:f64[9] = squeeze[dimensions=(1,)] fze fzg:f64[9] = mul fzd fzf fzh:f64[9] = add fzb fzg fzi:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] fuk fzj:f64[9] = squeeze[dimensions=(1,)] fzi fzk:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] ks fzl:f64[9] = squeeze[dimensions=(1,)] fzk fzm:f64[9] = mul fzj fzl fzn:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] fuk fzo:f64[9] = squeeze[dimensions=(1,)] fzn fzp:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] ks fzq:f64[9] = squeeze[dimensions=(1,)] fzp fzr:f64[9] = mul fzo fzq fzs:f64[9] = add fzm fzr fzt:f64[9,1] = slice[limit_indices=(9, 3) start_indices=(0, 2) strides=None] fuk fzu:f64[9] = squeeze[dimensions=(1,)] fzt fzv:f64[9,1] = slice[limit_indices=(9, 2) start_indices=(0, 1) strides=None] ks fzw:f64[9] = squeeze[dimensions=(1,)] fzv fzx:f64[9] = mul fzu fzw fzy:f64[9] = sub fzs fzx fzz:f64[9,1] = slice[limit_indices=(9, 4) start_indices=(0, 3) strides=None] fuk gaa:f64[9] = squeeze[dimensions=(1,)] fzz gab:f64[9,1] = slice[limit_indices=(9, 1) start_indices=(0, 0) strides=None] ks gac:f64[9] = squeeze[dimensions=(1,)] gab gad:f64[9] = mul gaa gac gae:f64[9] = add fzy gad gaf:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] fxn gag:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] fyk gah:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] fzh gai:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gae gaj:f64[9,4] = concatenate[dimension=1] gaf gag gah gai gak:f64[9,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(9, 4, 1) ] gaj gal:f64[9,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(9, 1, 4) ] gaj gam:f64[9,4,4] = mul gak gal gan:f64[9,1,1] = slice[ limit_indices=(9, 1, 1) start_indices=(0, 0, 0) strides=None ] gam gao:f64[9] = squeeze[dimensions=(1, 2)] gan gap:f64[9,1,1] = slice[ limit_indices=(9, 2, 2) start_indices=(0, 1, 1) strides=None ] gam gaq:f64[9] = squeeze[dimensions=(1, 2)] gap gar:f64[9] = add gao gaq gas:f64[9,1,1] = slice[ limit_indices=(9, 3, 3) start_indices=(0, 2, 2) strides=None ] gam gat:f64[9] = squeeze[dimensions=(1, 2)] gas gau:f64[9] = sub gar gat gav:f64[9,1,1] = slice[ limit_indices=(9, 4, 4) start_indices=(0, 3, 3) strides=None ] gam gaw:f64[9] = squeeze[dimensions=(1, 2)] gav gax:f64[9] = sub gau gaw gay:f64[9,1,1] = slice[ limit_indices=(9, 2, 3) start_indices=(0, 1, 2) strides=None ] gam gaz:f64[9] = squeeze[dimensions=(1, 2)] gay gba:f64[9,1,1] = slice[ limit_indices=(9, 1, 4) start_indices=(0, 0, 3) strides=None ] gam gbb:f64[9] = squeeze[dimensions=(1, 2)] gba gbc:f64[9] = sub gaz gbb gbd:f64[9] = mul 2.0 gbc gbe:f64[9,1,1] = slice[ limit_indices=(9, 2, 4) start_indices=(0, 1, 3) strides=None ] gam gbf:f64[9] = squeeze[dimensions=(1, 2)] gbe gbg:f64[9,1,1] = slice[ limit_indices=(9, 1, 3) start_indices=(0, 0, 2) strides=None ] gam gbh:f64[9] = squeeze[dimensions=(1, 2)] gbg gbi:f64[9] = add gbf gbh gbj:f64[9] = mul 2.0 gbi gbk:f64[9,1,1] = slice[ limit_indices=(9, 2, 3) start_indices=(0, 1, 2) strides=None ] gam gbl:f64[9] = squeeze[dimensions=(1, 2)] gbk gbm:f64[9,1,1] = slice[ limit_indices=(9, 1, 4) start_indices=(0, 0, 3) strides=None ] gam gbn:f64[9] = squeeze[dimensions=(1, 2)] gbm gbo:f64[9] = add gbl gbn gbp:f64[9] = mul 2.0 gbo gbq:f64[9,1,1] = slice[ limit_indices=(9, 1, 1) start_indices=(0, 0, 0) strides=None ] gam gbr:f64[9] = squeeze[dimensions=(1, 2)] gbq gbs:f64[9,1,1] = slice[ limit_indices=(9, 2, 2) start_indices=(0, 1, 1) strides=None ] gam gbt:f64[9] = squeeze[dimensions=(1, 2)] gbs gbu:f64[9] = sub gbr gbt gbv:f64[9,1,1] = slice[ limit_indices=(9, 3, 3) start_indices=(0, 2, 2) strides=None ] gam gbw:f64[9] = squeeze[dimensions=(1, 2)] gbv gbx:f64[9] = add gbu gbw gby:f64[9,1,1] = slice[ limit_indices=(9, 4, 4) start_indices=(0, 3, 3) strides=None ] gam gbz:f64[9] = squeeze[dimensions=(1, 2)] gby gca:f64[9] = sub gbx gbz gcb:f64[9,1,1] = slice[ limit_indices=(9, 3, 4) start_indices=(0, 2, 3) strides=None ] gam gcc:f64[9] = squeeze[dimensions=(1, 2)] gcb gcd:f64[9,1,1] = slice[ limit_indices=(9, 1, 2) start_indices=(0, 0, 1) strides=None ] gam gce:f64[9] = squeeze[dimensions=(1, 2)] gcd gcf:f64[9] = sub gcc gce gcg:f64[9] = mul 2.0 gcf gch:f64[9,1,1] = slice[ limit_indices=(9, 2, 4) start_indices=(0, 1, 3) strides=None ] gam gci:f64[9] = squeeze[dimensions=(1, 2)] gch gcj:f64[9,1,1] = slice[ limit_indices=(9, 1, 3) start_indices=(0, 0, 2) strides=None ] gam gck:f64[9] = squeeze[dimensions=(1, 2)] gcj gcl:f64[9] = sub gci gck gcm:f64[9] = mul 2.0 gcl gcn:f64[9,1,1] = slice[ limit_indices=(9, 3, 4) start_indices=(0, 2, 3) strides=None ] gam gco:f64[9] = squeeze[dimensions=(1, 2)] gcn gcp:f64[9,1,1] = slice[ limit_indices=(9, 1, 2) start_indices=(0, 0, 1) strides=None ] gam gcq:f64[9] = squeeze[dimensions=(1, 2)] gcp gcr:f64[9] = add gco gcq gcs:f64[9] = mul 2.0 gcr gct:f64[9,1,1] = slice[ limit_indices=(9, 1, 1) start_indices=(0, 0, 0) strides=None ] gam gcu:f64[9] = squeeze[dimensions=(1, 2)] gct gcv:f64[9,1,1] = slice[ limit_indices=(9, 2, 2) start_indices=(0, 1, 1) strides=None ] gam gcw:f64[9] = squeeze[dimensions=(1, 2)] gcv gcx:f64[9] = sub gcu gcw gcy:f64[9,1,1] = slice[ limit_indices=(9, 3, 3) start_indices=(0, 2, 2) strides=None ] gam gcz:f64[9] = squeeze[dimensions=(1, 2)] gcy gda:f64[9] = sub gcx gcz gdb:f64[9,1,1] = slice[ limit_indices=(9, 4, 4) start_indices=(0, 3, 3) strides=None ] gam gdc:f64[9] = squeeze[dimensions=(1, 2)] gdb gdd:f64[9] = add gda gdc gde:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gax gdf:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gbd gdg:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gbj gdh:f64[9,3] = concatenate[dimension=1] gde gdf gdg gdi:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gbp gdj:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gca gdk:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gcg gdl:f64[9,3] = concatenate[dimension=1] gdi gdj gdk gdm:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gcm gdn:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gcs gdo:f64[9,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(9, 1)] gdd gdp:f64[9,3] = concatenate[dimension=1] gdm gdn gdo gdq:f64[9,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(9, 1, 3) ] gdh gdr:f64[9,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(9, 1, 3) ] gdl gds:f64[9,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(9, 1, 3) ] gdp gdt:f64[9,3,3] = concatenate[dimension=1] gdq gdr gds gdu:i32[61] = device_put[devices=[None] srcs=[None]] jq gdv:bool[61] = lt gdu 0 gdw:i32[61] = add gdu 9 gdx:i32[61] = select_n gdv gdu gdw gdy:i32[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gdx gdz:f64[61,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] fud gdy gea:i32[61] = device_put[devices=[None] srcs=[None]] jq geb:bool[61] = lt gea 0 gec:i32[61] = add gea 9 ged:i32[61] = select_n geb gea gec gee:i32[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] ged gef:f64[61,4] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 4) unique_indices=False ] fuk gee geg:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] gef geh:f64[61] = squeeze[dimensions=(1,)] geg gei:f64[61,3] = slice[ limit_indices=(61, 4) start_indices=(0, 1) strides=None ] gef gej:f64[61] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] gei ls gek:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gej gel:f64[61,3] = mul gek gei gem:f64[61,3] = mul 2.0 gel gen:f64[61] = mul geh geh geo:f64[61] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] gei gei gep:f64[61] = sub gen geo geq:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gep ger:f64[61,3] = mul geq ls ges:f64[61,3] = add gem ger get:f64[61] = mul 2.0 geh geu:f64[61,3] = pjit[ name=cross jaxpr={ lambda ; gev:f64[61,3] gew:f64[61,3]. let gex:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 gey:f64[61] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(61, 1) unique_indices=True ] gev gex gez:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1 gfa:f64[61] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(61, 1) unique_indices=True ] gev gez gfb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2 gfc:f64[61] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(61, 1) unique_indices=True ] gev gfb gfd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0 gfe:f64[61] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(61, 1) unique_indices=True ] gew gfd gff:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1 gfg:f64[61] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(61, 1) unique_indices=True ] gew gff gfh:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2 gfi:f64[61] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(np.int64(1),), start_index_map=(np.int64(1),)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(61, 1) unique_indices=True ] gew gfh gfj:f64[61] = mul gfa gfi gfk:f64[61] = mul gfc gfg gfl:f64[61] = sub gfj gfk gfm:f64[61] = mul gfc gfe gfn:f64[61] = mul gey gfi gfo:f64[61] = sub gfm gfn gfp:f64[61] = mul gey gfg gfq:f64[61] = mul gfa gfe gfr:f64[61] = sub gfp gfq gfs:f64[61,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(61, 1) ] gfl gft:f64[61,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(61, 1) ] gfo gfu:f64[61,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(61, 1) ] gfr gfv:f64[61,3] = concatenate[dimension=1] gfs gft gfu in (gfv,) } ] gei ls gfw:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] get gfx:f64[61,3] = mul gfw geu gfy:f64[61,3] = add ges gfx gfz:f64[61,3] = add gdz gfy gga:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] gef ggb:f64[61] = squeeze[dimensions=(1,)] gga ggc:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] lt ggd:f64[61] = squeeze[dimensions=(1,)] ggc gge:f64[61] = mul ggb ggd ggf:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] gef ggg:f64[61] = squeeze[dimensions=(1,)] ggf ggh:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] lt ggi:f64[61] = squeeze[dimensions=(1,)] ggh ggj:f64[61] = mul ggg ggi ggk:f64[61] = sub gge ggj ggl:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] gef ggm:f64[61] = squeeze[dimensions=(1,)] ggl ggn:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] lt ggo:f64[61] = squeeze[dimensions=(1,)] ggn ggp:f64[61] = mul ggm ggo ggq:f64[61] = sub ggk ggp ggr:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] gef ggs:f64[61] = squeeze[dimensions=(1,)] ggr ggt:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] lt ggu:f64[61] = squeeze[dimensions=(1,)] ggt ggv:f64[61] = mul ggs ggu ggw:f64[61] = sub ggq ggv ggx:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] gef ggy:f64[61] = squeeze[dimensions=(1,)] ggx ggz:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] lt gha:f64[61] = squeeze[dimensions=(1,)] ggz ghb:f64[61] = mul ggy gha ghc:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] gef ghd:f64[61] = squeeze[dimensions=(1,)] ghc ghe:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] lt ghf:f64[61] = squeeze[dimensions=(1,)] ghe ghg:f64[61] = mul ghd ghf ghh:f64[61] = add ghb ghg ghi:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] gef ghj:f64[61] = squeeze[dimensions=(1,)] ghi ghk:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] lt ghl:f64[61] = squeeze[dimensions=(1,)] ghk ghm:f64[61] = mul ghj ghl ghn:f64[61] = add ghh ghm gho:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] gef ghp:f64[61] = squeeze[dimensions=(1,)] gho ghq:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] lt ghr:f64[61] = squeeze[dimensions=(1,)] ghq ghs:f64[61] = mul ghp ghr ght:f64[61] = sub ghn ghs ghu:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] gef ghv:f64[61] = squeeze[dimensions=(1,)] ghu ghw:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] lt ghx:f64[61] = squeeze[dimensions=(1,)] ghw ghy:f64[61] = mul ghv ghx ghz:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] gef gia:f64[61] = squeeze[dimensions=(1,)] ghz gib:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] lt gic:f64[61] = squeeze[dimensions=(1,)] gib gid:f64[61] = mul gia gic gie:f64[61] = sub ghy gid gif:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] gef gig:f64[61] = squeeze[dimensions=(1,)] gif gih:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] lt gii:f64[61] = squeeze[dimensions=(1,)] gih gij:f64[61] = mul gig gii gik:f64[61] = add gie gij gil:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] gef gim:f64[61] = squeeze[dimensions=(1,)] gil gin:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] lt gio:f64[61] = squeeze[dimensions=(1,)] gin gip:f64[61] = mul gim gio giq:f64[61] = add gik gip gir:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] gef gis:f64[61] = squeeze[dimensions=(1,)] gir git:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] lt giu:f64[61] = squeeze[dimensions=(1,)] git giv:f64[61] = mul gis giu giw:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] gef gix:f64[61] = squeeze[dimensions=(1,)] giw giy:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] lt giz:f64[61] = squeeze[dimensions=(1,)] giy gja:f64[61] = mul gix giz gjb:f64[61] = add giv gja gjc:f64[61,1] = slice[ limit_indices=(61, 3) start_indices=(0, 2) strides=None ] gef gjd:f64[61] = squeeze[dimensions=(1,)] gjc gje:f64[61,1] = slice[ limit_indices=(61, 2) start_indices=(0, 1) strides=None ] lt gjf:f64[61] = squeeze[dimensions=(1,)] gje gjg:f64[61] = mul gjd gjf gjh:f64[61] = sub gjb gjg gji:f64[61,1] = slice[ limit_indices=(61, 4) start_indices=(0, 3) strides=None ] gef gjj:f64[61] = squeeze[dimensions=(1,)] gji gjk:f64[61,1] = slice[ limit_indices=(61, 1) start_indices=(0, 0) strides=None ] lt gjl:f64[61] = squeeze[dimensions=(1,)] gjk gjm:f64[61] = mul gjj gjl gjn:f64[61] = add gjh gjm gjo:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] ggw gjp:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] ght gjq:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] giq gjr:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gjn gjs:f64[61,4] = concatenate[dimension=1] gjo gjp gjq gjr gjt:f64[61,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(61, 4, 1) ] gjs gju:f64[61,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(61, 1, 4) ] gjs gjv:f64[61,4,4] = mul gjt gju gjw:f64[61,1,1] = slice[ limit_indices=(61, 1, 1) start_indices=(0, 0, 0) strides=None ] gjv gjx:f64[61] = squeeze[dimensions=(1, 2)] gjw gjy:f64[61,1,1] = slice[ limit_indices=(61, 2, 2) start_indices=(0, 1, 1) strides=None ] gjv gjz:f64[61] = squeeze[dimensions=(1, 2)] gjy gka:f64[61] = add gjx gjz gkb:f64[61,1,1] = slice[ limit_indices=(61, 3, 3) start_indices=(0, 2, 2) strides=None ] gjv gkc:f64[61] = squeeze[dimensions=(1, 2)] gkb gkd:f64[61] = sub gka gkc gke:f64[61,1,1] = slice[ limit_indices=(61, 4, 4) start_indices=(0, 3, 3) strides=None ] gjv gkf:f64[61] = squeeze[dimensions=(1, 2)] gke gkg:f64[61] = sub gkd gkf gkh:f64[61,1,1] = slice[ limit_indices=(61, 2, 3) start_indices=(0, 1, 2) strides=None ] gjv gki:f64[61] = squeeze[dimensions=(1, 2)] gkh gkj:f64[61,1,1] = slice[ limit_indices=(61, 1, 4) start_indices=(0, 0, 3) strides=None ] gjv gkk:f64[61] = squeeze[dimensions=(1, 2)] gkj gkl:f64[61] = sub gki gkk gkm:f64[61] = mul 2.0 gkl gkn:f64[61,1,1] = slice[ limit_indices=(61, 2, 4) start_indices=(0, 1, 3) strides=None ] gjv gko:f64[61] = squeeze[dimensions=(1, 2)] gkn gkp:f64[61,1,1] = slice[ limit_indices=(61, 1, 3) start_indices=(0, 0, 2) strides=None ] gjv gkq:f64[61] = squeeze[dimensions=(1, 2)] gkp gkr:f64[61] = add gko gkq gks:f64[61] = mul 2.0 gkr gkt:f64[61,1,1] = slice[ limit_indices=(61, 2, 3) start_indices=(0, 1, 2) strides=None ] gjv gku:f64[61] = squeeze[dimensions=(1, 2)] gkt gkv:f64[61,1,1] = slice[ limit_indices=(61, 1, 4) start_indices=(0, 0, 3) strides=None ] gjv gkw:f64[61] = squeeze[dimensions=(1, 2)] gkv gkx:f64[61] = add gku gkw gky:f64[61] = mul 2.0 gkx gkz:f64[61,1,1] = slice[ limit_indices=(61, 1, 1) start_indices=(0, 0, 0) strides=None ] gjv gla:f64[61] = squeeze[dimensions=(1, 2)] gkz glb:f64[61,1,1] = slice[ limit_indices=(61, 2, 2) start_indices=(0, 1, 1) strides=None ] gjv glc:f64[61] = squeeze[dimensions=(1, 2)] glb gld:f64[61] = sub gla glc gle:f64[61,1,1] = slice[ limit_indices=(61, 3, 3) start_indices=(0, 2, 2) strides=None ] gjv glf:f64[61] = squeeze[dimensions=(1, 2)] gle glg:f64[61] = add gld glf glh:f64[61,1,1] = slice[ limit_indices=(61, 4, 4) start_indices=(0, 3, 3) strides=None ] gjv gli:f64[61] = squeeze[dimensions=(1, 2)] glh glj:f64[61] = sub glg gli glk:f64[61,1,1] = slice[ limit_indices=(61, 3, 4) start_indices=(0, 2, 3) strides=None ] gjv gll:f64[61] = squeeze[dimensions=(1, 2)] glk glm:f64[61,1,1] = slice[ limit_indices=(61, 1, 2) start_indices=(0, 0, 1) strides=None ] gjv gln:f64[61] = squeeze[dimensions=(1, 2)] glm glo:f64[61] = sub gll gln glp:f64[61] = mul 2.0 glo glq:f64[61,1,1] = slice[ limit_indices=(61, 2, 4) start_indices=(0, 1, 3) strides=None ] gjv glr:f64[61] = squeeze[dimensions=(1, 2)] glq gls:f64[61,1,1] = slice[ limit_indices=(61, 1, 3) start_indices=(0, 0, 2) strides=None ] gjv glt:f64[61] = squeeze[dimensions=(1, 2)] gls glu:f64[61] = sub glr glt glv:f64[61] = mul 2.0 glu glw:f64[61,1,1] = slice[ limit_indices=(61, 3, 4) start_indices=(0, 2, 3) strides=None ] gjv glx:f64[61] = squeeze[dimensions=(1, 2)] glw gly:f64[61,1,1] = slice[ limit_indices=(61, 1, 2) start_indices=(0, 0, 1) strides=None ] gjv glz:f64[61] = squeeze[dimensions=(1, 2)] gly gma:f64[61] = add glx glz gmb:f64[61] = mul 2.0 gma gmc:f64[61,1,1] = slice[ limit_indices=(61, 1, 1) start_indices=(0, 0, 0) strides=None ] gjv gmd:f64[61] = squeeze[dimensions=(1, 2)] gmc gme:f64[61,1,1] = slice[ limit_indices=(61, 2, 2) start_indices=(0, 1, 1) strides=None ] gjv gmf:f64[61] = squeeze[dimensions=(1, 2)] gme gmg:f64[61] = sub gmd gmf gmh:f64[61,1,1] = slice[ limit_indices=(61, 3, 3) start_indices=(0, 2, 2) strides=None ] gjv gmi:f64[61] = squeeze[dimensions=(1, 2)] gmh gmj:f64[61] = sub gmg gmi gmk:f64[61,1,1] = slice[ limit_indices=(61, 4, 4) start_indices=(0, 3, 3) strides=None ] gjv gml:f64[61] = squeeze[dimensions=(1, 2)] gmk gmm:f64[61] = add gmj gml gmn:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gkg gmo:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gkm gmp:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gks gmq:f64[61,3] = concatenate[dimension=1] gmn gmo gmp gmr:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gky gms:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] glj gmt:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] glp gmu:f64[61,3] = concatenate[dimension=1] gmr gms gmt gmv:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] glv gmw:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gmb gmx:f64[61,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(61, 1)] gmm gmy:f64[61,3] = concatenate[dimension=1] gmv gmw gmx gmz:f64[61,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(61, 1, 3) ] gmq gna:f64[61,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(61, 1, 3) ] gmu gnb:f64[61,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(61, 1, 3) ] gmy gnc:f64[61,3,3] = concatenate[dimension=1] gmz gna gnb gnd:i32[1] = device_put[devices=[None] srcs=[None]] jr gne:bool[1] = lt gnd 0 gnf:i32[1] = add gnd 9 gng:i32[1] = select_n gne gnd gnf gnh:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gng gni:f64[1,3] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 3) unique_indices=False ] fud gnh gnj:i32[1] = device_put[devices=[None] srcs=[None]] jr gnk:bool[1] = lt gnj 0 gnl:i32[1] = add gnj 9 gnm:i32[1] = select_n gnk gnj gnl gnn:i32[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gnm gno:f64[1,4] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 4) unique_indices=False ] fuk gnn gnp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] gno gnq:f64[1] = squeeze[dimensions=(1,)] gnp gnr:f64[1,3] = slice[limit_indices=(1, 4) start_indices=(0, 1) strides=None] gno gns:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] gnr lx gnt:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gns gnu:f64[1,3] = mul gnt gnr gnv:f64[1,3] = mul 2.0 gnu gnw:f64[1] = mul gnq gnq gnx:f64[1] = dot_general[ dimension_numbers=(([1], [1]), ([0], [0])) preferred_element_type=float64 ] gnr gnr gny:f64[1] = sub gnw gnx gnz:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gny goa:f64[1,3] = mul gnz lx gob:f64[1,3] = add gnv goa goc:f64[1] = mul 2.0 gnq god:f64[1,3] = pjit[name=cross jaxpr=cross] gnr lx goe:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] goc gof:f64[1,3] = mul goe god gog:f64[1,3] = add gob gof goh:f64[1,3] = add gni gog goi:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] gno goj:f64[1] = squeeze[dimensions=(1,)] goi gok:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ly gol:f64[1] = squeeze[dimensions=(1,)] gok gom:f64[1] = mul goj gol gon:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] gno goo:f64[1] = squeeze[dimensions=(1,)] gon gop:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ly goq:f64[1] = squeeze[dimensions=(1,)] gop gor:f64[1] = mul goo goq gos:f64[1] = sub gom gor got:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] gno gou:f64[1] = squeeze[dimensions=(1,)] got gov:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ly gow:f64[1] = squeeze[dimensions=(1,)] gov gox:f64[1] = mul gou gow goy:f64[1] = sub gos gox goz:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] gno gpa:f64[1] = squeeze[dimensions=(1,)] goz gpb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ly gpc:f64[1] = squeeze[dimensions=(1,)] gpb gpd:f64[1] = mul gpa gpc gpe:f64[1] = sub goy gpd gpf:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] gno gpg:f64[1] = squeeze[dimensions=(1,)] gpf gph:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ly gpi:f64[1] = squeeze[dimensions=(1,)] gph gpj:f64[1] = mul gpg gpi gpk:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] gno gpl:f64[1] = squeeze[dimensions=(1,)] gpk gpm:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ly gpn:f64[1] = squeeze[dimensions=(1,)] gpm gpo:f64[1] = mul gpl gpn gpp:f64[1] = add gpj gpo gpq:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] gno gpr:f64[1] = squeeze[dimensions=(1,)] gpq gps:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ly gpt:f64[1] = squeeze[dimensions=(1,)] gps gpu:f64[1] = mul gpr gpt gpv:f64[1] = add gpp gpu gpw:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] gno gpx:f64[1] = squeeze[dimensions=(1,)] gpw gpy:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ly gpz:f64[1] = squeeze[dimensions=(1,)] gpy gqa:f64[1] = mul gpx gpz gqb:f64[1] = sub gpv gqa gqc:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] gno gqd:f64[1] = squeeze[dimensions=(1,)] gqc gqe:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ly gqf:f64[1] = squeeze[dimensions=(1,)] gqe gqg:f64[1] = mul gqd gqf gqh:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] gno gqi:f64[1] = squeeze[dimensions=(1,)] gqh gqj:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ly gqk:f64[1] = squeeze[dimensions=(1,)] gqj gql:f64[1] = mul gqi gqk gqm:f64[1] = sub gqg gql gqn:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] gno gqo:f64[1] = squeeze[dimensions=(1,)] gqn gqp:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ly gqq:f64[1] = squeeze[dimensions=(1,)] gqp gqr:f64[1] = mul gqo gqq gqs:f64[1] = add gqm gqr gqt:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] gno gqu:f64[1] = squeeze[dimensions=(1,)] gqt gqv:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ly gqw:f64[1] = squeeze[dimensions=(1,)] gqv gqx:f64[1] = mul gqu gqw gqy:f64[1] = add gqs gqx gqz:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] gno gra:f64[1] = squeeze[dimensions=(1,)] gqz grb:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] ly grc:f64[1] = squeeze[dimensions=(1,)] grb grd:f64[1] = mul gra grc gre:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] gno grf:f64[1] = squeeze[dimensions=(1,)] gre grg:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] ly grh:f64[1] = squeeze[dimensions=(1,)] grg gri:f64[1] = mul grf grh grj:f64[1] = add grd gri grk:f64[1,1] = slice[limit_indices=(1, 3) start_indices=(0, 2) strides=None] gno grl:f64[1] = squeeze[dimensions=(1,)] grk grm:f64[1,1] = slice[limit_indices=(1, 2) start_indices=(0, 1) strides=None] ly grn:f64[1] = squeeze[dimensions=(1,)] grm gro:f64[1] = mul grl grn grp:f64[1] = sub grj gro grq:f64[1,1] = slice[limit_indices=(1, 4) start_indices=(0, 3) strides=None] gno grr:f64[1] = squeeze[dimensions=(1,)] grq grs:f64[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] ly grt:f64[1] = squeeze[dimensions=(1,)] grs gru:f64[1] = mul grr grt grv:f64[1] = add grp gru grw:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gpe grx:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gqb gry:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gqy grz:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] grv gsa:f64[1,4] = concatenate[dimension=1] grw grx gry grz gsb:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(1)) shape=(1, 4, 1) ] gsa gsc:f64[1,1,4] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 4) ] gsa gsd:f64[1,4,4] = mul gsb gsc gse:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] gsd gsf:f64[1] = squeeze[dimensions=(1, 2)] gse gsg:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] gsd gsh:f64[1] = squeeze[dimensions=(1, 2)] gsg gsi:f64[1] = add gsf gsh gsj:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] gsd gsk:f64[1] = squeeze[dimensions=(1, 2)] gsj gsl:f64[1] = sub gsi gsk gsm:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] gsd gsn:f64[1] = squeeze[dimensions=(1, 2)] gsm gso:f64[1] = sub gsl gsn gsp:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] gsd gsq:f64[1] = squeeze[dimensions=(1, 2)] gsp gsr:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] gsd gss:f64[1] = squeeze[dimensions=(1, 2)] gsr gst:f64[1] = sub gsq gss gsu:f64[1] = mul 2.0 gst gsv:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] gsd gsw:f64[1] = squeeze[dimensions=(1, 2)] gsv gsx:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] gsd gsy:f64[1] = squeeze[dimensions=(1, 2)] gsx gsz:f64[1] = add gsw gsy gta:f64[1] = mul 2.0 gsz gtb:f64[1,1,1] = slice[ limit_indices=(1, 2, 3) start_indices=(0, 1, 2) strides=None ] gsd gtc:f64[1] = squeeze[dimensions=(1, 2)] gtb gtd:f64[1,1,1] = slice[ limit_indices=(1, 1, 4) start_indices=(0, 0, 3) strides=None ] gsd gte:f64[1] = squeeze[dimensions=(1, 2)] gtd gtf:f64[1] = add gtc gte gtg:f64[1] = mul 2.0 gtf gth:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] gsd gti:f64[1] = squeeze[dimensions=(1, 2)] gth gtj:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] gsd gtk:f64[1] = squeeze[dimensions=(1, 2)] gtj gtl:f64[1] = sub gti gtk gtm:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] gsd gtn:f64[1] = squeeze[dimensions=(1, 2)] gtm gto:f64[1] = add gtl gtn gtp:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] gsd gtq:f64[1] = squeeze[dimensions=(1, 2)] gtp gtr:f64[1] = sub gto gtq gts:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] gsd gtt:f64[1] = squeeze[dimensions=(1, 2)] gts gtu:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] gsd gtv:f64[1] = squeeze[dimensions=(1, 2)] gtu gtw:f64[1] = sub gtt gtv gtx:f64[1] = mul 2.0 gtw gty:f64[1,1,1] = slice[ limit_indices=(1, 2, 4) start_indices=(0, 1, 3) strides=None ] gsd gtz:f64[1] = squeeze[dimensions=(1, 2)] gty gua:f64[1,1,1] = slice[ limit_indices=(1, 1, 3) start_indices=(0, 0, 2) strides=None ] gsd gub:f64[1] = squeeze[dimensions=(1, 2)] gua guc:f64[1] = sub gtz gub gud:f64[1] = mul 2.0 guc gue:f64[1,1,1] = slice[ limit_indices=(1, 3, 4) start_indices=(0, 2, 3) strides=None ] gsd guf:f64[1] = squeeze[dimensions=(1, 2)] gue gug:f64[1,1,1] = slice[ limit_indices=(1, 1, 2) start_indices=(0, 0, 1) strides=None ] gsd guh:f64[1] = squeeze[dimensions=(1, 2)] gug gui:f64[1] = add guf guh guj:f64[1] = mul 2.0 gui guk:f64[1,1,1] = slice[ limit_indices=(1, 1, 1) start_indices=(0, 0, 0) strides=None ] gsd gul:f64[1] = squeeze[dimensions=(1, 2)] guk gum:f64[1,1,1] = slice[ limit_indices=(1, 2, 2) start_indices=(0, 1, 1) strides=None ] gsd gun:f64[1] = squeeze[dimensions=(1, 2)] gum guo:f64[1] = sub gul gun gup:f64[1,1,1] = slice[ limit_indices=(1, 3, 3) start_indices=(0, 2, 2) strides=None ] gsd guq:f64[1] = squeeze[dimensions=(1, 2)] gup gur:f64[1] = sub guo guq gus:f64[1,1,1] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 3, 3) strides=None ] gsd gut:f64[1] = squeeze[dimensions=(1, 2)] gus guu:f64[1] = add gur gut guv:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gso guw:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gsu gux:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gta guy:f64[1,3] = concatenate[dimension=1] guv guw gux guz:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gtg gva:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gtr gvb:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gtx gvc:f64[1,3] = concatenate[dimension=1] guz gva gvb gvd:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] gud gve:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] guj gvf:f64[1,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(1, 1)] guu gvg:f64[1,3] = concatenate[dimension=1] gvd gve gvf gvh:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] guy gvi:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] gvc gvj:f64[1,1,3] = broadcast_in_dim[ broadcast_dimensions=(0, np.int64(2)) shape=(1, 1, 3) ] gvg gvk:f64[1,3,3] = concatenate[dimension=1] gvh gvi gvj in (ni, nj, fta, nl, nm, nn, no, np, nq, nr, ns, nt, nu, nv, nw, nx, fud, fuk, fur, fwq, gdt, ftq, fub, gfz, gnc, goh, gvk, oj, ok, ol, om, on, oo, op, oq, or, os, ot, ou, ov, ow, ox, oy, oz, pa, pb, pc, pd, pe, pf, pg, ph, pi, pj, pk, pl, pm, pn, po, pp, pq, pr, ps, pt, pu, pv, pw, px, py, pz, qa, qb, qc, qd, qe, qf, qg, qh, qi, qj, qk, ql, qm, qn, qo, qp, qq, qr, qs, qt, qu, qv, qw, qx, qy, qz, ra, rb, rc, rd, re, rf, rg, rh, ri, rj, rk, rl, rm, rn, ro, rp, rq) }