Matt Graham
UCL Centre for Advanced Research Computing
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
The wide user familiarity with NumPy’s API and the large amount of existing code using NumPy have led some libraries to provide NumPy-like APIs.
However NumPy API not designed for this purpose and has some shortcomings:
Array API standard defined by Consortium for Python Data API Standards (https://data-apis.org/).
Consider following implementation of
\[\text{LogSumExp}(x_{1:N}) = x^* + \log \sum_{n=1}^N \exp(x_n - x^*), \\ x^* = \max x_{1:N}\]
def log_sum_exp(x):
    xp = x.__array_namespace__()
    max_x = xp.max(x)
    return max_x + xp.log(xp.sum(xp.exp(x - max_x)))

log_sum_exp can then be called with array objects from any library supporting array API.
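As a minimal usage sketch, the function can be called on a NumPy array whose entries are large enough that a naive evaluation would overflow. The fallback to the `numpy` module for arrays without an `__array_namespace__` method is an assumption added here so the sketch also runs on older NumPy releases; it is not part of the original function. An array from another library supporting the array API could be passed in the same way.

```python
import numpy as np


def log_sum_exp(x):
    # Look up the array API namespace of the input. The fallback to the numpy
    # module is an assumption added for NumPy releases without this method.
    xp = x.__array_namespace__() if hasattr(x, "__array_namespace__") else np
    max_x = xp.max(x)
    return max_x + xp.log(xp.sum(xp.exp(x - max_x)))


x = np.asarray([1000.0, 1000.0])
# Evaluating log(sum(exp(x))) directly would overflow in double precision here,
# while the shifted form evaluates to 1000 + log(2).
result = log_sum_exp(x)
print(result)
```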
array_api_compat acts as a small wrapper around common array libraries to provide compatibility with array API standard.
Useful for libraries with only partial support for array API.
NumPy v2.1+ and JAX v0.4.32+ fully support array API so using array_api_compat not strictly necessary.
We will now demonstrate some of JAX and the Array API’s key features in an applied example.
The jax_enable_x64 configuration option can be used to use double-precision types by default, more closely matching NumPy’s behaviour.
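A minimal sketch of setting this option is shown below. It assumes the option is set at startup, before any JAX arrays are created, which is where the flag is normally expected to be set.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.arange(3.0)
# With the option enabled, floating point arrays default to double precision.
print(x.dtype)
```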
As an example of a scientific computing task, we will consider simulating motion of an incompressible inviscid fluid on a two-dimensional domain with periodic boundary conditions.
Mathematically such a flow can be described by
\[ \partial_t \omega = -(v \cdot \nabla) \omega, \quad \omega = \nabla \times v, \]
where \(\omega\) is the vorticity of a two-dimensional velocity field \(v\).
We will additionally consider the motion of a passive tracer field \(\tau\) advected by the velocity field \(v\)
\[ \partial_t \tau = -(v \cdot \nabla) \tau. \]
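The relation \(\omega = \nabla \times v\) can be inverted on a periodic domain with fast Fourier transforms. The following is a sketch in plain NumPy, assuming the convention \(\omega = \partial_x v_y - \partial_y v_x\), a stream function \(\psi\) with \(-\nabla^2 \psi = \omega\) and \(v = (\partial_y \psi, -\partial_x \psi)\), and a zero mean velocity. The name `velocity_from_vorticity_sketch` and the `domain_size` parameter are hypothetical; the model's own `velocity_from_vorticity` works with precomputed kernels and may differ in its details.

```python
import numpy as np


def velocity_from_vorticity_sketch(vorticity, domain_size=2 * np.pi):
    """Recover a divergence-free velocity field from its vorticity (periodic square)."""
    n_0, n_1 = vorticity.shape
    # Angular wavenumbers along each axis (axis 0 is x, axis 1 is y).
    k_x = 2 * np.pi * np.fft.fftfreq(n_0, d=domain_size / n_0)[:, None]
    k_y = 2 * np.pi * np.fft.fftfreq(n_1, d=domain_size / n_1)[None, :]
    k_squared = k_x**2 + k_y**2
    k_squared[0, 0] = 1.0  # avoid division by zero; the mean mode is zeroed below
    # Stream function psi solves -laplacian(psi) = vorticity.
    fft_psi = np.fft.fft2(vorticity) / k_squared
    fft_psi[0, 0] = 0.0
    # v = (d psi / dy, -d psi / dx), so curl v = vorticity and div v = 0.
    v_x = np.fft.ifft2(1j * k_y * fft_psi).real
    v_y = np.fft.ifft2(-1j * k_x * fft_psi).real
    return np.stack([v_x, v_y])


# Check against an analytic case: psi = sin(x) cos(2y) gives vorticity 5 sin(x) cos(2y).
n = 16
grid = np.arange(n) * 2 * np.pi / n
x, y = np.meshgrid(grid, grid, indexing="ij")
vorticity = 5 * np.sin(x) * np.cos(2 * y)
velocity = velocity_from_vorticity_sketch(vorticity)
```

For this test field the recovered components should match \(\partial_y \psi = -2 \sin x \sin 2y\) and \(-\partial_x \psi = -\cos x \cos 2y\) to floating point precision.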
Preview of example simulation output on a 64×46 mesh:
On each step we
def step(
    vorticity: Array, tracer: Array, kernels: Kernels, mesh: Mesh, time_step: float
) -> tuple[Array, Array]:
    """Perform single time step update of vorticity and tracers."""
    velocity = velocity_from_vorticity(vorticity, kernels)
    new_vorticity = bfecc_advect(vorticity, velocity, mesh, time_step)
    new_tracer = bfecc_advect(tracer, velocity, mesh, time_step)
    return new_vorticity, new_tracer

To allow generating smooth vorticity fields we
def generate_vorticity(u: Array, mesh: Mesh, kernels: Kernels) -> Array:
    """Generate vorticity field from standard normal random vector `u`."""
    xp = u.__array_namespace__()
    fft_coefficients = real_array_to_rfft2_coeff(u, mesh.coordinates.shape[1:])
    fft_vorticity = fft_coefficients * kernels.initialization
    return xp.fft.irfft2(fft_vorticity, norm="ortho")

We can generate initial vorticity and velocity fields with generate_vorticity and velocity_from_vorticity using either NumPy or JAX depending on array argument types:
mesh = generate_mesh(shape=(64, 64))
kernels = generate_kernels(mesh)
rng = np.random.default_rng(1234)
u = rng.standard_normal(size=mesh.shape[0] * mesh.shape[1])
vorticity_numpy = generate_vorticity(np.asarray(u), mesh, kernels)
velocity_numpy = velocity_from_vorticity(vorticity_numpy, kernels)
type(vorticity_numpy), type(velocity_numpy)

(numpy.ndarray, numpy.ndarray)
The output arrays are numerically equivalent:
We can also visualize the equivalence of the function outputs

How does calling functions using NumPy and JAX arrays compare in terms of compute time?
We time calling a subset of the model functions for both NumPy and JAX arrays and visualize the results:

We can see that when evaluating on JAX arrays the functions are consistently and significantly slower!
This may at first seem surprising as JAX describes itself as a performance-oriented library. However,
Text adapted from: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
This has led to different priorities in the packages.
NumPy has put significant effort into decreasing the per-call dispatch overhead for individual array operations, because in NumPy’s computational model that overhead cannot be avoided.
JAX, on the other hand, has several ways to avoid dispatch overhead and so reducing per-call overhead has been less of a priority.
One particular way of reducing dispatch overhead offered by JAX is just-in-time (JIT) compilation.
As mentioned earlier, JAX builds on top of the XLA compiler toolchain, which allows compiling code for a variety of device backends (CPU, GPU, TPU).
Before we try out JIT compilation we will first take a look at how JAX’s function transformations work.
At a high level, JAX’s function transformations operate by first tracing the Python function to be transformed to an intermediate representation (IR), with this IR then interpreted with a transformation-specific interpreter.
Jaxprs are JAX’s internal IR of traced programs.
Adapted from: https://jax.readthedocs.io/en/latest/jaxpr.html
make_jaxpr
Consider the advect_points function defined in the model:
def advect_points(points: Array, velocity: Array, time_step: float) -> Array:
    """Advect points along velocity with Euler method."""
    return points + time_step * velocity
JAX provides a special jax.make_jaxpr transformation that allows us to inspect the jaxpr representation of a function.
Internally JAX traces functions with abstract values. For most transformations the abstraction is at the level of the shape and datatype of arguments, but not their values.
We can use jax.ShapeDtypeStruct to create an abstract array with just shape and dtype attributes (plus some additional attributes for JAX bookkeeping such as device layout and sharding).
If we now call the jax.make_jaxpr transformed function with these abstract arguments we get exactly the same result as before
{ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
d:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c
e:f64[2,64,64] = mul d b
f:f64[2,64,64] = add a e
in (f,) }
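The tracing with abstract arguments described above can be sketched as follows. The function is restated so that the block is self-contained, and single-precision abstract arrays are used here (an assumption, so the sketch does not depend on the jax_enable_x64 option), so the dtypes in the printed jaxpr will differ from the double-precision output shown above.

```python
import jax
import jax.numpy as jnp


def advect_points(points, velocity, time_step):
    """Advect points along velocity with Euler method."""
    return points + time_step * velocity


# Abstract stand-ins carrying only shape and dtype information, no values.
abstract_array = jax.ShapeDtypeStruct((2, 64, 64), jnp.float32)
abstract_scalar = jax.ShapeDtypeStruct((), jnp.float32)

# Trace the function with the abstract arguments and inspect the resulting jaxpr.
jaxpr = jax.make_jaxpr(advect_points)(abstract_array, abstract_array, abstract_scalar)
print(jaxpr)
```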
Now consider the semi_lagrangian_advect function, which calls advect_points:
def semi_lagrangian_advect(
    field: Array, velocity: Array, mesh: Mesh, time_step: float
) -> Array:
    """Use semi-Lagrangian method to advect a given field a single step."""
    origin_points = advect_points(mesh.coordinates, velocity, -time_step)
    return bilinear_interpolate(field, origin_points / mesh.cell_size[:, None, None])
The jaxpr for this function corresponds to tracing through the full computation.
let remainder = { lambda ; a:i64[2,64,64] b:i64[2,1,1]. let
c:bool[2,1,1] = eq b 0:i64[]
d:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(2, 1, 1)
sharding=None
] 1:i64[]
e:i64[2,1,1] = jit[
name=_where
jaxpr={ lambda ; c:bool[2,1,1] d:i64[2,1,1] b:i64[2,1,1]. let
e:i64[2,1,1] = select_n c b d
in (e,) }
] c d b
f:i64[2,64,64] = rem a e
g:bool[2,64,64] = ne f 0:i64[]
h:bool[2,64,64] = lt f 0:i64[]
i:bool[2,1,1] = lt e 0:i64[]
j:bool[2,64,64] = ne h i
k:bool[2,64,64] = and j g
l:i64[2,64,64] = add f e
m:i64[2,64,64] = select_n k f l
in (m,) } in
let _where = { lambda ; c:bool[2,1,1] d:i64[2,1,1] b:i64[2,1,1]. let
e:i64[2,1,1] = select_n c b d
in (e,) } in
{ lambda n:i64[2]; o:f64[64,64] p:f64[2,64,64] q:i64[] r:i64[] s:f64[2] t:f64[2,64,64]
u:f64[]. let
v:f64[] = neg u
w:f64[] = convert_element_type[new_dtype=float64 weak_type=False] v
x:f64[2,64,64] = mul w p
y:f64[2,64,64] = add t x
z:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] s
ba:f64[2,64,64] = div y z
bb:f64[2,64,64] = floor ba
bc:f64[2,64,64] = sub ba bb
bd:f64[2,64,64] = sub 1.0:f64[] bc
be:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] n
bf:f64[2,64,64] = sub ba bc
bg:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bf
bh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bg be
bi:i64[2,64,64] = add bh 1:i64[]
bj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bi be
bk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bd
bl:f64[64,64] = squeeze[dimensions=(0,)] bk
bm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bd
bn:f64[64,64] = squeeze[dimensions=(0,)] bm
bo:f64[64,64] = mul bl bn
bp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bh
bq:i64[64,64] = squeeze[dimensions=(0,)] bp
br:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bh
bs:i64[64,64] = squeeze[dimensions=(0,)] br
bt:bool[64,64] = lt bq 0:i64[]
bu:i64[64,64] = add bq 64:i64[]
bv:i64[64,64] = select_n bt bq bu
bw:bool[64,64] = lt bs 0:i64[]
bx:i64[64,64] = add bs 64:i64[]
by:i64[64,64] = select_n bw bs bx
bz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bv
ca:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] by
cb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bz
cc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ca
cd:i32[64,64,2] = concatenate[dimension=2] cb cc
ce:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] o cd
cf:f64[64,64] = mul bo ce
cg:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bd
ch:f64[64,64] = squeeze[dimensions=(0,)] cg
ci:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bc
cj:f64[64,64] = squeeze[dimensions=(0,)] ci
ck:f64[64,64] = mul ch cj
cl:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bh
cm:i64[64,64] = squeeze[dimensions=(0,)] cl
cn:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bj
co:i64[64,64] = squeeze[dimensions=(0,)] cn
cp:bool[64,64] = lt cm 0:i64[]
cq:i64[64,64] = add cm 64:i64[]
cr:i64[64,64] = select_n cp cm cq
cs:bool[64,64] = lt co 0:i64[]
ct:i64[64,64] = add co 64:i64[]
cu:i64[64,64] = select_n cs co ct
cv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cr
cw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cu
cx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cv
cy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cw
cz:i32[64,64,2] = concatenate[dimension=2] cx cy
da:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] o cz
db:f64[64,64] = mul ck da
dc:f64[64,64] = add cf db
dd:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bc
de:f64[64,64] = squeeze[dimensions=(0,)] dd
df:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bd
dg:f64[64,64] = squeeze[dimensions=(0,)] df
dh:f64[64,64] = mul de dg
di:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bj
dj:i64[64,64] = squeeze[dimensions=(0,)] di
dk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bh
dl:i64[64,64] = squeeze[dimensions=(0,)] dk
dm:bool[64,64] = lt dj 0:i64[]
dn:i64[64,64] = add dj 64:i64[]
do:i64[64,64] = select_n dm dj dn
dp:bool[64,64] = lt dl 0:i64[]
dq:i64[64,64] = add dl 64:i64[]
dr:i64[64,64] = select_n dp dl dq
ds:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] do
dt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dr
du:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ds
dv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dt
dw:i32[64,64,2] = concatenate[dimension=2] du dv
dx:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] o dw
dy:f64[64,64] = mul dh dx
dz:f64[64,64] = add dc dy
ea:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bc
eb:f64[64,64] = squeeze[dimensions=(0,)] ea
ec:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bc
ed:f64[64,64] = squeeze[dimensions=(0,)] ec
ee:f64[64,64] = mul eb ed
ef:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bj
eg:i64[64,64] = squeeze[dimensions=(0,)] ef
eh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bj
ei:i64[64,64] = squeeze[dimensions=(0,)] eh
ej:bool[64,64] = lt eg 0:i64[]
ek:i64[64,64] = add eg 64:i64[]
el:i64[64,64] = select_n ej eg ek
em:bool[64,64] = lt ei 0:i64[]
en:i64[64,64] = add ei 64:i64[]
eo:i64[64,64] = select_n em ei en
ep:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] el
eq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eo
er:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ep
es:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eq
et:i32[64,64,2] = concatenate[dimension=2] er es
eu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] o et
ev:f64[64,64] = mul ee eu
ew:f64[64,64] = add dz ev
in (ew,) }
jax.jit
One of the key JAX transformations is the jax.jit function, which allows functions to be JIT compiled with XLA for execution on a particular device (CPU, GPU or TPU).
{ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
d:f64[2,64,64] = jit[
name=advect_points
jaxpr={ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
e:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c
f:f64[2,64,64] = mul e b
d:f64[2,64,64] = add a f
in (d,) }
] a b c
in (d,) }
The jitted and original functions give equivalent outputs to within floating point error
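A minimal sketch of such an equivalence check is given below. The small input arrays are illustrative assumptions rather than the model's mesh coordinates and velocity field.

```python
import jax
import jax.numpy as jnp


def advect_points(points, velocity, time_step):
    """Advect points along velocity with Euler method."""
    return points + time_step * velocity


# Wrap the function so that it is compiled on first call.
jitted_advect_points = jax.jit(advect_points)

# Small illustrative inputs (assumed values, not the model's fields).
points = jnp.arange(32.0).reshape(2, 4, 4)
velocity = jnp.ones((2, 4, 4))
time_step = 0.1

expected = advect_points(points, velocity, time_step)
actual = jitted_advect_points(points, velocity, time_step)
outputs_match = bool(jnp.allclose(expected, actual))
print(outputs_match)
```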
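We might expect the jitted function to be faster. In the timings shown here, however, it is somewhat slower, most likely because per-call overhead dominates for such a small function: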
print("Without jit")
%timeit advect_points(mesh.coordinates, velocity_jax, time_step).block_until_ready()
print("With jit")
%timeit jitted_advect_points(mesh.coordinates, velocity_jax, time_step).block_until_ready()

Without jit
59.9 μs ± 3.96 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
With jit
77.6 μs ± 2.44 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
jax.jit: under the hood
When calling a jitted function JAX does the following in order:
Text adapted from: https://jax.readthedocs.io/en/latest/aot.html
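The stages that jax.jit normally performs implicitly can also be run one at a time with JAX's ahead-of-time lowering and compilation methods. The following is a sketch under the assumption of small illustrative inputs. The compiled function is called with arguments of the same shapes and types as those used for lowering.

```python
import jax
import jax.numpy as jnp


def advect_points(points, velocity, time_step):
    """Advect points along velocity with Euler method."""
    return points + time_step * velocity


# Small illustrative inputs (assumed values).
points = jnp.zeros((2, 8, 8))
velocity = jnp.ones((2, 8, 8))

# Trace and lower the function for these argument shapes and dtypes.
lowered = jax.jit(advect_points).lower(points, velocity, 0.5)
# The lowered program can be inspected as text.
lowered_text = lowered.as_text()
# Compile the lowered program with XLA.
compiled = lowered.compile()
# Execute the compiled program on concrete arguments.
result = compiled(points, velocity, 0.5)
```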
jit’s effect on performance
We now time jitted versions of each of the model functions:

JAX typically gives greater speed-ups when JIT compiling larger function traces, due to the decreased dispatch overhead.
In our case the model functions we have been looking at so far are only substeps of the overall simulation of interest.
We can simulate the fields forward an arbitrary amount of time by iteratively applying the step function.
We use a helper function to generate a vertically banded initial tracer field.

Using the integrate function we can then simulate forward:

We cannot directly trace integrate as the number of iterations n_step is traced with an abstract value.
---------------------------------------------------------------------------
TracerIntegerConversionError              Traceback (most recent call last)
Cell In[36], line 1
----> 1 jax.make_jaxpr(integrate)(
      2     vorticity_jax, tracer, kernels, mesh, time_step, n_step
      3 )

    [... skipping hidden 14 frame]

Cell In[32], line 3, in integrate(vorticity, tracer, kernels, mesh, time_step, n_step)
      1 def integrate(vorticity, tracer, kernels, mesh, time_step, n_step):
      2     """Integrate fields forward in time a specified number of steps."""
----> 3     for s in range(n_step):
      4         vorticity, tracer = step(vorticity, tracer, kernels, mesh, time_step)
      5     return vorticity, tracer

    [... skipping hidden 1 frame]

File ~/projects/jax-sci-comp-demo/.venv/lib/python3.13t/site-packages/jax/_src/core.py:1834, in concretization_function_error.<locals>.error(self, arg)
   1833 def error(self, arg):
-> 1834   raise TracerIntegerConversionError(arg)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]
The error occurred while tracing the function integrate at /tmp/ipykernel_21229/42931077.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n_step.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError
To side-step this issue we can mark arguments as static.
JAX will then re-trace the relevant function each time its static arguments are changed.
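The effect of static arguments can be sketched with a small stand-in function. Here `repeated_halve` and `trace_log` are hypothetical names introduced only for illustration; the list records each time the Python function body is run, which happens only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_log = []


def repeated_halve(x, n_step):
    """Halve x n_step times with a Python loop (hypothetical stand-in for integrate)."""
    trace_log.append(n_step)  # Python side effect: runs only during tracing
    for _ in range(n_step):
        x = x / 2
    return x


x = jnp.ones(4)

# With n_step traced, range(n_step) fails as no concrete integer is available.
try:
    jax.jit(repeated_halve)(x, n_step=3)
    raised = False
except jax.errors.TracerIntegerConversionError:
    raised = True

trace_log.clear()
# Marking n_step as static makes it an ordinary Python integer during tracing.
jitted = jax.jit(repeated_halve, static_argnames="n_step")
y = jitted(x, n_step=3)  # traced and compiled for n_step=3
jitted(x, n_step=3)  # same static value: cached trace is reused
jitted(x, n_step=4)  # new static value: function is traced again
print(raised, trace_log)
```

Because each distinct static value triggers a new trace and compilation, static arguments are best suited to values that change rarely.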
For example applying to the n_step argument in integrate:
let fft = { lambda ; a:f64[64,64]. let
b:c128[64,33] = fft[fft_lengths=(64, 64) fft_type=2] a
in (b,) } in
let fft1 = { lambda ; c:c128[2,64,33]. let
d:f64[2,64,64] = fft[fft_lengths=(64, 64) fft_type=3] c
in (d,) } in
let remainder = { lambda ; e:i64[2,64,64] f:i64[2,1,1]. let
g:bool[2,1,1] = eq f 0:i64[]
h:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(2, 1, 1)
sharding=None
] 1:i64[]
i:i64[2,1,1] = jit[
name=_where
jaxpr={ lambda ; g:bool[2,1,1] h:i64[2,1,1] f:i64[2,1,1]. let
i:i64[2,1,1] = select_n g f h
in (i,) }
] g h f
j:i64[2,64,64] = rem e i
k:bool[2,64,64] = ne j 0:i64[]
l:bool[2,64,64] = lt j 0:i64[]
m:bool[2,1,1] = lt i 0:i64[]
n:bool[2,64,64] = ne l m
o:bool[2,64,64] = and n k
p:i64[2,64,64] = add j i
q:i64[2,64,64] = select_n o j p
in (q,) } in
let _where = { lambda ; g:bool[2,1,1] h:i64[2,1,1] f:i64[2,1,1]. let
i:i64[2,1,1] = select_n g f h
in (i,) } in
{ lambda r:c128[2] s:f64[2] t:i64[2] u:i64[2] v:i64[2] w:i64[2] x:i64[2] y:i64[2]
z:c128[2] ba:f64[2] bb:i64[2] bc:i64[2] bd:i64[2] be:i64[2] bf:i64[2] bg:i64[2]
bh:c128[2] bi:f64[2] bj:i64[2] bk:i64[2] bl:i64[2] bm:i64[2] bn:i64[2] bo:i64[2]
bp:c128[2] bq:f64[2] br:i64[2] bs:i64[2] bt:i64[2] bu:i64[2] bv:i64[2] bw:i64[2]
bx:c128[2] by:f64[2] bz:i64[2] ca:i64[2] cb:i64[2] cc:i64[2] cd:i64[2] ce:i64[2]
cf:c128[2] cg:f64[2] ch:i64[2] ci:i64[2] cj:i64[2] ck:i64[2] cl:i64[2] cm:i64[2]
cn:c128[2] co:f64[2] cp:i64[2] cq:i64[2] cr:i64[2] cs:i64[2] ct:i64[2] cu:i64[2]
cv:c128[2] cw:f64[2] cx:i64[2] cy:i64[2] cz:i64[2] da:i64[2] db:i64[2] dc:i64[2]
dd:c128[2] de:f64[2] df:i64[2] dg:i64[2] dh:i64[2] di:i64[2] dj:i64[2] dk:i64[2]
dl:c128[2] dm:f64[2] dn:i64[2] do:i64[2] dp:i64[2] dq:i64[2] dr:i64[2] ds:i64[2]; dt:f64[64,64]
du:f64[64,64] dv:f64[64,33] dw:c128[2,64,33] dx:i64[] dy:i64[] dz:f64[2] ea:f64[2,64,64]
eb:f64[]. let
ec:c128[64,33] = jit[name=fft jaxpr=fft] dt
ed:c128[] = reduce_prod[axes=(0,)] r
ee:c128[] = sqrt ed
ef:c128[] = div (1+0j):c128[] ee
eg:c128[64,33] = mul ec ef
eh:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] eg
ei:c128[2,64,33] = mul dw eh
ej:f64[2,64,64] = jit[name=fft jaxpr=fft1] ei
ek:f64[] = reduce_prod[axes=(0,)] s
el:f64[] = sqrt ek
em:f64[2,64,64] = mul ej el
en:f64[] = neg eb
eo:f64[] = convert_element_type[new_dtype=float64 weak_type=False] en
ep:f64[2,64,64] = mul eo em
eq:f64[2,64,64] = add ea ep
er:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
es:f64[2,64,64] = div eq er
et:f64[2,64,64] = floor es
eu:f64[2,64,64] = sub es et
ev:f64[2,64,64] = sub 1.0:f64[] eu
ew:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] t
ex:f64[2,64,64] = sub es eu
ey:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ex
ez:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ey ew
fa:i64[2,64,64] = add ez 1:i64[]
fb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fa ew
fc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ev
fd:f64[64,64] = squeeze[dimensions=(0,)] fc
fe:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ev
ff:f64[64,64] = squeeze[dimensions=(0,)] fe
fg:f64[64,64] = mul fd ff
fh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ez
fi:i64[64,64] = squeeze[dimensions=(0,)] fh
fj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ez
fk:i64[64,64] = squeeze[dimensions=(0,)] fj
fl:bool[64,64] = lt fi 0:i64[]
fm:i64[64,64] = add fi 64:i64[]
fn:i64[64,64] = select_n fl fi fm
fo:bool[64,64] = lt fk 0:i64[]
fp:i64[64,64] = add fk 64:i64[]
fq:i64[64,64] = select_n fo fk fp
fr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fn
fs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fq
ft:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fr
fu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fs
fv:i32[64,64,2] = concatenate[dimension=2] ft fu
fw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dt fv
fx:f64[64,64] = mul fg fw
fy:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ev
fz:f64[64,64] = squeeze[dimensions=(0,)] fy
ga:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eu
gb:f64[64,64] = squeeze[dimensions=(0,)] ga
gc:f64[64,64] = mul fz gb
gd:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ez
ge:i64[64,64] = squeeze[dimensions=(0,)] gd
gf:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fb
gg:i64[64,64] = squeeze[dimensions=(0,)] gf
gh:bool[64,64] = lt ge 0:i64[]
gi:i64[64,64] = add ge 64:i64[]
gj:i64[64,64] = select_n gh ge gi
gk:bool[64,64] = lt gg 0:i64[]
gl:i64[64,64] = add gg 64:i64[]
gm:i64[64,64] = select_n gk gg gl
gn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gj
go:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gm
gp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gn
gq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] go
gr:i32[64,64,2] = concatenate[dimension=2] gp gq
gs:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dt gr
gt:f64[64,64] = mul gc gs
gu:f64[64,64] = add fx gt
gv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eu
gw:f64[64,64] = squeeze[dimensions=(0,)] gv
gx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ev
gy:f64[64,64] = squeeze[dimensions=(0,)] gx
gz:f64[64,64] = mul gw gy
ha:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fb
hb:i64[64,64] = squeeze[dimensions=(0,)] ha
hc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ez
hd:i64[64,64] = squeeze[dimensions=(0,)] hc
he:bool[64,64] = lt hb 0:i64[]
hf:i64[64,64] = add hb 64:i64[]
hg:i64[64,64] = select_n he hb hf
hh:bool[64,64] = lt hd 0:i64[]
hi:i64[64,64] = add hd 64:i64[]
hj:i64[64,64] = select_n hh hd hi
hk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hg
hl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hj
hm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hk
hn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hl
ho:i32[64,64,2] = concatenate[dimension=2] hm hn
hp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dt ho
hq:f64[64,64] = mul gz hp
hr:f64[64,64] = add gu hq
hs:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eu
ht:f64[64,64] = squeeze[dimensions=(0,)] hs
hu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eu
hv:f64[64,64] = squeeze[dimensions=(0,)] hu
hw:f64[64,64] = mul ht hv
hx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fb
hy:i64[64,64] = squeeze[dimensions=(0,)] hx
hz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fb
ia:i64[64,64] = squeeze[dimensions=(0,)] hz
ib:bool[64,64] = lt hy 0:i64[]
ic:i64[64,64] = add hy 64:i64[]
id:i64[64,64] = select_n ib hy ic
ie:bool[64,64] = lt ia 0:i64[]
if:i64[64,64] = add ia 64:i64[]
ig:i64[64,64] = select_n ie ia if
ih:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] id
ii:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ig
ij:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ih
ik:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ii
il:i32[64,64,2] = concatenate[dimension=2] ij ik
im:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dt il
in:f64[64,64] = mul hw im
io:f64[64,64] = add hr in
ip:f64[2,64,64] = neg em
iq:f64[] = neg eb
ir:f64[] = convert_element_type[new_dtype=float64 weak_type=False] iq
is:f64[2,64,64] = mul ir ip
it:f64[2,64,64] = add ea is
iu:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
iv:f64[2,64,64] = div it iu
iw:f64[2,64,64] = floor iv
ix:f64[2,64,64] = sub iv iw
iy:f64[2,64,64] = sub 1.0:f64[] ix
iz:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] u
ja:f64[2,64,64] = sub iv ix
jb:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ja
jc:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jb iz
jd:i64[2,64,64] = add jc 1:i64[]
je:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jd iz
jf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iy
jg:f64[64,64] = squeeze[dimensions=(0,)] jf
jh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iy
ji:f64[64,64] = squeeze[dimensions=(0,)] jh
jj:f64[64,64] = mul jg ji
jk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jc
jl:i64[64,64] = squeeze[dimensions=(0,)] jk
jm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jc
jn:i64[64,64] = squeeze[dimensions=(0,)] jm
jo:bool[64,64] = lt jl 0:i64[]
jp:i64[64,64] = add jl 64:i64[]
jq:i64[64,64] = select_n jo jl jp
jr:bool[64,64] = lt jn 0:i64[]
js:i64[64,64] = add jn 64:i64[]
jt:i64[64,64] = select_n jr jn js
ju:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jq
jv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jt
jw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ju
jx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jv
jy:i32[64,64,2] = concatenate[dimension=2] jw jx
jz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] io jy
ka:f64[64,64] = mul jj jz
kb:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iy
kc:f64[64,64] = squeeze[dimensions=(0,)] kb
kd:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ix
ke:f64[64,64] = squeeze[dimensions=(0,)] kd
kf:f64[64,64] = mul kc ke
kg:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jc
kh:i64[64,64] = squeeze[dimensions=(0,)] kg
ki:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] je
kj:i64[64,64] = squeeze[dimensions=(0,)] ki
kk:bool[64,64] = lt kh 0:i64[]
kl:i64[64,64] = add kh 64:i64[]
km:i64[64,64] = select_n kk kh kl
kn:bool[64,64] = lt kj 0:i64[]
ko:i64[64,64] = add kj 64:i64[]
kp:i64[64,64] = select_n kn kj ko
kq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] km
kr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] kp
ks:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] kq
kt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] kr
ku:i32[64,64,2] = concatenate[dimension=2] ks kt
kv:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] io ku
kw:f64[64,64] = mul kf kv
kx:f64[64,64] = add ka kw
ky:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ix
kz:f64[64,64] = squeeze[dimensions=(0,)] ky
la:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iy
lb:f64[64,64] = squeeze[dimensions=(0,)] la
lc:f64[64,64] = mul kz lb
ld:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] je
le:i64[64,64] = squeeze[dimensions=(0,)] ld
lf:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jc
lg:i64[64,64] = squeeze[dimensions=(0,)] lf
lh:bool[64,64] = lt le 0:i64[]
li:i64[64,64] = add le 64:i64[]
lj:i64[64,64] = select_n lh le li
lk:bool[64,64] = lt lg 0:i64[]
ll:i64[64,64] = add lg 64:i64[]
lm:i64[64,64] = select_n lk lg ll
ln:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] lj
lo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] lm
lp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ln
lq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] lo
lr:i32[64,64,2] = concatenate[dimension=2] lp lq
ls:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] io lr
lt:f64[64,64] = mul lc ls
lu:f64[64,64] = add kx lt
lv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ix
lw:f64[64,64] = squeeze[dimensions=(0,)] lv
lx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ix
ly:f64[64,64] = squeeze[dimensions=(0,)] lx
lz:f64[64,64] = mul lw ly
ma:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] je
mb:i64[64,64] = squeeze[dimensions=(0,)] ma
mc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] je
md:i64[64,64] = squeeze[dimensions=(0,)] mc
me:bool[64,64] = lt mb 0:i64[]
mf:i64[64,64] = add mb 64:i64[]
mg:i64[64,64] = select_n me mb mf
mh:bool[64,64] = lt md 0:i64[]
mi:i64[64,64] = add md 64:i64[]
mj:i64[64,64] = select_n mh md mi
mk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] mg
ml:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] mj
mm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] mk
mn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ml
mo:i32[64,64,2] = concatenate[dimension=2] mm mn
mp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] io mo
mq:f64[64,64] = mul lz mp
mr:f64[64,64] = add lu mq
ms:f64[64,64] = sub dt mr
mt:f64[64,64] = div ms 2.0:f64[]
mu:f64[64,64] = add dt mt
mv:f64[] = neg eb
mw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] mv
mx:f64[2,64,64] = mul mw em
my:f64[2,64,64] = add ea mx
mz:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
na:f64[2,64,64] = div my mz
nb:f64[2,64,64] = floor na
nc:f64[2,64,64] = sub na nb
nd:f64[2,64,64] = sub 1.0:f64[] nc
ne:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] v
nf:f64[2,64,64] = sub na nc
ng:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] nf
nh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ng ne
ni:i64[2,64,64] = add nh 1:i64[]
nj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ni ne
nk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nd
nl:f64[64,64] = squeeze[dimensions=(0,)] nk
nm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nd
nn:f64[64,64] = squeeze[dimensions=(0,)] nm
no:f64[64,64] = mul nl nn
np:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nh
nq:i64[64,64] = squeeze[dimensions=(0,)] np
nr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nh
ns:i64[64,64] = squeeze[dimensions=(0,)] nr
nt:bool[64,64] = lt nq 0:i64[]
nu:i64[64,64] = add nq 64:i64[]
nv:i64[64,64] = select_n nt nq nu
nw:bool[64,64] = lt ns 0:i64[]
nx:i64[64,64] = add ns 64:i64[]
ny:i64[64,64] = select_n nw ns nx
nz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] nv
oa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ny
ob:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] nz
oc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] oa
od:i32[64,64,2] = concatenate[dimension=2] ob oc
oe:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] mu od
of:f64[64,64] = mul no oe
og:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nd
oh:f64[64,64] = squeeze[dimensions=(0,)] og
oi:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nc
oj:f64[64,64] = squeeze[dimensions=(0,)] oi
ok:f64[64,64] = mul oh oj
ol:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nh
om:i64[64,64] = squeeze[dimensions=(0,)] ol
on:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nj
oo:i64[64,64] = squeeze[dimensions=(0,)] on
op:bool[64,64] = lt om 0:i64[]
oq:i64[64,64] = add om 64:i64[]
or:i64[64,64] = select_n op om oq
os:bool[64,64] = lt oo 0:i64[]
ot:i64[64,64] = add oo 64:i64[]
ou:i64[64,64] = select_n os oo ot
ov:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] or
ow:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ou
ox:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ov
oy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ow
oz:i32[64,64,2] = concatenate[dimension=2] ox oy
pa:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] mu oz
pb:f64[64,64] = mul ok pa
pc:f64[64,64] = add of pb
pd:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nc
pe:f64[64,64] = squeeze[dimensions=(0,)] pd
pf:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nd
pg:f64[64,64] = squeeze[dimensions=(0,)] pf
ph:f64[64,64] = mul pe pg
pi:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nj
pj:i64[64,64] = squeeze[dimensions=(0,)] pi
pk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nh
pl:i64[64,64] = squeeze[dimensions=(0,)] pk
pm:bool[64,64] = lt pj 0:i64[]
pn:i64[64,64] = add pj 64:i64[]
po:i64[64,64] = select_n pm pj pn
pp:bool[64,64] = lt pl 0:i64[]
pq:i64[64,64] = add pl 64:i64[]
pr:i64[64,64] = select_n pp pl pq
ps:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] po
pt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] pr
pu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ps
pv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] pt
pw:i32[64,64,2] = concatenate[dimension=2] pu pv
px:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] mu pw
py:f64[64,64] = mul ph px
pz:f64[64,64] = add pc py
qa:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nc
qb:f64[64,64] = squeeze[dimensions=(0,)] qa
qc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nc
qd:f64[64,64] = squeeze[dimensions=(0,)] qc
qe:f64[64,64] = mul qb qd
qf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] nj
qg:i64[64,64] = squeeze[dimensions=(0,)] qf
qh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] nj
qi:i64[64,64] = squeeze[dimensions=(0,)] qh
qj:bool[64,64] = lt qg 0:i64[]
qk:i64[64,64] = add qg 64:i64[]
ql:i64[64,64] = select_n qj qg qk
qm:bool[64,64] = lt qi 0:i64[]
qn:i64[64,64] = add qi 64:i64[]
qo:i64[64,64] = select_n qm qi qn
qp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ql
qq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] qo
qr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] qp
qs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] qq
qt:i32[64,64,2] = concatenate[dimension=2] qr qs
qu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] mu qt
qv:f64[64,64] = mul qe qu
qw:f64[64,64] = add pz qv
qx:f64[] = neg eb
qy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] qx
qz:f64[2,64,64] = mul qy em
ra:f64[2,64,64] = add ea qz
rb:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
rc:f64[2,64,64] = div ra rb
rd:f64[2,64,64] = floor rc
re:f64[2,64,64] = sub rc rd
rf:f64[2,64,64] = sub 1.0:f64[] re
rg:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] w
rh:f64[2,64,64] = sub rc re
ri:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] rh
rj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ri rg
rk:i64[2,64,64] = add rj 1:i64[]
rl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] rk rg
rm:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] rf
rn:f64[64,64] = squeeze[dimensions=(0,)] rm
ro:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] rf
rp:f64[64,64] = squeeze[dimensions=(0,)] ro
rq:f64[64,64] = mul rn rp
rr:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] rj
rs:i64[64,64] = squeeze[dimensions=(0,)] rr
rt:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] rj
ru:i64[64,64] = squeeze[dimensions=(0,)] rt
rv:bool[64,64] = lt rs 0:i64[]
rw:i64[64,64] = add rs 64:i64[]
rx:i64[64,64] = select_n rv rs rw
ry:bool[64,64] = lt ru 0:i64[]
rz:i64[64,64] = add ru 64:i64[]
sa:i64[64,64] = select_n ry ru rz
sb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] rx
sc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] sa
sd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] sb
se:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] sc
sf:i32[64,64,2] = concatenate[dimension=2] sd se
sg:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] du sf
sh:f64[64,64] = mul rq sg
si:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] rf
sj:f64[64,64] = squeeze[dimensions=(0,)] si
sk:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] re
sl:f64[64,64] = squeeze[dimensions=(0,)] sk
sm:f64[64,64] = mul sj sl
sn:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] rj
so:i64[64,64] = squeeze[dimensions=(0,)] sn
sp:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] rl
sq:i64[64,64] = squeeze[dimensions=(0,)] sp
sr:bool[64,64] = lt so 0:i64[]
ss:i64[64,64] = add so 64:i64[]
st:i64[64,64] = select_n sr so ss
su:bool[64,64] = lt sq 0:i64[]
sv:i64[64,64] = add sq 64:i64[]
sw:i64[64,64] = select_n su sq sv
sx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] st
sy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] sw
sz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] sx
ta:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] sy
tb:i32[64,64,2] = concatenate[dimension=2] sz ta
tc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] du tb
td:f64[64,64] = mul sm tc
te:f64[64,64] = add sh td
tf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] re
tg:f64[64,64] = squeeze[dimensions=(0,)] tf
th:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] rf
ti:f64[64,64] = squeeze[dimensions=(0,)] th
tj:f64[64,64] = mul tg ti
tk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] rl
tl:i64[64,64] = squeeze[dimensions=(0,)] tk
tm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] rj
tn:i64[64,64] = squeeze[dimensions=(0,)] tm
to:bool[64,64] = lt tl 0:i64[]
tp:i64[64,64] = add tl 64:i64[]
tq:i64[64,64] = select_n to tl tp
tr:bool[64,64] = lt tn 0:i64[]
ts:i64[64,64] = add tn 64:i64[]
tt:i64[64,64] = select_n tr tn ts
tu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] tq
tv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] tt
tw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] tu
tx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] tv
ty:i32[64,64,2] = concatenate[dimension=2] tw tx
tz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] du ty
ua:f64[64,64] = mul tj tz
ub:f64[64,64] = add te ua
uc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] re
ud:f64[64,64] = squeeze[dimensions=(0,)] uc
ue:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] re
uf:f64[64,64] = squeeze[dimensions=(0,)] ue
ug:f64[64,64] = mul ud uf
uh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] rl
ui:i64[64,64] = squeeze[dimensions=(0,)] uh
uj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] rl
uk:i64[64,64] = squeeze[dimensions=(0,)] uj
ul:bool[64,64] = lt ui 0:i64[]
um:i64[64,64] = add ui 64:i64[]
un:i64[64,64] = select_n ul ui um
uo:bool[64,64] = lt uk 0:i64[]
up:i64[64,64] = add uk 64:i64[]
uq:i64[64,64] = select_n uo uk up
ur:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] un
us:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] uq
ut:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ur
uu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] us
uv:i32[64,64,2] = concatenate[dimension=2] ut uu
uw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] du uv
ux:f64[64,64] = mul ug uw
uy:f64[64,64] = add ub ux
uz:f64[2,64,64] = neg em
va:f64[] = neg eb
vb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] va
vc:f64[2,64,64] = mul vb uz
vd:f64[2,64,64] = add ea vc
ve:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
vf:f64[2,64,64] = div vd ve
vg:f64[2,64,64] = floor vf
vh:f64[2,64,64] = sub vf vg
vi:f64[2,64,64] = sub 1.0:f64[] vh
vj:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] x
vk:f64[2,64,64] = sub vf vh
vl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] vk
vm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] vl vj
vn:i64[2,64,64] = add vm 1:i64[]
vo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] vn vj
vp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vi
vq:f64[64,64] = squeeze[dimensions=(0,)] vp
vr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vi
vs:f64[64,64] = squeeze[dimensions=(0,)] vr
vt:f64[64,64] = mul vq vs
vu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vm
vv:i64[64,64] = squeeze[dimensions=(0,)] vu
vw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vm
vx:i64[64,64] = squeeze[dimensions=(0,)] vw
vy:bool[64,64] = lt vv 0:i64[]
vz:i64[64,64] = add vv 64:i64[]
wa:i64[64,64] = select_n vy vv vz
wb:bool[64,64] = lt vx 0:i64[]
wc:i64[64,64] = add vx 64:i64[]
wd:i64[64,64] = select_n wb vx wc
we:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] wa
wf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] wd
wg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] we
wh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] wf
wi:i32[64,64,2] = concatenate[dimension=2] wg wh
wj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] uy wi
wk:f64[64,64] = mul vt wj
wl:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vi
wm:f64[64,64] = squeeze[dimensions=(0,)] wl
wn:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vh
wo:f64[64,64] = squeeze[dimensions=(0,)] wn
wp:f64[64,64] = mul wm wo
wq:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vm
wr:i64[64,64] = squeeze[dimensions=(0,)] wq
ws:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vo
wt:i64[64,64] = squeeze[dimensions=(0,)] ws
wu:bool[64,64] = lt wr 0:i64[]
wv:i64[64,64] = add wr 64:i64[]
ww:i64[64,64] = select_n wu wr wv
wx:bool[64,64] = lt wt 0:i64[]
wy:i64[64,64] = add wt 64:i64[]
wz:i64[64,64] = select_n wx wt wy
xa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ww
xb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] wz
xc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] xa
xd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] xb
xe:i32[64,64,2] = concatenate[dimension=2] xc xd
xf:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] uy xe
xg:f64[64,64] = mul wp xf
xh:f64[64,64] = add wk xg
xi:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vh
xj:f64[64,64] = squeeze[dimensions=(0,)] xi
xk:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vi
xl:f64[64,64] = squeeze[dimensions=(0,)] xk
xm:f64[64,64] = mul xj xl
xn:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vo
xo:i64[64,64] = squeeze[dimensions=(0,)] xn
xp:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vm
xq:i64[64,64] = squeeze[dimensions=(0,)] xp
xr:bool[64,64] = lt xo 0:i64[]
xs:i64[64,64] = add xo 64:i64[]
xt:i64[64,64] = select_n xr xo xs
xu:bool[64,64] = lt xq 0:i64[]
xv:i64[64,64] = add xq 64:i64[]
xw:i64[64,64] = select_n xu xq xv
xx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] xt
xy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] xw
xz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] xx
ya:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] xy
yb:i32[64,64,2] = concatenate[dimension=2] xz ya
yc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] uy yb
yd:f64[64,64] = mul xm yc
ye:f64[64,64] = add xh yd
yf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vh
yg:f64[64,64] = squeeze[dimensions=(0,)] yf
yh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vh
yi:f64[64,64] = squeeze[dimensions=(0,)] yh
yj:f64[64,64] = mul yg yi
yk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] vo
yl:i64[64,64] = squeeze[dimensions=(0,)] yk
ym:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] vo
yn:i64[64,64] = squeeze[dimensions=(0,)] ym
yo:bool[64,64] = lt yl 0:i64[]
yp:i64[64,64] = add yl 64:i64[]
yq:i64[64,64] = select_n yo yl yp
yr:bool[64,64] = lt yn 0:i64[]
ys:i64[64,64] = add yn 64:i64[]
yt:i64[64,64] = select_n yr yn ys
yu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] yq
yv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] yt
yw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] yu
yx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] yv
yy:i32[64,64,2] = concatenate[dimension=2] yw yx
yz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] uy yy
za:f64[64,64] = mul yj yz
zb:f64[64,64] = add ye za
zc:f64[64,64] = sub du zb
zd:f64[64,64] = div zc 2.0:f64[]
ze:f64[64,64] = add du zd
zf:f64[] = neg eb
zg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] zf
zh:f64[2,64,64] = mul zg em
zi:f64[2,64,64] = add ea zh
zj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
zk:f64[2,64,64] = div zi zj
zl:f64[2,64,64] = floor zk
zm:f64[2,64,64] = sub zk zl
zn:f64[2,64,64] = sub 1.0:f64[] zm
zo:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] y
zp:f64[2,64,64] = sub zk zm
zq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] zp
zr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] zq zo
zs:i64[2,64,64] = add zr 1:i64[]
zt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] zs zo
zu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zn
zv:f64[64,64] = squeeze[dimensions=(0,)] zu
zw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zn
zx:f64[64,64] = squeeze[dimensions=(0,)] zw
zy:f64[64,64] = mul zv zx
zz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zr
baa:i64[64,64] = squeeze[dimensions=(0,)] zz
bab:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zr
bac:i64[64,64] = squeeze[dimensions=(0,)] bab
bad:bool[64,64] = lt baa 0:i64[]
bae:i64[64,64] = add baa 64:i64[]
baf:i64[64,64] = select_n bad baa bae
bag:bool[64,64] = lt bac 0:i64[]
bah:i64[64,64] = add bac 64:i64[]
bai:i64[64,64] = select_n bag bac bah
baj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] baf
bak:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bai
bal:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] baj
bam:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bak
ban:i32[64,64,2] = concatenate[dimension=2] bal bam
bao:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ze ban
bap:f64[64,64] = mul zy bao
baq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zn
bar:f64[64,64] = squeeze[dimensions=(0,)] baq
bas:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zm
bat:f64[64,64] = squeeze[dimensions=(0,)] bas
bau:f64[64,64] = mul bar bat
bav:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zr
baw:i64[64,64] = squeeze[dimensions=(0,)] bav
bax:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zt
bay:i64[64,64] = squeeze[dimensions=(0,)] bax
baz:bool[64,64] = lt baw 0:i64[]
bba:i64[64,64] = add baw 64:i64[]
bbb:i64[64,64] = select_n baz baw bba
bbc:bool[64,64] = lt bay 0:i64[]
bbd:i64[64,64] = add bay 64:i64[]
bbe:i64[64,64] = select_n bbc bay bbd
bbf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bbb
bbg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bbe
bbh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bbf
bbi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bbg
bbj:i32[64,64,2] = concatenate[dimension=2] bbh bbi
bbk:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ze bbj
bbl:f64[64,64] = mul bau bbk
bbm:f64[64,64] = add bap bbl
bbn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zm
bbo:f64[64,64] = squeeze[dimensions=(0,)] bbn
bbp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zn
bbq:f64[64,64] = squeeze[dimensions=(0,)] bbp
bbr:f64[64,64] = mul bbo bbq
bbs:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zt
bbt:i64[64,64] = squeeze[dimensions=(0,)] bbs
bbu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zr
bbv:i64[64,64] = squeeze[dimensions=(0,)] bbu
bbw:bool[64,64] = lt bbt 0:i64[]
bbx:i64[64,64] = add bbt 64:i64[]
bby:i64[64,64] = select_n bbw bbt bbx
bbz:bool[64,64] = lt bbv 0:i64[]
bca:i64[64,64] = add bbv 64:i64[]
bcb:i64[64,64] = select_n bbz bbv bca
bcc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bby
bcd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bcb
bce:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bcc
bcf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bcd
bcg:i32[64,64,2] = concatenate[dimension=2] bce bcf
bch:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ze bcg
bci:f64[64,64] = mul bbr bch
bcj:f64[64,64] = add bbm bci
bck:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zm
bcl:f64[64,64] = squeeze[dimensions=(0,)] bck
bcm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zm
bcn:f64[64,64] = squeeze[dimensions=(0,)] bcm
bco:f64[64,64] = mul bcl bcn
bcp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] zt
bcq:i64[64,64] = squeeze[dimensions=(0,)] bcp
bcr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] zt
bcs:i64[64,64] = squeeze[dimensions=(0,)] bcr
bct:bool[64,64] = lt bcq 0:i64[]
bcu:i64[64,64] = add bcq 64:i64[]
bcv:i64[64,64] = select_n bct bcq bcu
bcw:bool[64,64] = lt bcs 0:i64[]
bcx:i64[64,64] = add bcs 64:i64[]
bcy:i64[64,64] = select_n bcw bcs bcx
bcz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bcv
bda:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bcy
bdb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bcz
bdc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bda
bdd:i32[64,64,2] = concatenate[dimension=2] bdb bdc
bde:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ze bdd
bdf:f64[64,64] = mul bco bde
bdg:f64[64,64] = add bcj bdf
bdh:c128[64,33] = jit[name=fft jaxpr=fft] qw
bdi:c128[] = reduce_prod[axes=(0,)] z
bdj:c128[] = sqrt bdi
bdk:c128[] = div (1+0j):c128[] bdj
bdl:c128[64,33] = mul bdh bdk
bdm:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] bdl
bdn:c128[2,64,33] = mul dw bdm
bdo:f64[2,64,64] = jit[name=fft jaxpr=fft1] bdn
bdp:f64[] = reduce_prod[axes=(0,)] ba
bdq:f64[] = sqrt bdp
bdr:f64[2,64,64] = mul bdo bdq
bds:f64[] = neg eb
bdt:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bds
bdu:f64[2,64,64] = mul bdt bdr
bdv:f64[2,64,64] = add ea bdu
bdw:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
bdx:f64[2,64,64] = div bdv bdw
bdy:f64[2,64,64] = floor bdx
bdz:f64[2,64,64] = sub bdx bdy
bea:f64[2,64,64] = sub 1.0:f64[] bdz
beb:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bb
bec:f64[2,64,64] = sub bdx bdz
bed:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bec
bee:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bed beb
bef:i64[2,64,64] = add bee 1:i64[]
beg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bef beb
beh:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bea
bei:f64[64,64] = squeeze[dimensions=(0,)] beh
bej:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bea
bek:f64[64,64] = squeeze[dimensions=(0,)] bej
bel:f64[64,64] = mul bei bek
bem:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bee
ben:i64[64,64] = squeeze[dimensions=(0,)] bem
beo:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bee
bep:i64[64,64] = squeeze[dimensions=(0,)] beo
beq:bool[64,64] = lt ben 0:i64[]
ber:i64[64,64] = add ben 64:i64[]
bes:i64[64,64] = select_n beq ben ber
bet:bool[64,64] = lt bep 0:i64[]
beu:i64[64,64] = add bep 64:i64[]
bev:i64[64,64] = select_n bet bep beu
bew:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bes
bex:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bev
bey:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bew
bez:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bex
bfa:i32[64,64,2] = concatenate[dimension=2] bey bez
bfb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] qw bfa
bfc:f64[64,64] = mul bel bfb
bfd:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bea
bfe:f64[64,64] = squeeze[dimensions=(0,)] bfd
bff:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bdz
bfg:f64[64,64] = squeeze[dimensions=(0,)] bff
bfh:f64[64,64] = mul bfe bfg
bfi:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bee
bfj:i64[64,64] = squeeze[dimensions=(0,)] bfi
bfk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] beg
bfl:i64[64,64] = squeeze[dimensions=(0,)] bfk
bfm:bool[64,64] = lt bfj 0:i64[]
bfn:i64[64,64] = add bfj 64:i64[]
bfo:i64[64,64] = select_n bfm bfj bfn
bfp:bool[64,64] = lt bfl 0:i64[]
bfq:i64[64,64] = add bfl 64:i64[]
bfr:i64[64,64] = select_n bfp bfl bfq
bfs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bfo
bft:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bfr
bfu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bfs
bfv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bft
bfw:i32[64,64,2] = concatenate[dimension=2] bfu bfv
bfx:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] qw bfw
bfy:f64[64,64] = mul bfh bfx
bfz:f64[64,64] = add bfc bfy
bga:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bdz
bgb:f64[64,64] = squeeze[dimensions=(0,)] bga
bgc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bea
bgd:f64[64,64] = squeeze[dimensions=(0,)] bgc
bge:f64[64,64] = mul bgb bgd
bgf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] beg
bgg:i64[64,64] = squeeze[dimensions=(0,)] bgf
bgh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bee
bgi:i64[64,64] = squeeze[dimensions=(0,)] bgh
bgj:bool[64,64] = lt bgg 0:i64[]
bgk:i64[64,64] = add bgg 64:i64[]
bgl:i64[64,64] = select_n bgj bgg bgk
bgm:bool[64,64] = lt bgi 0:i64[]
bgn:i64[64,64] = add bgi 64:i64[]
bgo:i64[64,64] = select_n bgm bgi bgn
bgp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bgl
bgq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bgo
bgr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bgp
bgs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bgq
bgt:i32[64,64,2] = concatenate[dimension=2] bgr bgs
bgu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] qw bgt
bgv:f64[64,64] = mul bge bgu
bgw:f64[64,64] = add bfz bgv
bgx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bdz
bgy:f64[64,64] = squeeze[dimensions=(0,)] bgx
bgz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bdz
bha:f64[64,64] = squeeze[dimensions=(0,)] bgz
bhb:f64[64,64] = mul bgy bha
bhc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] beg
bhd:i64[64,64] = squeeze[dimensions=(0,)] bhc
bhe:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] beg
bhf:i64[64,64] = squeeze[dimensions=(0,)] bhe
bhg:bool[64,64] = lt bhd 0:i64[]
bhh:i64[64,64] = add bhd 64:i64[]
bhi:i64[64,64] = select_n bhg bhd bhh
bhj:bool[64,64] = lt bhf 0:i64[]
bhk:i64[64,64] = add bhf 64:i64[]
bhl:i64[64,64] = select_n bhj bhf bhk
bhm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bhi
bhn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bhl
bho:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bhm
bhp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bhn
bhq:i32[64,64,2] = concatenate[dimension=2] bho bhp
bhr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] qw bhq
bhs:f64[64,64] = mul bhb bhr
bht:f64[64,64] = add bgw bhs
bhu:f64[2,64,64] = neg bdr
bhv:f64[] = neg eb
bhw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bhv
bhx:f64[2,64,64] = mul bhw bhu
bhy:f64[2,64,64] = add ea bhx
bhz:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
bia:f64[2,64,64] = div bhy bhz
bib:f64[2,64,64] = floor bia
bic:f64[2,64,64] = sub bia bib
bid:f64[2,64,64] = sub 1.0:f64[] bic
bie:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bc
bif:f64[2,64,64] = sub bia bic
big:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bif
bih:i64[2,64,64] = jit[name=remainder jaxpr=remainder] big bie
bii:i64[2,64,64] = add bih 1:i64[]
bij:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bii bie
bik:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bid
bil:f64[64,64] = squeeze[dimensions=(0,)] bik
bim:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bid
bin:f64[64,64] = squeeze[dimensions=(0,)] bim
bio:f64[64,64] = mul bil bin
bip:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bih
biq:i64[64,64] = squeeze[dimensions=(0,)] bip
bir:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bih
bis:i64[64,64] = squeeze[dimensions=(0,)] bir
bit:bool[64,64] = lt biq 0:i64[]
biu:i64[64,64] = add biq 64:i64[]
biv:i64[64,64] = select_n bit biq biu
biw:bool[64,64] = lt bis 0:i64[]
bix:i64[64,64] = add bis 64:i64[]
biy:i64[64,64] = select_n biw bis bix
biz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] biv
bja:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] biy
bjb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] biz
bjc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bja
bjd:i32[64,64,2] = concatenate[dimension=2] bjb bjc
bje:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bht bjd
bjf:f64[64,64] = mul bio bje
bjg:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bid
bjh:f64[64,64] = squeeze[dimensions=(0,)] bjg
bji:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bic
bjj:f64[64,64] = squeeze[dimensions=(0,)] bji
bjk:f64[64,64] = mul bjh bjj
bjl:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bih
bjm:i64[64,64] = squeeze[dimensions=(0,)] bjl
bjn:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bij
bjo:i64[64,64] = squeeze[dimensions=(0,)] bjn
bjp:bool[64,64] = lt bjm 0:i64[]
bjq:i64[64,64] = add bjm 64:i64[]
bjr:i64[64,64] = select_n bjp bjm bjq
bjs:bool[64,64] = lt bjo 0:i64[]
bjt:i64[64,64] = add bjo 64:i64[]
bju:i64[64,64] = select_n bjs bjo bjt
bjv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bjr
bjw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bju
bjx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bjv
bjy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bjw
bjz:i32[64,64,2] = concatenate[dimension=2] bjx bjy
bka:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bht bjz
bkb:f64[64,64] = mul bjk bka
bkc:f64[64,64] = add bjf bkb
bkd:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bic
bke:f64[64,64] = squeeze[dimensions=(0,)] bkd
bkf:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bid
bkg:f64[64,64] = squeeze[dimensions=(0,)] bkf
bkh:f64[64,64] = mul bke bkg
bki:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bij
bkj:i64[64,64] = squeeze[dimensions=(0,)] bki
bkk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bih
bkl:i64[64,64] = squeeze[dimensions=(0,)] bkk
bkm:bool[64,64] = lt bkj 0:i64[]
bkn:i64[64,64] = add bkj 64:i64[]
bko:i64[64,64] = select_n bkm bkj bkn
bkp:bool[64,64] = lt bkl 0:i64[]
bkq:i64[64,64] = add bkl 64:i64[]
bkr:i64[64,64] = select_n bkp bkl bkq
bks:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bko
bkt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bkr
bku:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bks
bkv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bkt
bkw:i32[64,64,2] = concatenate[dimension=2] bku bkv
bkx:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bht bkw
bky:f64[64,64] = mul bkh bkx
bkz:f64[64,64] = add bkc bky
bla:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bic
blb:f64[64,64] = squeeze[dimensions=(0,)] bla
blc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bic
bld:f64[64,64] = squeeze[dimensions=(0,)] blc
ble:f64[64,64] = mul blb bld
blf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bij
blg:i64[64,64] = squeeze[dimensions=(0,)] blf
blh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bij
bli:i64[64,64] = squeeze[dimensions=(0,)] blh
blj:bool[64,64] = lt blg 0:i64[]
blk:i64[64,64] = add blg 64:i64[]
bll:i64[64,64] = select_n blj blg blk
blm:bool[64,64] = lt bli 0:i64[]
bln:i64[64,64] = add bli 64:i64[]
blo:i64[64,64] = select_n blm bli bln
blp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bll
blq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] blo
blr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] blp
bls:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] blq
blt:i32[64,64,2] = concatenate[dimension=2] blr bls
blu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bht blt
blv:f64[64,64] = mul ble blu
blw:f64[64,64] = add bkz blv
blx:f64[64,64] = sub qw blw
bly:f64[64,64] = div blx 2.0:f64[]
blz:f64[64,64] = add qw bly
bma:f64[] = neg eb
bmb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bma
bmc:f64[2,64,64] = mul bmb bdr
bmd:f64[2,64,64] = add ea bmc
bme:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
bmf:f64[2,64,64] = div bmd bme
bmg:f64[2,64,64] = floor bmf
bmh:f64[2,64,64] = sub bmf bmg
bmi:f64[2,64,64] = sub 1.0:f64[] bmh
bmj:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bd
bmk:f64[2,64,64] = sub bmf bmh
bml:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bmk
bmm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bml bmj
bmn:i64[2,64,64] = add bmm 1:i64[]
bmo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bmn bmj
bmp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmi
bmq:f64[64,64] = squeeze[dimensions=(0,)] bmp
bmr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmi
bms:f64[64,64] = squeeze[dimensions=(0,)] bmr
bmt:f64[64,64] = mul bmq bms
bmu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmm
bmv:i64[64,64] = squeeze[dimensions=(0,)] bmu
bmw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmm
bmx:i64[64,64] = squeeze[dimensions=(0,)] bmw
bmy:bool[64,64] = lt bmv 0:i64[]
bmz:i64[64,64] = add bmv 64:i64[]
bna:i64[64,64] = select_n bmy bmv bmz
bnb:bool[64,64] = lt bmx 0:i64[]
bnc:i64[64,64] = add bmx 64:i64[]
bnd:i64[64,64] = select_n bnb bmx bnc
bne:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bna
bnf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bnd
bng:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bne
bnh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bnf
bni:i32[64,64,2] = concatenate[dimension=2] bng bnh
bnj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] blz bni
bnk:f64[64,64] = mul bmt bnj
bnl:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmi
bnm:f64[64,64] = squeeze[dimensions=(0,)] bnl
bnn:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmh
bno:f64[64,64] = squeeze[dimensions=(0,)] bnn
bnp:f64[64,64] = mul bnm bno
bnq:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmm
bnr:i64[64,64] = squeeze[dimensions=(0,)] bnq
bns:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmo
bnt:i64[64,64] = squeeze[dimensions=(0,)] bns
bnu:bool[64,64] = lt bnr 0:i64[]
bnv:i64[64,64] = add bnr 64:i64[]
bnw:i64[64,64] = select_n bnu bnr bnv
bnx:bool[64,64] = lt bnt 0:i64[]
bny:i64[64,64] = add bnt 64:i64[]
bnz:i64[64,64] = select_n bnx bnt bny
boa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bnw
bob:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bnz
boc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] boa
bod:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bob
boe:i32[64,64,2] = concatenate[dimension=2] boc bod
bof:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] blz boe
bog:f64[64,64] = mul bnp bof
boh:f64[64,64] = add bnk bog
boi:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmh
boj:f64[64,64] = squeeze[dimensions=(0,)] boi
bok:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmi
bol:f64[64,64] = squeeze[dimensions=(0,)] bok
bom:f64[64,64] = mul boj bol
bon:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmo
boo:i64[64,64] = squeeze[dimensions=(0,)] bon
bop:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmm
boq:i64[64,64] = squeeze[dimensions=(0,)] bop
bor:bool[64,64] = lt boo 0:i64[]
bos:i64[64,64] = add boo 64:i64[]
bot:i64[64,64] = select_n bor boo bos
bou:bool[64,64] = lt boq 0:i64[]
bov:i64[64,64] = add boq 64:i64[]
bow:i64[64,64] = select_n bou boq bov
box:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bot
boy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bow
boz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] box
bpa:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] boy
bpb:i32[64,64,2] = concatenate[dimension=2] boz bpa
bpc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] blz bpb
bpd:f64[64,64] = mul bom bpc
bpe:f64[64,64] = add boh bpd
bpf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmh
bpg:f64[64,64] = squeeze[dimensions=(0,)] bpf
bph:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmh
bpi:f64[64,64] = squeeze[dimensions=(0,)] bph
bpj:f64[64,64] = mul bpg bpi
bpk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bmo
bpl:i64[64,64] = squeeze[dimensions=(0,)] bpk
bpm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bmo
bpn:i64[64,64] = squeeze[dimensions=(0,)] bpm
bpo:bool[64,64] = lt bpl 0:i64[]
bpp:i64[64,64] = add bpl 64:i64[]
bpq:i64[64,64] = select_n bpo bpl bpp
bpr:bool[64,64] = lt bpn 0:i64[]
bps:i64[64,64] = add bpn 64:i64[]
bpt:i64[64,64] = select_n bpr bpn bps
bpu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bpq
bpv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bpt
bpw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bpu
bpx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bpv
bpy:i32[64,64,2] = concatenate[dimension=2] bpw bpx
bpz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] blz bpy
bqa:f64[64,64] = mul bpj bpz
bqb:f64[64,64] = add bpe bqa
bqc:f64[] = neg eb
bqd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bqc
bqe:f64[2,64,64] = mul bqd bdr
bqf:f64[2,64,64] = add ea bqe
bqg:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
bqh:f64[2,64,64] = div bqf bqg
bqi:f64[2,64,64] = floor bqh
bqj:f64[2,64,64] = sub bqh bqi
bqk:f64[2,64,64] = sub 1.0:f64[] bqj
bql:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] be
bqm:f64[2,64,64] = sub bqh bqj
bqn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bqm
bqo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bqn bql
bqp:i64[2,64,64] = add bqo 1:i64[]
bqq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bqp bql
bqr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqk
bqs:f64[64,64] = squeeze[dimensions=(0,)] bqr
bqt:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqk
bqu:f64[64,64] = squeeze[dimensions=(0,)] bqt
bqv:f64[64,64] = mul bqs bqu
bqw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqo
bqx:i64[64,64] = squeeze[dimensions=(0,)] bqw
bqy:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqo
bqz:i64[64,64] = squeeze[dimensions=(0,)] bqy
bra:bool[64,64] = lt bqx 0:i64[]
brb:i64[64,64] = add bqx 64:i64[]
brc:i64[64,64] = select_n bra bqx brb
brd:bool[64,64] = lt bqz 0:i64[]
bre:i64[64,64] = add bqz 64:i64[]
brf:i64[64,64] = select_n brd bqz bre
brg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] brc
brh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] brf
bri:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] brg
brj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] brh
brk:i32[64,64,2] = concatenate[dimension=2] bri brj
brl:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bdg brk
brm:f64[64,64] = mul bqv brl
brn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqk
bro:f64[64,64] = squeeze[dimensions=(0,)] brn
brp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqj
brq:f64[64,64] = squeeze[dimensions=(0,)] brp
brr:f64[64,64] = mul bro brq
brs:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqo
brt:i64[64,64] = squeeze[dimensions=(0,)] brs
bru:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqq
brv:i64[64,64] = squeeze[dimensions=(0,)] bru
brw:bool[64,64] = lt brt 0:i64[]
brx:i64[64,64] = add brt 64:i64[]
bry:i64[64,64] = select_n brw brt brx
brz:bool[64,64] = lt brv 0:i64[]
bsa:i64[64,64] = add brv 64:i64[]
bsb:i64[64,64] = select_n brz brv bsa
bsc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bry
bsd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bsb
bse:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bsc
bsf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bsd
bsg:i32[64,64,2] = concatenate[dimension=2] bse bsf
bsh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bdg bsg
bsi:f64[64,64] = mul brr bsh
bsj:f64[64,64] = add brm bsi
bsk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqj
bsl:f64[64,64] = squeeze[dimensions=(0,)] bsk
bsm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqk
bsn:f64[64,64] = squeeze[dimensions=(0,)] bsm
bso:f64[64,64] = mul bsl bsn
bsp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqq
bsq:i64[64,64] = squeeze[dimensions=(0,)] bsp
bsr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqo
bss:i64[64,64] = squeeze[dimensions=(0,)] bsr
bst:bool[64,64] = lt bsq 0:i64[]
bsu:i64[64,64] = add bsq 64:i64[]
bsv:i64[64,64] = select_n bst bsq bsu
bsw:bool[64,64] = lt bss 0:i64[]
bsx:i64[64,64] = add bss 64:i64[]
bsy:i64[64,64] = select_n bsw bss bsx
bsz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bsv
bta:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bsy
btb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bsz
btc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bta
btd:i32[64,64,2] = concatenate[dimension=2] btb btc
bte:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bdg btd
btf:f64[64,64] = mul bso bte
btg:f64[64,64] = add bsj btf
bth:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqj
bti:f64[64,64] = squeeze[dimensions=(0,)] bth
btj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqj
btk:f64[64,64] = squeeze[dimensions=(0,)] btj
btl:f64[64,64] = mul bti btk
btm:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bqq
btn:i64[64,64] = squeeze[dimensions=(0,)] btm
bto:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bqq
btp:i64[64,64] = squeeze[dimensions=(0,)] bto
btq:bool[64,64] = lt btn 0:i64[]
btr:i64[64,64] = add btn 64:i64[]
bts:i64[64,64] = select_n btq btn btr
btt:bool[64,64] = lt btp 0:i64[]
btu:i64[64,64] = add btp 64:i64[]
btv:i64[64,64] = select_n btt btp btu
btw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bts
btx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] btv
bty:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] btw
btz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] btx
bua:i32[64,64,2] = concatenate[dimension=2] bty btz
bub:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bdg bua
buc:f64[64,64] = mul btl bub
bud:f64[64,64] = add btg buc
bue:f64[2,64,64] = neg bdr
buf:f64[] = neg eb
bug:f64[] = convert_element_type[new_dtype=float64 weak_type=False] buf
buh:f64[2,64,64] = mul bug bue
bui:f64[2,64,64] = add ea buh
buj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
buk:f64[2,64,64] = div bui buj
bul:f64[2,64,64] = floor buk
bum:f64[2,64,64] = sub buk bul
bun:f64[2,64,64] = sub 1.0:f64[] bum
buo:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bf
bup:f64[2,64,64] = sub buk bum
buq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bup
bur:i64[2,64,64] = jit[name=remainder jaxpr=remainder] buq buo
bus:i64[2,64,64] = add bur 1:i64[]
but:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bus buo
buu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bun
buv:f64[64,64] = squeeze[dimensions=(0,)] buu
buw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bun
bux:f64[64,64] = squeeze[dimensions=(0,)] buw
buy:f64[64,64] = mul buv bux
buz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bur
bva:i64[64,64] = squeeze[dimensions=(0,)] buz
bvb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bur
bvc:i64[64,64] = squeeze[dimensions=(0,)] bvb
bvd:bool[64,64] = lt bva 0:i64[]
bve:i64[64,64] = add bva 64:i64[]
bvf:i64[64,64] = select_n bvd bva bve
bvg:bool[64,64] = lt bvc 0:i64[]
bvh:i64[64,64] = add bvc 64:i64[]
bvi:i64[64,64] = select_n bvg bvc bvh
bvj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bvf
bvk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bvi
bvl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bvj
bvm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bvk
bvn:i32[64,64,2] = concatenate[dimension=2] bvl bvm
bvo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bud bvn
bvp:f64[64,64] = mul buy bvo
bvq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bun
bvr:f64[64,64] = squeeze[dimensions=(0,)] bvq
bvs:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bum
bvt:f64[64,64] = squeeze[dimensions=(0,)] bvs
bvu:f64[64,64] = mul bvr bvt
bvv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bur
bvw:i64[64,64] = squeeze[dimensions=(0,)] bvv
bvx:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] but
bvy:i64[64,64] = squeeze[dimensions=(0,)] bvx
bvz:bool[64,64] = lt bvw 0:i64[]
bwa:i64[64,64] = add bvw 64:i64[]
bwb:i64[64,64] = select_n bvz bvw bwa
bwc:bool[64,64] = lt bvy 0:i64[]
bwd:i64[64,64] = add bvy 64:i64[]
bwe:i64[64,64] = select_n bwc bvy bwd
bwf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bwb
bwg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bwe
bwh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bwf
bwi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bwg
bwj:i32[64,64,2] = concatenate[dimension=2] bwh bwi
bwk:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bud bwj
bwl:f64[64,64] = mul bvu bwk
bwm:f64[64,64] = add bvp bwl
bwn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bum
bwo:f64[64,64] = squeeze[dimensions=(0,)] bwn
bwp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bun
bwq:f64[64,64] = squeeze[dimensions=(0,)] bwp
bwr:f64[64,64] = mul bwo bwq
bws:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] but
bwt:i64[64,64] = squeeze[dimensions=(0,)] bws
bwu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bur
bwv:i64[64,64] = squeeze[dimensions=(0,)] bwu
bww:bool[64,64] = lt bwt 0:i64[]
bwx:i64[64,64] = add bwt 64:i64[]
bwy:i64[64,64] = select_n bww bwt bwx
bwz:bool[64,64] = lt bwv 0:i64[]
bxa:i64[64,64] = add bwv 64:i64[]
bxb:i64[64,64] = select_n bwz bwv bxa
bxc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bwy
bxd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bxb
bxe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bxc
bxf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bxd
bxg:i32[64,64,2] = concatenate[dimension=2] bxe bxf
bxh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bud bxg
bxi:f64[64,64] = mul bwr bxh
bxj:f64[64,64] = add bwm bxi
bxk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bum
bxl:f64[64,64] = squeeze[dimensions=(0,)] bxk
bxm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bum
bxn:f64[64,64] = squeeze[dimensions=(0,)] bxm
bxo:f64[64,64] = mul bxl bxn
bxp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] but
bxq:i64[64,64] = squeeze[dimensions=(0,)] bxp
bxr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] but
bxs:i64[64,64] = squeeze[dimensions=(0,)] bxr
bxt:bool[64,64] = lt bxq 0:i64[]
bxu:i64[64,64] = add bxq 64:i64[]
bxv:i64[64,64] = select_n bxt bxq bxu
bxw:bool[64,64] = lt bxs 0:i64[]
bxx:i64[64,64] = add bxs 64:i64[]
bxy:i64[64,64] = select_n bxw bxs bxx
bxz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bxv
bya:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bxy
byb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bxz
byc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bya
byd:i32[64,64,2] = concatenate[dimension=2] byb byc
bye:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bud byd
byf:f64[64,64] = mul bxo bye
byg:f64[64,64] = add bxj byf
byh:f64[64,64] = sub bdg byg
byi:f64[64,64] = div byh 2.0:f64[]
byj:f64[64,64] = add bdg byi
byk:f64[] = neg eb
byl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] byk
bym:f64[2,64,64] = mul byl bdr
byn:f64[2,64,64] = add ea bym
byo:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
byp:f64[2,64,64] = div byn byo
byq:f64[2,64,64] = floor byp
byr:f64[2,64,64] = sub byp byq
bys:f64[2,64,64] = sub 1.0:f64[] byr
byt:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bg
byu:f64[2,64,64] = sub byp byr
byv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] byu
byw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] byv byt
byx:i64[2,64,64] = add byw 1:i64[]
byy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] byx byt
byz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bys
bza:f64[64,64] = squeeze[dimensions=(0,)] byz
bzb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bys
bzc:f64[64,64] = squeeze[dimensions=(0,)] bzb
bzd:f64[64,64] = mul bza bzc
bze:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] byw
bzf:i64[64,64] = squeeze[dimensions=(0,)] bze
bzg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] byw
bzh:i64[64,64] = squeeze[dimensions=(0,)] bzg
bzi:bool[64,64] = lt bzf 0:i64[]
bzj:i64[64,64] = add bzf 64:i64[]
bzk:i64[64,64] = select_n bzi bzf bzj
bzl:bool[64,64] = lt bzh 0:i64[]
bzm:i64[64,64] = add bzh 64:i64[]
bzn:i64[64,64] = select_n bzl bzh bzm
bzo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bzk
bzp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bzn
bzq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bzo
bzr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] bzp
bzs:i32[64,64,2] = concatenate[dimension=2] bzq bzr
bzt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] byj bzs
bzu:f64[64,64] = mul bzd bzt
bzv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] bys
bzw:f64[64,64] = squeeze[dimensions=(0,)] bzv
bzx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] byr
bzy:f64[64,64] = squeeze[dimensions=(0,)] bzx
bzz:f64[64,64] = mul bzw bzy
caa:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] byw
cab:i64[64,64] = squeeze[dimensions=(0,)] caa
cac:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] byy
cad:i64[64,64] = squeeze[dimensions=(0,)] cac
cae:bool[64,64] = lt cab 0:i64[]
caf:i64[64,64] = add cab 64:i64[]
cag:i64[64,64] = select_n cae cab caf
cah:bool[64,64] = lt cad 0:i64[]
cai:i64[64,64] = add cad 64:i64[]
caj:i64[64,64] = select_n cah cad cai
cak:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cag
cal:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] caj
cam:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cak
can:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cal
cao:i32[64,64,2] = concatenate[dimension=2] cam can
cap:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] byj cao
caq:f64[64,64] = mul bzz cap
car:f64[64,64] = add bzu caq
cas:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] byr
cat:f64[64,64] = squeeze[dimensions=(0,)] cas
cau:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] bys
cav:f64[64,64] = squeeze[dimensions=(0,)] cau
caw:f64[64,64] = mul cat cav
cax:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] byy
cay:i64[64,64] = squeeze[dimensions=(0,)] cax
caz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] byw
cba:i64[64,64] = squeeze[dimensions=(0,)] caz
cbb:bool[64,64] = lt cay 0:i64[]
cbc:i64[64,64] = add cay 64:i64[]
cbd:i64[64,64] = select_n cbb cay cbc
cbe:bool[64,64] = lt cba 0:i64[]
cbf:i64[64,64] = add cba 64:i64[]
cbg:i64[64,64] = select_n cbe cba cbf
cbh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cbd
cbi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cbg
cbj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cbh
cbk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cbi
cbl:i32[64,64,2] = concatenate[dimension=2] cbj cbk
cbm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] byj cbl
cbn:f64[64,64] = mul caw cbm
cbo:f64[64,64] = add car cbn
cbp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] byr
cbq:f64[64,64] = squeeze[dimensions=(0,)] cbp
cbr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] byr
cbs:f64[64,64] = squeeze[dimensions=(0,)] cbr
cbt:f64[64,64] = mul cbq cbs
cbu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] byy
cbv:i64[64,64] = squeeze[dimensions=(0,)] cbu
cbw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] byy
cbx:i64[64,64] = squeeze[dimensions=(0,)] cbw
cby:bool[64,64] = lt cbv 0:i64[]
cbz:i64[64,64] = add cbv 64:i64[]
cca:i64[64,64] = select_n cby cbv cbz
ccb:bool[64,64] = lt cbx 0:i64[]
ccc:i64[64,64] = add cbx 64:i64[]
ccd:i64[64,64] = select_n ccb cbx ccc
cce:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cca
ccf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ccd
ccg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cce
cch:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ccf
cci:i32[64,64,2] = concatenate[dimension=2] ccg cch
ccj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] byj cci
cck:f64[64,64] = mul cbt ccj
ccl:f64[64,64] = add cbo cck
ccm:c128[64,33] = jit[name=fft jaxpr=fft] bqb
ccn:c128[] = reduce_prod[axes=(0,)] bh
cco:c128[] = sqrt ccn
ccp:c128[] = div (1+0j):c128[] cco
ccq:c128[64,33] = mul ccm ccp
ccr:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] ccq
ccs:c128[2,64,33] = mul dw ccr
cct:f64[2,64,64] = jit[name=fft jaxpr=fft1] ccs
ccu:f64[] = reduce_prod[axes=(0,)] bi
ccv:f64[] = sqrt ccu
ccw:f64[2,64,64] = mul cct ccv
ccx:f64[] = neg eb
ccy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ccx
ccz:f64[2,64,64] = mul ccy ccw
cda:f64[2,64,64] = add ea ccz
cdb:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
cdc:f64[2,64,64] = div cda cdb
cdd:f64[2,64,64] = floor cdc
cde:f64[2,64,64] = sub cdc cdd
cdf:f64[2,64,64] = sub 1.0:f64[] cde
cdg:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bj
cdh:f64[2,64,64] = sub cdc cde
cdi:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] cdh
cdj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cdi cdg
cdk:i64[2,64,64] = add cdj 1:i64[]
cdl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cdk cdg
cdm:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cdf
cdn:f64[64,64] = squeeze[dimensions=(0,)] cdm
cdo:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cdf
cdp:f64[64,64] = squeeze[dimensions=(0,)] cdo
cdq:f64[64,64] = mul cdn cdp
cdr:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cdj
cds:i64[64,64] = squeeze[dimensions=(0,)] cdr
cdt:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cdj
cdu:i64[64,64] = squeeze[dimensions=(0,)] cdt
cdv:bool[64,64] = lt cds 0:i64[]
cdw:i64[64,64] = add cds 64:i64[]
cdx:i64[64,64] = select_n cdv cds cdw
cdy:bool[64,64] = lt cdu 0:i64[]
cdz:i64[64,64] = add cdu 64:i64[]
cea:i64[64,64] = select_n cdy cdu cdz
ceb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cdx
cec:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cea
ced:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ceb
cee:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cec
cef:i32[64,64,2] = concatenate[dimension=2] ced cee
ceg:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bqb cef
ceh:f64[64,64] = mul cdq ceg
cei:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cdf
cej:f64[64,64] = squeeze[dimensions=(0,)] cei
cek:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cde
cel:f64[64,64] = squeeze[dimensions=(0,)] cek
cem:f64[64,64] = mul cej cel
cen:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cdj
ceo:i64[64,64] = squeeze[dimensions=(0,)] cen
cep:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cdl
ceq:i64[64,64] = squeeze[dimensions=(0,)] cep
cer:bool[64,64] = lt ceo 0:i64[]
ces:i64[64,64] = add ceo 64:i64[]
cet:i64[64,64] = select_n cer ceo ces
ceu:bool[64,64] = lt ceq 0:i64[]
cev:i64[64,64] = add ceq 64:i64[]
cew:i64[64,64] = select_n ceu ceq cev
cex:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cet
cey:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cew
cez:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cex
cfa:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cey
cfb:i32[64,64,2] = concatenate[dimension=2] cez cfa
cfc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bqb cfb
cfd:f64[64,64] = mul cem cfc
cfe:f64[64,64] = add ceh cfd
cff:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cde
cfg:f64[64,64] = squeeze[dimensions=(0,)] cff
cfh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cdf
cfi:f64[64,64] = squeeze[dimensions=(0,)] cfh
cfj:f64[64,64] = mul cfg cfi
cfk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cdl
cfl:i64[64,64] = squeeze[dimensions=(0,)] cfk
cfm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cdj
cfn:i64[64,64] = squeeze[dimensions=(0,)] cfm
cfo:bool[64,64] = lt cfl 0:i64[]
cfp:i64[64,64] = add cfl 64:i64[]
cfq:i64[64,64] = select_n cfo cfl cfp
cfr:bool[64,64] = lt cfn 0:i64[]
cfs:i64[64,64] = add cfn 64:i64[]
cft:i64[64,64] = select_n cfr cfn cfs
cfu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cfq
cfv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cft
cfw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cfu
cfx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cfv
cfy:i32[64,64,2] = concatenate[dimension=2] cfw cfx
cfz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bqb cfy
cga:f64[64,64] = mul cfj cfz
cgb:f64[64,64] = add cfe cga
cgc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cde
cgd:f64[64,64] = squeeze[dimensions=(0,)] cgc
cge:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cde
cgf:f64[64,64] = squeeze[dimensions=(0,)] cge
cgg:f64[64,64] = mul cgd cgf
cgh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cdl
cgi:i64[64,64] = squeeze[dimensions=(0,)] cgh
cgj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cdl
cgk:i64[64,64] = squeeze[dimensions=(0,)] cgj
cgl:bool[64,64] = lt cgi 0:i64[]
cgm:i64[64,64] = add cgi 64:i64[]
cgn:i64[64,64] = select_n cgl cgi cgm
cgo:bool[64,64] = lt cgk 0:i64[]
cgp:i64[64,64] = add cgk 64:i64[]
cgq:i64[64,64] = select_n cgo cgk cgp
cgr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cgn
cgs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cgq
cgt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cgr
cgu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cgs
cgv:i32[64,64,2] = concatenate[dimension=2] cgt cgu
cgw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bqb cgv
cgx:f64[64,64] = mul cgg cgw
cgy:f64[64,64] = add cgb cgx
cgz:f64[2,64,64] = neg ccw
cha:f64[] = neg eb
chb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cha
chc:f64[2,64,64] = mul chb cgz
chd:f64[2,64,64] = add ea chc
che:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
chf:f64[2,64,64] = div chd che
chg:f64[2,64,64] = floor chf
chh:f64[2,64,64] = sub chf chg
chi:f64[2,64,64] = sub 1.0:f64[] chh
chj:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bk
chk:f64[2,64,64] = sub chf chh
chl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] chk
chm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] chl chj
chn:i64[2,64,64] = add chm 1:i64[]
cho:i64[2,64,64] = jit[name=remainder jaxpr=remainder] chn chj
chp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] chi
chq:f64[64,64] = squeeze[dimensions=(0,)] chp
chr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] chi
chs:f64[64,64] = squeeze[dimensions=(0,)] chr
cht:f64[64,64] = mul chq chs
chu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] chm
chv:i64[64,64] = squeeze[dimensions=(0,)] chu
chw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] chm
chx:i64[64,64] = squeeze[dimensions=(0,)] chw
chy:bool[64,64] = lt chv 0:i64[]
chz:i64[64,64] = add chv 64:i64[]
cia:i64[64,64] = select_n chy chv chz
cib:bool[64,64] = lt chx 0:i64[]
cic:i64[64,64] = add chx 64:i64[]
cid:i64[64,64] = select_n cib chx cic
cie:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cia
cif:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cid
cig:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cie
cih:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cif
cii:i32[64,64,2] = concatenate[dimension=2] cig cih
cij:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cgy cii
cik:f64[64,64] = mul cht cij
cil:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] chi
cim:f64[64,64] = squeeze[dimensions=(0,)] cil
cin:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] chh
cio:f64[64,64] = squeeze[dimensions=(0,)] cin
cip:f64[64,64] = mul cim cio
ciq:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] chm
cir:i64[64,64] = squeeze[dimensions=(0,)] ciq
cis:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cho
cit:i64[64,64] = squeeze[dimensions=(0,)] cis
ciu:bool[64,64] = lt cir 0:i64[]
civ:i64[64,64] = add cir 64:i64[]
ciw:i64[64,64] = select_n ciu cir civ
cix:bool[64,64] = lt cit 0:i64[]
ciy:i64[64,64] = add cit 64:i64[]
ciz:i64[64,64] = select_n cix cit ciy
cja:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ciw
cjb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ciz
cjc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cja
cjd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cjb
cje:i32[64,64,2] = concatenate[dimension=2] cjc cjd
cjf:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cgy cje
cjg:f64[64,64] = mul cip cjf
cjh:f64[64,64] = add cik cjg
cji:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] chh
cjj:f64[64,64] = squeeze[dimensions=(0,)] cji
cjk:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] chi
cjl:f64[64,64] = squeeze[dimensions=(0,)] cjk
cjm:f64[64,64] = mul cjj cjl
cjn:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cho
cjo:i64[64,64] = squeeze[dimensions=(0,)] cjn
cjp:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] chm
cjq:i64[64,64] = squeeze[dimensions=(0,)] cjp
cjr:bool[64,64] = lt cjo 0:i64[]
cjs:i64[64,64] = add cjo 64:i64[]
cjt:i64[64,64] = select_n cjr cjo cjs
cju:bool[64,64] = lt cjq 0:i64[]
cjv:i64[64,64] = add cjq 64:i64[]
cjw:i64[64,64] = select_n cju cjq cjv
cjx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cjt
cjy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cjw
cjz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cjx
cka:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cjy
ckb:i32[64,64,2] = concatenate[dimension=2] cjz cka
ckc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cgy ckb
ckd:f64[64,64] = mul cjm ckc
cke:f64[64,64] = add cjh ckd
ckf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] chh
ckg:f64[64,64] = squeeze[dimensions=(0,)] ckf
ckh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] chh
cki:f64[64,64] = squeeze[dimensions=(0,)] ckh
ckj:f64[64,64] = mul ckg cki
ckk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cho
ckl:i64[64,64] = squeeze[dimensions=(0,)] ckk
ckm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cho
ckn:i64[64,64] = squeeze[dimensions=(0,)] ckm
cko:bool[64,64] = lt ckl 0:i64[]
ckp:i64[64,64] = add ckl 64:i64[]
ckq:i64[64,64] = select_n cko ckl ckp
ckr:bool[64,64] = lt ckn 0:i64[]
cks:i64[64,64] = add ckn 64:i64[]
ckt:i64[64,64] = select_n ckr ckn cks
cku:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ckq
ckv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ckt
ckw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cku
ckx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ckv
cky:i32[64,64,2] = concatenate[dimension=2] ckw ckx
ckz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cgy cky
cla:f64[64,64] = mul ckj ckz
clb:f64[64,64] = add cke cla
clc:f64[64,64] = sub bqb clb
cld:f64[64,64] = div clc 2.0:f64[]
cle:f64[64,64] = add bqb cld
clf:f64[] = neg eb
clg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] clf
clh:f64[2,64,64] = mul clg ccw
cli:f64[2,64,64] = add ea clh
clj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
clk:f64[2,64,64] = div cli clj
cll:f64[2,64,64] = floor clk
clm:f64[2,64,64] = sub clk cll
cln:f64[2,64,64] = sub 1.0:f64[] clm
clo:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bl
clp:f64[2,64,64] = sub clk clm
clq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] clp
clr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] clq clo
cls:i64[2,64,64] = add clr 1:i64[]
clt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cls clo
clu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cln
clv:f64[64,64] = squeeze[dimensions=(0,)] clu
clw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cln
clx:f64[64,64] = squeeze[dimensions=(0,)] clw
cly:f64[64,64] = mul clv clx
clz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] clr
cma:i64[64,64] = squeeze[dimensions=(0,)] clz
cmb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] clr
cmc:i64[64,64] = squeeze[dimensions=(0,)] cmb
cmd:bool[64,64] = lt cma 0:i64[]
cme:i64[64,64] = add cma 64:i64[]
cmf:i64[64,64] = select_n cmd cma cme
cmg:bool[64,64] = lt cmc 0:i64[]
cmh:i64[64,64] = add cmc 64:i64[]
cmi:i64[64,64] = select_n cmg cmc cmh
cmj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cmf
cmk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cmi
cml:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cmj
cmm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cmk
cmn:i32[64,64,2] = concatenate[dimension=2] cml cmm
cmo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cle cmn
cmp:f64[64,64] = mul cly cmo
cmq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cln
cmr:f64[64,64] = squeeze[dimensions=(0,)] cmq
cms:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] clm
cmt:f64[64,64] = squeeze[dimensions=(0,)] cms
cmu:f64[64,64] = mul cmr cmt
cmv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] clr
cmw:i64[64,64] = squeeze[dimensions=(0,)] cmv
cmx:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] clt
cmy:i64[64,64] = squeeze[dimensions=(0,)] cmx
cmz:bool[64,64] = lt cmw 0:i64[]
cna:i64[64,64] = add cmw 64:i64[]
cnb:i64[64,64] = select_n cmz cmw cna
cnc:bool[64,64] = lt cmy 0:i64[]
cnd:i64[64,64] = add cmy 64:i64[]
cne:i64[64,64] = select_n cnc cmy cnd
cnf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cnb
cng:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cne
cnh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cnf
cni:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cng
cnj:i32[64,64,2] = concatenate[dimension=2] cnh cni
cnk:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cle cnj
cnl:f64[64,64] = mul cmu cnk
cnm:f64[64,64] = add cmp cnl
cnn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] clm
cno:f64[64,64] = squeeze[dimensions=(0,)] cnn
cnp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cln
cnq:f64[64,64] = squeeze[dimensions=(0,)] cnp
cnr:f64[64,64] = mul cno cnq
cns:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] clt
cnt:i64[64,64] = squeeze[dimensions=(0,)] cns
cnu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] clr
cnv:i64[64,64] = squeeze[dimensions=(0,)] cnu
cnw:bool[64,64] = lt cnt 0:i64[]
cnx:i64[64,64] = add cnt 64:i64[]
cny:i64[64,64] = select_n cnw cnt cnx
cnz:bool[64,64] = lt cnv 0:i64[]
coa:i64[64,64] = add cnv 64:i64[]
cob:i64[64,64] = select_n cnz cnv coa
coc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cny
cod:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cob
coe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] coc
cof:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cod
cog:i32[64,64,2] = concatenate[dimension=2] coe cof
coh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cle cog
coi:f64[64,64] = mul cnr coh
coj:f64[64,64] = add cnm coi
cok:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] clm
col:f64[64,64] = squeeze[dimensions=(0,)] cok
com:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] clm
con:f64[64,64] = squeeze[dimensions=(0,)] com
coo:f64[64,64] = mul col con
cop:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] clt
coq:i64[64,64] = squeeze[dimensions=(0,)] cop
cor:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] clt
cos:i64[64,64] = squeeze[dimensions=(0,)] cor
cot:bool[64,64] = lt coq 0:i64[]
cou:i64[64,64] = add coq 64:i64[]
cov:i64[64,64] = select_n cot coq cou
cow:bool[64,64] = lt cos 0:i64[]
cox:i64[64,64] = add cos 64:i64[]
coy:i64[64,64] = select_n cow cos cox
coz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cov
cpa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] coy
cpb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] coz
cpc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cpa
cpd:i32[64,64,2] = concatenate[dimension=2] cpb cpc
cpe:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cle cpd
cpf:f64[64,64] = mul coo cpe
cpg:f64[64,64] = add coj cpf
cph:f64[] = neg eb
cpi:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cph
cpj:f64[2,64,64] = mul cpi ccw
cpk:f64[2,64,64] = add ea cpj
cpl:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
cpm:f64[2,64,64] = div cpk cpl
cpn:f64[2,64,64] = floor cpm
cpo:f64[2,64,64] = sub cpm cpn
cpp:f64[2,64,64] = sub 1.0:f64[] cpo
cpq:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bm
cpr:f64[2,64,64] = sub cpm cpo
cps:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] cpr
cpt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cps cpq
cpu:i64[2,64,64] = add cpt 1:i64[]
cpv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cpu cpq
cpw:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpp
cpx:f64[64,64] = squeeze[dimensions=(0,)] cpw
cpy:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpp
cpz:f64[64,64] = squeeze[dimensions=(0,)] cpy
cqa:f64[64,64] = mul cpx cpz
cqb:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpt
cqc:i64[64,64] = squeeze[dimensions=(0,)] cqb
cqd:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpt
cqe:i64[64,64] = squeeze[dimensions=(0,)] cqd
cqf:bool[64,64] = lt cqc 0:i64[]
cqg:i64[64,64] = add cqc 64:i64[]
cqh:i64[64,64] = select_n cqf cqc cqg
cqi:bool[64,64] = lt cqe 0:i64[]
cqj:i64[64,64] = add cqe 64:i64[]
cqk:i64[64,64] = select_n cqi cqe cqj
cql:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cqh
cqm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cqk
cqn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cql
cqo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cqm
cqp:i32[64,64,2] = concatenate[dimension=2] cqn cqo
cqq:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ccl cqp
cqr:f64[64,64] = mul cqa cqq
cqs:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpp
cqt:f64[64,64] = squeeze[dimensions=(0,)] cqs
cqu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpo
cqv:f64[64,64] = squeeze[dimensions=(0,)] cqu
cqw:f64[64,64] = mul cqt cqv
cqx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpt
cqy:i64[64,64] = squeeze[dimensions=(0,)] cqx
cqz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpv
cra:i64[64,64] = squeeze[dimensions=(0,)] cqz
crb:bool[64,64] = lt cqy 0:i64[]
crc:i64[64,64] = add cqy 64:i64[]
crd:i64[64,64] = select_n crb cqy crc
cre:bool[64,64] = lt cra 0:i64[]
crf:i64[64,64] = add cra 64:i64[]
crg:i64[64,64] = select_n cre cra crf
crh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] crd
cri:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] crg
crj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] crh
crk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cri
crl:i32[64,64,2] = concatenate[dimension=2] crj crk
crm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ccl crl
crn:f64[64,64] = mul cqw crm
cro:f64[64,64] = add cqr crn
crp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpo
crq:f64[64,64] = squeeze[dimensions=(0,)] crp
crr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpp
crs:f64[64,64] = squeeze[dimensions=(0,)] crr
crt:f64[64,64] = mul crq crs
cru:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpv
crv:i64[64,64] = squeeze[dimensions=(0,)] cru
crw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpt
crx:i64[64,64] = squeeze[dimensions=(0,)] crw
cry:bool[64,64] = lt crv 0:i64[]
crz:i64[64,64] = add crv 64:i64[]
csa:i64[64,64] = select_n cry crv crz
csb:bool[64,64] = lt crx 0:i64[]
csc:i64[64,64] = add crx 64:i64[]
csd:i64[64,64] = select_n csb crx csc
cse:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] csa
csf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] csd
csg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cse
csh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] csf
csi:i32[64,64,2] = concatenate[dimension=2] csg csh
csj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ccl csi
csk:f64[64,64] = mul crt csj
csl:f64[64,64] = add cro csk
csm:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpo
csn:f64[64,64] = squeeze[dimensions=(0,)] csm
cso:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpo
csp:f64[64,64] = squeeze[dimensions=(0,)] cso
csq:f64[64,64] = mul csn csp
csr:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cpv
css:i64[64,64] = squeeze[dimensions=(0,)] csr
cst:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cpv
csu:i64[64,64] = squeeze[dimensions=(0,)] cst
csv:bool[64,64] = lt css 0:i64[]
csw:i64[64,64] = add css 64:i64[]
csx:i64[64,64] = select_n csv css csw
csy:bool[64,64] = lt csu 0:i64[]
csz:i64[64,64] = add csu 64:i64[]
cta:i64[64,64] = select_n csy csu csz
ctb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] csx
ctc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cta
ctd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ctb
cte:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ctc
ctf:i32[64,64,2] = concatenate[dimension=2] ctd cte
ctg:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ccl ctf
cth:f64[64,64] = mul csq ctg
cti:f64[64,64] = add csl cth
ctj:f64[2,64,64] = neg ccw
ctk:f64[] = neg eb
ctl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ctk
ctm:f64[2,64,64] = mul ctl ctj
ctn:f64[2,64,64] = add ea ctm
cto:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
ctp:f64[2,64,64] = div ctn cto
ctq:f64[2,64,64] = floor ctp
ctr:f64[2,64,64] = sub ctp ctq
cts:f64[2,64,64] = sub 1.0:f64[] ctr
ctt:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bn
ctu:f64[2,64,64] = sub ctp ctr
ctv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ctu
ctw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ctv ctt
ctx:i64[2,64,64] = add ctw 1:i64[]
cty:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ctx ctt
ctz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cts
cua:f64[64,64] = squeeze[dimensions=(0,)] ctz
cub:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cts
cuc:f64[64,64] = squeeze[dimensions=(0,)] cub
cud:f64[64,64] = mul cua cuc
cue:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ctw
cuf:i64[64,64] = squeeze[dimensions=(0,)] cue
cug:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ctw
cuh:i64[64,64] = squeeze[dimensions=(0,)] cug
cui:bool[64,64] = lt cuf 0:i64[]
cuj:i64[64,64] = add cuf 64:i64[]
cuk:i64[64,64] = select_n cui cuf cuj
cul:bool[64,64] = lt cuh 0:i64[]
cum:i64[64,64] = add cuh 64:i64[]
cun:i64[64,64] = select_n cul cuh cum
cuo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cuk
cup:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cun
cuq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cuo
cur:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cup
cus:i32[64,64,2] = concatenate[dimension=2] cuq cur
cut:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cti cus
cuu:f64[64,64] = mul cud cut
cuv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cts
cuw:f64[64,64] = squeeze[dimensions=(0,)] cuv
cux:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ctr
cuy:f64[64,64] = squeeze[dimensions=(0,)] cux
cuz:f64[64,64] = mul cuw cuy
cva:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ctw
cvb:i64[64,64] = squeeze[dimensions=(0,)] cva
cvc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cty
cvd:i64[64,64] = squeeze[dimensions=(0,)] cvc
cve:bool[64,64] = lt cvb 0:i64[]
cvf:i64[64,64] = add cvb 64:i64[]
cvg:i64[64,64] = select_n cve cvb cvf
cvh:bool[64,64] = lt cvd 0:i64[]
cvi:i64[64,64] = add cvd 64:i64[]
cvj:i64[64,64] = select_n cvh cvd cvi
cvk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cvg
cvl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cvj
cvm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cvk
cvn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cvl
cvo:i32[64,64,2] = concatenate[dimension=2] cvm cvn
cvp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cti cvo
cvq:f64[64,64] = mul cuz cvp
cvr:f64[64,64] = add cuu cvq
cvs:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ctr
cvt:f64[64,64] = squeeze[dimensions=(0,)] cvs
cvu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cts
cvv:f64[64,64] = squeeze[dimensions=(0,)] cvu
cvw:f64[64,64] = mul cvt cvv
cvx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cty
cvy:i64[64,64] = squeeze[dimensions=(0,)] cvx
cvz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ctw
cwa:i64[64,64] = squeeze[dimensions=(0,)] cvz
cwb:bool[64,64] = lt cvy 0:i64[]
cwc:i64[64,64] = add cvy 64:i64[]
cwd:i64[64,64] = select_n cwb cvy cwc
cwe:bool[64,64] = lt cwa 0:i64[]
cwf:i64[64,64] = add cwa 64:i64[]
cwg:i64[64,64] = select_n cwe cwa cwf
cwh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cwd
cwi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cwg
cwj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cwh
cwk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cwi
cwl:i32[64,64,2] = concatenate[dimension=2] cwj cwk
cwm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cti cwl
cwn:f64[64,64] = mul cvw cwm
cwo:f64[64,64] = add cvr cwn
cwp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ctr
cwq:f64[64,64] = squeeze[dimensions=(0,)] cwp
cwr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ctr
cws:f64[64,64] = squeeze[dimensions=(0,)] cwr
cwt:f64[64,64] = mul cwq cws
cwu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cty
cwv:i64[64,64] = squeeze[dimensions=(0,)] cwu
cww:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cty
cwx:i64[64,64] = squeeze[dimensions=(0,)] cww
cwy:bool[64,64] = lt cwv 0:i64[]
cwz:i64[64,64] = add cwv 64:i64[]
cxa:i64[64,64] = select_n cwy cwv cwz
cxb:bool[64,64] = lt cwx 0:i64[]
cxc:i64[64,64] = add cwx 64:i64[]
cxd:i64[64,64] = select_n cxb cwx cxc
cxe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cxa
cxf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cxd
cxg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cxe
cxh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cxf
cxi:i32[64,64,2] = concatenate[dimension=2] cxg cxh
cxj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cti cxi
cxk:f64[64,64] = mul cwt cxj
cxl:f64[64,64] = add cwo cxk
cxm:f64[64,64] = sub ccl cxl
cxn:f64[64,64] = div cxm 2.0:f64[]
cxo:f64[64,64] = add ccl cxn
cxp:f64[] = neg eb
cxq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cxp
cxr:f64[2,64,64] = mul cxq ccw
cxs:f64[2,64,64] = add ea cxr
cxt:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
cxu:f64[2,64,64] = div cxs cxt
cxv:f64[2,64,64] = floor cxu
cxw:f64[2,64,64] = sub cxu cxv
cxx:f64[2,64,64] = sub 1.0:f64[] cxw
cxy:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bo
cxz:f64[2,64,64] = sub cxu cxw
cya:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] cxz
cyb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cya cxy
cyc:i64[2,64,64] = add cyb 1:i64[]
cyd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cyc cxy
cye:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cxx
cyf:f64[64,64] = squeeze[dimensions=(0,)] cye
cyg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cxx
cyh:f64[64,64] = squeeze[dimensions=(0,)] cyg
cyi:f64[64,64] = mul cyf cyh
cyj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cyb
cyk:i64[64,64] = squeeze[dimensions=(0,)] cyj
cyl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cyb
cym:i64[64,64] = squeeze[dimensions=(0,)] cyl
cyn:bool[64,64] = lt cyk 0:i64[]
cyo:i64[64,64] = add cyk 64:i64[]
cyp:i64[64,64] = select_n cyn cyk cyo
cyq:bool[64,64] = lt cym 0:i64[]
cyr:i64[64,64] = add cym 64:i64[]
cys:i64[64,64] = select_n cyq cym cyr
cyt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cyp
cyu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cys
cyv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cyt
cyw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] cyu
cyx:i32[64,64,2] = concatenate[dimension=2] cyv cyw
cyy:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cxo cyx
cyz:f64[64,64] = mul cyi cyy
cza:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cxx
czb:f64[64,64] = squeeze[dimensions=(0,)] cza
czc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cxw
czd:f64[64,64] = squeeze[dimensions=(0,)] czc
cze:f64[64,64] = mul czb czd
czf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cyb
czg:i64[64,64] = squeeze[dimensions=(0,)] czf
czh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cyd
czi:i64[64,64] = squeeze[dimensions=(0,)] czh
czj:bool[64,64] = lt czg 0:i64[]
czk:i64[64,64] = add czg 64:i64[]
czl:i64[64,64] = select_n czj czg czk
czm:bool[64,64] = lt czi 0:i64[]
czn:i64[64,64] = add czi 64:i64[]
czo:i64[64,64] = select_n czm czi czn
czp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] czl
czq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] czo
czr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] czp
czs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] czq
czt:i32[64,64,2] = concatenate[dimension=2] czr czs
czu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cxo czt
czv:f64[64,64] = mul cze czu
czw:f64[64,64] = add cyz czv
czx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cxw
czy:f64[64,64] = squeeze[dimensions=(0,)] czx
czz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cxx
daa:f64[64,64] = squeeze[dimensions=(0,)] czz
dab:f64[64,64] = mul czy daa
dac:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cyd
dad:i64[64,64] = squeeze[dimensions=(0,)] dac
dae:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cyb
daf:i64[64,64] = squeeze[dimensions=(0,)] dae
dag:bool[64,64] = lt dad 0:i64[]
dah:i64[64,64] = add dad 64:i64[]
dai:i64[64,64] = select_n dag dad dah
daj:bool[64,64] = lt daf 0:i64[]
dak:i64[64,64] = add daf 64:i64[]
dal:i64[64,64] = select_n daj daf dak
dam:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dai
dan:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dal
dao:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dam
dap:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dan
daq:i32[64,64,2] = concatenate[dimension=2] dao dap
dar:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cxo daq
das:f64[64,64] = mul dab dar
dat:f64[64,64] = add czw das
dau:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cxw
dav:f64[64,64] = squeeze[dimensions=(0,)] dau
daw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cxw
dax:f64[64,64] = squeeze[dimensions=(0,)] daw
day:f64[64,64] = mul dav dax
daz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] cyd
dba:i64[64,64] = squeeze[dimensions=(0,)] daz
dbb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] cyd
dbc:i64[64,64] = squeeze[dimensions=(0,)] dbb
dbd:bool[64,64] = lt dba 0:i64[]
dbe:i64[64,64] = add dba 64:i64[]
dbf:i64[64,64] = select_n dbd dba dbe
dbg:bool[64,64] = lt dbc 0:i64[]
dbh:i64[64,64] = add dbc 64:i64[]
dbi:i64[64,64] = select_n dbg dbc dbh
dbj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dbf
dbk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dbi
dbl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dbj
dbm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dbk
dbn:i32[64,64,2] = concatenate[dimension=2] dbl dbm
dbo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cxo dbn
dbp:f64[64,64] = mul day dbo
dbq:f64[64,64] = add dat dbp
dbr:c128[64,33] = jit[name=fft jaxpr=fft] cpg
dbs:c128[] = reduce_prod[axes=(0,)] bp
dbt:c128[] = sqrt dbs
dbu:c128[] = div (1+0j):c128[] dbt
dbv:c128[64,33] = mul dbr dbu
dbw:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] dbv
dbx:c128[2,64,33] = mul dw dbw
dby:f64[2,64,64] = jit[name=fft jaxpr=fft1] dbx
dbz:f64[] = reduce_prod[axes=(0,)] bq
dca:f64[] = sqrt dbz
dcb:f64[2,64,64] = mul dby dca
dcc:f64[] = neg eb
dcd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dcc
dce:f64[2,64,64] = mul dcd dcb
dcf:f64[2,64,64] = add ea dce
dcg:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
dch:f64[2,64,64] = div dcf dcg
dci:f64[2,64,64] = floor dch
dcj:f64[2,64,64] = sub dch dci
dck:f64[2,64,64] = sub 1.0:f64[] dcj
dcl:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] br
dcm:f64[2,64,64] = sub dch dcj
dcn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dcm
dco:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dcn dcl
dcp:i64[2,64,64] = add dco 1:i64[]
dcq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dcp dcl
dcr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dck
dcs:f64[64,64] = squeeze[dimensions=(0,)] dcr
dct:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dck
dcu:f64[64,64] = squeeze[dimensions=(0,)] dct
dcv:f64[64,64] = mul dcs dcu
dcw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dco
dcx:i64[64,64] = squeeze[dimensions=(0,)] dcw
dcy:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dco
dcz:i64[64,64] = squeeze[dimensions=(0,)] dcy
dda:bool[64,64] = lt dcx 0:i64[]
ddb:i64[64,64] = add dcx 64:i64[]
ddc:i64[64,64] = select_n dda dcx ddb
ddd:bool[64,64] = lt dcz 0:i64[]
dde:i64[64,64] = add dcz 64:i64[]
ddf:i64[64,64] = select_n ddd dcz dde
ddg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ddc
ddh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ddf
ddi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ddg
ddj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ddh
ddk:i32[64,64,2] = concatenate[dimension=2] ddi ddj
ddl:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cpg ddk
ddm:f64[64,64] = mul dcv ddl
ddn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dck
ddo:f64[64,64] = squeeze[dimensions=(0,)] ddn
ddp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dcj
ddq:f64[64,64] = squeeze[dimensions=(0,)] ddp
ddr:f64[64,64] = mul ddo ddq
dds:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dco
ddt:i64[64,64] = squeeze[dimensions=(0,)] dds
ddu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dcq
ddv:i64[64,64] = squeeze[dimensions=(0,)] ddu
ddw:bool[64,64] = lt ddt 0:i64[]
ddx:i64[64,64] = add ddt 64:i64[]
ddy:i64[64,64] = select_n ddw ddt ddx
ddz:bool[64,64] = lt ddv 0:i64[]
dea:i64[64,64] = add ddv 64:i64[]
deb:i64[64,64] = select_n ddz ddv dea
dec:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ddy
ded:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] deb
dee:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dec
def:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ded
deg:i32[64,64,2] = concatenate[dimension=2] dee def
deh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cpg deg
dei:f64[64,64] = mul ddr deh
dej:f64[64,64] = add ddm dei
dek:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dcj
del:f64[64,64] = squeeze[dimensions=(0,)] dek
dem:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dck
den:f64[64,64] = squeeze[dimensions=(0,)] dem
deo:f64[64,64] = mul del den
dep:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dcq
deq:i64[64,64] = squeeze[dimensions=(0,)] dep
der:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dco
des:i64[64,64] = squeeze[dimensions=(0,)] der
det:bool[64,64] = lt deq 0:i64[]
deu:i64[64,64] = add deq 64:i64[]
dev:i64[64,64] = select_n det deq deu
dew:bool[64,64] = lt des 0:i64[]
dex:i64[64,64] = add des 64:i64[]
dey:i64[64,64] = select_n dew des dex
dez:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dev
dfa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dey
dfb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dez
dfc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dfa
dfd:i32[64,64,2] = concatenate[dimension=2] dfb dfc
dfe:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cpg dfd
dff:f64[64,64] = mul deo dfe
dfg:f64[64,64] = add dej dff
dfh:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dcj
dfi:f64[64,64] = squeeze[dimensions=(0,)] dfh
dfj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dcj
dfk:f64[64,64] = squeeze[dimensions=(0,)] dfj
dfl:f64[64,64] = mul dfi dfk
dfm:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dcq
dfn:i64[64,64] = squeeze[dimensions=(0,)] dfm
dfo:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dcq
dfp:i64[64,64] = squeeze[dimensions=(0,)] dfo
dfq:bool[64,64] = lt dfn 0:i64[]
dfr:i64[64,64] = add dfn 64:i64[]
dfs:i64[64,64] = select_n dfq dfn dfr
dft:bool[64,64] = lt dfp 0:i64[]
dfu:i64[64,64] = add dfp 64:i64[]
dfv:i64[64,64] = select_n dft dfp dfu
dfw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dfs
dfx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dfv
dfy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dfw
dfz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dfx
dga:i32[64,64,2] = concatenate[dimension=2] dfy dfz
dgb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] cpg dga
dgc:f64[64,64] = mul dfl dgb
dgd:f64[64,64] = add dfg dgc
dge:f64[2,64,64] = neg dcb
dgf:f64[] = neg eb
dgg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dgf
dgh:f64[2,64,64] = mul dgg dge
dgi:f64[2,64,64] = add ea dgh
dgj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
dgk:f64[2,64,64] = div dgi dgj
dgl:f64[2,64,64] = floor dgk
dgm:f64[2,64,64] = sub dgk dgl
dgn:f64[2,64,64] = sub 1.0:f64[] dgm
dgo:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bs
dgp:f64[2,64,64] = sub dgk dgm
dgq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dgp
dgr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dgq dgo
dgs:i64[2,64,64] = add dgr 1:i64[]
dgt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dgs dgo
dgu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgn
dgv:f64[64,64] = squeeze[dimensions=(0,)] dgu
dgw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgn
dgx:f64[64,64] = squeeze[dimensions=(0,)] dgw
dgy:f64[64,64] = mul dgv dgx
dgz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgr
dha:i64[64,64] = squeeze[dimensions=(0,)] dgz
dhb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgr
dhc:i64[64,64] = squeeze[dimensions=(0,)] dhb
dhd:bool[64,64] = lt dha 0:i64[]
dhe:i64[64,64] = add dha 64:i64[]
dhf:i64[64,64] = select_n dhd dha dhe
dhg:bool[64,64] = lt dhc 0:i64[]
dhh:i64[64,64] = add dhc 64:i64[]
dhi:i64[64,64] = select_n dhg dhc dhh
dhj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dhf
dhk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dhi
dhl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dhj
dhm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dhk
dhn:i32[64,64,2] = concatenate[dimension=2] dhl dhm
dho:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dgd dhn
dhp:f64[64,64] = mul dgy dho
dhq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgn
dhr:f64[64,64] = squeeze[dimensions=(0,)] dhq
dhs:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgm
dht:f64[64,64] = squeeze[dimensions=(0,)] dhs
dhu:f64[64,64] = mul dhr dht
dhv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgr
dhw:i64[64,64] = squeeze[dimensions=(0,)] dhv
dhx:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgt
dhy:i64[64,64] = squeeze[dimensions=(0,)] dhx
dhz:bool[64,64] = lt dhw 0:i64[]
dia:i64[64,64] = add dhw 64:i64[]
dib:i64[64,64] = select_n dhz dhw dia
dic:bool[64,64] = lt dhy 0:i64[]
did:i64[64,64] = add dhy 64:i64[]
die:i64[64,64] = select_n dic dhy did
dif:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dib
dig:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] die
dih:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dif
dii:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dig
dij:i32[64,64,2] = concatenate[dimension=2] dih dii
dik:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dgd dij
dil:f64[64,64] = mul dhu dik
dim:f64[64,64] = add dhp dil
din:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgm
dio:f64[64,64] = squeeze[dimensions=(0,)] din
dip:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgn
diq:f64[64,64] = squeeze[dimensions=(0,)] dip
dir:f64[64,64] = mul dio diq
dis:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgt
dit:i64[64,64] = squeeze[dimensions=(0,)] dis
diu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgr
div:i64[64,64] = squeeze[dimensions=(0,)] diu
diw:bool[64,64] = lt dit 0:i64[]
dix:i64[64,64] = add dit 64:i64[]
diy:i64[64,64] = select_n diw dit dix
diz:bool[64,64] = lt div 0:i64[]
dja:i64[64,64] = add div 64:i64[]
djb:i64[64,64] = select_n diz div dja
djc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] diy
djd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] djb
dje:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] djc
djf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] djd
djg:i32[64,64,2] = concatenate[dimension=2] dje djf
djh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dgd djg
dji:f64[64,64] = mul dir djh
djj:f64[64,64] = add dim dji
djk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgm
djl:f64[64,64] = squeeze[dimensions=(0,)] djk
djm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgm
djn:f64[64,64] = squeeze[dimensions=(0,)] djm
djo:f64[64,64] = mul djl djn
djp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dgt
djq:i64[64,64] = squeeze[dimensions=(0,)] djp
djr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dgt
djs:i64[64,64] = squeeze[dimensions=(0,)] djr
djt:bool[64,64] = lt djq 0:i64[]
dju:i64[64,64] = add djq 64:i64[]
djv:i64[64,64] = select_n djt djq dju
djw:bool[64,64] = lt djs 0:i64[]
djx:i64[64,64] = add djs 64:i64[]
djy:i64[64,64] = select_n djw djs djx
djz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] djv
dka:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] djy
dkb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] djz
dkc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dka
dkd:i32[64,64,2] = concatenate[dimension=2] dkb dkc
dke:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dgd dkd
dkf:f64[64,64] = mul djo dke
dkg:f64[64,64] = add djj dkf
dkh:f64[64,64] = sub cpg dkg
dki:f64[64,64] = div dkh 2.0:f64[]
dkj:f64[64,64] = add cpg dki
dkk:f64[] = neg eb
dkl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dkk
dkm:f64[2,64,64] = mul dkl dcb
dkn:f64[2,64,64] = add ea dkm
dko:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
dkp:f64[2,64,64] = div dkn dko
dkq:f64[2,64,64] = floor dkp
dkr:f64[2,64,64] = sub dkp dkq
dks:f64[2,64,64] = sub 1.0:f64[] dkr
dkt:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bt
dku:f64[2,64,64] = sub dkp dkr
dkv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dku
dkw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dkv dkt
dkx:i64[2,64,64] = add dkw 1:i64[]
dky:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dkx dkt
dkz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dks
dla:f64[64,64] = squeeze[dimensions=(0,)] dkz
dlb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dks
dlc:f64[64,64] = squeeze[dimensions=(0,)] dlb
dld:f64[64,64] = mul dla dlc
dle:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dkw
dlf:i64[64,64] = squeeze[dimensions=(0,)] dle
dlg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dkw
dlh:i64[64,64] = squeeze[dimensions=(0,)] dlg
dli:bool[64,64] = lt dlf 0:i64[]
dlj:i64[64,64] = add dlf 64:i64[]
dlk:i64[64,64] = select_n dli dlf dlj
dll:bool[64,64] = lt dlh 0:i64[]
dlm:i64[64,64] = add dlh 64:i64[]
dln:i64[64,64] = select_n dll dlh dlm
dlo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dlk
dlp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dln
dlq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dlo
dlr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dlp
dls:i32[64,64,2] = concatenate[dimension=2] dlq dlr
dlt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dkj dls
dlu:f64[64,64] = mul dld dlt
dlv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dks
dlw:f64[64,64] = squeeze[dimensions=(0,)] dlv
dlx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dkr
dly:f64[64,64] = squeeze[dimensions=(0,)] dlx
dlz:f64[64,64] = mul dlw dly
dma:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dkw
dmb:i64[64,64] = squeeze[dimensions=(0,)] dma
dmc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dky
dmd:i64[64,64] = squeeze[dimensions=(0,)] dmc
dme:bool[64,64] = lt dmb 0:i64[]
dmf:i64[64,64] = add dmb 64:i64[]
dmg:i64[64,64] = select_n dme dmb dmf
dmh:bool[64,64] = lt dmd 0:i64[]
dmi:i64[64,64] = add dmd 64:i64[]
dmj:i64[64,64] = select_n dmh dmd dmi
dmk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dmg
dml:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dmj
dmm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dmk
dmn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dml
dmo:i32[64,64,2] = concatenate[dimension=2] dmm dmn
dmp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dkj dmo
dmq:f64[64,64] = mul dlz dmp
dmr:f64[64,64] = add dlu dmq
dms:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dkr
dmt:f64[64,64] = squeeze[dimensions=(0,)] dms
dmu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dks
dmv:f64[64,64] = squeeze[dimensions=(0,)] dmu
dmw:f64[64,64] = mul dmt dmv
dmx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dky
dmy:i64[64,64] = squeeze[dimensions=(0,)] dmx
dmz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dkw
dna:i64[64,64] = squeeze[dimensions=(0,)] dmz
dnb:bool[64,64] = lt dmy 0:i64[]
dnc:i64[64,64] = add dmy 64:i64[]
dnd:i64[64,64] = select_n dnb dmy dnc
dne:bool[64,64] = lt dna 0:i64[]
dnf:i64[64,64] = add dna 64:i64[]
dng:i64[64,64] = select_n dne dna dnf
dnh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dnd
dni:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dng
dnj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dnh
dnk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dni
dnl:i32[64,64,2] = concatenate[dimension=2] dnj dnk
dnm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dkj dnl
dnn:f64[64,64] = mul dmw dnm
dno:f64[64,64] = add dmr dnn
dnp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dkr
dnq:f64[64,64] = squeeze[dimensions=(0,)] dnp
dnr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dkr
dns:f64[64,64] = squeeze[dimensions=(0,)] dnr
dnt:f64[64,64] = mul dnq dns
dnu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dky
dnv:i64[64,64] = squeeze[dimensions=(0,)] dnu
dnw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dky
dnx:i64[64,64] = squeeze[dimensions=(0,)] dnw
dny:bool[64,64] = lt dnv 0:i64[]
dnz:i64[64,64] = add dnv 64:i64[]
doa:i64[64,64] = select_n dny dnv dnz
dob:bool[64,64] = lt dnx 0:i64[]
doc:i64[64,64] = add dnx 64:i64[]
dod:i64[64,64] = select_n dob dnx doc
doe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] doa
dof:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dod
dog:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] doe
doh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dof
doi:i32[64,64,2] = concatenate[dimension=2] dog doh
doj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dkj doi
dok:f64[64,64] = mul dnt doj
dol:f64[64,64] = add dno dok
dom:f64[] = neg eb
don:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dom
doo:f64[2,64,64] = mul don dcb
dop:f64[2,64,64] = add ea doo
doq:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
dor:f64[2,64,64] = div dop doq
dos:f64[2,64,64] = floor dor
dot:f64[2,64,64] = sub dor dos
dou:f64[2,64,64] = sub 1.0:f64[] dot
dov:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bu
dow:f64[2,64,64] = sub dor dot
dox:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dow
doy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dox dov
doz:i64[2,64,64] = add doy 1:i64[]
dpa:i64[2,64,64] = jit[name=remainder jaxpr=remainder] doz dov
dpb:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dou
dpc:f64[64,64] = squeeze[dimensions=(0,)] dpb
dpd:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dou
dpe:f64[64,64] = squeeze[dimensions=(0,)] dpd
dpf:f64[64,64] = mul dpc dpe
dpg:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] doy
dph:i64[64,64] = squeeze[dimensions=(0,)] dpg
dpi:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] doy
dpj:i64[64,64] = squeeze[dimensions=(0,)] dpi
dpk:bool[64,64] = lt dph 0:i64[]
dpl:i64[64,64] = add dph 64:i64[]
dpm:i64[64,64] = select_n dpk dph dpl
dpn:bool[64,64] = lt dpj 0:i64[]
dpo:i64[64,64] = add dpj 64:i64[]
dpp:i64[64,64] = select_n dpn dpj dpo
dpq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dpm
dpr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dpp
dps:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dpq
dpt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dpr
dpu:i32[64,64,2] = concatenate[dimension=2] dps dpt
dpv:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dbq dpu
dpw:f64[64,64] = mul dpf dpv
dpx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dou
dpy:f64[64,64] = squeeze[dimensions=(0,)] dpx
dpz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dot
dqa:f64[64,64] = squeeze[dimensions=(0,)] dpz
dqb:f64[64,64] = mul dpy dqa
dqc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] doy
dqd:i64[64,64] = squeeze[dimensions=(0,)] dqc
dqe:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dpa
dqf:i64[64,64] = squeeze[dimensions=(0,)] dqe
dqg:bool[64,64] = lt dqd 0:i64[]
dqh:i64[64,64] = add dqd 64:i64[]
dqi:i64[64,64] = select_n dqg dqd dqh
dqj:bool[64,64] = lt dqf 0:i64[]
dqk:i64[64,64] = add dqf 64:i64[]
dql:i64[64,64] = select_n dqj dqf dqk
dqm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dqi
dqn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dql
dqo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dqm
dqp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dqn
dqq:i32[64,64,2] = concatenate[dimension=2] dqo dqp
dqr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dbq dqq
dqs:f64[64,64] = mul dqb dqr
dqt:f64[64,64] = add dpw dqs
dqu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dot
dqv:f64[64,64] = squeeze[dimensions=(0,)] dqu
dqw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dou
dqx:f64[64,64] = squeeze[dimensions=(0,)] dqw
dqy:f64[64,64] = mul dqv dqx
dqz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dpa
dra:i64[64,64] = squeeze[dimensions=(0,)] dqz
drb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] doy
drc:i64[64,64] = squeeze[dimensions=(0,)] drb
drd:bool[64,64] = lt dra 0:i64[]
dre:i64[64,64] = add dra 64:i64[]
drf:i64[64,64] = select_n drd dra dre
drg:bool[64,64] = lt drc 0:i64[]
drh:i64[64,64] = add drc 64:i64[]
dri:i64[64,64] = select_n drg drc drh
drj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] drf
drk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dri
drl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] drj
drm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] drk
drn:i32[64,64,2] = concatenate[dimension=2] drl drm
dro:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dbq drn
drp:f64[64,64] = mul dqy dro
drq:f64[64,64] = add dqt drp
drr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dot
drs:f64[64,64] = squeeze[dimensions=(0,)] drr
drt:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dot
dru:f64[64,64] = squeeze[dimensions=(0,)] drt
drv:f64[64,64] = mul drs dru
drw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dpa
drx:i64[64,64] = squeeze[dimensions=(0,)] drw
dry:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dpa
drz:i64[64,64] = squeeze[dimensions=(0,)] dry
dsa:bool[64,64] = lt drx 0:i64[]
dsb:i64[64,64] = add drx 64:i64[]
dsc:i64[64,64] = select_n dsa drx dsb
dsd:bool[64,64] = lt drz 0:i64[]
dse:i64[64,64] = add drz 64:i64[]
dsf:i64[64,64] = select_n dsd drz dse
dsg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dsc
dsh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dsf
dsi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dsg
dsj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dsh
dsk:i32[64,64,2] = concatenate[dimension=2] dsi dsj
dsl:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dbq dsk
dsm:f64[64,64] = mul drv dsl
dsn:f64[64,64] = add drq dsm
dso:f64[2,64,64] = neg dcb
dsp:f64[] = neg eb
dsq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dsp
dsr:f64[2,64,64] = mul dsq dso
dss:f64[2,64,64] = add ea dsr
dst:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
dsu:f64[2,64,64] = div dss dst
dsv:f64[2,64,64] = floor dsu
dsw:f64[2,64,64] = sub dsu dsv
dsx:f64[2,64,64] = sub 1.0:f64[] dsw
dsy:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bv
dsz:f64[2,64,64] = sub dsu dsw
dta:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dsz
dtb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dta dsy
dtc:i64[2,64,64] = add dtb 1:i64[]
dtd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dtc dsy
dte:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dsx
dtf:f64[64,64] = squeeze[dimensions=(0,)] dte
dtg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dsx
dth:f64[64,64] = squeeze[dimensions=(0,)] dtg
dti:f64[64,64] = mul dtf dth
dtj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dtb
dtk:i64[64,64] = squeeze[dimensions=(0,)] dtj
dtl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dtb
dtm:i64[64,64] = squeeze[dimensions=(0,)] dtl
dtn:bool[64,64] = lt dtk 0:i64[]
dto:i64[64,64] = add dtk 64:i64[]
dtp:i64[64,64] = select_n dtn dtk dto
dtq:bool[64,64] = lt dtm 0:i64[]
dtr:i64[64,64] = add dtm 64:i64[]
dts:i64[64,64] = select_n dtq dtm dtr
dtt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dtp
dtu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dts
dtv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dtt
dtw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dtu
dtx:i32[64,64,2] = concatenate[dimension=2] dtv dtw
dty:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dsn dtx
dtz:f64[64,64] = mul dti dty
dua:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dsx
dub:f64[64,64] = squeeze[dimensions=(0,)] dua
duc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dsw
dud:f64[64,64] = squeeze[dimensions=(0,)] duc
due:f64[64,64] = mul dub dud
duf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dtb
dug:i64[64,64] = squeeze[dimensions=(0,)] duf
duh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dtd
dui:i64[64,64] = squeeze[dimensions=(0,)] duh
duj:bool[64,64] = lt dug 0:i64[]
duk:i64[64,64] = add dug 64:i64[]
dul:i64[64,64] = select_n duj dug duk
dum:bool[64,64] = lt dui 0:i64[]
dun:i64[64,64] = add dui 64:i64[]
duo:i64[64,64] = select_n dum dui dun
dup:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dul
duq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] duo
dur:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dup
dus:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] duq
dut:i32[64,64,2] = concatenate[dimension=2] dur dus
duu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dsn dut
duv:f64[64,64] = mul due duu
duw:f64[64,64] = add dtz duv
dux:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dsw
duy:f64[64,64] = squeeze[dimensions=(0,)] dux
duz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dsx
dva:f64[64,64] = squeeze[dimensions=(0,)] duz
dvb:f64[64,64] = mul duy dva
dvc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dtd
dvd:i64[64,64] = squeeze[dimensions=(0,)] dvc
dve:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dtb
dvf:i64[64,64] = squeeze[dimensions=(0,)] dve
dvg:bool[64,64] = lt dvd 0:i64[]
dvh:i64[64,64] = add dvd 64:i64[]
dvi:i64[64,64] = select_n dvg dvd dvh
dvj:bool[64,64] = lt dvf 0:i64[]
dvk:i64[64,64] = add dvf 64:i64[]
dvl:i64[64,64] = select_n dvj dvf dvk
dvm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dvi
dvn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dvl
dvo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dvm
dvp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dvn
dvq:i32[64,64,2] = concatenate[dimension=2] dvo dvp
dvr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dsn dvq
dvs:f64[64,64] = mul dvb dvr
dvt:f64[64,64] = add duw dvs
dvu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dsw
dvv:f64[64,64] = squeeze[dimensions=(0,)] dvu
dvw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dsw
dvx:f64[64,64] = squeeze[dimensions=(0,)] dvw
dvy:f64[64,64] = mul dvv dvx
dvz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dtd
dwa:i64[64,64] = squeeze[dimensions=(0,)] dvz
dwb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dtd
dwc:i64[64,64] = squeeze[dimensions=(0,)] dwb
dwd:bool[64,64] = lt dwa 0:i64[]
dwe:i64[64,64] = add dwa 64:i64[]
dwf:i64[64,64] = select_n dwd dwa dwe
dwg:bool[64,64] = lt dwc 0:i64[]
dwh:i64[64,64] = add dwc 64:i64[]
dwi:i64[64,64] = select_n dwg dwc dwh
dwj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dwf
dwk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dwi
dwl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dwj
dwm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dwk
dwn:i32[64,64,2] = concatenate[dimension=2] dwl dwm
dwo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dsn dwn
dwp:f64[64,64] = mul dvy dwo
dwq:f64[64,64] = add dvt dwp
dwr:f64[64,64] = sub dbq dwq
dws:f64[64,64] = div dwr 2.0:f64[]
dwt:f64[64,64] = add dbq dws
dwu:f64[] = neg eb
dwv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dwu
dww:f64[2,64,64] = mul dwv dcb
dwx:f64[2,64,64] = add ea dww
dwy:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
dwz:f64[2,64,64] = div dwx dwy
dxa:f64[2,64,64] = floor dwz
dxb:f64[2,64,64] = sub dwz dxa
dxc:f64[2,64,64] = sub 1.0:f64[] dxb
dxd:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bw
dxe:f64[2,64,64] = sub dwz dxb
dxf:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dxe
dxg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dxf dxd
dxh:i64[2,64,64] = add dxg 1:i64[]
dxi:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dxh dxd
dxj:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxc
dxk:f64[64,64] = squeeze[dimensions=(0,)] dxj
dxl:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxc
dxm:f64[64,64] = squeeze[dimensions=(0,)] dxl
dxn:f64[64,64] = mul dxk dxm
dxo:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxg
dxp:i64[64,64] = squeeze[dimensions=(0,)] dxo
dxq:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxg
dxr:i64[64,64] = squeeze[dimensions=(0,)] dxq
dxs:bool[64,64] = lt dxp 0:i64[]
dxt:i64[64,64] = add dxp 64:i64[]
dxu:i64[64,64] = select_n dxs dxp dxt
dxv:bool[64,64] = lt dxr 0:i64[]
dxw:i64[64,64] = add dxr 64:i64[]
dxx:i64[64,64] = select_n dxv dxr dxw
dxy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dxu
dxz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dxx
dya:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dxy
dyb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dxz
dyc:i32[64,64,2] = concatenate[dimension=2] dya dyb
dyd:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dwt dyc
dye:f64[64,64] = mul dxn dyd
dyf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxc
dyg:f64[64,64] = squeeze[dimensions=(0,)] dyf
dyh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxb
dyi:f64[64,64] = squeeze[dimensions=(0,)] dyh
dyj:f64[64,64] = mul dyg dyi
dyk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxg
dyl:i64[64,64] = squeeze[dimensions=(0,)] dyk
dym:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxi
dyn:i64[64,64] = squeeze[dimensions=(0,)] dym
dyo:bool[64,64] = lt dyl 0:i64[]
dyp:i64[64,64] = add dyl 64:i64[]
dyq:i64[64,64] = select_n dyo dyl dyp
dyr:bool[64,64] = lt dyn 0:i64[]
dys:i64[64,64] = add dyn 64:i64[]
dyt:i64[64,64] = select_n dyr dyn dys
dyu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dyq
dyv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dyt
dyw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dyu
dyx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dyv
dyy:i32[64,64,2] = concatenate[dimension=2] dyw dyx
dyz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dwt dyy
dza:f64[64,64] = mul dyj dyz
dzb:f64[64,64] = add dye dza
dzc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxb
dzd:f64[64,64] = squeeze[dimensions=(0,)] dzc
dze:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxc
dzf:f64[64,64] = squeeze[dimensions=(0,)] dze
dzg:f64[64,64] = mul dzd dzf
dzh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxi
dzi:i64[64,64] = squeeze[dimensions=(0,)] dzh
dzj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxg
dzk:i64[64,64] = squeeze[dimensions=(0,)] dzj
dzl:bool[64,64] = lt dzi 0:i64[]
dzm:i64[64,64] = add dzi 64:i64[]
dzn:i64[64,64] = select_n dzl dzi dzm
dzo:bool[64,64] = lt dzk 0:i64[]
dzp:i64[64,64] = add dzk 64:i64[]
dzq:i64[64,64] = select_n dzo dzk dzp
dzr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dzn
dzs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dzq
dzt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dzr
dzu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] dzs
dzv:i32[64,64,2] = concatenate[dimension=2] dzt dzu
dzw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dwt dzv
dzx:f64[64,64] = mul dzg dzw
dzy:f64[64,64] = add dzb dzx
dzz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxb
eaa:f64[64,64] = squeeze[dimensions=(0,)] dzz
eab:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxb
eac:f64[64,64] = squeeze[dimensions=(0,)] eab
ead:f64[64,64] = mul eaa eac
eae:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] dxi
eaf:i64[64,64] = squeeze[dimensions=(0,)] eae
eag:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] dxi
eah:i64[64,64] = squeeze[dimensions=(0,)] eag
eai:bool[64,64] = lt eaf 0:i64[]
eaj:i64[64,64] = add eaf 64:i64[]
eak:i64[64,64] = select_n eai eaf eaj
eal:bool[64,64] = lt eah 0:i64[]
eam:i64[64,64] = add eah 64:i64[]
ean:i64[64,64] = select_n eal eah eam
eao:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eak
eap:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ean
eaq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eao
ear:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eap
eas:i32[64,64,2] = concatenate[dimension=2] eaq ear
eat:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dwt eas
eau:f64[64,64] = mul ead eat
eav:f64[64,64] = add dzy eau
eaw:c128[64,33] = jit[name=fft jaxpr=fft] dol
eax:c128[] = reduce_prod[axes=(0,)] bx
eay:c128[] = sqrt eax
eaz:c128[] = div (1+0j):c128[] eay
eba:c128[64,33] = mul eaw eaz
ebb:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] eba
ebc:c128[2,64,33] = mul dw ebb
ebd:f64[2,64,64] = jit[name=fft jaxpr=fft1] ebc
ebe:f64[] = reduce_prod[axes=(0,)] by
ebf:f64[] = sqrt ebe
ebg:f64[2,64,64] = mul ebd ebf
ebh:f64[] = neg eb
ebi:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ebh
ebj:f64[2,64,64] = mul ebi ebg
ebk:f64[2,64,64] = add ea ebj
ebl:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
ebm:f64[2,64,64] = div ebk ebl
ebn:f64[2,64,64] = floor ebm
ebo:f64[2,64,64] = sub ebm ebn
ebp:f64[2,64,64] = sub 1.0:f64[] ebo
ebq:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] bz
ebr:f64[2,64,64] = sub ebm ebo
ebs:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ebr
ebt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ebs ebq
ebu:i64[2,64,64] = add ebt 1:i64[]
ebv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ebu ebq
ebw:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebp
ebx:f64[64,64] = squeeze[dimensions=(0,)] ebw
eby:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebp
ebz:f64[64,64] = squeeze[dimensions=(0,)] eby
eca:f64[64,64] = mul ebx ebz
ecb:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebt
ecc:i64[64,64] = squeeze[dimensions=(0,)] ecb
ecd:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebt
ece:i64[64,64] = squeeze[dimensions=(0,)] ecd
ecf:bool[64,64] = lt ecc 0:i64[]
ecg:i64[64,64] = add ecc 64:i64[]
ech:i64[64,64] = select_n ecf ecc ecg
eci:bool[64,64] = lt ece 0:i64[]
ecj:i64[64,64] = add ece 64:i64[]
eck:i64[64,64] = select_n eci ece ecj
ecl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ech
ecm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eck
ecn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ecl
eco:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ecm
ecp:i32[64,64,2] = concatenate[dimension=2] ecn eco
ecq:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dol ecp
ecr:f64[64,64] = mul eca ecq
ecs:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebp
ect:f64[64,64] = squeeze[dimensions=(0,)] ecs
ecu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebo
ecv:f64[64,64] = squeeze[dimensions=(0,)] ecu
ecw:f64[64,64] = mul ect ecv
ecx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebt
ecy:i64[64,64] = squeeze[dimensions=(0,)] ecx
ecz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebv
eda:i64[64,64] = squeeze[dimensions=(0,)] ecz
edb:bool[64,64] = lt ecy 0:i64[]
edc:i64[64,64] = add ecy 64:i64[]
edd:i64[64,64] = select_n edb ecy edc
ede:bool[64,64] = lt eda 0:i64[]
edf:i64[64,64] = add eda 64:i64[]
edg:i64[64,64] = select_n ede eda edf
edh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] edd
edi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] edg
edj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] edh
edk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] edi
edl:i32[64,64,2] = concatenate[dimension=2] edj edk
edm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dol edl
edn:f64[64,64] = mul ecw edm
edo:f64[64,64] = add ecr edn
edp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebo
edq:f64[64,64] = squeeze[dimensions=(0,)] edp
edr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebp
eds:f64[64,64] = squeeze[dimensions=(0,)] edr
edt:f64[64,64] = mul edq eds
edu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebv
edv:i64[64,64] = squeeze[dimensions=(0,)] edu
edw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebt
edx:i64[64,64] = squeeze[dimensions=(0,)] edw
edy:bool[64,64] = lt edv 0:i64[]
edz:i64[64,64] = add edv 64:i64[]
eea:i64[64,64] = select_n edy edv edz
eeb:bool[64,64] = lt edx 0:i64[]
eec:i64[64,64] = add edx 64:i64[]
eed:i64[64,64] = select_n eeb edx eec
eee:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eea
eef:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eed
eeg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eee
eeh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eef
eei:i32[64,64,2] = concatenate[dimension=2] eeg eeh
eej:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dol eei
eek:f64[64,64] = mul edt eej
eel:f64[64,64] = add edo eek
eem:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebo
een:f64[64,64] = squeeze[dimensions=(0,)] eem
eeo:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebo
eep:f64[64,64] = squeeze[dimensions=(0,)] eeo
eeq:f64[64,64] = mul een eep
eer:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ebv
ees:i64[64,64] = squeeze[dimensions=(0,)] eer
eet:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ebv
eeu:i64[64,64] = squeeze[dimensions=(0,)] eet
eev:bool[64,64] = lt ees 0:i64[]
eew:i64[64,64] = add ees 64:i64[]
eex:i64[64,64] = select_n eev ees eew
eey:bool[64,64] = lt eeu 0:i64[]
eez:i64[64,64] = add eeu 64:i64[]
efa:i64[64,64] = select_n eey eeu eez
efb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eex
efc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] efa
efd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] efb
efe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] efc
eff:i32[64,64,2] = concatenate[dimension=2] efd efe
efg:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] dol eff
efh:f64[64,64] = mul eeq efg
efi:f64[64,64] = add eel efh
efj:f64[2,64,64] = neg ebg
efk:f64[] = neg eb
efl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] efk
efm:f64[2,64,64] = mul efl efj
efn:f64[2,64,64] = add ea efm
efo:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
efp:f64[2,64,64] = div efn efo
efq:f64[2,64,64] = floor efp
efr:f64[2,64,64] = sub efp efq
efs:f64[2,64,64] = sub 1.0:f64[] efr
eft:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ca
efu:f64[2,64,64] = sub efp efr
efv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] efu
efw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] efv eft
efx:i64[2,64,64] = add efw 1:i64[]
efy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] efx eft
efz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efs
ega:f64[64,64] = squeeze[dimensions=(0,)] efz
egb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efs
egc:f64[64,64] = squeeze[dimensions=(0,)] egb
egd:f64[64,64] = mul ega egc
ege:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efw
egf:i64[64,64] = squeeze[dimensions=(0,)] ege
egg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efw
egh:i64[64,64] = squeeze[dimensions=(0,)] egg
egi:bool[64,64] = lt egf 0:i64[]
egj:i64[64,64] = add egf 64:i64[]
egk:i64[64,64] = select_n egi egf egj
egl:bool[64,64] = lt egh 0:i64[]
egm:i64[64,64] = add egh 64:i64[]
egn:i64[64,64] = select_n egl egh egm
ego:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] egk
egp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] egn
egq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ego
egr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] egp
egs:i32[64,64,2] = concatenate[dimension=2] egq egr
egt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] efi egs
egu:f64[64,64] = mul egd egt
egv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efs
egw:f64[64,64] = squeeze[dimensions=(0,)] egv
egx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efr
egy:f64[64,64] = squeeze[dimensions=(0,)] egx
egz:f64[64,64] = mul egw egy
eha:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efw
ehb:i64[64,64] = squeeze[dimensions=(0,)] eha
ehc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efy
ehd:i64[64,64] = squeeze[dimensions=(0,)] ehc
ehe:bool[64,64] = lt ehb 0:i64[]
ehf:i64[64,64] = add ehb 64:i64[]
ehg:i64[64,64] = select_n ehe ehb ehf
ehh:bool[64,64] = lt ehd 0:i64[]
ehi:i64[64,64] = add ehd 64:i64[]
ehj:i64[64,64] = select_n ehh ehd ehi
ehk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ehg
ehl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ehj
ehm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ehk
ehn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ehl
eho:i32[64,64,2] = concatenate[dimension=2] ehm ehn
ehp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] efi eho
ehq:f64[64,64] = mul egz ehp
ehr:f64[64,64] = add egu ehq
ehs:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efr
eht:f64[64,64] = squeeze[dimensions=(0,)] ehs
ehu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efs
ehv:f64[64,64] = squeeze[dimensions=(0,)] ehu
ehw:f64[64,64] = mul eht ehv
ehx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efy
ehy:i64[64,64] = squeeze[dimensions=(0,)] ehx
ehz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efw
eia:i64[64,64] = squeeze[dimensions=(0,)] ehz
eib:bool[64,64] = lt ehy 0:i64[]
eic:i64[64,64] = add ehy 64:i64[]
eid:i64[64,64] = select_n eib ehy eic
eie:bool[64,64] = lt eia 0:i64[]
eif:i64[64,64] = add eia 64:i64[]
eig:i64[64,64] = select_n eie eia eif
eih:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eid
eii:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eig
eij:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eih
eik:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eii
eil:i32[64,64,2] = concatenate[dimension=2] eij eik
eim:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] efi eil
ein:f64[64,64] = mul ehw eim
eio:f64[64,64] = add ehr ein
eip:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efr
eiq:f64[64,64] = squeeze[dimensions=(0,)] eip
eir:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efr
eis:f64[64,64] = squeeze[dimensions=(0,)] eir
eit:f64[64,64] = mul eiq eis
eiu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] efy
eiv:i64[64,64] = squeeze[dimensions=(0,)] eiu
eiw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] efy
eix:i64[64,64] = squeeze[dimensions=(0,)] eiw
eiy:bool[64,64] = lt eiv 0:i64[]
eiz:i64[64,64] = add eiv 64:i64[]
eja:i64[64,64] = select_n eiy eiv eiz
ejb:bool[64,64] = lt eix 0:i64[]
ejc:i64[64,64] = add eix 64:i64[]
ejd:i64[64,64] = select_n ejb eix ejc
eje:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eja
ejf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ejd
ejg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eje
ejh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ejf
eji:i32[64,64,2] = concatenate[dimension=2] ejg ejh
ejj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] efi eji
ejk:f64[64,64] = mul eit ejj
ejl:f64[64,64] = add eio ejk
ejm:f64[64,64] = sub dol ejl
ejn:f64[64,64] = div ejm 2.0:f64[]
ejo:f64[64,64] = add dol ejn
ejp:f64[] = neg eb
ejq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ejp
ejr:f64[2,64,64] = mul ejq ebg
ejs:f64[2,64,64] = add ea ejr
ejt:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
eju:f64[2,64,64] = div ejs ejt
ejv:f64[2,64,64] = floor eju
ejw:f64[2,64,64] = sub eju ejv
ejx:f64[2,64,64] = sub 1.0:f64[] ejw
ejy:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cb
ejz:f64[2,64,64] = sub eju ejw
eka:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ejz
ekb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] eka ejy
ekc:i64[2,64,64] = add ekb 1:i64[]
ekd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ekc ejy
eke:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ejx
ekf:f64[64,64] = squeeze[dimensions=(0,)] eke
ekg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ejx
ekh:f64[64,64] = squeeze[dimensions=(0,)] ekg
eki:f64[64,64] = mul ekf ekh
ekj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ekb
ekk:i64[64,64] = squeeze[dimensions=(0,)] ekj
ekl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ekb
ekm:i64[64,64] = squeeze[dimensions=(0,)] ekl
ekn:bool[64,64] = lt ekk 0:i64[]
eko:i64[64,64] = add ekk 64:i64[]
ekp:i64[64,64] = select_n ekn ekk eko
ekq:bool[64,64] = lt ekm 0:i64[]
ekr:i64[64,64] = add ekm 64:i64[]
eks:i64[64,64] = select_n ekq ekm ekr
ekt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ekp
eku:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eks
ekv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ekt
ekw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eku
ekx:i32[64,64,2] = concatenate[dimension=2] ekv ekw
eky:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ejo ekx
ekz:f64[64,64] = mul eki eky
ela:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ejx
elb:f64[64,64] = squeeze[dimensions=(0,)] ela
elc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ejw
eld:f64[64,64] = squeeze[dimensions=(0,)] elc
ele:f64[64,64] = mul elb eld
elf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ekb
elg:i64[64,64] = squeeze[dimensions=(0,)] elf
elh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ekd
eli:i64[64,64] = squeeze[dimensions=(0,)] elh
elj:bool[64,64] = lt elg 0:i64[]
elk:i64[64,64] = add elg 64:i64[]
ell:i64[64,64] = select_n elj elg elk
elm:bool[64,64] = lt eli 0:i64[]
eln:i64[64,64] = add eli 64:i64[]
elo:i64[64,64] = select_n elm eli eln
elp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ell
elq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] elo
elr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] elp
els:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] elq
elt:i32[64,64,2] = concatenate[dimension=2] elr els
elu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ejo elt
elv:f64[64,64] = mul ele elu
elw:f64[64,64] = add ekz elv
elx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ejw
ely:f64[64,64] = squeeze[dimensions=(0,)] elx
elz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ejx
ema:f64[64,64] = squeeze[dimensions=(0,)] elz
emb:f64[64,64] = mul ely ema
emc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ekd
emd:i64[64,64] = squeeze[dimensions=(0,)] emc
eme:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ekb
emf:i64[64,64] = squeeze[dimensions=(0,)] eme
emg:bool[64,64] = lt emd 0:i64[]
emh:i64[64,64] = add emd 64:i64[]
emi:i64[64,64] = select_n emg emd emh
emj:bool[64,64] = lt emf 0:i64[]
emk:i64[64,64] = add emf 64:i64[]
eml:i64[64,64] = select_n emj emf emk
emm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] emi
emn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eml
emo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] emm
emp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] emn
emq:i32[64,64,2] = concatenate[dimension=2] emo emp
emr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ejo emq
ems:f64[64,64] = mul emb emr
emt:f64[64,64] = add elw ems
emu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ejw
emv:f64[64,64] = squeeze[dimensions=(0,)] emu
emw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ejw
emx:f64[64,64] = squeeze[dimensions=(0,)] emw
emy:f64[64,64] = mul emv emx
emz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ekd
ena:i64[64,64] = squeeze[dimensions=(0,)] emz
enb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ekd
enc:i64[64,64] = squeeze[dimensions=(0,)] enb
end:bool[64,64] = lt ena 0:i64[]
ene:i64[64,64] = add ena 64:i64[]
enf:i64[64,64] = select_n end ena ene
eng:bool[64,64] = lt enc 0:i64[]
enh:i64[64,64] = add enc 64:i64[]
eni:i64[64,64] = select_n eng enc enh
enj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] enf
enk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eni
enl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] enj
enm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] enk
enn:i32[64,64,2] = concatenate[dimension=2] enl enm
eno:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ejo enn
enp:f64[64,64] = mul emy eno
enq:f64[64,64] = add emt enp
enr:f64[] = neg eb
ens:f64[] = convert_element_type[new_dtype=float64 weak_type=False] enr
ent:f64[2,64,64] = mul ens ebg
enu:f64[2,64,64] = add ea ent
env:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
enw:f64[2,64,64] = div enu env
enx:f64[2,64,64] = floor enw
eny:f64[2,64,64] = sub enw enx
enz:f64[2,64,64] = sub 1.0:f64[] eny
eoa:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cc
eob:f64[2,64,64] = sub enw eny
eoc:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] eob
eod:i64[2,64,64] = jit[name=remainder jaxpr=remainder] eoc eoa
eoe:i64[2,64,64] = add eod 1:i64[]
eof:i64[2,64,64] = jit[name=remainder jaxpr=remainder] eoe eoa
eog:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] enz
eoh:f64[64,64] = squeeze[dimensions=(0,)] eog
eoi:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] enz
eoj:f64[64,64] = squeeze[dimensions=(0,)] eoi
eok:f64[64,64] = mul eoh eoj
eol:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eod
eom:i64[64,64] = squeeze[dimensions=(0,)] eol
eon:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eod
eoo:i64[64,64] = squeeze[dimensions=(0,)] eon
eop:bool[64,64] = lt eom 0:i64[]
eoq:i64[64,64] = add eom 64:i64[]
eor:i64[64,64] = select_n eop eom eoq
eos:bool[64,64] = lt eoo 0:i64[]
eot:i64[64,64] = add eoo 64:i64[]
eou:i64[64,64] = select_n eos eoo eot
eov:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eor
eow:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eou
eox:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eov
eoy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eow
eoz:i32[64,64,2] = concatenate[dimension=2] eox eoy
epa:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] eav eoz
epb:f64[64,64] = mul eok epa
epc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] enz
epd:f64[64,64] = squeeze[dimensions=(0,)] epc
epe:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eny
epf:f64[64,64] = squeeze[dimensions=(0,)] epe
epg:f64[64,64] = mul epd epf
eph:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eod
epi:i64[64,64] = squeeze[dimensions=(0,)] eph
epj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eof
epk:i64[64,64] = squeeze[dimensions=(0,)] epj
epl:bool[64,64] = lt epi 0:i64[]
epm:i64[64,64] = add epi 64:i64[]
epn:i64[64,64] = select_n epl epi epm
epo:bool[64,64] = lt epk 0:i64[]
epp:i64[64,64] = add epk 64:i64[]
epq:i64[64,64] = select_n epo epk epp
epr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] epn
eps:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] epq
ept:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] epr
epu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eps
epv:i32[64,64,2] = concatenate[dimension=2] ept epu
epw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] eav epv
epx:f64[64,64] = mul epg epw
epy:f64[64,64] = add epb epx
epz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eny
eqa:f64[64,64] = squeeze[dimensions=(0,)] epz
eqb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] enz
eqc:f64[64,64] = squeeze[dimensions=(0,)] eqb
eqd:f64[64,64] = mul eqa eqc
eqe:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eof
eqf:i64[64,64] = squeeze[dimensions=(0,)] eqe
eqg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eod
eqh:i64[64,64] = squeeze[dimensions=(0,)] eqg
eqi:bool[64,64] = lt eqf 0:i64[]
eqj:i64[64,64] = add eqf 64:i64[]
eqk:i64[64,64] = select_n eqi eqf eqj
eql:bool[64,64] = lt eqh 0:i64[]
eqm:i64[64,64] = add eqh 64:i64[]
eqn:i64[64,64] = select_n eql eqh eqm
eqo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eqk
eqp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eqn
eqq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eqo
eqr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eqp
eqs:i32[64,64,2] = concatenate[dimension=2] eqq eqr
eqt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] eav eqs
equ:f64[64,64] = mul eqd eqt
eqv:f64[64,64] = add epy equ
eqw:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eny
eqx:f64[64,64] = squeeze[dimensions=(0,)] eqw
eqy:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eny
eqz:f64[64,64] = squeeze[dimensions=(0,)] eqy
era:f64[64,64] = mul eqx eqz
erb:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] eof
erc:i64[64,64] = squeeze[dimensions=(0,)] erb
erd:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] eof
ere:i64[64,64] = squeeze[dimensions=(0,)] erd
erf:bool[64,64] = lt erc 0:i64[]
erg:i64[64,64] = add erc 64:i64[]
erh:i64[64,64] = select_n erf erc erg
eri:bool[64,64] = lt ere 0:i64[]
erj:i64[64,64] = add ere 64:i64[]
erk:i64[64,64] = select_n eri ere erj
erl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] erh
erm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] erk
ern:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] erl
ero:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] erm
erp:i32[64,64,2] = concatenate[dimension=2] ern ero
erq:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] eav erp
err:f64[64,64] = mul era erq
ers:f64[64,64] = add eqv err
ert:f64[2,64,64] = neg ebg
eru:f64[] = neg eb
erv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] eru
erw:f64[2,64,64] = mul erv ert
erx:f64[2,64,64] = add ea erw
ery:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
erz:f64[2,64,64] = div erx ery
esa:f64[2,64,64] = floor erz
esb:f64[2,64,64] = sub erz esa
esc:f64[2,64,64] = sub 1.0:f64[] esb
esd:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cd
ese:f64[2,64,64] = sub erz esb
esf:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ese
esg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] esf esd
esh:i64[2,64,64] = add esg 1:i64[]
esi:i64[2,64,64] = jit[name=remainder jaxpr=remainder] esh esd
esj:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esc
esk:f64[64,64] = squeeze[dimensions=(0,)] esj
esl:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esc
esm:f64[64,64] = squeeze[dimensions=(0,)] esl
esn:f64[64,64] = mul esk esm
eso:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esg
esp:i64[64,64] = squeeze[dimensions=(0,)] eso
esq:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esg
esr:i64[64,64] = squeeze[dimensions=(0,)] esq
ess:bool[64,64] = lt esp 0:i64[]
est:i64[64,64] = add esp 64:i64[]
esu:i64[64,64] = select_n ess esp est
esv:bool[64,64] = lt esr 0:i64[]
esw:i64[64,64] = add esr 64:i64[]
esx:i64[64,64] = select_n esv esr esw
esy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] esu
esz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] esx
eta:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] esy
etb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] esz
etc:i32[64,64,2] = concatenate[dimension=2] eta etb
etd:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ers etc
ete:f64[64,64] = mul esn etd
etf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esc
etg:f64[64,64] = squeeze[dimensions=(0,)] etf
eth:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esb
eti:f64[64,64] = squeeze[dimensions=(0,)] eth
etj:f64[64,64] = mul etg eti
etk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esg
etl:i64[64,64] = squeeze[dimensions=(0,)] etk
etm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esi
etn:i64[64,64] = squeeze[dimensions=(0,)] etm
eto:bool[64,64] = lt etl 0:i64[]
etp:i64[64,64] = add etl 64:i64[]
etq:i64[64,64] = select_n eto etl etp
etr:bool[64,64] = lt etn 0:i64[]
ets:i64[64,64] = add etn 64:i64[]
ett:i64[64,64] = select_n etr etn ets
etu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] etq
etv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ett
etw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] etu
etx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] etv
ety:i32[64,64,2] = concatenate[dimension=2] etw etx
etz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ers ety
eua:f64[64,64] = mul etj etz
eub:f64[64,64] = add ete eua
euc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esb
eud:f64[64,64] = squeeze[dimensions=(0,)] euc
eue:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esc
euf:f64[64,64] = squeeze[dimensions=(0,)] eue
eug:f64[64,64] = mul eud euf
euh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esi
eui:i64[64,64] = squeeze[dimensions=(0,)] euh
euj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esg
euk:i64[64,64] = squeeze[dimensions=(0,)] euj
eul:bool[64,64] = lt eui 0:i64[]
eum:i64[64,64] = add eui 64:i64[]
eun:i64[64,64] = select_n eul eui eum
euo:bool[64,64] = lt euk 0:i64[]
eup:i64[64,64] = add euk 64:i64[]
euq:i64[64,64] = select_n euo euk eup
eur:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eun
eus:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] euq
eut:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eur
euu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eus
euv:i32[64,64,2] = concatenate[dimension=2] eut euu
euw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ers euv
eux:f64[64,64] = mul eug euw
euy:f64[64,64] = add eub eux
euz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esb
eva:f64[64,64] = squeeze[dimensions=(0,)] euz
evb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esb
evc:f64[64,64] = squeeze[dimensions=(0,)] evb
evd:f64[64,64] = mul eva evc
eve:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] esi
evf:i64[64,64] = squeeze[dimensions=(0,)] eve
evg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] esi
evh:i64[64,64] = squeeze[dimensions=(0,)] evg
evi:bool[64,64] = lt evf 0:i64[]
evj:i64[64,64] = add evf 64:i64[]
evk:i64[64,64] = select_n evi evf evj
evl:bool[64,64] = lt evh 0:i64[]
evm:i64[64,64] = add evh 64:i64[]
evn:i64[64,64] = select_n evl evh evm
evo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] evk
evp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] evn
evq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] evo
evr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] evp
evs:i32[64,64,2] = concatenate[dimension=2] evq evr
evt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ers evs
evu:f64[64,64] = mul evd evt
evv:f64[64,64] = add euy evu
evw:f64[64,64] = sub eav evv
evx:f64[64,64] = div evw 2.0:f64[]
evy:f64[64,64] = add eav evx
evz:f64[] = neg eb
ewa:f64[] = convert_element_type[new_dtype=float64 weak_type=False] evz
ewb:f64[2,64,64] = mul ewa ebg
ewc:f64[2,64,64] = add ea ewb
ewd:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
ewe:f64[2,64,64] = div ewc ewd
ewf:f64[2,64,64] = floor ewe
ewg:f64[2,64,64] = sub ewe ewf
ewh:f64[2,64,64] = sub 1.0:f64[] ewg
ewi:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ce
ewj:f64[2,64,64] = sub ewe ewg
ewk:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ewj
ewl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ewk ewi
ewm:i64[2,64,64] = add ewl 1:i64[]
ewn:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ewm ewi
ewo:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewh
ewp:f64[64,64] = squeeze[dimensions=(0,)] ewo
ewq:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewh
ewr:f64[64,64] = squeeze[dimensions=(0,)] ewq
ews:f64[64,64] = mul ewp ewr
ewt:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewl
ewu:i64[64,64] = squeeze[dimensions=(0,)] ewt
ewv:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewl
eww:i64[64,64] = squeeze[dimensions=(0,)] ewv
ewx:bool[64,64] = lt ewu 0:i64[]
ewy:i64[64,64] = add ewu 64:i64[]
ewz:i64[64,64] = select_n ewx ewu ewy
exa:bool[64,64] = lt eww 0:i64[]
exb:i64[64,64] = add eww 64:i64[]
exc:i64[64,64] = select_n exa eww exb
exd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ewz
exe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] exc
exf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] exd
exg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] exe
exh:i32[64,64,2] = concatenate[dimension=2] exf exg
exi:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] evy exh
exj:f64[64,64] = mul ews exi
exk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewh
exl:f64[64,64] = squeeze[dimensions=(0,)] exk
exm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewg
exn:f64[64,64] = squeeze[dimensions=(0,)] exm
exo:f64[64,64] = mul exl exn
exp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewl
exq:i64[64,64] = squeeze[dimensions=(0,)] exp
exr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewn
exs:i64[64,64] = squeeze[dimensions=(0,)] exr
ext:bool[64,64] = lt exq 0:i64[]
exu:i64[64,64] = add exq 64:i64[]
exv:i64[64,64] = select_n ext exq exu
exw:bool[64,64] = lt exs 0:i64[]
exx:i64[64,64] = add exs 64:i64[]
exy:i64[64,64] = select_n exw exs exx
exz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] exv
eya:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] exy
eyb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] exz
eyc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eya
eyd:i32[64,64,2] = concatenate[dimension=2] eyb eyc
eye:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] evy eyd
eyf:f64[64,64] = mul exo eye
eyg:f64[64,64] = add exj eyf
eyh:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewg
eyi:f64[64,64] = squeeze[dimensions=(0,)] eyh
eyj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewh
eyk:f64[64,64] = squeeze[dimensions=(0,)] eyj
eyl:f64[64,64] = mul eyi eyk
eym:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewn
eyn:i64[64,64] = squeeze[dimensions=(0,)] eym
eyo:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewl
eyp:i64[64,64] = squeeze[dimensions=(0,)] eyo
eyq:bool[64,64] = lt eyn 0:i64[]
eyr:i64[64,64] = add eyn 64:i64[]
eys:i64[64,64] = select_n eyq eyn eyr
eyt:bool[64,64] = lt eyp 0:i64[]
eyu:i64[64,64] = add eyp 64:i64[]
eyv:i64[64,64] = select_n eyt eyp eyu
eyw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eys
eyx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eyv
eyy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eyw
eyz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] eyx
eza:i32[64,64,2] = concatenate[dimension=2] eyy eyz
ezb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] evy eza
ezc:f64[64,64] = mul eyl ezb
ezd:f64[64,64] = add eyg ezc
eze:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewg
ezf:f64[64,64] = squeeze[dimensions=(0,)] eze
ezg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewg
ezh:f64[64,64] = squeeze[dimensions=(0,)] ezg
ezi:f64[64,64] = mul ezf ezh
ezj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ewn
ezk:i64[64,64] = squeeze[dimensions=(0,)] ezj
ezl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ewn
ezm:i64[64,64] = squeeze[dimensions=(0,)] ezl
ezn:bool[64,64] = lt ezk 0:i64[]
ezo:i64[64,64] = add ezk 64:i64[]
ezp:i64[64,64] = select_n ezn ezk ezo
ezq:bool[64,64] = lt ezm 0:i64[]
ezr:i64[64,64] = add ezm 64:i64[]
ezs:i64[64,64] = select_n ezq ezm ezr
ezt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ezp
ezu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ezs
ezv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ezt
ezw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ezu
ezx:i32[64,64,2] = concatenate[dimension=2] ezv ezw
ezy:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] evy ezx
ezz:f64[64,64] = mul ezi ezy
faa:f64[64,64] = add ezd ezz
fab:c128[64,33] = jit[name=fft jaxpr=fft] enq
fac:c128[] = reduce_prod[axes=(0,)] cf
fad:c128[] = sqrt fac
fae:c128[] = div (1+0j):c128[] fad
faf:c128[64,33] = mul fab fae
fag:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] faf
fah:c128[2,64,33] = mul dw fag
fai:f64[2,64,64] = jit[name=fft jaxpr=fft1] fah
faj:f64[] = reduce_prod[axes=(0,)] cg
fak:f64[] = sqrt faj
fal:f64[2,64,64] = mul fai fak
fam:f64[] = neg eb
fan:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fam
fao:f64[2,64,64] = mul fan fal
fap:f64[2,64,64] = add ea fao
faq:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
far:f64[2,64,64] = div fap faq
fas:f64[2,64,64] = floor far
fat:f64[2,64,64] = sub far fas
fau:f64[2,64,64] = sub 1.0:f64[] fat
fav:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ch
faw:f64[2,64,64] = sub far fat
fax:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] faw
fay:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fax fav
faz:i64[2,64,64] = add fay 1:i64[]
fba:i64[2,64,64] = jit[name=remainder jaxpr=remainder] faz fav
fbb:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fau
fbc:f64[64,64] = squeeze[dimensions=(0,)] fbb
fbd:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fau
fbe:f64[64,64] = squeeze[dimensions=(0,)] fbd
fbf:f64[64,64] = mul fbc fbe
fbg:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fay
fbh:i64[64,64] = squeeze[dimensions=(0,)] fbg
fbi:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fay
fbj:i64[64,64] = squeeze[dimensions=(0,)] fbi
fbk:bool[64,64] = lt fbh 0:i64[]
fbl:i64[64,64] = add fbh 64:i64[]
fbm:i64[64,64] = select_n fbk fbh fbl
fbn:bool[64,64] = lt fbj 0:i64[]
fbo:i64[64,64] = add fbj 64:i64[]
fbp:i64[64,64] = select_n fbn fbj fbo
fbq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fbm
fbr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fbp
fbs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fbq
fbt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fbr
fbu:i32[64,64,2] = concatenate[dimension=2] fbs fbt
fbv:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] enq fbu
fbw:f64[64,64] = mul fbf fbv
fbx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fau
fby:f64[64,64] = squeeze[dimensions=(0,)] fbx
fbz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fat
fca:f64[64,64] = squeeze[dimensions=(0,)] fbz
fcb:f64[64,64] = mul fby fca
fcc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fay
fcd:i64[64,64] = squeeze[dimensions=(0,)] fcc
fce:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fba
fcf:i64[64,64] = squeeze[dimensions=(0,)] fce
fcg:bool[64,64] = lt fcd 0:i64[]
fch:i64[64,64] = add fcd 64:i64[]
fci:i64[64,64] = select_n fcg fcd fch
fcj:bool[64,64] = lt fcf 0:i64[]
fck:i64[64,64] = add fcf 64:i64[]
fcl:i64[64,64] = select_n fcj fcf fck
fcm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fci
fcn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fcl
fco:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fcm
fcp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fcn
fcq:i32[64,64,2] = concatenate[dimension=2] fco fcp
fcr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] enq fcq
fcs:f64[64,64] = mul fcb fcr
fct:f64[64,64] = add fbw fcs
fcu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fat
fcv:f64[64,64] = squeeze[dimensions=(0,)] fcu
fcw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fau
fcx:f64[64,64] = squeeze[dimensions=(0,)] fcw
fcy:f64[64,64] = mul fcv fcx
fcz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fba
fda:i64[64,64] = squeeze[dimensions=(0,)] fcz
fdb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fay
fdc:i64[64,64] = squeeze[dimensions=(0,)] fdb
fdd:bool[64,64] = lt fda 0:i64[]
fde:i64[64,64] = add fda 64:i64[]
fdf:i64[64,64] = select_n fdd fda fde
fdg:bool[64,64] = lt fdc 0:i64[]
fdh:i64[64,64] = add fdc 64:i64[]
fdi:i64[64,64] = select_n fdg fdc fdh
fdj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fdf
fdk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fdi
fdl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fdj
fdm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fdk
fdn:i32[64,64,2] = concatenate[dimension=2] fdl fdm
fdo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] enq fdn
fdp:f64[64,64] = mul fcy fdo
fdq:f64[64,64] = add fct fdp
fdr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fat
fds:f64[64,64] = squeeze[dimensions=(0,)] fdr
fdt:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fat
fdu:f64[64,64] = squeeze[dimensions=(0,)] fdt
fdv:f64[64,64] = mul fds fdu
fdw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fba
fdx:i64[64,64] = squeeze[dimensions=(0,)] fdw
fdy:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fba
fdz:i64[64,64] = squeeze[dimensions=(0,)] fdy
fea:bool[64,64] = lt fdx 0:i64[]
feb:i64[64,64] = add fdx 64:i64[]
fec:i64[64,64] = select_n fea fdx feb
fed:bool[64,64] = lt fdz 0:i64[]
fee:i64[64,64] = add fdz 64:i64[]
fef:i64[64,64] = select_n fed fdz fee
feg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fec
feh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fef
fei:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] feg
fej:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] feh
fek:i32[64,64,2] = concatenate[dimension=2] fei fej
fel:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] enq fek
fem:f64[64,64] = mul fdv fel
fen:f64[64,64] = add fdq fem
feo:f64[2,64,64] = neg fal
fep:f64[] = neg eb
feq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fep
fer:f64[2,64,64] = mul feq feo
fes:f64[2,64,64] = add ea fer
fet:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
feu:f64[2,64,64] = div fes fet
fev:f64[2,64,64] = floor feu
few:f64[2,64,64] = sub feu fev
fex:f64[2,64,64] = sub 1.0:f64[] few
fey:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ci
fez:f64[2,64,64] = sub feu few
ffa:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fez
ffb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ffa fey
ffc:i64[2,64,64] = add ffb 1:i64[]
ffd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ffc fey
ffe:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fex
fff:f64[64,64] = squeeze[dimensions=(0,)] ffe
ffg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fex
ffh:f64[64,64] = squeeze[dimensions=(0,)] ffg
ffi:f64[64,64] = mul fff ffh
ffj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ffb
ffk:i64[64,64] = squeeze[dimensions=(0,)] ffj
ffl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ffb
ffm:i64[64,64] = squeeze[dimensions=(0,)] ffl
ffn:bool[64,64] = lt ffk 0:i64[]
ffo:i64[64,64] = add ffk 64:i64[]
ffp:i64[64,64] = select_n ffn ffk ffo
ffq:bool[64,64] = lt ffm 0:i64[]
ffr:i64[64,64] = add ffm 64:i64[]
ffs:i64[64,64] = select_n ffq ffm ffr
ffu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ffp
ffv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ffs
ffw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ffu
ffx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ffv
ffy:i32[64,64,2] = concatenate[dimension=2] ffw ffx
ffz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fen ffy
fga:f64[64,64] = mul ffi ffz
fgb:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fex
fgc:f64[64,64] = squeeze[dimensions=(0,)] fgb
fgd:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] few
fge:f64[64,64] = squeeze[dimensions=(0,)] fgd
fgf:f64[64,64] = mul fgc fge
fgg:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ffb
fgh:i64[64,64] = squeeze[dimensions=(0,)] fgg
fgi:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ffd
fgj:i64[64,64] = squeeze[dimensions=(0,)] fgi
fgk:bool[64,64] = lt fgh 0:i64[]
fgl:i64[64,64] = add fgh 64:i64[]
fgm:i64[64,64] = select_n fgk fgh fgl
fgn:bool[64,64] = lt fgj 0:i64[]
fgo:i64[64,64] = add fgj 64:i64[]
fgp:i64[64,64] = select_n fgn fgj fgo
fgq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fgm
fgr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fgp
fgs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fgq
fgt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fgr
fgu:i32[64,64,2] = concatenate[dimension=2] fgs fgt
fgv:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fen fgu
fgw:f64[64,64] = mul fgf fgv
fgx:f64[64,64] = add fga fgw
fgy:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] few
fgz:f64[64,64] = squeeze[dimensions=(0,)] fgy
fha:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fex
fhb:f64[64,64] = squeeze[dimensions=(0,)] fha
fhc:f64[64,64] = mul fgz fhb
fhd:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ffd
fhe:i64[64,64] = squeeze[dimensions=(0,)] fhd
fhf:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ffb
fhg:i64[64,64] = squeeze[dimensions=(0,)] fhf
fhh:bool[64,64] = lt fhe 0:i64[]
fhi:i64[64,64] = add fhe 64:i64[]
fhj:i64[64,64] = select_n fhh fhe fhi
fhk:bool[64,64] = lt fhg 0:i64[]
fhl:i64[64,64] = add fhg 64:i64[]
fhm:i64[64,64] = select_n fhk fhg fhl
fhn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fhj
fho:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fhm
fhp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fhn
fhq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fho
fhr:i32[64,64,2] = concatenate[dimension=2] fhp fhq
fhs:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fen fhr
fht:f64[64,64] = mul fhc fhs
fhu:f64[64,64] = add fgx fht
fhv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] few
fhw:f64[64,64] = squeeze[dimensions=(0,)] fhv
fhx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] few
fhy:f64[64,64] = squeeze[dimensions=(0,)] fhx
fhz:f64[64,64] = mul fhw fhy
fia:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ffd
fib:i64[64,64] = squeeze[dimensions=(0,)] fia
fic:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ffd
fid:i64[64,64] = squeeze[dimensions=(0,)] fic
fie:bool[64,64] = lt fib 0:i64[]
fif:i64[64,64] = add fib 64:i64[]
fig:i64[64,64] = select_n fie fib fif
fih:bool[64,64] = lt fid 0:i64[]
fii:i64[64,64] = add fid 64:i64[]
fij:i64[64,64] = select_n fih fid fii
fik:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fig
fil:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fij
fim:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fik
fin:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fil
fio:i32[64,64,2] = concatenate[dimension=2] fim fin
fip:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fen fio
fiq:f64[64,64] = mul fhz fip
fir:f64[64,64] = add fhu fiq
fis:f64[64,64] = sub enq fir
fit:f64[64,64] = div fis 2.0:f64[]
fiu:f64[64,64] = add enq fit
fiv:f64[] = neg eb
fiw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fiv
fix:f64[2,64,64] = mul fiw fal
fiy:f64[2,64,64] = add ea fix
fiz:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
fja:f64[2,64,64] = div fiy fiz
fjb:f64[2,64,64] = floor fja
fjc:f64[2,64,64] = sub fja fjb
fjd:f64[2,64,64] = sub 1.0:f64[] fjc
fje:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cj
fjf:f64[2,64,64] = sub fja fjc
fjg:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fjf
fjh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fjg fje
fji:i64[2,64,64] = add fjh 1:i64[]
fjj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fji fje
fjk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjd
fjl:f64[64,64] = squeeze[dimensions=(0,)] fjk
fjm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjd
fjn:f64[64,64] = squeeze[dimensions=(0,)] fjm
fjo:f64[64,64] = mul fjl fjn
fjp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjh
fjq:i64[64,64] = squeeze[dimensions=(0,)] fjp
fjr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjh
fjs:i64[64,64] = squeeze[dimensions=(0,)] fjr
fjt:bool[64,64] = lt fjq 0:i64[]
fju:i64[64,64] = add fjq 64:i64[]
fjv:i64[64,64] = select_n fjt fjq fju
fjw:bool[64,64] = lt fjs 0:i64[]
fjx:i64[64,64] = add fjs 64:i64[]
fjy:i64[64,64] = select_n fjw fjs fjx
fjz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fjv
fka:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fjy
fkb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fjz
fkc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fka
fkd:i32[64,64,2] = concatenate[dimension=2] fkb fkc
fke:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fiu fkd
fkf:f64[64,64] = mul fjo fke
fkg:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjd
fkh:f64[64,64] = squeeze[dimensions=(0,)] fkg
fki:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjc
fkj:f64[64,64] = squeeze[dimensions=(0,)] fki
fkk:f64[64,64] = mul fkh fkj
fkl:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjh
fkm:i64[64,64] = squeeze[dimensions=(0,)] fkl
fkn:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjj
fko:i64[64,64] = squeeze[dimensions=(0,)] fkn
fkp:bool[64,64] = lt fkm 0:i64[]
fkq:i64[64,64] = add fkm 64:i64[]
fkr:i64[64,64] = select_n fkp fkm fkq
fks:bool[64,64] = lt fko 0:i64[]
fkt:i64[64,64] = add fko 64:i64[]
fku:i64[64,64] = select_n fks fko fkt
fkv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fkr
fkw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fku
fkx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fkv
fky:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fkw
fkz:i32[64,64,2] = concatenate[dimension=2] fkx fky
fla:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fiu fkz
flb:f64[64,64] = mul fkk fla
flc:f64[64,64] = add fkf flb
fld:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjc
fle:f64[64,64] = squeeze[dimensions=(0,)] fld
flf:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjd
flg:f64[64,64] = squeeze[dimensions=(0,)] flf
flh:f64[64,64] = mul fle flg
fli:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjj
flj:i64[64,64] = squeeze[dimensions=(0,)] fli
flk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjh
fll:i64[64,64] = squeeze[dimensions=(0,)] flk
flm:bool[64,64] = lt flj 0:i64[]
fln:i64[64,64] = add flj 64:i64[]
flo:i64[64,64] = select_n flm flj fln
flp:bool[64,64] = lt fll 0:i64[]
flq:i64[64,64] = add fll 64:i64[]
flr:i64[64,64] = select_n flp fll flq
fls:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] flo
flt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] flr
flu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fls
flv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] flt
flw:i32[64,64,2] = concatenate[dimension=2] flu flv
flx:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fiu flw
fly:f64[64,64] = mul flh flx
flz:f64[64,64] = add flc fly
fma:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjc
fmb:f64[64,64] = squeeze[dimensions=(0,)] fma
fmc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjc
fmd:f64[64,64] = squeeze[dimensions=(0,)] fmc
fme:f64[64,64] = mul fmb fmd
fmf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fjj
fmg:i64[64,64] = squeeze[dimensions=(0,)] fmf
fmh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fjj
fmi:i64[64,64] = squeeze[dimensions=(0,)] fmh
fmj:bool[64,64] = lt fmg 0:i64[]
fmk:i64[64,64] = add fmg 64:i64[]
fml:i64[64,64] = select_n fmj fmg fmk
fmm:bool[64,64] = lt fmi 0:i64[]
fmn:i64[64,64] = add fmi 64:i64[]
fmo:i64[64,64] = select_n fmm fmi fmn
fmp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fml
fmq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fmo
fmr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fmp
fms:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fmq
fmt:i32[64,64,2] = concatenate[dimension=2] fmr fms
fmu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fiu fmt
fmv:f64[64,64] = mul fme fmu
fmw:f64[64,64] = add flz fmv
fmx:f64[] = neg eb
fmy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fmx
fmz:f64[2,64,64] = mul fmy fal
fna:f64[2,64,64] = add ea fmz
fnb:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
fnc:f64[2,64,64] = div fna fnb
fnd:f64[2,64,64] = floor fnc
fne:f64[2,64,64] = sub fnc fnd
fnf:f64[2,64,64] = sub 1.0:f64[] fne
fng:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ck
fnh:f64[2,64,64] = sub fnc fne
fni:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fnh
fnj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fni fng
fnk:i64[2,64,64] = add fnj 1:i64[]
fnl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fnk fng
fnm:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fnf
fnn:f64[64,64] = squeeze[dimensions=(0,)] fnm
fno:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fnf
fnp:f64[64,64] = squeeze[dimensions=(0,)] fno
fnq:f64[64,64] = mul fnn fnp
fnr:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fnj
fns:i64[64,64] = squeeze[dimensions=(0,)] fnr
fnt:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fnj
fnu:i64[64,64] = squeeze[dimensions=(0,)] fnt
fnv:bool[64,64] = lt fns 0:i64[]
fnw:i64[64,64] = add fns 64:i64[]
fnx:i64[64,64] = select_n fnv fns fnw
fny:bool[64,64] = lt fnu 0:i64[]
fnz:i64[64,64] = add fnu 64:i64[]
foa:i64[64,64] = select_n fny fnu fnz
fob:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fnx
foc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] foa
fod:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fob
foe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] foc
fof:i32[64,64,2] = concatenate[dimension=2] fod foe
fog:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] faa fof
foh:f64[64,64] = mul fnq fog
foi:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fnf
foj:f64[64,64] = squeeze[dimensions=(0,)] foi
fok:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fne
fol:f64[64,64] = squeeze[dimensions=(0,)] fok
fom:f64[64,64] = mul foj fol
fon:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fnj
foo:i64[64,64] = squeeze[dimensions=(0,)] fon
fop:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fnl
foq:i64[64,64] = squeeze[dimensions=(0,)] fop
for:bool[64,64] = lt foo 0:i64[]
fos:i64[64,64] = add foo 64:i64[]
fot:i64[64,64] = select_n for foo fos
fou:bool[64,64] = lt foq 0:i64[]
fov:i64[64,64] = add foq 64:i64[]
fow:i64[64,64] = select_n fou foq fov
fox:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fot
foy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fow
foz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fox
fpa:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] foy
fpb:i32[64,64,2] = concatenate[dimension=2] foz fpa
fpc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] faa fpb
fpd:f64[64,64] = mul fom fpc
fpe:f64[64,64] = add foh fpd
fpf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fne
fpg:f64[64,64] = squeeze[dimensions=(0,)] fpf
fph:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fnf
fpi:f64[64,64] = squeeze[dimensions=(0,)] fph
fpj:f64[64,64] = mul fpg fpi
fpk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fnl
fpl:i64[64,64] = squeeze[dimensions=(0,)] fpk
fpm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fnj
fpn:i64[64,64] = squeeze[dimensions=(0,)] fpm
fpo:bool[64,64] = lt fpl 0:i64[]
fpp:i64[64,64] = add fpl 64:i64[]
fpq:i64[64,64] = select_n fpo fpl fpp
fpr:bool[64,64] = lt fpn 0:i64[]
fps:i64[64,64] = add fpn 64:i64[]
fpt:i64[64,64] = select_n fpr fpn fps
fpu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fpq
fpv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fpt
fpw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fpu
fpx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fpv
fpy:i32[64,64,2] = concatenate[dimension=2] fpw fpx
fpz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] faa fpy
fqa:f64[64,64] = mul fpj fpz
fqb:f64[64,64] = add fpe fqa
fqc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fne
fqd:f64[64,64] = squeeze[dimensions=(0,)] fqc
fqe:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fne
fqf:f64[64,64] = squeeze[dimensions=(0,)] fqe
fqg:f64[64,64] = mul fqd fqf
fqh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fnl
fqi:i64[64,64] = squeeze[dimensions=(0,)] fqh
fqj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fnl
fqk:i64[64,64] = squeeze[dimensions=(0,)] fqj
fql:bool[64,64] = lt fqi 0:i64[]
fqm:i64[64,64] = add fqi 64:i64[]
fqn:i64[64,64] = select_n fql fqi fqm
fqo:bool[64,64] = lt fqk 0:i64[]
fqp:i64[64,64] = add fqk 64:i64[]
fqq:i64[64,64] = select_n fqo fqk fqp
fqr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fqn
fqs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fqq
fqt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fqr
fqu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fqs
fqv:i32[64,64,2] = concatenate[dimension=2] fqt fqu
fqw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] faa fqv
fqx:f64[64,64] = mul fqg fqw
fqy:f64[64,64] = add fqb fqx
fqz:f64[2,64,64] = neg fal
fra:f64[] = neg eb
frb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fra
frc:f64[2,64,64] = mul frb fqz
frd:f64[2,64,64] = add ea frc
fre:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
frf:f64[2,64,64] = div frd fre
frg:f64[2,64,64] = floor frf
frh:f64[2,64,64] = sub frf frg
fri:f64[2,64,64] = sub 1.0:f64[] frh
frj:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cl
frk:f64[2,64,64] = sub frf frh
frl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] frk
frm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] frl frj
frn:i64[2,64,64] = add frm 1:i64[]
fro:i64[2,64,64] = jit[name=remainder jaxpr=remainder] frn frj
frp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fri
frq:f64[64,64] = squeeze[dimensions=(0,)] frp
frr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fri
frs:f64[64,64] = squeeze[dimensions=(0,)] frr
frt:f64[64,64] = mul frq frs
fru:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] frm
frv:i64[64,64] = squeeze[dimensions=(0,)] fru
frw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] frm
frx:i64[64,64] = squeeze[dimensions=(0,)] frw
fry:bool[64,64] = lt frv 0:i64[]
frz:i64[64,64] = add frv 64:i64[]
fsa:i64[64,64] = select_n fry frv frz
fsb:bool[64,64] = lt frx 0:i64[]
fsc:i64[64,64] = add frx 64:i64[]
fsd:i64[64,64] = select_n fsb frx fsc
fse:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsa
fsf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsd
fsg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fse
fsh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fsf
fsi:i32[64,64,2] = concatenate[dimension=2] fsg fsh
fsj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fqy fsi
fsk:f64[64,64] = mul frt fsj
fsl:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fri
fsm:f64[64,64] = squeeze[dimensions=(0,)] fsl
fsn:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] frh
fso:f64[64,64] = squeeze[dimensions=(0,)] fsn
fsp:f64[64,64] = mul fsm fso
fsq:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] frm
fsr:i64[64,64] = squeeze[dimensions=(0,)] fsq
fss:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fro
fst:i64[64,64] = squeeze[dimensions=(0,)] fss
fsu:bool[64,64] = lt fsr 0:i64[]
fsv:i64[64,64] = add fsr 64:i64[]
fsw:i64[64,64] = select_n fsu fsr fsv
fsx:bool[64,64] = lt fst 0:i64[]
fsy:i64[64,64] = add fst 64:i64[]
fsz:i64[64,64] = select_n fsx fst fsy
fta:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsw
ftb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsz
ftc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fta
ftd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ftb
fte:i32[64,64,2] = concatenate[dimension=2] ftc ftd
ftf:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fqy fte
ftg:f64[64,64] = mul fsp ftf
fth:f64[64,64] = add fsk ftg
fti:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] frh
ftj:f64[64,64] = squeeze[dimensions=(0,)] fti
ftk:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fri
ftl:f64[64,64] = squeeze[dimensions=(0,)] ftk
ftm:f64[64,64] = mul ftj ftl
ftn:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fro
fto:i64[64,64] = squeeze[dimensions=(0,)] ftn
ftp:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] frm
ftq:i64[64,64] = squeeze[dimensions=(0,)] ftp
ftr:bool[64,64] = lt fto 0:i64[]
fts:i64[64,64] = add fto 64:i64[]
ftt:i64[64,64] = select_n ftr fto fts
ftu:bool[64,64] = lt ftq 0:i64[]
ftv:i64[64,64] = add ftq 64:i64[]
ftw:i64[64,64] = select_n ftu ftq ftv
ftx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ftt
fty:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ftw
ftz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ftx
fua:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fty
fub:i32[64,64,2] = concatenate[dimension=2] ftz fua
fuc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fqy fub
fud:f64[64,64] = mul ftm fuc
fue:f64[64,64] = add fth fud
fuf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] frh
fug:f64[64,64] = squeeze[dimensions=(0,)] fuf
fuh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] frh
fui:f64[64,64] = squeeze[dimensions=(0,)] fuh
fuj:f64[64,64] = mul fug fui
fuk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fro
ful:i64[64,64] = squeeze[dimensions=(0,)] fuk
fum:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fro
fun:i64[64,64] = squeeze[dimensions=(0,)] fum
fuo:bool[64,64] = lt ful 0:i64[]
fup:i64[64,64] = add ful 64:i64[]
fuq:i64[64,64] = select_n fuo ful fup
fur:bool[64,64] = lt fun 0:i64[]
fus:i64[64,64] = add fun 64:i64[]
fut:i64[64,64] = select_n fur fun fus
fuu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fuq
fuv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fut
fuw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fuu
fux:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fuv
fuy:i32[64,64,2] = concatenate[dimension=2] fuw fux
fuz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fqy fuy
fva:f64[64,64] = mul fuj fuz
fvb:f64[64,64] = add fue fva
fvc:f64[64,64] = sub faa fvb
fvd:f64[64,64] = div fvc 2.0:f64[]
fve:f64[64,64] = add faa fvd
fvf:f64[] = neg eb
fvg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fvf
fvh:f64[2,64,64] = mul fvg fal
fvi:f64[2,64,64] = add ea fvh
fvj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
fvk:f64[2,64,64] = div fvi fvj
fvl:f64[2,64,64] = floor fvk
fvm:f64[2,64,64] = sub fvk fvl
fvn:f64[2,64,64] = sub 1.0:f64[] fvm
fvo:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cm
fvp:f64[2,64,64] = sub fvk fvm
fvq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fvp
fvr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fvq fvo
fvs:i64[2,64,64] = add fvr 1:i64[]
fvt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fvs fvo
fvu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvn
fvv:f64[64,64] = squeeze[dimensions=(0,)] fvu
fvw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvn
fvx:f64[64,64] = squeeze[dimensions=(0,)] fvw
fvy:f64[64,64] = mul fvv fvx
fvz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvr
fwa:i64[64,64] = squeeze[dimensions=(0,)] fvz
fwb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvr
fwc:i64[64,64] = squeeze[dimensions=(0,)] fwb
fwd:bool[64,64] = lt fwa 0:i64[]
fwe:i64[64,64] = add fwa 64:i64[]
fwf:i64[64,64] = select_n fwd fwa fwe
fwg:bool[64,64] = lt fwc 0:i64[]
fwh:i64[64,64] = add fwc 64:i64[]
fwi:i64[64,64] = select_n fwg fwc fwh
fwj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fwf
fwk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fwi
fwl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fwj
fwm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fwk
fwn:i32[64,64,2] = concatenate[dimension=2] fwl fwm
fwo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fve fwn
fwp:f64[64,64] = mul fvy fwo
fwq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvn
fwr:f64[64,64] = squeeze[dimensions=(0,)] fwq
fws:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvm
fwt:f64[64,64] = squeeze[dimensions=(0,)] fws
fwu:f64[64,64] = mul fwr fwt
fwv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvr
fww:i64[64,64] = squeeze[dimensions=(0,)] fwv
fwx:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvt
fwy:i64[64,64] = squeeze[dimensions=(0,)] fwx
fwz:bool[64,64] = lt fww 0:i64[]
fxa:i64[64,64] = add fww 64:i64[]
fxb:i64[64,64] = select_n fwz fww fxa
fxc:bool[64,64] = lt fwy 0:i64[]
fxd:i64[64,64] = add fwy 64:i64[]
fxe:i64[64,64] = select_n fxc fwy fxd
fxf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fxb
fxg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fxe
fxh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fxf
fxi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fxg
fxj:i32[64,64,2] = concatenate[dimension=2] fxh fxi
fxk:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fve fxj
fxl:f64[64,64] = mul fwu fxk
fxm:f64[64,64] = add fwp fxl
fxn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvm
fxo:f64[64,64] = squeeze[dimensions=(0,)] fxn
fxp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvn
fxq:f64[64,64] = squeeze[dimensions=(0,)] fxp
fxr:f64[64,64] = mul fxo fxq
fxs:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvt
fxt:i64[64,64] = squeeze[dimensions=(0,)] fxs
fxu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvr
fxv:i64[64,64] = squeeze[dimensions=(0,)] fxu
fxw:bool[64,64] = lt fxt 0:i64[]
fxx:i64[64,64] = add fxt 64:i64[]
fxy:i64[64,64] = select_n fxw fxt fxx
fxz:bool[64,64] = lt fxv 0:i64[]
fya:i64[64,64] = add fxv 64:i64[]
fyb:i64[64,64] = select_n fxz fxv fya
fyc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fxy
fyd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fyb
fye:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fyc
fyf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fyd
fyg:i32[64,64,2] = concatenate[dimension=2] fye fyf
fyh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fve fyg
fyi:f64[64,64] = mul fxr fyh
fyj:f64[64,64] = add fxm fyi
fyk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvm
fyl:f64[64,64] = squeeze[dimensions=(0,)] fyk
fym:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvm
fyn:f64[64,64] = squeeze[dimensions=(0,)] fym
fyo:f64[64,64] = mul fyl fyn
fyp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fvt
fyq:i64[64,64] = squeeze[dimensions=(0,)] fyp
fyr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fvt
fys:i64[64,64] = squeeze[dimensions=(0,)] fyr
fyt:bool[64,64] = lt fyq 0:i64[]
fyu:i64[64,64] = add fyq 64:i64[]
fyv:i64[64,64] = select_n fyt fyq fyu
fyw:bool[64,64] = lt fys 0:i64[]
fyx:i64[64,64] = add fys 64:i64[]
fyy:i64[64,64] = select_n fyw fys fyx
fyz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fyv
fza:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fyy
fzb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fyz
fzc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] fza
fzd:i32[64,64,2] = concatenate[dimension=2] fzb fzc
fze:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fve fzd
fzf:f64[64,64] = mul fyo fze
fzg:f64[64,64] = add fyj fzf
fzh:c128[64,33] = jit[name=fft jaxpr=fft] fmw
fzi:c128[] = reduce_prod[axes=(0,)] cn
fzj:c128[] = sqrt fzi
fzk:c128[] = div (1+0j):c128[] fzj
fzl:c128[64,33] = mul fzh fzk
fzm:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] fzl
fzn:c128[2,64,33] = mul dw fzm
fzo:f64[2,64,64] = jit[name=fft jaxpr=fft1] fzn
fzp:f64[] = reduce_prod[axes=(0,)] co
fzq:f64[] = sqrt fzp
fzr:f64[2,64,64] = mul fzo fzq
fzs:f64[] = neg eb
fzt:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fzs
fzu:f64[2,64,64] = mul fzt fzr
fzv:f64[2,64,64] = add ea fzu
fzw:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
fzx:f64[2,64,64] = div fzv fzw
fzy:f64[2,64,64] = floor fzx
fzz:f64[2,64,64] = sub fzx fzy
gaa:f64[2,64,64] = sub 1.0:f64[] fzz
gab:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cp
gac:f64[2,64,64] = sub fzx fzz
gad:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gac
gae:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gad gab
gaf:i64[2,64,64] = add gae 1:i64[]
gag:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gaf gab
gah:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gaa
gai:f64[64,64] = squeeze[dimensions=(0,)] gah
gaj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gaa
gak:f64[64,64] = squeeze[dimensions=(0,)] gaj
gal:f64[64,64] = mul gai gak
gam:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gae
gan:i64[64,64] = squeeze[dimensions=(0,)] gam
gao:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gae
gap:i64[64,64] = squeeze[dimensions=(0,)] gao
gaq:bool[64,64] = lt gan 0:i64[]
gar:i64[64,64] = add gan 64:i64[]
gas:i64[64,64] = select_n gaq gan gar
gat:bool[64,64] = lt gap 0:i64[]
gau:i64[64,64] = add gap 64:i64[]
gav:i64[64,64] = select_n gat gap gau
gaw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gas
gax:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gav
gay:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gaw
gaz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gax
gba:i32[64,64,2] = concatenate[dimension=2] gay gaz
gbb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fmw gba
gbc:f64[64,64] = mul gal gbb
gbd:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gaa
gbe:f64[64,64] = squeeze[dimensions=(0,)] gbd
gbf:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fzz
gbg:f64[64,64] = squeeze[dimensions=(0,)] gbf
gbh:f64[64,64] = mul gbe gbg
gbi:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gae
gbj:i64[64,64] = squeeze[dimensions=(0,)] gbi
gbk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gag
gbl:i64[64,64] = squeeze[dimensions=(0,)] gbk
gbm:bool[64,64] = lt gbj 0:i64[]
gbn:i64[64,64] = add gbj 64:i64[]
gbo:i64[64,64] = select_n gbm gbj gbn
gbp:bool[64,64] = lt gbl 0:i64[]
gbq:i64[64,64] = add gbl 64:i64[]
gbr:i64[64,64] = select_n gbp gbl gbq
gbs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gbo
gbt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gbr
gbu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gbs
gbv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gbt
gbw:i32[64,64,2] = concatenate[dimension=2] gbu gbv
gbx:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fmw gbw
gby:f64[64,64] = mul gbh gbx
gbz:f64[64,64] = add gbc gby
gca:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fzz
gcb:f64[64,64] = squeeze[dimensions=(0,)] gca
gcc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gaa
gcd:f64[64,64] = squeeze[dimensions=(0,)] gcc
gce:f64[64,64] = mul gcb gcd
gcf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gag
gcg:i64[64,64] = squeeze[dimensions=(0,)] gcf
gch:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gae
gci:i64[64,64] = squeeze[dimensions=(0,)] gch
gcj:bool[64,64] = lt gcg 0:i64[]
gck:i64[64,64] = add gcg 64:i64[]
gcl:i64[64,64] = select_n gcj gcg gck
gcm:bool[64,64] = lt gci 0:i64[]
gcn:i64[64,64] = add gci 64:i64[]
gco:i64[64,64] = select_n gcm gci gcn
gcp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gcl
gcq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gco
gcr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gcp
gcs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gcq
gct:i32[64,64,2] = concatenate[dimension=2] gcr gcs
gcu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fmw gct
gcv:f64[64,64] = mul gce gcu
gcw:f64[64,64] = add gbz gcv
gcx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] fzz
gcy:f64[64,64] = squeeze[dimensions=(0,)] gcx
gcz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] fzz
gda:f64[64,64] = squeeze[dimensions=(0,)] gcz
gdb:f64[64,64] = mul gcy gda
gdc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gag
gdd:i64[64,64] = squeeze[dimensions=(0,)] gdc
gde:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gag
gdf:i64[64,64] = squeeze[dimensions=(0,)] gde
gdg:bool[64,64] = lt gdd 0:i64[]
gdh:i64[64,64] = add gdd 64:i64[]
gdi:i64[64,64] = select_n gdg gdd gdh
gdj:bool[64,64] = lt gdf 0:i64[]
gdk:i64[64,64] = add gdf 64:i64[]
gdl:i64[64,64] = select_n gdj gdf gdk
gdm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gdi
gdn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gdl
gdo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gdm
gdp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gdn
gdq:i32[64,64,2] = concatenate[dimension=2] gdo gdp
gdr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fmw gdq
gds:f64[64,64] = mul gdb gdr
gdt:f64[64,64] = add gcw gds
gdu:f64[2,64,64] = neg fzr
gdv:f64[] = neg eb
gdw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gdv
gdx:f64[2,64,64] = mul gdw gdu
gdy:f64[2,64,64] = add ea gdx
gdz:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
gea:f64[2,64,64] = div gdy gdz
geb:f64[2,64,64] = floor gea
gec:f64[2,64,64] = sub gea geb
ged:f64[2,64,64] = sub 1.0:f64[] gec
gee:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cq
gef:f64[2,64,64] = sub gea gec
geg:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gef
geh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] geg gee
gei:i64[2,64,64] = add geh 1:i64[]
gej:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gei gee
gek:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ged
gel:f64[64,64] = squeeze[dimensions=(0,)] gek
gem:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ged
gen:f64[64,64] = squeeze[dimensions=(0,)] gem
geo:f64[64,64] = mul gel gen
gep:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] geh
geq:i64[64,64] = squeeze[dimensions=(0,)] gep
ger:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] geh
ges:i64[64,64] = squeeze[dimensions=(0,)] ger
get:bool[64,64] = lt geq 0:i64[]
geu:i64[64,64] = add geq 64:i64[]
gev:i64[64,64] = select_n get geq geu
gew:bool[64,64] = lt ges 0:i64[]
gex:i64[64,64] = add ges 64:i64[]
gey:i64[64,64] = select_n gew ges gex
gez:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gev
gfa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gey
gfb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gez
gfc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gfa
gfd:i32[64,64,2] = concatenate[dimension=2] gfb gfc
gfe:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gdt gfd
gff:f64[64,64] = mul geo gfe
gfg:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ged
gfh:f64[64,64] = squeeze[dimensions=(0,)] gfg
gfi:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gec
gfj:f64[64,64] = squeeze[dimensions=(0,)] gfi
gfk:f64[64,64] = mul gfh gfj
gfl:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] geh
gfm:i64[64,64] = squeeze[dimensions=(0,)] gfl
gfn:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gej
gfo:i64[64,64] = squeeze[dimensions=(0,)] gfn
gfp:bool[64,64] = lt gfm 0:i64[]
gfq:i64[64,64] = add gfm 64:i64[]
gfr:i64[64,64] = select_n gfp gfm gfq
gfs:bool[64,64] = lt gfo 0:i64[]
gft:i64[64,64] = add gfo 64:i64[]
gfu:i64[64,64] = select_n gfs gfo gft
gfv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gfr
gfw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gfu
gfx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gfv
gfy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gfw
gfz:i32[64,64,2] = concatenate[dimension=2] gfx gfy
gga:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gdt gfz
ggb:f64[64,64] = mul gfk gga
ggc:f64[64,64] = add gff ggb
ggd:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gec
gge:f64[64,64] = squeeze[dimensions=(0,)] ggd
ggf:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ged
ggg:f64[64,64] = squeeze[dimensions=(0,)] ggf
ggh:f64[64,64] = mul gge ggg
ggi:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gej
ggj:i64[64,64] = squeeze[dimensions=(0,)] ggi
ggk:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] geh
ggl:i64[64,64] = squeeze[dimensions=(0,)] ggk
ggm:bool[64,64] = lt ggj 0:i64[]
ggn:i64[64,64] = add ggj 64:i64[]
ggo:i64[64,64] = select_n ggm ggj ggn
ggp:bool[64,64] = lt ggl 0:i64[]
ggq:i64[64,64] = add ggl 64:i64[]
ggr:i64[64,64] = select_n ggp ggl ggq
ggs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ggo
ggt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ggr
ggu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ggs
ggv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ggt
ggw:i32[64,64,2] = concatenate[dimension=2] ggu ggv
ggx:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gdt ggw
ggy:f64[64,64] = mul ggh ggx
ggz:f64[64,64] = add ggc ggy
gha:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gec
ghb:f64[64,64] = squeeze[dimensions=(0,)] gha
ghc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gec
ghd:f64[64,64] = squeeze[dimensions=(0,)] ghc
ghe:f64[64,64] = mul ghb ghd
ghf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gej
ghg:i64[64,64] = squeeze[dimensions=(0,)] ghf
ghh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gej
ghi:i64[64,64] = squeeze[dimensions=(0,)] ghh
ghj:bool[64,64] = lt ghg 0:i64[]
ghk:i64[64,64] = add ghg 64:i64[]
ghl:i64[64,64] = select_n ghj ghg ghk
ghm:bool[64,64] = lt ghi 0:i64[]
ghn:i64[64,64] = add ghi 64:i64[]
gho:i64[64,64] = select_n ghm ghi ghn
ghp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ghl
ghq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gho
ghr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ghp
ghs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ghq
ght:i32[64,64,2] = concatenate[dimension=2] ghr ghs
ghu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gdt ght
ghv:f64[64,64] = mul ghe ghu
ghw:f64[64,64] = add ggz ghv
ghx:f64[64,64] = sub fmw ghw
ghy:f64[64,64] = div ghx 2.0:f64[]
ghz:f64[64,64] = add fmw ghy
gia:f64[] = neg eb
gib:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gia
gic:f64[2,64,64] = mul gib fzr
gid:f64[2,64,64] = add ea gic
gie:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
gif:f64[2,64,64] = div gid gie
gig:f64[2,64,64] = floor gif
gih:f64[2,64,64] = sub gif gig
gii:f64[2,64,64] = sub 1.0:f64[] gih
gij:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cr
gik:f64[2,64,64] = sub gif gih
gil:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gik
gim:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gil gij
gin:i64[2,64,64] = add gim 1:i64[]
gio:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gin gij
gip:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gii
giq:f64[64,64] = squeeze[dimensions=(0,)] gip
gir:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gii
gis:f64[64,64] = squeeze[dimensions=(0,)] gir
git:f64[64,64] = mul giq gis
giu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gim
giv:i64[64,64] = squeeze[dimensions=(0,)] giu
giw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gim
gix:i64[64,64] = squeeze[dimensions=(0,)] giw
giy:bool[64,64] = lt giv 0:i64[]
giz:i64[64,64] = add giv 64:i64[]
gja:i64[64,64] = select_n giy giv giz
gjb:bool[64,64] = lt gix 0:i64[]
gjc:i64[64,64] = add gix 64:i64[]
gjd:i64[64,64] = select_n gjb gix gjc
gje:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gja
gjf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gjd
gjg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gje
gjh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gjf
gji:i32[64,64,2] = concatenate[dimension=2] gjg gjh
gjj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ghz gji
gjk:f64[64,64] = mul git gjj
gjl:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gii
gjm:f64[64,64] = squeeze[dimensions=(0,)] gjl
gjn:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gih
gjo:f64[64,64] = squeeze[dimensions=(0,)] gjn
gjp:f64[64,64] = mul gjm gjo
gjq:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gim
gjr:i64[64,64] = squeeze[dimensions=(0,)] gjq
gjs:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gio
gjt:i64[64,64] = squeeze[dimensions=(0,)] gjs
gju:bool[64,64] = lt gjr 0:i64[]
gjv:i64[64,64] = add gjr 64:i64[]
gjw:i64[64,64] = select_n gju gjr gjv
gjx:bool[64,64] = lt gjt 0:i64[]
gjy:i64[64,64] = add gjt 64:i64[]
gjz:i64[64,64] = select_n gjx gjt gjy
gka:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gjw
gkb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gjz
gkc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gka
gkd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gkb
gke:i32[64,64,2] = concatenate[dimension=2] gkc gkd
gkf:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ghz gke
gkg:f64[64,64] = mul gjp gkf
gkh:f64[64,64] = add gjk gkg
gki:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gih
gkj:f64[64,64] = squeeze[dimensions=(0,)] gki
gkk:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gii
gkl:f64[64,64] = squeeze[dimensions=(0,)] gkk
gkm:f64[64,64] = mul gkj gkl
gkn:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gio
gko:i64[64,64] = squeeze[dimensions=(0,)] gkn
gkp:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gim
gkq:i64[64,64] = squeeze[dimensions=(0,)] gkp
gkr:bool[64,64] = lt gko 0:i64[]
gks:i64[64,64] = add gko 64:i64[]
gkt:i64[64,64] = select_n gkr gko gks
gku:bool[64,64] = lt gkq 0:i64[]
gkv:i64[64,64] = add gkq 64:i64[]
gkw:i64[64,64] = select_n gku gkq gkv
gkx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gkt
gky:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gkw
gkz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gkx
gla:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gky
glb:i32[64,64,2] = concatenate[dimension=2] gkz gla
glc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ghz glb
gld:f64[64,64] = mul gkm glc
gle:f64[64,64] = add gkh gld
glf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gih
glg:f64[64,64] = squeeze[dimensions=(0,)] glf
glh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gih
gli:f64[64,64] = squeeze[dimensions=(0,)] glh
glj:f64[64,64] = mul glg gli
glk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gio
gll:i64[64,64] = squeeze[dimensions=(0,)] glk
glm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gio
gln:i64[64,64] = squeeze[dimensions=(0,)] glm
glo:bool[64,64] = lt gll 0:i64[]
glp:i64[64,64] = add gll 64:i64[]
glq:i64[64,64] = select_n glo gll glp
glr:bool[64,64] = lt gln 0:i64[]
gls:i64[64,64] = add gln 64:i64[]
glt:i64[64,64] = select_n glr gln gls
glu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] glq
glv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] glt
glw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] glu
glx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] glv
gly:i32[64,64,2] = concatenate[dimension=2] glw glx
glz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ghz gly
gma:f64[64,64] = mul glj glz
gmb:f64[64,64] = add gle gma
gmc:f64[] = neg eb
gmd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gmc
gme:f64[2,64,64] = mul gmd fzr
gmf:f64[2,64,64] = add ea gme
gmg:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
gmh:f64[2,64,64] = div gmf gmg
gmi:f64[2,64,64] = floor gmh
gmj:f64[2,64,64] = sub gmh gmi
gmk:f64[2,64,64] = sub 1.0:f64[] gmj
gml:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cs
gmm:f64[2,64,64] = sub gmh gmj
gmn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gmm
gmo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gmn gml
gmp:i64[2,64,64] = add gmo 1:i64[]
gmq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gmp gml
gmr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmk
gms:f64[64,64] = squeeze[dimensions=(0,)] gmr
gmt:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmk
gmu:f64[64,64] = squeeze[dimensions=(0,)] gmt
gmv:f64[64,64] = mul gms gmu
gmw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmo
gmx:i64[64,64] = squeeze[dimensions=(0,)] gmw
gmy:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmo
gmz:i64[64,64] = squeeze[dimensions=(0,)] gmy
gna:bool[64,64] = lt gmx 0:i64[]
gnb:i64[64,64] = add gmx 64:i64[]
gnc:i64[64,64] = select_n gna gmx gnb
gnd:bool[64,64] = lt gmz 0:i64[]
gne:i64[64,64] = add gmz 64:i64[]
gnf:i64[64,64] = select_n gnd gmz gne
gng:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gnc
gnh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gnf
gni:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gng
gnj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gnh
gnk:i32[64,64,2] = concatenate[dimension=2] gni gnj
gnl:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fzg gnk
gnm:f64[64,64] = mul gmv gnl
gnn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmk
gno:f64[64,64] = squeeze[dimensions=(0,)] gnn
gnp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmj
gnq:f64[64,64] = squeeze[dimensions=(0,)] gnp
gnr:f64[64,64] = mul gno gnq
gns:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmo
gnt:i64[64,64] = squeeze[dimensions=(0,)] gns
gnu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmq
gnv:i64[64,64] = squeeze[dimensions=(0,)] gnu
gnw:bool[64,64] = lt gnt 0:i64[]
gnx:i64[64,64] = add gnt 64:i64[]
gny:i64[64,64] = select_n gnw gnt gnx
gnz:bool[64,64] = lt gnv 0:i64[]
goa:i64[64,64] = add gnv 64:i64[]
gob:i64[64,64] = select_n gnz gnv goa
goc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gny
god:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gob
goe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] goc
gof:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] god
gog:i32[64,64,2] = concatenate[dimension=2] goe gof
goh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fzg gog
goi:f64[64,64] = mul gnr goh
goj:f64[64,64] = add gnm goi
gok:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmj
gol:f64[64,64] = squeeze[dimensions=(0,)] gok
gom:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmk
gon:f64[64,64] = squeeze[dimensions=(0,)] gom
goo:f64[64,64] = mul gol gon
gop:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmq
goq:i64[64,64] = squeeze[dimensions=(0,)] gop
gor:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmo
gos:i64[64,64] = squeeze[dimensions=(0,)] gor
got:bool[64,64] = lt goq 0:i64[]
gou:i64[64,64] = add goq 64:i64[]
gov:i64[64,64] = select_n got goq gou
gow:bool[64,64] = lt gos 0:i64[]
gox:i64[64,64] = add gos 64:i64[]
goy:i64[64,64] = select_n gow gos gox
goz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gov
gpa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] goy
gpb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] goz
gpc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gpa
gpd:i32[64,64,2] = concatenate[dimension=2] gpb gpc
gpe:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fzg gpd
gpf:f64[64,64] = mul goo gpe
gpg:f64[64,64] = add goj gpf
gph:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmj
gpi:f64[64,64] = squeeze[dimensions=(0,)] gph
gpj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmj
gpk:f64[64,64] = squeeze[dimensions=(0,)] gpj
gpl:f64[64,64] = mul gpi gpk
gpm:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gmq
gpn:i64[64,64] = squeeze[dimensions=(0,)] gpm
gpo:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gmq
gpp:i64[64,64] = squeeze[dimensions=(0,)] gpo
gpq:bool[64,64] = lt gpn 0:i64[]
gpr:i64[64,64] = add gpn 64:i64[]
gps:i64[64,64] = select_n gpq gpn gpr
gpt:bool[64,64] = lt gpp 0:i64[]
gpu:i64[64,64] = add gpp 64:i64[]
gpv:i64[64,64] = select_n gpt gpp gpu
gpw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gps
gpx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gpv
gpy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gpw
gpz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gpx
gqa:i32[64,64,2] = concatenate[dimension=2] gpy gpz
gqb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] fzg gqa
gqc:f64[64,64] = mul gpl gqb
gqd:f64[64,64] = add gpg gqc
gqe:f64[2,64,64] = neg fzr
gqf:f64[] = neg eb
gqg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gqf
gqh:f64[2,64,64] = mul gqg gqe
gqi:f64[2,64,64] = add ea gqh
gqj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
gqk:f64[2,64,64] = div gqi gqj
gql:f64[2,64,64] = floor gqk
gqm:f64[2,64,64] = sub gqk gql
gqn:f64[2,64,64] = sub 1.0:f64[] gqm
gqo:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ct
gqp:f64[2,64,64] = sub gqk gqm
gqq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gqp
gqr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gqq gqo
gqs:i64[2,64,64] = add gqr 1:i64[]
gqt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gqs gqo
gqu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqn
gqv:f64[64,64] = squeeze[dimensions=(0,)] gqu
gqw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqn
gqx:f64[64,64] = squeeze[dimensions=(0,)] gqw
gqy:f64[64,64] = mul gqv gqx
gqz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqr
gra:i64[64,64] = squeeze[dimensions=(0,)] gqz
grb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqr
grc:i64[64,64] = squeeze[dimensions=(0,)] grb
grd:bool[64,64] = lt gra 0:i64[]
gre:i64[64,64] = add gra 64:i64[]
grf:i64[64,64] = select_n grd gra gre
grg:bool[64,64] = lt grc 0:i64[]
grh:i64[64,64] = add grc 64:i64[]
gri:i64[64,64] = select_n grg grc grh
grj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] grf
grk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gri
grl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] grj
grm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] grk
grn:i32[64,64,2] = concatenate[dimension=2] grl grm
gro:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gqd grn
grp:f64[64,64] = mul gqy gro
grq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqn
grr:f64[64,64] = squeeze[dimensions=(0,)] grq
grs:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqm
grt:f64[64,64] = squeeze[dimensions=(0,)] grs
gru:f64[64,64] = mul grr grt
grv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqr
grw:i64[64,64] = squeeze[dimensions=(0,)] grv
grx:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqt
gry:i64[64,64] = squeeze[dimensions=(0,)] grx
grz:bool[64,64] = lt grw 0:i64[]
gsa:i64[64,64] = add grw 64:i64[]
gsb:i64[64,64] = select_n grz grw gsa
gsc:bool[64,64] = lt gry 0:i64[]
gsd:i64[64,64] = add gry 64:i64[]
gse:i64[64,64] = select_n gsc gry gsd
gsf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gsb
gsg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gse
gsh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gsf
gsi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gsg
gsj:i32[64,64,2] = concatenate[dimension=2] gsh gsi
gsk:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gqd gsj
gsl:f64[64,64] = mul gru gsk
gsm:f64[64,64] = add grp gsl
gsn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqm
gso:f64[64,64] = squeeze[dimensions=(0,)] gsn
gsp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqn
gsq:f64[64,64] = squeeze[dimensions=(0,)] gsp
gsr:f64[64,64] = mul gso gsq
gss:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqt
gst:i64[64,64] = squeeze[dimensions=(0,)] gss
gsu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqr
gsv:i64[64,64] = squeeze[dimensions=(0,)] gsu
gsw:bool[64,64] = lt gst 0:i64[]
gsx:i64[64,64] = add gst 64:i64[]
gsy:i64[64,64] = select_n gsw gst gsx
gsz:bool[64,64] = lt gsv 0:i64[]
gta:i64[64,64] = add gsv 64:i64[]
gtb:i64[64,64] = select_n gsz gsv gta
gtc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gsy
gtd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gtb
gte:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gtc
gtf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gtd
gtg:i32[64,64,2] = concatenate[dimension=2] gte gtf
gth:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gqd gtg
gti:f64[64,64] = mul gsr gth
gtj:f64[64,64] = add gsm gti
gtk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqm
gtl:f64[64,64] = squeeze[dimensions=(0,)] gtk
gtm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqm
gtn:f64[64,64] = squeeze[dimensions=(0,)] gtm
gto:f64[64,64] = mul gtl gtn
gtp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gqt
gtq:i64[64,64] = squeeze[dimensions=(0,)] gtp
gtr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gqt
gts:i64[64,64] = squeeze[dimensions=(0,)] gtr
gtt:bool[64,64] = lt gtq 0:i64[]
gtu:i64[64,64] = add gtq 64:i64[]
gtv:i64[64,64] = select_n gtt gtq gtu
gtw:bool[64,64] = lt gts 0:i64[]
gtx:i64[64,64] = add gts 64:i64[]
gty:i64[64,64] = select_n gtw gts gtx
gtz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gtv
gua:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gty
gub:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gtz
guc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gua
gud:i32[64,64,2] = concatenate[dimension=2] gub guc
gue:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gqd gud
guf:f64[64,64] = mul gto gue
gug:f64[64,64] = add gtj guf
guh:f64[64,64] = sub fzg gug
gui:f64[64,64] = div guh 2.0:f64[]
guj:f64[64,64] = add fzg gui
guk:f64[] = neg eb
gul:f64[] = convert_element_type[new_dtype=float64 weak_type=False] guk
gum:f64[2,64,64] = mul gul fzr
gun:f64[2,64,64] = add ea gum
guo:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
gup:f64[2,64,64] = div gun guo
guq:f64[2,64,64] = floor gup
gur:f64[2,64,64] = sub gup guq
gus:f64[2,64,64] = sub 1.0:f64[] gur
gut:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cu
guu:f64[2,64,64] = sub gup gur
guv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] guu
guw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] guv gut
gux:i64[2,64,64] = add guw 1:i64[]
guy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gux gut
guz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gus
gva:f64[64,64] = squeeze[dimensions=(0,)] guz
gvb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gus
gvc:f64[64,64] = squeeze[dimensions=(0,)] gvb
gvd:f64[64,64] = mul gva gvc
gve:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] guw
gvf:i64[64,64] = squeeze[dimensions=(0,)] gve
gvg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] guw
gvh:i64[64,64] = squeeze[dimensions=(0,)] gvg
gvi:bool[64,64] = lt gvf 0:i64[]
gvj:i64[64,64] = add gvf 64:i64[]
gvk:i64[64,64] = select_n gvi gvf gvj
gvl:bool[64,64] = lt gvh 0:i64[]
gvm:i64[64,64] = add gvh 64:i64[]
gvn:i64[64,64] = select_n gvl gvh gvm
gvo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gvk
gvp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gvn
gvq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gvo
gvr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gvp
gvs:i32[64,64,2] = concatenate[dimension=2] gvq gvr
gvt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] guj gvs
gvu:f64[64,64] = mul gvd gvt
gvv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gus
gvw:f64[64,64] = squeeze[dimensions=(0,)] gvv
gvx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gur
gvy:f64[64,64] = squeeze[dimensions=(0,)] gvx
gvz:f64[64,64] = mul gvw gvy
gwa:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] guw
gwb:i64[64,64] = squeeze[dimensions=(0,)] gwa
gwc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] guy
gwd:i64[64,64] = squeeze[dimensions=(0,)] gwc
gwe:bool[64,64] = lt gwb 0:i64[]
gwf:i64[64,64] = add gwb 64:i64[]
gwg:i64[64,64] = select_n gwe gwb gwf
gwh:bool[64,64] = lt gwd 0:i64[]
gwi:i64[64,64] = add gwd 64:i64[]
gwj:i64[64,64] = select_n gwh gwd gwi
gwk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gwg
gwl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gwj
gwm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gwk
gwn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gwl
gwo:i32[64,64,2] = concatenate[dimension=2] gwm gwn
gwp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] guj gwo
gwq:f64[64,64] = mul gvz gwp
gwr:f64[64,64] = add gvu gwq
gws:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gur
gwt:f64[64,64] = squeeze[dimensions=(0,)] gws
gwu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gus
gwv:f64[64,64] = squeeze[dimensions=(0,)] gwu
gww:f64[64,64] = mul gwt gwv
gwx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] guy
gwy:i64[64,64] = squeeze[dimensions=(0,)] gwx
gwz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] guw
gxa:i64[64,64] = squeeze[dimensions=(0,)] gwz
gxb:bool[64,64] = lt gwy 0:i64[]
gxc:i64[64,64] = add gwy 64:i64[]
gxd:i64[64,64] = select_n gxb gwy gxc
gxe:bool[64,64] = lt gxa 0:i64[]
gxf:i64[64,64] = add gxa 64:i64[]
gxg:i64[64,64] = select_n gxe gxa gxf
gxh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gxd
gxi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gxg
gxj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gxh
gxk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gxi
gxl:i32[64,64,2] = concatenate[dimension=2] gxj gxk
gxm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] guj gxl
gxn:f64[64,64] = mul gww gxm
gxo:f64[64,64] = add gwr gxn
gxp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gur
gxq:f64[64,64] = squeeze[dimensions=(0,)] gxp
gxr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gur
gxs:f64[64,64] = squeeze[dimensions=(0,)] gxr
gxt:f64[64,64] = mul gxq gxs
gxu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] guy
gxv:i64[64,64] = squeeze[dimensions=(0,)] gxu
gxw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] guy
gxx:i64[64,64] = squeeze[dimensions=(0,)] gxw
gxy:bool[64,64] = lt gxv 0:i64[]
gxz:i64[64,64] = add gxv 64:i64[]
gya:i64[64,64] = select_n gxy gxv gxz
gyb:bool[64,64] = lt gxx 0:i64[]
gyc:i64[64,64] = add gxx 64:i64[]
gyd:i64[64,64] = select_n gyb gxx gyc
gye:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gya
gyf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gyd
gyg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gye
gyh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] gyf
gyi:i32[64,64,2] = concatenate[dimension=2] gyg gyh
gyj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] guj gyi
gyk:f64[64,64] = mul gxt gyj
gyl:f64[64,64] = add gxo gyk
gym:c128[64,33] = jit[name=fft jaxpr=fft] gmb
gyn:c128[] = reduce_prod[axes=(0,)] cv
gyo:c128[] = sqrt gyn
gyp:c128[] = div (1+0j):c128[] gyo
gyq:c128[64,33] = mul gym gyp
gyr:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] gyq
gys:c128[2,64,33] = mul dw gyr
gyt:f64[2,64,64] = jit[name=fft jaxpr=fft1] gys
gyu:f64[] = reduce_prod[axes=(0,)] cw
gyv:f64[] = sqrt gyu
gyw:f64[2,64,64] = mul gyt gyv
gyx:f64[] = neg eb
gyy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gyx
gyz:f64[2,64,64] = mul gyy gyw
gza:f64[2,64,64] = add ea gyz
gzb:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
gzc:f64[2,64,64] = div gza gzb
gzd:f64[2,64,64] = floor gzc
gze:f64[2,64,64] = sub gzc gzd
gzf:f64[2,64,64] = sub 1.0:f64[] gze
gzg:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cx
gzh:f64[2,64,64] = sub gzc gze
gzi:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gzh
gzj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gzi gzg
gzk:i64[2,64,64] = add gzj 1:i64[]
gzl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gzk gzg
gzm:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gzf
gzn:f64[64,64] = squeeze[dimensions=(0,)] gzm
gzo:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gzf
gzp:f64[64,64] = squeeze[dimensions=(0,)] gzo
gzq:f64[64,64] = mul gzn gzp
gzr:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gzj
gzs:i64[64,64] = squeeze[dimensions=(0,)] gzr
gzt:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gzj
gzu:i64[64,64] = squeeze[dimensions=(0,)] gzt
gzv:bool[64,64] = lt gzs 0:i64[]
gzw:i64[64,64] = add gzs 64:i64[]
gzx:i64[64,64] = select_n gzv gzs gzw
gzy:bool[64,64] = lt gzu 0:i64[]
gzz:i64[64,64] = add gzu 64:i64[]
haa:i64[64,64] = select_n gzy gzu gzz
hab:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gzx
hac:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] haa
had:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hab
hae:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hac
haf:i32[64,64,2] = concatenate[dimension=2] had hae
hag:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gmb haf
hah:f64[64,64] = mul gzq hag
hai:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gzf
haj:f64[64,64] = squeeze[dimensions=(0,)] hai
hak:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gze
hal:f64[64,64] = squeeze[dimensions=(0,)] hak
ham:f64[64,64] = mul haj hal
han:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gzj
hao:i64[64,64] = squeeze[dimensions=(0,)] han
hap:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gzl
haq:i64[64,64] = squeeze[dimensions=(0,)] hap
har:bool[64,64] = lt hao 0:i64[]
has:i64[64,64] = add hao 64:i64[]
hat:i64[64,64] = select_n har hao has
hau:bool[64,64] = lt haq 0:i64[]
hav:i64[64,64] = add haq 64:i64[]
haw:i64[64,64] = select_n hau haq hav
hax:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hat
hay:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] haw
haz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hax
hba:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hay
hbb:i32[64,64,2] = concatenate[dimension=2] haz hba
hbc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gmb hbb
hbd:f64[64,64] = mul ham hbc
hbe:f64[64,64] = add hah hbd
hbf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gze
hbg:f64[64,64] = squeeze[dimensions=(0,)] hbf
hbh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gzf
hbi:f64[64,64] = squeeze[dimensions=(0,)] hbh
hbj:f64[64,64] = mul hbg hbi
hbk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gzl
hbl:i64[64,64] = squeeze[dimensions=(0,)] hbk
hbm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gzj
hbn:i64[64,64] = squeeze[dimensions=(0,)] hbm
hbo:bool[64,64] = lt hbl 0:i64[]
hbp:i64[64,64] = add hbl 64:i64[]
hbq:i64[64,64] = select_n hbo hbl hbp
hbr:bool[64,64] = lt hbn 0:i64[]
hbs:i64[64,64] = add hbn 64:i64[]
hbt:i64[64,64] = select_n hbr hbn hbs
hbu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hbq
hbv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hbt
hbw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hbu
hbx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hbv
hby:i32[64,64,2] = concatenate[dimension=2] hbw hbx
hbz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gmb hby
hca:f64[64,64] = mul hbj hbz
hcb:f64[64,64] = add hbe hca
hcc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gze
hcd:f64[64,64] = squeeze[dimensions=(0,)] hcc
hce:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gze
hcf:f64[64,64] = squeeze[dimensions=(0,)] hce
hcg:f64[64,64] = mul hcd hcf
hch:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] gzl
hci:i64[64,64] = squeeze[dimensions=(0,)] hch
hcj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] gzl
hck:i64[64,64] = squeeze[dimensions=(0,)] hcj
hcl:bool[64,64] = lt hci 0:i64[]
hcm:i64[64,64] = add hci 64:i64[]
hcn:i64[64,64] = select_n hcl hci hcm
hco:bool[64,64] = lt hck 0:i64[]
hcp:i64[64,64] = add hck 64:i64[]
hcq:i64[64,64] = select_n hco hck hcp
hcr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hcn
hcs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hcq
hct:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hcr
hcu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hcs
hcv:i32[64,64,2] = concatenate[dimension=2] hct hcu
hcw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gmb hcv
hcx:f64[64,64] = mul hcg hcw
hcy:f64[64,64] = add hcb hcx
hcz:f64[2,64,64] = neg gyw
hda:f64[] = neg eb
hdb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hda
hdc:f64[2,64,64] = mul hdb hcz
hdd:f64[2,64,64] = add ea hdc
hde:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
hdf:f64[2,64,64] = div hdd hde
hdg:f64[2,64,64] = floor hdf
hdh:f64[2,64,64] = sub hdf hdg
hdi:f64[2,64,64] = sub 1.0:f64[] hdh
hdj:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cy
hdk:f64[2,64,64] = sub hdf hdh
hdl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hdk
hdm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hdl hdj
hdn:i64[2,64,64] = add hdm 1:i64[]
hdo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hdn hdj
hdp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdi
hdq:f64[64,64] = squeeze[dimensions=(0,)] hdp
hdr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdi
hds:f64[64,64] = squeeze[dimensions=(0,)] hdr
hdt:f64[64,64] = mul hdq hds
hdu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdm
hdv:i64[64,64] = squeeze[dimensions=(0,)] hdu
hdw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdm
hdx:i64[64,64] = squeeze[dimensions=(0,)] hdw
hdy:bool[64,64] = lt hdv 0:i64[]
hdz:i64[64,64] = add hdv 64:i64[]
hea:i64[64,64] = select_n hdy hdv hdz
heb:bool[64,64] = lt hdx 0:i64[]
hec:i64[64,64] = add hdx 64:i64[]
hed:i64[64,64] = select_n heb hdx hec
hee:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hea
hef:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hed
heg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hee
heh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hef
hei:i32[64,64,2] = concatenate[dimension=2] heg heh
hej:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hcy hei
hek:f64[64,64] = mul hdt hej
hel:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdi
hem:f64[64,64] = squeeze[dimensions=(0,)] hel
hen:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdh
heo:f64[64,64] = squeeze[dimensions=(0,)] hen
hep:f64[64,64] = mul hem heo
heq:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdm
her:i64[64,64] = squeeze[dimensions=(0,)] heq
hes:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdo
het:i64[64,64] = squeeze[dimensions=(0,)] hes
heu:bool[64,64] = lt her 0:i64[]
hev:i64[64,64] = add her 64:i64[]
hew:i64[64,64] = select_n heu her hev
hex:bool[64,64] = lt het 0:i64[]
hey:i64[64,64] = add het 64:i64[]
hez:i64[64,64] = select_n hex het hey
hfa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hew
hfb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hez
hfc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hfa
hfd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hfb
hfe:i32[64,64,2] = concatenate[dimension=2] hfc hfd
hff:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hcy hfe
hfg:f64[64,64] = mul hep hff
hfh:f64[64,64] = add hek hfg
hfi:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdh
hfj:f64[64,64] = squeeze[dimensions=(0,)] hfi
hfk:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdi
hfl:f64[64,64] = squeeze[dimensions=(0,)] hfk
hfm:f64[64,64] = mul hfj hfl
hfn:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdo
hfo:i64[64,64] = squeeze[dimensions=(0,)] hfn
hfp:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdm
hfq:i64[64,64] = squeeze[dimensions=(0,)] hfp
hfr:bool[64,64] = lt hfo 0:i64[]
hfs:i64[64,64] = add hfo 64:i64[]
hft:i64[64,64] = select_n hfr hfo hfs
hfu:bool[64,64] = lt hfq 0:i64[]
hfv:i64[64,64] = add hfq 64:i64[]
hfw:i64[64,64] = select_n hfu hfq hfv
hfx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hft
hfy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hfw
hfz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hfx
hga:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hfy
hgb:i32[64,64,2] = concatenate[dimension=2] hfz hga
hgc:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hcy hgb
hgd:f64[64,64] = mul hfm hgc
hge:f64[64,64] = add hfh hgd
hgf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdh
hgg:f64[64,64] = squeeze[dimensions=(0,)] hgf
hgh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdh
hgi:f64[64,64] = squeeze[dimensions=(0,)] hgh
hgj:f64[64,64] = mul hgg hgi
hgk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hdo
hgl:i64[64,64] = squeeze[dimensions=(0,)] hgk
hgm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hdo
hgn:i64[64,64] = squeeze[dimensions=(0,)] hgm
hgo:bool[64,64] = lt hgl 0:i64[]
hgp:i64[64,64] = add hgl 64:i64[]
hgq:i64[64,64] = select_n hgo hgl hgp
hgr:bool[64,64] = lt hgn 0:i64[]
hgs:i64[64,64] = add hgn 64:i64[]
hgt:i64[64,64] = select_n hgr hgn hgs
hgu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hgq
hgv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hgt
hgw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hgu
hgx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hgv
hgy:i32[64,64,2] = concatenate[dimension=2] hgw hgx
hgz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hcy hgy
hha:f64[64,64] = mul hgj hgz
hhb:f64[64,64] = add hge hha
hhc:f64[64,64] = sub gmb hhb
hhd:f64[64,64] = div hhc 2.0:f64[]
hhe:f64[64,64] = add gmb hhd
hhf:f64[] = neg eb
hhg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hhf
hhh:f64[2,64,64] = mul hhg gyw
hhi:f64[2,64,64] = add ea hhh
hhj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
hhk:f64[2,64,64] = div hhi hhj
hhl:f64[2,64,64] = floor hhk
hhm:f64[2,64,64] = sub hhk hhl
hhn:f64[2,64,64] = sub 1.0:f64[] hhm
hho:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] cz
hhp:f64[2,64,64] = sub hhk hhm
hhq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hhp
hhr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hhq hho
hhs:i64[2,64,64] = add hhr 1:i64[]
hht:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hhs hho
hhu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hhn
hhv:f64[64,64] = squeeze[dimensions=(0,)] hhu
hhw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hhn
hhx:f64[64,64] = squeeze[dimensions=(0,)] hhw
hhy:f64[64,64] = mul hhv hhx
hhz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hhr
hia:i64[64,64] = squeeze[dimensions=(0,)] hhz
hib:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hhr
hic:i64[64,64] = squeeze[dimensions=(0,)] hib
hid:bool[64,64] = lt hia 0:i64[]
hie:i64[64,64] = add hia 64:i64[]
hif:i64[64,64] = select_n hid hia hie
hig:bool[64,64] = lt hic 0:i64[]
hih:i64[64,64] = add hic 64:i64[]
hii:i64[64,64] = select_n hig hic hih
hij:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hif
hik:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hii
hil:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hij
him:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hik
hin:i32[64,64,2] = concatenate[dimension=2] hil him
hio:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hhe hin
hip:f64[64,64] = mul hhy hio
hiq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hhn
hir:f64[64,64] = squeeze[dimensions=(0,)] hiq
his:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hhm
hit:f64[64,64] = squeeze[dimensions=(0,)] his
hiu:f64[64,64] = mul hir hit
hiv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hhr
hiw:i64[64,64] = squeeze[dimensions=(0,)] hiv
hix:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hht
hiy:i64[64,64] = squeeze[dimensions=(0,)] hix
hiz:bool[64,64] = lt hiw 0:i64[]
hja:i64[64,64] = add hiw 64:i64[]
hjb:i64[64,64] = select_n hiz hiw hja
hjc:bool[64,64] = lt hiy 0:i64[]
hjd:i64[64,64] = add hiy 64:i64[]
hje:i64[64,64] = select_n hjc hiy hjd
hjf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hjb
hjg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hje
hjh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hjf
hji:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hjg
hjj:i32[64,64,2] = concatenate[dimension=2] hjh hji
hjk:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hhe hjj
hjl:f64[64,64] = mul hiu hjk
hjm:f64[64,64] = add hip hjl
hjn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hhm
hjo:f64[64,64] = squeeze[dimensions=(0,)] hjn
hjp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hhn
hjq:f64[64,64] = squeeze[dimensions=(0,)] hjp
hjr:f64[64,64] = mul hjo hjq
hjs:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hht
hjt:i64[64,64] = squeeze[dimensions=(0,)] hjs
hju:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hhr
hjv:i64[64,64] = squeeze[dimensions=(0,)] hju
hjw:bool[64,64] = lt hjt 0:i64[]
hjx:i64[64,64] = add hjt 64:i64[]
hjy:i64[64,64] = select_n hjw hjt hjx
hjz:bool[64,64] = lt hjv 0:i64[]
hka:i64[64,64] = add hjv 64:i64[]
hkb:i64[64,64] = select_n hjz hjv hka
hkc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hjy
hkd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hkb
hke:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hkc
hkf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hkd
hkg:i32[64,64,2] = concatenate[dimension=2] hke hkf
hkh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hhe hkg
hki:f64[64,64] = mul hjr hkh
hkj:f64[64,64] = add hjm hki
hkk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hhm
hkl:f64[64,64] = squeeze[dimensions=(0,)] hkk
hkm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hhm
hkn:f64[64,64] = squeeze[dimensions=(0,)] hkm
hko:f64[64,64] = mul hkl hkn
hkp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hht
hkq:i64[64,64] = squeeze[dimensions=(0,)] hkp
hkr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hht
hks:i64[64,64] = squeeze[dimensions=(0,)] hkr
hkt:bool[64,64] = lt hkq 0:i64[]
hku:i64[64,64] = add hkq 64:i64[]
hkv:i64[64,64] = select_n hkt hkq hku
hkw:bool[64,64] = lt hks 0:i64[]
hkx:i64[64,64] = add hks 64:i64[]
hky:i64[64,64] = select_n hkw hks hkx
hkz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hkv
hla:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hky
hlb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hkz
hlc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hla
hld:i32[64,64,2] = concatenate[dimension=2] hlb hlc
hle:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hhe hld
hlf:f64[64,64] = mul hko hle
hlg:f64[64,64] = add hkj hlf
hlh:f64[] = neg eb
hli:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hlh
hlj:f64[2,64,64] = mul hli gyw
hlk:f64[2,64,64] = add ea hlj
hll:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
hlm:f64[2,64,64] = div hlk hll
hln:f64[2,64,64] = floor hlm
hlo:f64[2,64,64] = sub hlm hln
hlp:f64[2,64,64] = sub 1.0:f64[] hlo
hlq:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] da
hlr:f64[2,64,64] = sub hlm hlo
hls:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hlr
hlt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hls hlq
hlu:i64[2,64,64] = add hlt 1:i64[]
hlv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hlu hlq
hlw:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlp
hlx:f64[64,64] = squeeze[dimensions=(0,)] hlw
hly:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlp
hlz:f64[64,64] = squeeze[dimensions=(0,)] hly
hma:f64[64,64] = mul hlx hlz
hmb:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlt
hmc:i64[64,64] = squeeze[dimensions=(0,)] hmb
hmd:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlt
hme:i64[64,64] = squeeze[dimensions=(0,)] hmd
hmf:bool[64,64] = lt hmc 0:i64[]
hmg:i64[64,64] = add hmc 64:i64[]
hmh:i64[64,64] = select_n hmf hmc hmg
hmi:bool[64,64] = lt hme 0:i64[]
hmj:i64[64,64] = add hme 64:i64[]
hmk:i64[64,64] = select_n hmi hme hmj
hml:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hmh
hmm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hmk
hmn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hml
hmo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hmm
hmp:i32[64,64,2] = concatenate[dimension=2] hmn hmo
hmq:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gyl hmp
hmr:f64[64,64] = mul hma hmq
hms:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlp
hmt:f64[64,64] = squeeze[dimensions=(0,)] hms
hmu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlo
hmv:f64[64,64] = squeeze[dimensions=(0,)] hmu
hmw:f64[64,64] = mul hmt hmv
hmx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlt
hmy:i64[64,64] = squeeze[dimensions=(0,)] hmx
hmz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlv
hna:i64[64,64] = squeeze[dimensions=(0,)] hmz
hnb:bool[64,64] = lt hmy 0:i64[]
hnc:i64[64,64] = add hmy 64:i64[]
hnd:i64[64,64] = select_n hnb hmy hnc
hne:bool[64,64] = lt hna 0:i64[]
hnf:i64[64,64] = add hna 64:i64[]
hng:i64[64,64] = select_n hne hna hnf
hnh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hnd
hni:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hng
hnj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hnh
hnk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hni
hnl:i32[64,64,2] = concatenate[dimension=2] hnj hnk
hnm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gyl hnl
hnn:f64[64,64] = mul hmw hnm
hno:f64[64,64] = add hmr hnn
hnp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlo
hnq:f64[64,64] = squeeze[dimensions=(0,)] hnp
hnr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlp
hns:f64[64,64] = squeeze[dimensions=(0,)] hnr
hnt:f64[64,64] = mul hnq hns
hnu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlv
hnv:i64[64,64] = squeeze[dimensions=(0,)] hnu
hnw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlt
hnx:i64[64,64] = squeeze[dimensions=(0,)] hnw
hny:bool[64,64] = lt hnv 0:i64[]
hnz:i64[64,64] = add hnv 64:i64[]
hoa:i64[64,64] = select_n hny hnv hnz
hob:bool[64,64] = lt hnx 0:i64[]
hoc:i64[64,64] = add hnx 64:i64[]
hod:i64[64,64] = select_n hob hnx hoc
hoe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hoa
hof:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hod
hog:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hoe
hoh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hof
hoi:i32[64,64,2] = concatenate[dimension=2] hog hoh
hoj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gyl hoi
hok:f64[64,64] = mul hnt hoj
hol:f64[64,64] = add hno hok
hom:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlo
hon:f64[64,64] = squeeze[dimensions=(0,)] hom
hoo:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlo
hop:f64[64,64] = squeeze[dimensions=(0,)] hoo
hoq:f64[64,64] = mul hon hop
hor:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hlv
hos:i64[64,64] = squeeze[dimensions=(0,)] hor
hot:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hlv
hou:i64[64,64] = squeeze[dimensions=(0,)] hot
hov:bool[64,64] = lt hos 0:i64[]
how:i64[64,64] = add hos 64:i64[]
hox:i64[64,64] = select_n hov hos how
hoy:bool[64,64] = lt hou 0:i64[]
hoz:i64[64,64] = add hou 64:i64[]
hpa:i64[64,64] = select_n hoy hou hoz
hpb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hox
hpc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hpa
hpd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hpb
hpe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hpc
hpf:i32[64,64,2] = concatenate[dimension=2] hpd hpe
hpg:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] gyl hpf
hph:f64[64,64] = mul hoq hpg
hpi:f64[64,64] = add hol hph
hpj:f64[2,64,64] = neg gyw
hpk:f64[] = neg eb
hpl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hpk
hpm:f64[2,64,64] = mul hpl hpj
hpn:f64[2,64,64] = add ea hpm
hpo:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
hpp:f64[2,64,64] = div hpn hpo
hpq:f64[2,64,64] = floor hpp
hpr:f64[2,64,64] = sub hpp hpq
hps:f64[2,64,64] = sub 1.0:f64[] hpr
hpt:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] db
hpu:f64[2,64,64] = sub hpp hpr
hpv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hpu
hpw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hpv hpt
hpx:i64[2,64,64] = add hpw 1:i64[]
hpy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hpx hpt
hpz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hps
hqa:f64[64,64] = squeeze[dimensions=(0,)] hpz
hqb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hps
hqc:f64[64,64] = squeeze[dimensions=(0,)] hqb
hqd:f64[64,64] = mul hqa hqc
hqe:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hpw
hqf:i64[64,64] = squeeze[dimensions=(0,)] hqe
hqg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hpw
hqh:i64[64,64] = squeeze[dimensions=(0,)] hqg
hqi:bool[64,64] = lt hqf 0:i64[]
hqj:i64[64,64] = add hqf 64:i64[]
hqk:i64[64,64] = select_n hqi hqf hqj
hql:bool[64,64] = lt hqh 0:i64[]
hqm:i64[64,64] = add hqh 64:i64[]
hqn:i64[64,64] = select_n hql hqh hqm
hqo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hqk
hqp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hqn
hqq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hqo
hqr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hqp
hqs:i32[64,64,2] = concatenate[dimension=2] hqq hqr
hqt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hpi hqs
hqu:f64[64,64] = mul hqd hqt
hqv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hps
hqw:f64[64,64] = squeeze[dimensions=(0,)] hqv
hqx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hpr
hqy:f64[64,64] = squeeze[dimensions=(0,)] hqx
hqz:f64[64,64] = mul hqw hqy
hra:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hpw
hrb:i64[64,64] = squeeze[dimensions=(0,)] hra
hrc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hpy
hrd:i64[64,64] = squeeze[dimensions=(0,)] hrc
hre:bool[64,64] = lt hrb 0:i64[]
hrf:i64[64,64] = add hrb 64:i64[]
hrg:i64[64,64] = select_n hre hrb hrf
hrh:bool[64,64] = lt hrd 0:i64[]
hri:i64[64,64] = add hrd 64:i64[]
hrj:i64[64,64] = select_n hrh hrd hri
hrk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hrg
hrl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hrj
hrm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hrk
hrn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hrl
hro:i32[64,64,2] = concatenate[dimension=2] hrm hrn
hrp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hpi hro
hrq:f64[64,64] = mul hqz hrp
hrr:f64[64,64] = add hqu hrq
hrs:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hpr
hrt:f64[64,64] = squeeze[dimensions=(0,)] hrs
hru:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hps
hrv:f64[64,64] = squeeze[dimensions=(0,)] hru
hrw:f64[64,64] = mul hrt hrv
hrx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hpy
hry:i64[64,64] = squeeze[dimensions=(0,)] hrx
hrz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hpw
hsa:i64[64,64] = squeeze[dimensions=(0,)] hrz
hsb:bool[64,64] = lt hry 0:i64[]
hsc:i64[64,64] = add hry 64:i64[]
hsd:i64[64,64] = select_n hsb hry hsc
hse:bool[64,64] = lt hsa 0:i64[]
hsf:i64[64,64] = add hsa 64:i64[]
hsg:i64[64,64] = select_n hse hsa hsf
hsh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hsd
hsi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hsg
hsj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hsh
hsk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hsi
hsl:i32[64,64,2] = concatenate[dimension=2] hsj hsk
hsm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hpi hsl
hsn:f64[64,64] = mul hrw hsm
hso:f64[64,64] = add hrr hsn
hsp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hpr
hsq:f64[64,64] = squeeze[dimensions=(0,)] hsp
hsr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hpr
hss:f64[64,64] = squeeze[dimensions=(0,)] hsr
hst:f64[64,64] = mul hsq hss
hsu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hpy
hsv:i64[64,64] = squeeze[dimensions=(0,)] hsu
hsw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hpy
hsx:i64[64,64] = squeeze[dimensions=(0,)] hsw
hsy:bool[64,64] = lt hsv 0:i64[]
hsz:i64[64,64] = add hsv 64:i64[]
hta:i64[64,64] = select_n hsy hsv hsz
htb:bool[64,64] = lt hsx 0:i64[]
htc:i64[64,64] = add hsx 64:i64[]
htd:i64[64,64] = select_n htb hsx htc
hte:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hta
htf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] htd
htg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hte
hth:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] htf
hti:i32[64,64,2] = concatenate[dimension=2] htg hth
htj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hpi hti
htk:f64[64,64] = mul hst htj
htl:f64[64,64] = add hso htk
htm:f64[64,64] = sub gyl htl
htn:f64[64,64] = div htm 2.0:f64[]
hto:f64[64,64] = add gyl htn
htp:f64[] = neg eb
htq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] htp
htr:f64[2,64,64] = mul htq gyw
hts:f64[2,64,64] = add ea htr
htt:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
htu:f64[2,64,64] = div hts htt
htv:f64[2,64,64] = floor htu
htw:f64[2,64,64] = sub htu htv
htx:f64[2,64,64] = sub 1.0:f64[] htw
hty:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dc
htz:f64[2,64,64] = sub htu htw
hua:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] htz
hub:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hua hty
huc:i64[2,64,64] = add hub 1:i64[]
hud:i64[2,64,64] = jit[name=remainder jaxpr=remainder] huc hty
hue:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] htx
huf:f64[64,64] = squeeze[dimensions=(0,)] hue
hug:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] htx
huh:f64[64,64] = squeeze[dimensions=(0,)] hug
hui:f64[64,64] = mul huf huh
huj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hub
huk:i64[64,64] = squeeze[dimensions=(0,)] huj
hul:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hub
hum:i64[64,64] = squeeze[dimensions=(0,)] hul
hun:bool[64,64] = lt huk 0:i64[]
huo:i64[64,64] = add huk 64:i64[]
hup:i64[64,64] = select_n hun huk huo
huq:bool[64,64] = lt hum 0:i64[]
hur:i64[64,64] = add hum 64:i64[]
hus:i64[64,64] = select_n huq hum hur
hut:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hup
huu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hus
huv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hut
huw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] huu
hux:i32[64,64,2] = concatenate[dimension=2] huv huw
huy:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hto hux
huz:f64[64,64] = mul hui huy
hva:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] htx
hvb:f64[64,64] = squeeze[dimensions=(0,)] hva
hvc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] htw
hvd:f64[64,64] = squeeze[dimensions=(0,)] hvc
hve:f64[64,64] = mul hvb hvd
hvf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hub
hvg:i64[64,64] = squeeze[dimensions=(0,)] hvf
hvh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hud
hvi:i64[64,64] = squeeze[dimensions=(0,)] hvh
hvj:bool[64,64] = lt hvg 0:i64[]
hvk:i64[64,64] = add hvg 64:i64[]
hvl:i64[64,64] = select_n hvj hvg hvk
hvm:bool[64,64] = lt hvi 0:i64[]
hvn:i64[64,64] = add hvi 64:i64[]
hvo:i64[64,64] = select_n hvm hvi hvn
hvp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hvl
hvq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hvo
hvr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hvp
hvs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hvq
hvt:i32[64,64,2] = concatenate[dimension=2] hvr hvs
hvu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hto hvt
hvv:f64[64,64] = mul hve hvu
hvw:f64[64,64] = add huz hvv
hvx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] htw
hvy:f64[64,64] = squeeze[dimensions=(0,)] hvx
hvz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] htx
hwa:f64[64,64] = squeeze[dimensions=(0,)] hvz
hwb:f64[64,64] = mul hvy hwa
hwc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hud
hwd:i64[64,64] = squeeze[dimensions=(0,)] hwc
hwe:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hub
hwf:i64[64,64] = squeeze[dimensions=(0,)] hwe
hwg:bool[64,64] = lt hwd 0:i64[]
hwh:i64[64,64] = add hwd 64:i64[]
hwi:i64[64,64] = select_n hwg hwd hwh
hwj:bool[64,64] = lt hwf 0:i64[]
hwk:i64[64,64] = add hwf 64:i64[]
hwl:i64[64,64] = select_n hwj hwf hwk
hwm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hwi
hwn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hwl
hwo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hwm
hwp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hwn
hwq:i32[64,64,2] = concatenate[dimension=2] hwo hwp
hwr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hto hwq
hws:f64[64,64] = mul hwb hwr
hwt:f64[64,64] = add hvw hws
hwu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] htw
hwv:f64[64,64] = squeeze[dimensions=(0,)] hwu
hww:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] htw
hwx:f64[64,64] = squeeze[dimensions=(0,)] hww
hwy:f64[64,64] = mul hwv hwx
hwz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hud
hxa:i64[64,64] = squeeze[dimensions=(0,)] hwz
hxb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hud
hxc:i64[64,64] = squeeze[dimensions=(0,)] hxb
hxd:bool[64,64] = lt hxa 0:i64[]
hxe:i64[64,64] = add hxa 64:i64[]
hxf:i64[64,64] = select_n hxd hxa hxe
hxg:bool[64,64] = lt hxc 0:i64[]
hxh:i64[64,64] = add hxc 64:i64[]
hxi:i64[64,64] = select_n hxg hxc hxh
hxj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hxf
hxk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hxi
hxl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hxj
hxm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hxk
hxn:i32[64,64,2] = concatenate[dimension=2] hxl hxm
hxo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hto hxn
hxp:f64[64,64] = mul hwy hxo
hxq:f64[64,64] = add hwt hxp
hxr:c128[64,33] = jit[name=fft jaxpr=fft] hlg
hxs:c128[] = reduce_prod[axes=(0,)] dd
hxt:c128[] = sqrt hxs
hxu:c128[] = div (1+0j):c128[] hxt
hxv:c128[64,33] = mul hxr hxu
hxw:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] hxv
hxx:c128[2,64,33] = mul dw hxw
hxy:f64[2,64,64] = jit[name=fft jaxpr=fft1] hxx
hxz:f64[] = reduce_prod[axes=(0,)] de
hya:f64[] = sqrt hxz
hyb:f64[2,64,64] = mul hxy hya
hyc:f64[] = neg eb
hyd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hyc
hye:f64[2,64,64] = mul hyd hyb
hyf:f64[2,64,64] = add ea hye
hyg:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
hyh:f64[2,64,64] = div hyf hyg
hyi:f64[2,64,64] = floor hyh
hyj:f64[2,64,64] = sub hyh hyi
hyk:f64[2,64,64] = sub 1.0:f64[] hyj
hyl:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] df
hym:f64[2,64,64] = sub hyh hyj
hyn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hym
hyo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hyn hyl
hyp:i64[2,64,64] = add hyo 1:i64[]
hyq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hyp hyl
hyr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyk
hys:f64[64,64] = squeeze[dimensions=(0,)] hyr
hyt:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyk
hyu:f64[64,64] = squeeze[dimensions=(0,)] hyt
hyv:f64[64,64] = mul hys hyu
hyw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyo
hyx:i64[64,64] = squeeze[dimensions=(0,)] hyw
hyy:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyo
hyz:i64[64,64] = squeeze[dimensions=(0,)] hyy
hza:bool[64,64] = lt hyx 0:i64[]
hzb:i64[64,64] = add hyx 64:i64[]
hzc:i64[64,64] = select_n hza hyx hzb
hzd:bool[64,64] = lt hyz 0:i64[]
hze:i64[64,64] = add hyz 64:i64[]
hzf:i64[64,64] = select_n hzd hyz hze
hzg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hzc
hzh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hzf
hzi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hzg
hzj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] hzh
hzk:i32[64,64,2] = concatenate[dimension=2] hzi hzj
hzl:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hlg hzk
hzm:f64[64,64] = mul hyv hzl
hzn:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyk
hzo:f64[64,64] = squeeze[dimensions=(0,)] hzn
hzp:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyj
hzq:f64[64,64] = squeeze[dimensions=(0,)] hzp
hzr:f64[64,64] = mul hzo hzq
hzs:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyo
hzt:i64[64,64] = squeeze[dimensions=(0,)] hzs
hzu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyq
hzv:i64[64,64] = squeeze[dimensions=(0,)] hzu
hzw:bool[64,64] = lt hzt 0:i64[]
hzx:i64[64,64] = add hzt 64:i64[]
hzy:i64[64,64] = select_n hzw hzt hzx
hzz:bool[64,64] = lt hzv 0:i64[]
iaa:i64[64,64] = add hzv 64:i64[]
iab:i64[64,64] = select_n hzz hzv iaa
iac:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hzy
iad:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iab
iae:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iac
iaf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iad
iag:i32[64,64,2] = concatenate[dimension=2] iae iaf
iah:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hlg iag
iai:f64[64,64] = mul hzr iah
iaj:f64[64,64] = add hzm iai
iak:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyj
ial:f64[64,64] = squeeze[dimensions=(0,)] iak
iam:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyk
ian:f64[64,64] = squeeze[dimensions=(0,)] iam
iao:f64[64,64] = mul ial ian
iap:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyq
iaq:i64[64,64] = squeeze[dimensions=(0,)] iap
iar:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyo
ias:i64[64,64] = squeeze[dimensions=(0,)] iar
iat:bool[64,64] = lt iaq 0:i64[]
iau:i64[64,64] = add iaq 64:i64[]
iav:i64[64,64] = select_n iat iaq iau
iaw:bool[64,64] = lt ias 0:i64[]
iax:i64[64,64] = add ias 64:i64[]
iay:i64[64,64] = select_n iaw ias iax
iaz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iav
iba:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iay
ibb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iaz
ibc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iba
ibd:i32[64,64,2] = concatenate[dimension=2] ibb ibc
ibe:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hlg ibd
ibf:f64[64,64] = mul iao ibe
ibg:f64[64,64] = add iaj ibf
ibh:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyj
ibi:f64[64,64] = squeeze[dimensions=(0,)] ibh
ibj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyj
ibk:f64[64,64] = squeeze[dimensions=(0,)] ibj
ibl:f64[64,64] = mul ibi ibk
ibm:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] hyq
ibn:i64[64,64] = squeeze[dimensions=(0,)] ibm
ibo:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] hyq
ibp:i64[64,64] = squeeze[dimensions=(0,)] ibo
ibq:bool[64,64] = lt ibn 0:i64[]
ibr:i64[64,64] = add ibn 64:i64[]
ibs:i64[64,64] = select_n ibq ibn ibr
ibt:bool[64,64] = lt ibp 0:i64[]
ibu:i64[64,64] = add ibp 64:i64[]
ibv:i64[64,64] = select_n ibt ibp ibu
ibw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ibs
ibx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ibv
iby:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ibw
ibz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ibx
ica:i32[64,64,2] = concatenate[dimension=2] iby ibz
icb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hlg ica
icc:f64[64,64] = mul ibl icb
icd:f64[64,64] = add ibg icc
ice:f64[2,64,64] = neg hyb
icf:f64[] = neg eb
icg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] icf
ich:f64[2,64,64] = mul icg ice
ici:f64[2,64,64] = add ea ich
icj:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
ick:f64[2,64,64] = div ici icj
icl:f64[2,64,64] = floor ick
icm:f64[2,64,64] = sub ick icl
icn:f64[2,64,64] = sub 1.0:f64[] icm
ico:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dg
icp:f64[2,64,64] = sub ick icm
icq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] icp
icr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] icq ico
ics:i64[2,64,64] = add icr 1:i64[]
ict:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ics ico
icu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] icn
icv:f64[64,64] = squeeze[dimensions=(0,)] icu
icw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] icn
icx:f64[64,64] = squeeze[dimensions=(0,)] icw
icy:f64[64,64] = mul icv icx
icz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] icr
ida:i64[64,64] = squeeze[dimensions=(0,)] icz
idb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] icr
idc:i64[64,64] = squeeze[dimensions=(0,)] idb
idd:bool[64,64] = lt ida 0:i64[]
ide:i64[64,64] = add ida 64:i64[]
idf:i64[64,64] = select_n idd ida ide
idg:bool[64,64] = lt idc 0:i64[]
idh:i64[64,64] = add idc 64:i64[]
idi:i64[64,64] = select_n idg idc idh
idj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] idf
idk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] idi
idl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] idj
idm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] idk
idn:i32[64,64,2] = concatenate[dimension=2] idl idm
ido:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] icd idn
idp:f64[64,64] = mul icy ido
idq:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] icn
idr:f64[64,64] = squeeze[dimensions=(0,)] idq
ids:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] icm
idt:f64[64,64] = squeeze[dimensions=(0,)] ids
idu:f64[64,64] = mul idr idt
idv:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] icr
idw:i64[64,64] = squeeze[dimensions=(0,)] idv
idx:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ict
idy:i64[64,64] = squeeze[dimensions=(0,)] idx
idz:bool[64,64] = lt idw 0:i64[]
iea:i64[64,64] = add idw 64:i64[]
ieb:i64[64,64] = select_n idz idw iea
iec:bool[64,64] = lt idy 0:i64[]
ied:i64[64,64] = add idy 64:i64[]
iee:i64[64,64] = select_n iec idy ied
ief:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ieb
ieg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iee
ieh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ief
iei:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ieg
iej:i32[64,64,2] = concatenate[dimension=2] ieh iei
iek:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] icd iej
iel:f64[64,64] = mul idu iek
iem:f64[64,64] = add idp iel
ien:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] icm
ieo:f64[64,64] = squeeze[dimensions=(0,)] ien
iep:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] icn
ieq:f64[64,64] = squeeze[dimensions=(0,)] iep
ier:f64[64,64] = mul ieo ieq
ies:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ict
iet:i64[64,64] = squeeze[dimensions=(0,)] ies
ieu:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] icr
iev:i64[64,64] = squeeze[dimensions=(0,)] ieu
iew:bool[64,64] = lt iet 0:i64[]
iex:i64[64,64] = add iet 64:i64[]
iey:i64[64,64] = select_n iew iet iex
iez:bool[64,64] = lt iev 0:i64[]
ifa:i64[64,64] = add iev 64:i64[]
ifb:i64[64,64] = select_n iez iev ifa
ifc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iey
ifd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ifb
ife:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ifc
iff:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ifd
ifg:i32[64,64,2] = concatenate[dimension=2] ife iff
ifh:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] icd ifg
ifi:f64[64,64] = mul ier ifh
ifj:f64[64,64] = add iem ifi
ifk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] icm
ifl:f64[64,64] = squeeze[dimensions=(0,)] ifk
ifm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] icm
ifn:f64[64,64] = squeeze[dimensions=(0,)] ifm
ifo:f64[64,64] = mul ifl ifn
ifp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ict
ifq:i64[64,64] = squeeze[dimensions=(0,)] ifp
ifr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ict
ifs:i64[64,64] = squeeze[dimensions=(0,)] ifr
ift:bool[64,64] = lt ifq 0:i64[]
ifu:i64[64,64] = add ifq 64:i64[]
ifv:i64[64,64] = select_n ift ifq ifu
ifw:bool[64,64] = lt ifs 0:i64[]
ifx:i64[64,64] = add ifs 64:i64[]
ify:i64[64,64] = select_n ifw ifs ifx
ifz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ifv
iga:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ify
igb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ifz
igc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iga
igd:i32[64,64,2] = concatenate[dimension=2] igb igc
ige:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] icd igd
igf:f64[64,64] = mul ifo ige
igg:f64[64,64] = add ifj igf
igh:f64[64,64] = sub hlg igg
igi:f64[64,64] = div igh 2.0:f64[]
igj:f64[64,64] = add hlg igi
igk:f64[] = neg eb
igl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] igk
igm:f64[2,64,64] = mul igl hyb
ign:f64[2,64,64] = add ea igm
igo:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
igp:f64[2,64,64] = div ign igo
igq:f64[2,64,64] = floor igp
igr:f64[2,64,64] = sub igp igq
igs:f64[2,64,64] = sub 1.0:f64[] igr
igt:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dh
igu:f64[2,64,64] = sub igp igr
igv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] igu
igw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] igv igt
igx:i64[2,64,64] = add igw 1:i64[]
igy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] igx igt
igz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igs
iha:f64[64,64] = squeeze[dimensions=(0,)] igz
ihb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igs
ihc:f64[64,64] = squeeze[dimensions=(0,)] ihb
ihd:f64[64,64] = mul iha ihc
ihe:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igw
ihf:i64[64,64] = squeeze[dimensions=(0,)] ihe
ihg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igw
ihh:i64[64,64] = squeeze[dimensions=(0,)] ihg
ihi:bool[64,64] = lt ihf 0:i64[]
ihj:i64[64,64] = add ihf 64:i64[]
ihk:i64[64,64] = select_n ihi ihf ihj
ihl:bool[64,64] = lt ihh 0:i64[]
ihm:i64[64,64] = add ihh 64:i64[]
ihn:i64[64,64] = select_n ihl ihh ihm
iho:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ihk
ihp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ihn
ihq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iho
ihr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ihp
ihs:i32[64,64,2] = concatenate[dimension=2] ihq ihr
iht:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] igj ihs
ihu:f64[64,64] = mul ihd iht
ihv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igs
ihw:f64[64,64] = squeeze[dimensions=(0,)] ihv
ihx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igr
ihy:f64[64,64] = squeeze[dimensions=(0,)] ihx
ihz:f64[64,64] = mul ihw ihy
iia:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igw
iib:i64[64,64] = squeeze[dimensions=(0,)] iia
iic:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igy
iid:i64[64,64] = squeeze[dimensions=(0,)] iic
iie:bool[64,64] = lt iib 0:i64[]
iif:i64[64,64] = add iib 64:i64[]
iig:i64[64,64] = select_n iie iib iif
iih:bool[64,64] = lt iid 0:i64[]
iii:i64[64,64] = add iid 64:i64[]
iij:i64[64,64] = select_n iih iid iii
iik:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iig
iil:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iij
iim:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iik
iin:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iil
iio:i32[64,64,2] = concatenate[dimension=2] iim iin
iip:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] igj iio
iiq:f64[64,64] = mul ihz iip
iir:f64[64,64] = add ihu iiq
iis:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igr
iit:f64[64,64] = squeeze[dimensions=(0,)] iis
iiu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igs
iiv:f64[64,64] = squeeze[dimensions=(0,)] iiu
iiw:f64[64,64] = mul iit iiv
iix:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igy
iiy:i64[64,64] = squeeze[dimensions=(0,)] iix
iiz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igw
ija:i64[64,64] = squeeze[dimensions=(0,)] iiz
ijb:bool[64,64] = lt iiy 0:i64[]
ijc:i64[64,64] = add iiy 64:i64[]
ijd:i64[64,64] = select_n ijb iiy ijc
ije:bool[64,64] = lt ija 0:i64[]
ijf:i64[64,64] = add ija 64:i64[]
ijg:i64[64,64] = select_n ije ija ijf
ijh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ijd
iji:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ijg
ijj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ijh
ijk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iji
ijl:i32[64,64,2] = concatenate[dimension=2] ijj ijk
ijm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] igj ijl
ijn:f64[64,64] = mul iiw ijm
ijo:f64[64,64] = add iir ijn
ijp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igr
ijq:f64[64,64] = squeeze[dimensions=(0,)] ijp
ijr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igr
ijs:f64[64,64] = squeeze[dimensions=(0,)] ijr
ijt:f64[64,64] = mul ijq ijs
iju:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] igy
ijv:i64[64,64] = squeeze[dimensions=(0,)] iju
ijw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] igy
ijx:i64[64,64] = squeeze[dimensions=(0,)] ijw
ijy:bool[64,64] = lt ijv 0:i64[]
ijz:i64[64,64] = add ijv 64:i64[]
ika:i64[64,64] = select_n ijy ijv ijz
ikb:bool[64,64] = lt ijx 0:i64[]
ikc:i64[64,64] = add ijx 64:i64[]
ikd:i64[64,64] = select_n ikb ijx ikc
ike:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ika
ikf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ikd
ikg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ike
ikh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ikf
iki:i32[64,64,2] = concatenate[dimension=2] ikg ikh
ikj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] igj iki
ikk:f64[64,64] = mul ijt ikj
ikl:f64[64,64] = add ijo ikk
ikm:f64[] = neg eb
ikn:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ikm
iko:f64[2,64,64] = mul ikn hyb
ikp:f64[2,64,64] = add ea iko
ikq:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
ikr:f64[2,64,64] = div ikp ikq
iks:f64[2,64,64] = floor ikr
ikt:f64[2,64,64] = sub ikr iks
iku:f64[2,64,64] = sub 1.0:f64[] ikt
ikv:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] di
ikw:f64[2,64,64] = sub ikr ikt
ikx:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ikw
iky:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ikx ikv
ikz:i64[2,64,64] = add iky 1:i64[]
ila:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ikz ikv
ilb:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iku
ilc:f64[64,64] = squeeze[dimensions=(0,)] ilb
ild:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iku
ile:f64[64,64] = squeeze[dimensions=(0,)] ild
ilf:f64[64,64] = mul ilc ile
ilg:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iky
ilh:i64[64,64] = squeeze[dimensions=(0,)] ilg
ili:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iky
ilj:i64[64,64] = squeeze[dimensions=(0,)] ili
ilk:bool[64,64] = lt ilh 0:i64[]
ill:i64[64,64] = add ilh 64:i64[]
ilm:i64[64,64] = select_n ilk ilh ill
iln:bool[64,64] = lt ilj 0:i64[]
ilo:i64[64,64] = add ilj 64:i64[]
ilp:i64[64,64] = select_n iln ilj ilo
ilq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ilm
ilr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ilp
ils:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ilq
ilt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ilr
ilu:i32[64,64,2] = concatenate[dimension=2] ils ilt
ilv:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hxq ilu
ilw:f64[64,64] = mul ilf ilv
ilx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iku
ily:f64[64,64] = squeeze[dimensions=(0,)] ilx
ilz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ikt
ima:f64[64,64] = squeeze[dimensions=(0,)] ilz
imb:f64[64,64] = mul ily ima
imc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iky
imd:i64[64,64] = squeeze[dimensions=(0,)] imc
ime:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ila
imf:i64[64,64] = squeeze[dimensions=(0,)] ime
img:bool[64,64] = lt imd 0:i64[]
imh:i64[64,64] = add imd 64:i64[]
imi:i64[64,64] = select_n img imd imh
imj:bool[64,64] = lt imf 0:i64[]
imk:i64[64,64] = add imf 64:i64[]
iml:i64[64,64] = select_n imj imf imk
imm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] imi
imn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iml
imo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] imm
imp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] imn
imq:i32[64,64,2] = concatenate[dimension=2] imo imp
imr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hxq imq
ims:f64[64,64] = mul imb imr
imt:f64[64,64] = add ilw ims
imu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ikt
imv:f64[64,64] = squeeze[dimensions=(0,)] imu
imw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iku
imx:f64[64,64] = squeeze[dimensions=(0,)] imw
imy:f64[64,64] = mul imv imx
imz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ila
ina:i64[64,64] = squeeze[dimensions=(0,)] imz
inb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iky
inc:i64[64,64] = squeeze[dimensions=(0,)] inb
ind:bool[64,64] = lt ina 0:i64[]
ine:i64[64,64] = add ina 64:i64[]
inf:i64[64,64] = select_n ind ina ine
ing:bool[64,64] = lt inc 0:i64[]
inh:i64[64,64] = add inc 64:i64[]
ini:i64[64,64] = select_n ing inc inh
inj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] inf
ink:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ini
inl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] inj
inm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ink
inn:i32[64,64,2] = concatenate[dimension=2] inl inm
ino:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hxq inn
inp:f64[64,64] = mul imy ino
inq:f64[64,64] = add imt inp
inr:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ikt
ins:f64[64,64] = squeeze[dimensions=(0,)] inr
int:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ikt
inu:f64[64,64] = squeeze[dimensions=(0,)] int
inv:f64[64,64] = mul ins inu
inw:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ila
inx:i64[64,64] = squeeze[dimensions=(0,)] inw
iny:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ila
inz:i64[64,64] = squeeze[dimensions=(0,)] iny
ioa:bool[64,64] = lt inx 0:i64[]
iob:i64[64,64] = add inx 64:i64[]
ioc:i64[64,64] = select_n ioa inx iob
iod:bool[64,64] = lt inz 0:i64[]
ioe:i64[64,64] = add inz 64:i64[]
iof:i64[64,64] = select_n iod inz ioe
iog:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ioc
ioh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iof
ioi:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iog
ioj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ioh
iok:i32[64,64,2] = concatenate[dimension=2] ioi ioj
iol:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] hxq iok
iom:f64[64,64] = mul inv iol
ion:f64[64,64] = add inq iom
ioo:f64[2,64,64] = neg hyb
iop:f64[] = neg eb
ioq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] iop
ior:f64[2,64,64] = mul ioq ioo
ios:f64[2,64,64] = add ea ior
iot:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
iou:f64[2,64,64] = div ios iot
iov:f64[2,64,64] = floor iou
iow:f64[2,64,64] = sub iou iov
iox:f64[2,64,64] = sub 1.0:f64[] iow
ioy:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dj
ioz:f64[2,64,64] = sub iou iow
ipa:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ioz
ipb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ipa ioy
ipc:i64[2,64,64] = add ipb 1:i64[]
ipd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ipc ioy
ipe:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iox
ipf:f64[64,64] = squeeze[dimensions=(0,)] ipe
ipg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iox
iph:f64[64,64] = squeeze[dimensions=(0,)] ipg
ipi:f64[64,64] = mul ipf iph
ipj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ipb
ipk:i64[64,64] = squeeze[dimensions=(0,)] ipj
ipl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ipb
ipm:i64[64,64] = squeeze[dimensions=(0,)] ipl
ipn:bool[64,64] = lt ipk 0:i64[]
ipo:i64[64,64] = add ipk 64:i64[]
ipp:i64[64,64] = select_n ipn ipk ipo
ipq:bool[64,64] = lt ipm 0:i64[]
ipr:i64[64,64] = add ipm 64:i64[]
ips:i64[64,64] = select_n ipq ipm ipr
ipt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ipp
ipu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ips
ipv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ipt
ipw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ipu
ipx:i32[64,64,2] = concatenate[dimension=2] ipv ipw
ipy:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ion ipx
ipz:f64[64,64] = mul ipi ipy
iqa:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iox
iqb:f64[64,64] = squeeze[dimensions=(0,)] iqa
iqc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iow
iqd:f64[64,64] = squeeze[dimensions=(0,)] iqc
iqe:f64[64,64] = mul iqb iqd
iqf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ipb
iqg:i64[64,64] = squeeze[dimensions=(0,)] iqf
iqh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ipd
iqi:i64[64,64] = squeeze[dimensions=(0,)] iqh
iqj:bool[64,64] = lt iqg 0:i64[]
iqk:i64[64,64] = add iqg 64:i64[]
iql:i64[64,64] = select_n iqj iqg iqk
iqm:bool[64,64] = lt iqi 0:i64[]
iqn:i64[64,64] = add iqi 64:i64[]
iqo:i64[64,64] = select_n iqm iqi iqn
iqp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iql
iqq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iqo
iqr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iqp
iqs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iqq
iqt:i32[64,64,2] = concatenate[dimension=2] iqr iqs
iqu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ion iqt
iqv:f64[64,64] = mul iqe iqu
iqw:f64[64,64] = add ipz iqv
iqx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iow
iqy:f64[64,64] = squeeze[dimensions=(0,)] iqx
iqz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iox
ira:f64[64,64] = squeeze[dimensions=(0,)] iqz
irb:f64[64,64] = mul iqy ira
irc:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ipd
ird:i64[64,64] = squeeze[dimensions=(0,)] irc
ire:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ipb
irf:i64[64,64] = squeeze[dimensions=(0,)] ire
irg:bool[64,64] = lt ird 0:i64[]
irh:i64[64,64] = add ird 64:i64[]
iri:i64[64,64] = select_n irg ird irh
irj:bool[64,64] = lt irf 0:i64[]
irk:i64[64,64] = add irf 64:i64[]
irl:i64[64,64] = select_n irj irf irk
irm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iri
irn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] irl
iro:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] irm
irp:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] irn
irq:i32[64,64,2] = concatenate[dimension=2] iro irp
irr:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ion irq
irs:f64[64,64] = mul irb irr
irt:f64[64,64] = add iqw irs
iru:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iow
irv:f64[64,64] = squeeze[dimensions=(0,)] iru
irw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iow
irx:f64[64,64] = squeeze[dimensions=(0,)] irw
iry:f64[64,64] = mul irv irx
irz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ipd
isa:i64[64,64] = squeeze[dimensions=(0,)] irz
isb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ipd
isc:i64[64,64] = squeeze[dimensions=(0,)] isb
isd:bool[64,64] = lt isa 0:i64[]
ise:i64[64,64] = add isa 64:i64[]
isf:i64[64,64] = select_n isd isa ise
isg:bool[64,64] = lt isc 0:i64[]
ish:i64[64,64] = add isc 64:i64[]
isi:i64[64,64] = select_n isg isc ish
isj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] isf
isk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] isi
isl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] isj
ism:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] isk
isn:i32[64,64,2] = concatenate[dimension=2] isl ism
iso:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ion isn
isp:f64[64,64] = mul iry iso
isq:f64[64,64] = add irt isp
isr:f64[64,64] = sub hxq isq
iss:f64[64,64] = div isr 2.0:f64[]
ist:f64[64,64] = add hxq iss
isu:f64[] = neg eb
isv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] isu
isw:f64[2,64,64] = mul isv hyb
isx:f64[2,64,64] = add ea isw
isy:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
isz:f64[2,64,64] = div isx isy
ita:f64[2,64,64] = floor isz
itb:f64[2,64,64] = sub isz ita
itc:f64[2,64,64] = sub 1.0:f64[] itb
itd:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dk
ite:f64[2,64,64] = sub isz itb
itf:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ite
itg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] itf itd
ith:i64[2,64,64] = add itg 1:i64[]
iti:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ith itd
itj:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] itc
itk:f64[64,64] = squeeze[dimensions=(0,)] itj
itl:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] itc
itm:f64[64,64] = squeeze[dimensions=(0,)] itl
itn:f64[64,64] = mul itk itm
ito:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] itg
itp:i64[64,64] = squeeze[dimensions=(0,)] ito
itq:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] itg
itr:i64[64,64] = squeeze[dimensions=(0,)] itq
its:bool[64,64] = lt itp 0:i64[]
itt:i64[64,64] = add itp 64:i64[]
itu:i64[64,64] = select_n its itp itt
itv:bool[64,64] = lt itr 0:i64[]
itw:i64[64,64] = add itr 64:i64[]
itx:i64[64,64] = select_n itv itr itw
ity:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] itu
itz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] itx
iua:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ity
iub:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] itz
iuc:i32[64,64,2] = concatenate[dimension=2] iua iub
iud:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ist iuc
iue:f64[64,64] = mul itn iud
iuf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] itc
iug:f64[64,64] = squeeze[dimensions=(0,)] iuf
iuh:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] itb
iui:f64[64,64] = squeeze[dimensions=(0,)] iuh
iuj:f64[64,64] = mul iug iui
iuk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] itg
iul:i64[64,64] = squeeze[dimensions=(0,)] iuk
ium:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iti
iun:i64[64,64] = squeeze[dimensions=(0,)] ium
iuo:bool[64,64] = lt iul 0:i64[]
iup:i64[64,64] = add iul 64:i64[]
iuq:i64[64,64] = select_n iuo iul iup
iur:bool[64,64] = lt iun 0:i64[]
ius:i64[64,64] = add iun 64:i64[]
iut:i64[64,64] = select_n iur iun ius
iuu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iuq
iuv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iut
iuw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iuu
iux:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iuv
iuy:i32[64,64,2] = concatenate[dimension=2] iuw iux
iuz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ist iuy
iva:f64[64,64] = mul iuj iuz
ivb:f64[64,64] = add iue iva
ivc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] itb
ivd:f64[64,64] = squeeze[dimensions=(0,)] ivc
ive:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] itc
ivf:f64[64,64] = squeeze[dimensions=(0,)] ive
ivg:f64[64,64] = mul ivd ivf
ivh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iti
ivi:i64[64,64] = squeeze[dimensions=(0,)] ivh
ivj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] itg
ivk:i64[64,64] = squeeze[dimensions=(0,)] ivj
ivl:bool[64,64] = lt ivi 0:i64[]
ivm:i64[64,64] = add ivi 64:i64[]
ivn:i64[64,64] = select_n ivl ivi ivm
ivo:bool[64,64] = lt ivk 0:i64[]
ivp:i64[64,64] = add ivk 64:i64[]
ivq:i64[64,64] = select_n ivo ivk ivp
ivr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ivn
ivs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ivq
ivt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ivr
ivu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] ivs
ivv:i32[64,64,2] = concatenate[dimension=2] ivt ivu
ivw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ist ivv
ivx:f64[64,64] = mul ivg ivw
ivy:f64[64,64] = add ivb ivx
ivz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] itb
iwa:f64[64,64] = squeeze[dimensions=(0,)] ivz
iwb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] itb
iwc:f64[64,64] = squeeze[dimensions=(0,)] iwb
iwd:f64[64,64] = mul iwa iwc
iwe:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] iti
iwf:i64[64,64] = squeeze[dimensions=(0,)] iwe
iwg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] iti
iwh:i64[64,64] = squeeze[dimensions=(0,)] iwg
iwi:bool[64,64] = lt iwf 0:i64[]
iwj:i64[64,64] = add iwf 64:i64[]
iwk:i64[64,64] = select_n iwi iwf iwj
iwl:bool[64,64] = lt iwh 0:i64[]
iwm:i64[64,64] = add iwh 64:i64[]
iwn:i64[64,64] = select_n iwl iwh iwm
iwo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iwk
iwp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iwn
iwq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iwo
iwr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iwp
iws:i32[64,64,2] = concatenate[dimension=2] iwq iwr
iwt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ist iws
iwu:f64[64,64] = mul iwd iwt
iwv:f64[64,64] = add ivy iwu
iww:c128[64,33] = jit[name=fft jaxpr=fft] ikl
iwx:c128[] = reduce_prod[axes=(0,)] dl
iwy:c128[] = sqrt iwx
iwz:c128[] = div (1+0j):c128[] iwy
ixa:c128[64,33] = mul iww iwz
ixb:c128[1,64,33] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 64, 33)
sharding=None
] ixa
ixc:c128[2,64,33] = mul dw ixb
ixd:f64[2,64,64] = jit[name=fft jaxpr=fft1] ixc
ixe:f64[] = reduce_prod[axes=(0,)] dm
ixf:f64[] = sqrt ixe
ixg:f64[2,64,64] = mul ixd ixf
ixh:f64[] = neg eb
ixi:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ixh
ixj:f64[2,64,64] = mul ixi ixg
ixk:f64[2,64,64] = add ea ixj
ixl:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
ixm:f64[2,64,64] = div ixk ixl
ixn:f64[2,64,64] = floor ixm
ixo:f64[2,64,64] = sub ixm ixn
ixp:f64[2,64,64] = sub 1.0:f64[] ixo
ixq:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dn
ixr:f64[2,64,64] = sub ixm ixo
ixs:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ixr
ixt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ixs ixq
ixu:i64[2,64,64] = add ixt 1:i64[]
ixv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ixu ixq
ixw:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixp
ixx:f64[64,64] = squeeze[dimensions=(0,)] ixw
ixy:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixp
ixz:f64[64,64] = squeeze[dimensions=(0,)] ixy
iya:f64[64,64] = mul ixx ixz
iyb:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixt
iyc:i64[64,64] = squeeze[dimensions=(0,)] iyb
iyd:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixt
iye:i64[64,64] = squeeze[dimensions=(0,)] iyd
iyf:bool[64,64] = lt iyc 0:i64[]
iyg:i64[64,64] = add iyc 64:i64[]
iyh:i64[64,64] = select_n iyf iyc iyg
iyi:bool[64,64] = lt iye 0:i64[]
iyj:i64[64,64] = add iye 64:i64[]
iyk:i64[64,64] = select_n iyi iye iyj
iyl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iyh
iym:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iyk
iyn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iyl
iyo:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] iym
iyp:i32[64,64,2] = concatenate[dimension=2] iyn iyo
iyq:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ikl iyp
iyr:f64[64,64] = mul iya iyq
iys:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixp
iyt:f64[64,64] = squeeze[dimensions=(0,)] iys
iyu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixo
iyv:f64[64,64] = squeeze[dimensions=(0,)] iyu
iyw:f64[64,64] = mul iyt iyv
iyx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixt
iyy:i64[64,64] = squeeze[dimensions=(0,)] iyx
iyz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixv
iza:i64[64,64] = squeeze[dimensions=(0,)] iyz
izb:bool[64,64] = lt iyy 0:i64[]
izc:i64[64,64] = add iyy 64:i64[]
izd:i64[64,64] = select_n izb iyy izc
ize:bool[64,64] = lt iza 0:i64[]
izf:i64[64,64] = add iza 64:i64[]
izg:i64[64,64] = select_n ize iza izf
izh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] izd
izi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] izg
izj:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] izh
izk:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] izi
izl:i32[64,64,2] = concatenate[dimension=2] izj izk
izm:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ikl izl
izn:f64[64,64] = mul iyw izm
izo:f64[64,64] = add iyr izn
izp:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixo
izq:f64[64,64] = squeeze[dimensions=(0,)] izp
izr:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixp
izs:f64[64,64] = squeeze[dimensions=(0,)] izr
izt:f64[64,64] = mul izq izs
izu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixv
izv:i64[64,64] = squeeze[dimensions=(0,)] izu
izw:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixt
izx:i64[64,64] = squeeze[dimensions=(0,)] izw
izy:bool[64,64] = lt izv 0:i64[]
izz:i64[64,64] = add izv 64:i64[]
jaa:i64[64,64] = select_n izy izv izz
jab:bool[64,64] = lt izx 0:i64[]
jac:i64[64,64] = add izx 64:i64[]
jad:i64[64,64] = select_n jab izx jac
jae:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jaa
jaf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jad
jag:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jae
jah:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jaf
jai:i32[64,64,2] = concatenate[dimension=2] jag jah
jaj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ikl jai
jak:f64[64,64] = mul izt jaj
jal:f64[64,64] = add izo jak
jam:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixo
jan:f64[64,64] = squeeze[dimensions=(0,)] jam
jao:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixo
jap:f64[64,64] = squeeze[dimensions=(0,)] jao
jaq:f64[64,64] = mul jan jap
jar:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] ixv
jas:i64[64,64] = squeeze[dimensions=(0,)] jar
jat:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] ixv
jau:i64[64,64] = squeeze[dimensions=(0,)] jat
jav:bool[64,64] = lt jas 0:i64[]
jaw:i64[64,64] = add jas 64:i64[]
jax:i64[64,64] = select_n jav jas jaw
jay:bool[64,64] = lt jau 0:i64[]
jaz:i64[64,64] = add jau 64:i64[]
jba:i64[64,64] = select_n jay jau jaz
jbb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jax
jbc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jba
jbd:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jbb
jbe:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jbc
jbf:i32[64,64,2] = concatenate[dimension=2] jbd jbe
jbg:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] ikl jbf
jbh:f64[64,64] = mul jaq jbg
jbi:f64[64,64] = add jal jbh
jbj:f64[2,64,64] = neg ixg
jbk:f64[] = neg eb
jbl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jbk
jbm:f64[2,64,64] = mul jbl jbj
jbn:f64[2,64,64] = add ea jbm
jbo:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
jbp:f64[2,64,64] = div jbn jbo
jbq:f64[2,64,64] = floor jbp
jbr:f64[2,64,64] = sub jbp jbq
jbs:f64[2,64,64] = sub 1.0:f64[] jbr
jbt:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] do
jbu:f64[2,64,64] = sub jbp jbr
jbv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jbu
jbw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jbv jbt
jbx:i64[2,64,64] = add jbw 1:i64[]
jby:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jbx jbt
jbz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jbs
jca:f64[64,64] = squeeze[dimensions=(0,)] jbz
jcb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jbs
jcc:f64[64,64] = squeeze[dimensions=(0,)] jcb
jcd:f64[64,64] = mul jca jcc
jce:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jbw
jcf:i64[64,64] = squeeze[dimensions=(0,)] jce
jcg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jbw
jch:i64[64,64] = squeeze[dimensions=(0,)] jcg
jci:bool[64,64] = lt jcf 0:i64[]
jcj:i64[64,64] = add jcf 64:i64[]
jck:i64[64,64] = select_n jci jcf jcj
jcl:bool[64,64] = lt jch 0:i64[]
jcm:i64[64,64] = add jch 64:i64[]
jcn:i64[64,64] = select_n jcl jch jcm
jco:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jck
jcp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jcn
jcq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jco
jcr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jcp
jcs:i32[64,64,2] = concatenate[dimension=2] jcq jcr
jct:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jbi jcs
jcu:f64[64,64] = mul jcd jct
jcv:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jbs
jcw:f64[64,64] = squeeze[dimensions=(0,)] jcv
jcx:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jbr
jcy:f64[64,64] = squeeze[dimensions=(0,)] jcx
jcz:f64[64,64] = mul jcw jcy
jda:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jbw
jdb:i64[64,64] = squeeze[dimensions=(0,)] jda
jdc:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jby
jdd:i64[64,64] = squeeze[dimensions=(0,)] jdc
jde:bool[64,64] = lt jdb 0:i64[]
jdf:i64[64,64] = add jdb 64:i64[]
jdg:i64[64,64] = select_n jde jdb jdf
jdh:bool[64,64] = lt jdd 0:i64[]
jdi:i64[64,64] = add jdd 64:i64[]
jdj:i64[64,64] = select_n jdh jdd jdi
jdk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jdg
jdl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jdj
jdm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jdk
jdn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jdl
jdo:i32[64,64,2] = concatenate[dimension=2] jdm jdn
jdp:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jbi jdo
jdq:f64[64,64] = mul jcz jdp
jdr:f64[64,64] = add jcu jdq
jds:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jbr
jdt:f64[64,64] = squeeze[dimensions=(0,)] jds
jdu:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jbs
jdv:f64[64,64] = squeeze[dimensions=(0,)] jdu
jdw:f64[64,64] = mul jdt jdv
jdx:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jby
jdy:i64[64,64] = squeeze[dimensions=(0,)] jdx
jdz:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jbw
jea:i64[64,64] = squeeze[dimensions=(0,)] jdz
jeb:bool[64,64] = lt jdy 0:i64[]
jec:i64[64,64] = add jdy 64:i64[]
jed:i64[64,64] = select_n jeb jdy jec
jee:bool[64,64] = lt jea 0:i64[]
jef:i64[64,64] = add jea 64:i64[]
jeg:i64[64,64] = select_n jee jea jef
jeh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jed
jei:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jeg
jej:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jeh
jek:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jei
jel:i32[64,64,2] = concatenate[dimension=2] jej jek
jem:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jbi jel
jen:f64[64,64] = mul jdw jem
jeo:f64[64,64] = add jdr jen
jep:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jbr
jeq:f64[64,64] = squeeze[dimensions=(0,)] jep
jer:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jbr
jes:f64[64,64] = squeeze[dimensions=(0,)] jer
jet:f64[64,64] = mul jeq jes
jeu:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jby
jev:i64[64,64] = squeeze[dimensions=(0,)] jeu
jew:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jby
jex:i64[64,64] = squeeze[dimensions=(0,)] jew
jey:bool[64,64] = lt jev 0:i64[]
jez:i64[64,64] = add jev 64:i64[]
jfa:i64[64,64] = select_n jey jev jez
jfb:bool[64,64] = lt jex 0:i64[]
jfc:i64[64,64] = add jex 64:i64[]
jfd:i64[64,64] = select_n jfb jex jfc
jfe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jfa
jff:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jfd
jfg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jfe
jfh:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jff
jfi:i32[64,64,2] = concatenate[dimension=2] jfg jfh
jfj:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jbi jfi
jfk:f64[64,64] = mul jet jfj
jfl:f64[64,64] = add jeo jfk
jfm:f64[64,64] = sub ikl jfl
jfn:f64[64,64] = div jfm 2.0:f64[]
jfo:f64[64,64] = add ikl jfn
jfp:f64[] = neg eb
jfq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jfp
jfr:f64[2,64,64] = mul jfq ixg
jfs:f64[2,64,64] = add ea jfr
jft:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
jfu:f64[2,64,64] = div jfs jft
jfv:f64[2,64,64] = floor jfu
jfw:f64[2,64,64] = sub jfu jfv
jfx:f64[2,64,64] = sub 1.0:f64[] jfw
jfy:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dp
jfz:f64[2,64,64] = sub jfu jfw
jga:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jfz
jgb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jga jfy
jgc:i64[2,64,64] = add jgb 1:i64[]
jgd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jgc jfy
jge:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jfx
jgf:f64[64,64] = squeeze[dimensions=(0,)] jge
jgg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jfx
jgh:f64[64,64] = squeeze[dimensions=(0,)] jgg
jgi:f64[64,64] = mul jgf jgh
jgj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jgb
jgk:i64[64,64] = squeeze[dimensions=(0,)] jgj
jgl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jgb
jgm:i64[64,64] = squeeze[dimensions=(0,)] jgl
jgn:bool[64,64] = lt jgk 0:i64[]
jgo:i64[64,64] = add jgk 64:i64[]
jgp:i64[64,64] = select_n jgn jgk jgo
jgq:bool[64,64] = lt jgm 0:i64[]
jgr:i64[64,64] = add jgm 64:i64[]
jgs:i64[64,64] = select_n jgq jgm jgr
jgt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jgp
jgu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jgs
jgv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jgt
jgw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jgu
jgx:i32[64,64,2] = concatenate[dimension=2] jgv jgw
jgy:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jfo jgx
jgz:f64[64,64] = mul jgi jgy
jha:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jfx
jhb:f64[64,64] = squeeze[dimensions=(0,)] jha
jhc:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jfw
jhd:f64[64,64] = squeeze[dimensions=(0,)] jhc
jhe:f64[64,64] = mul jhb jhd
jhf:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jgb
jhg:i64[64,64] = squeeze[dimensions=(0,)] jhf
jhh:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jgd
jhi:i64[64,64] = squeeze[dimensions=(0,)] jhh
jhj:bool[64,64] = lt jhg 0:i64[]
jhk:i64[64,64] = add jhg 64:i64[]
jhl:i64[64,64] = select_n jhj jhg jhk
jhm:bool[64,64] = lt jhi 0:i64[]
jhn:i64[64,64] = add jhi 64:i64[]
jho:i64[64,64] = select_n jhm jhi jhn
jhp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jhl
jhq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jho
jhr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jhp
jhs:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jhq
jht:i32[64,64,2] = concatenate[dimension=2] jhr jhs
jhu:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jfo jht
jhv:f64[64,64] = mul jhe jhu
jhw:f64[64,64] = add jgz jhv
jhx:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jfw
jhy:f64[64,64] = squeeze[dimensions=(0,)] jhx
jhz:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jfx
jia:f64[64,64] = squeeze[dimensions=(0,)] jhz
jib:f64[64,64] = mul jhy jia
jic:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jgd
jid:i64[64,64] = squeeze[dimensions=(0,)] jic
jie:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jgb
jif:i64[64,64] = squeeze[dimensions=(0,)] jie
jig:bool[64,64] = lt jid 0:i64[]
jih:i64[64,64] = add jid 64:i64[]
jii:i64[64,64] = select_n jig jid jih
jij:bool[64,64] = lt jif 0:i64[]
jik:i64[64,64] = add jif 64:i64[]
jil:i64[64,64] = select_n jij jif jik
jim:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jii
jin:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jil
jio:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jim
jip:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jin
jiq:i32[64,64,2] = concatenate[dimension=2] jio jip
jir:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jfo jiq
jis:f64[64,64] = mul jib jir
jit:f64[64,64] = add jhw jis
jiu:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jfw
jiv:f64[64,64] = squeeze[dimensions=(0,)] jiu
jiw:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jfw
jix:f64[64,64] = squeeze[dimensions=(0,)] jiw
jiy:f64[64,64] = mul jiv jix
jiz:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jgd
jja:i64[64,64] = squeeze[dimensions=(0,)] jiz
jjb:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jgd
jjc:i64[64,64] = squeeze[dimensions=(0,)] jjb
jjd:bool[64,64] = lt jja 0:i64[]
jje:i64[64,64] = add jja 64:i64[]
jjf:i64[64,64] = select_n jjd jja jje
jjg:bool[64,64] = lt jjc 0:i64[]
jjh:i64[64,64] = add jjc 64:i64[]
jji:i64[64,64] = select_n jjg jjc jjh
jjj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jjf
jjk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jji
jjl:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jjj
jjm:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jjk
jjn:i32[64,64,2] = concatenate[dimension=2] jjl jjm
jjo:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jfo jjn
jjp:f64[64,64] = mul jiy jjo
jjq:f64[64,64] = add jit jjp
jjr:f64[] = neg eb
jjs:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jjr
jjt:f64[2,64,64] = mul jjs ixg
jju:f64[2,64,64] = add ea jjt
jjv:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
jjw:f64[2,64,64] = div jju jjv
jjx:f64[2,64,64] = floor jjw
jjy:f64[2,64,64] = sub jjw jjx
jjz:f64[2,64,64] = sub 1.0:f64[] jjy
jka:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dq
jkb:f64[2,64,64] = sub jjw jjy
jkc:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jkb
jkd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jkc jka
jke:i64[2,64,64] = add jkd 1:i64[]
jkf:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jke jka
jkg:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jjz
jkh:f64[64,64] = squeeze[dimensions=(0,)] jkg
jki:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jjz
jkj:f64[64,64] = squeeze[dimensions=(0,)] jki
jkk:f64[64,64] = mul jkh jkj
jkl:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jkd
jkm:i64[64,64] = squeeze[dimensions=(0,)] jkl
jkn:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jkd
jko:i64[64,64] = squeeze[dimensions=(0,)] jkn
jkp:bool[64,64] = lt jkm 0:i64[]
jkq:i64[64,64] = add jkm 64:i64[]
jkr:i64[64,64] = select_n jkp jkm jkq
jks:bool[64,64] = lt jko 0:i64[]
jkt:i64[64,64] = add jko 64:i64[]
jku:i64[64,64] = select_n jks jko jkt
jkv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jkr
jkw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jku
jkx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jkv
jky:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jkw
jkz:i32[64,64,2] = concatenate[dimension=2] jkx jky
jla:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] iwv jkz
jlb:f64[64,64] = mul jkk jla
jlc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jjz
jld:f64[64,64] = squeeze[dimensions=(0,)] jlc
jle:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jjy
jlf:f64[64,64] = squeeze[dimensions=(0,)] jle
jlg:f64[64,64] = mul jld jlf
jlh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jkd
jli:i64[64,64] = squeeze[dimensions=(0,)] jlh
jlj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jkf
jlk:i64[64,64] = squeeze[dimensions=(0,)] jlj
jll:bool[64,64] = lt jli 0:i64[]
jlm:i64[64,64] = add jli 64:i64[]
jln:i64[64,64] = select_n jll jli jlm
jlo:bool[64,64] = lt jlk 0:i64[]
jlp:i64[64,64] = add jlk 64:i64[]
jlq:i64[64,64] = select_n jlo jlk jlp
jlr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jln
jls:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jlq
jlt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jlr
jlu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jls
jlv:i32[64,64,2] = concatenate[dimension=2] jlt jlu
jlw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] iwv jlv
jlx:f64[64,64] = mul jlg jlw
jly:f64[64,64] = add jlb jlx
jlz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jjy
jma:f64[64,64] = squeeze[dimensions=(0,)] jlz
jmb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jjz
jmc:f64[64,64] = squeeze[dimensions=(0,)] jmb
jmd:f64[64,64] = mul jma jmc
jme:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jkf
jmf:i64[64,64] = squeeze[dimensions=(0,)] jme
jmg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jkd
jmh:i64[64,64] = squeeze[dimensions=(0,)] jmg
jmi:bool[64,64] = lt jmf 0:i64[]
jmj:i64[64,64] = add jmf 64:i64[]
jmk:i64[64,64] = select_n jmi jmf jmj
jml:bool[64,64] = lt jmh 0:i64[]
jmm:i64[64,64] = add jmh 64:i64[]
jmn:i64[64,64] = select_n jml jmh jmm
jmo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jmk
jmp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jmn
jmq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jmo
jmr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jmp
jms:i32[64,64,2] = concatenate[dimension=2] jmq jmr
jmt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] iwv jms
jmu:f64[64,64] = mul jmd jmt
jmv:f64[64,64] = add jly jmu
jmw:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jjy
jmx:f64[64,64] = squeeze[dimensions=(0,)] jmw
jmy:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jjy
jmz:f64[64,64] = squeeze[dimensions=(0,)] jmy
jna:f64[64,64] = mul jmx jmz
jnb:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jkf
jnc:i64[64,64] = squeeze[dimensions=(0,)] jnb
jnd:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jkf
jne:i64[64,64] = squeeze[dimensions=(0,)] jnd
jnf:bool[64,64] = lt jnc 0:i64[]
jng:i64[64,64] = add jnc 64:i64[]
jnh:i64[64,64] = select_n jnf jnc jng
jni:bool[64,64] = lt jne 0:i64[]
jnj:i64[64,64] = add jne 64:i64[]
jnk:i64[64,64] = select_n jni jne jnj
jnl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jnh
jnm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jnk
jnn:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jnl
jno:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jnm
jnp:i32[64,64,2] = concatenate[dimension=2] jnn jno
jnq:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] iwv jnp
jnr:f64[64,64] = mul jna jnq
jns:f64[64,64] = add jmv jnr
jnt:f64[2,64,64] = neg ixg
jnu:f64[] = neg eb
jnv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jnu
jnw:f64[2,64,64] = mul jnv jnt
jnx:f64[2,64,64] = add ea jnw
jny:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
jnz:f64[2,64,64] = div jnx jny
joa:f64[2,64,64] = floor jnz
job:f64[2,64,64] = sub jnz joa
joc:f64[2,64,64] = sub 1.0:f64[] job
jod:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dr
joe:f64[2,64,64] = sub jnz job
jof:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] joe
jog:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jof jod
joh:i64[2,64,64] = add jog 1:i64[]
joi:i64[2,64,64] = jit[name=remainder jaxpr=remainder] joh jod
joj:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] joc
jok:f64[64,64] = squeeze[dimensions=(0,)] joj
jol:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] joc
jom:f64[64,64] = squeeze[dimensions=(0,)] jol
jon:f64[64,64] = mul jok jom
joo:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jog
jop:i64[64,64] = squeeze[dimensions=(0,)] joo
joq:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jog
jor:i64[64,64] = squeeze[dimensions=(0,)] joq
jos:bool[64,64] = lt jop 0:i64[]
jot:i64[64,64] = add jop 64:i64[]
jou:i64[64,64] = select_n jos jop jot
jov:bool[64,64] = lt jor 0:i64[]
jow:i64[64,64] = add jor 64:i64[]
jox:i64[64,64] = select_n jov jor jow
joy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jou
joz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jox
jpa:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] joy
jpb:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] joz
jpc:i32[64,64,2] = concatenate[dimension=2] jpa jpb
jpd:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jns jpc
jpe:f64[64,64] = mul jon jpd
jpf:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] joc
jpg:f64[64,64] = squeeze[dimensions=(0,)] jpf
jph:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] job
jpi:f64[64,64] = squeeze[dimensions=(0,)] jph
jpj:f64[64,64] = mul jpg jpi
jpk:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jog
jpl:i64[64,64] = squeeze[dimensions=(0,)] jpk
jpm:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] joi
jpn:i64[64,64] = squeeze[dimensions=(0,)] jpm
jpo:bool[64,64] = lt jpl 0:i64[]
jpp:i64[64,64] = add jpl 64:i64[]
jpq:i64[64,64] = select_n jpo jpl jpp
jpr:bool[64,64] = lt jpn 0:i64[]
jps:i64[64,64] = add jpn 64:i64[]
jpt:i64[64,64] = select_n jpr jpn jps
jpu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jpq
jpv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jpt
jpw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jpu
jpx:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jpv
jpy:i32[64,64,2] = concatenate[dimension=2] jpw jpx
jpz:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jns jpy
jqa:f64[64,64] = mul jpj jpz
jqb:f64[64,64] = add jpe jqa
jqc:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] job
jqd:f64[64,64] = squeeze[dimensions=(0,)] jqc
jqe:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] joc
jqf:f64[64,64] = squeeze[dimensions=(0,)] jqe
jqg:f64[64,64] = mul jqd jqf
jqh:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] joi
jqi:i64[64,64] = squeeze[dimensions=(0,)] jqh
jqj:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jog
jqk:i64[64,64] = squeeze[dimensions=(0,)] jqj
jql:bool[64,64] = lt jqi 0:i64[]
jqm:i64[64,64] = add jqi 64:i64[]
jqn:i64[64,64] = select_n jql jqi jqm
jqo:bool[64,64] = lt jqk 0:i64[]
jqp:i64[64,64] = add jqk 64:i64[]
jqq:i64[64,64] = select_n jqo jqk jqp
jqr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jqn
jqs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jqq
jqt:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jqr
jqu:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jqs
jqv:i32[64,64,2] = concatenate[dimension=2] jqt jqu
jqw:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jns jqv
jqx:f64[64,64] = mul jqg jqw
jqy:f64[64,64] = add jqb jqx
jqz:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] job
jra:f64[64,64] = squeeze[dimensions=(0,)] jqz
jrb:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] job
jrc:f64[64,64] = squeeze[dimensions=(0,)] jrb
jrd:f64[64,64] = mul jra jrc
jre:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] joi
jrf:i64[64,64] = squeeze[dimensions=(0,)] jre
jrg:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] joi
jrh:i64[64,64] = squeeze[dimensions=(0,)] jrg
jri:bool[64,64] = lt jrf 0:i64[]
jrj:i64[64,64] = add jrf 64:i64[]
jrk:i64[64,64] = select_n jri jrf jrj
jrl:bool[64,64] = lt jrh 0:i64[]
jrm:i64[64,64] = add jrh 64:i64[]
jrn:i64[64,64] = select_n jrl jrh jrm
jro:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jrk
jrp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jrn
jrq:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jro
jrr:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jrp
jrs:i32[64,64,2] = concatenate[dimension=2] jrq jrr
jrt:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jns jrs
jru:f64[64,64] = mul jrd jrt
jrv:f64[64,64] = add jqy jru
jrw:f64[64,64] = sub iwv jrv
jrx:f64[64,64] = div jrw 2.0:f64[]
jry:f64[64,64] = add iwv jrx
jrz:f64[] = neg eb
jsa:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jrz
jsb:f64[2,64,64] = mul jsa ixg
jsc:f64[2,64,64] = add ea jsb
jsd:f64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] dz
jse:f64[2,64,64] = div jsc jsd
jsf:f64[2,64,64] = floor jse
jsg:f64[2,64,64] = sub jse jsf
jsh:f64[2,64,64] = sub 1.0:f64[] jsg
jsi:i64[2,1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1, 1)
sharding=None
] ds
jsj:f64[2,64,64] = sub jse jsg
jsk:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jsj
jsl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jsk jsi
jsm:i64[2,64,64] = add jsl 1:i64[]
jsn:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jsm jsi
jso:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsh
jsp:f64[64,64] = squeeze[dimensions=(0,)] jso
jsq:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsh
jsr:f64[64,64] = squeeze[dimensions=(0,)] jsq
jss:f64[64,64] = mul jsp jsr
jst:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsl
jsu:i64[64,64] = squeeze[dimensions=(0,)] jst
jsv:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsl
jsw:i64[64,64] = squeeze[dimensions=(0,)] jsv
jsx:bool[64,64] = lt jsu 0:i64[]
jsy:i64[64,64] = add jsu 64:i64[]
jsz:i64[64,64] = select_n jsx jsu jsy
jta:bool[64,64] = lt jsw 0:i64[]
jtb:i64[64,64] = add jsw 64:i64[]
jtc:i64[64,64] = select_n jta jsw jtb
jtd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jsz
jte:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jtc
jtf:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jtd
jtg:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jte
jth:i32[64,64,2] = concatenate[dimension=2] jtf jtg
jti:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jry jth
jtj:f64[64,64] = mul jss jti
jtk:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsh
jtl:f64[64,64] = squeeze[dimensions=(0,)] jtk
jtm:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsg
jtn:f64[64,64] = squeeze[dimensions=(0,)] jtm
jto:f64[64,64] = mul jtl jtn
jtp:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsl
jtq:i64[64,64] = squeeze[dimensions=(0,)] jtp
jtr:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsn
jts:i64[64,64] = squeeze[dimensions=(0,)] jtr
jtt:bool[64,64] = lt jtq 0:i64[]
jtu:i64[64,64] = add jtq 64:i64[]
jtv:i64[64,64] = select_n jtt jtq jtu
jtw:bool[64,64] = lt jts 0:i64[]
jtx:i64[64,64] = add jts 64:i64[]
jty:i64[64,64] = select_n jtw jts jtx
jtz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jtv
jua:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jty
jub:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jtz
juc:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jua
jud:i32[64,64,2] = concatenate[dimension=2] jub juc
jue:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jry jud
juf:f64[64,64] = mul jto jue
jug:f64[64,64] = add jtj juf
juh:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsg
jui:f64[64,64] = squeeze[dimensions=(0,)] juh
juj:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsh
juk:f64[64,64] = squeeze[dimensions=(0,)] juj
jul:f64[64,64] = mul jui juk
jum:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsn
jun:i64[64,64] = squeeze[dimensions=(0,)] jum
juo:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsl
jup:i64[64,64] = squeeze[dimensions=(0,)] juo
juq:bool[64,64] = lt jun 0:i64[]
jur:i64[64,64] = add jun 64:i64[]
jus:i64[64,64] = select_n juq jun jur
jut:bool[64,64] = lt jup 0:i64[]
juu:i64[64,64] = add jup 64:i64[]
juv:i64[64,64] = select_n jut jup juu
juw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jus
jux:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] juv
juy:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] juw
juz:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jux
jva:i32[64,64,2] = concatenate[dimension=2] juy juz
jvb:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jry jva
jvc:f64[64,64] = mul jul jvb
jvd:f64[64,64] = add jug jvc
jve:f64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsg
jvf:f64[64,64] = squeeze[dimensions=(0,)] jve
jvg:f64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsg
jvh:f64[64,64] = squeeze[dimensions=(0,)] jvg
jvi:f64[64,64] = mul jvf jvh
jvj:i64[1,64,64] = slice[
limit_indices=(1, 64, 64)
start_indices=(0, 0, 0)
strides=None
] jsn
jvk:i64[64,64] = squeeze[dimensions=(0,)] jvj
jvl:i64[1,64,64] = slice[
limit_indices=(2, 64, 64)
start_indices=(1, 0, 0)
strides=None
] jsn
jvm:i64[64,64] = squeeze[dimensions=(0,)] jvl
jvn:bool[64,64] = lt jvk 0:i64[]
jvo:i64[64,64] = add jvk 64:i64[]
jvp:i64[64,64] = select_n jvn jvk jvo
jvq:bool[64,64] = lt jvm 0:i64[]
jvr:i64[64,64] = add jvm 64:i64[]
jvs:i64[64,64] = select_n jvq jvm jvr
jvt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jvp
jvu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jvs
jvv:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jvt
jvw:i32[64,64,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(64, 64, 1)
sharding=None
] jvu
jvx:i32[64,64,2] = concatenate[dimension=2] jvv jvw
jvy:f64[64,64] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] jry jvx
jvz:f64[64,64] = mul jvi jvy
jwa:f64[64,64] = add jvd jvz
in (jjq, jwa) }
The Python for loop in the integrate function is unrolled on tracing.
The traced program therefore grows as n_step increases, with a (steep) accompanying increase in compile time.
Structured control flow primitives are provided in the jax.lax module.
lax.fori_loop can be used in place of a native Python for loop.
Importantly, structured loop control flow primitives like lax.fori_loop are not unrolled on tracing, so compile time does not increase with the number of iterations.
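The difference can be seen on a toy recurrence by counting the equations in the traced program. This is a minimal sketch rather than the simulation code: python_loop, structured_loop and count_equations are hypothetical names introduced only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def python_loop(x, n_step):
    # Native Python loop: each iteration is traced separately (unrolled).
    for _ in range(n_step):
        x = 0.5 * x + 1.0
    return x


def structured_loop(x, n_step):
    # lax.fori_loop: the loop body is traced once, whatever the trip count.
    return lax.fori_loop(0, n_step, lambda i, x: 0.5 * x + 1.0, x)


def count_equations(function, n_step):
    # Number of top-level equations in the traced program (jaxpr).
    jaxpr = jax.make_jaxpr(lambda x: function(x, n_step))(jnp.ones(3))
    return len(jaxpr.jaxpr.eqns)


for n_step in (10, 100):
    print(
        n_step,
        count_equations(python_loop, n_step),
        count_equations(structured_loop, n_step),
    )
```

The count for python_loop grows in proportion to n_step, while the count for structured_loop stays the same for both values.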
jit to integration
We can now JIT compile this integrate_fori_loop function.
jitted_integrate
Comparing performance, we see that JIT-compiled JAX now exceeds NumPy’s performance by a wide margin, roughly 50× in the timings below:
NumPy
448 ms ± 63.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jitted JAX
9.14 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
lax.scan
As well as lax.fori_loop, JAX provides other loop primitives such as lax.scan for sequence-to-sequence mapping.
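To illustrate the calling convention, here is a minimal sketch of lax.scan on a toy decay recurrence. The function decay_trajectory and its arguments are hypothetical and not part of the simulation code. The step function returns a pair: the carried state passed to the next step, and a per-step output that lax.scan stacks along a new leading axis.

```python
import jax.numpy as jnp
from jax import lax


def decay_trajectory(x0, rate, n_step):
    def scan_step(x, _):
        # New carried state, also emitted as this step's output.
        x_next = (1.0 - rate) * x
        return x_next, x_next

    # No input sequence (xs=None); length sets the number of steps.
    final_x, trajectory = lax.scan(scan_step, x0, None, length=n_step)
    return final_x, trajectory


final_x, trajectory = decay_trajectory(jnp.array(1.0), 0.5, 3)
# trajectory holds the states after steps 1, 2 and 3: 0.5, 0.25, 0.125
```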
In our case we can use lax.scan to define a variant of integrate which outputs sequences of simulated fields.
@jax.jit(static_argnames=("n_step",))
def integrate_scan(vorticity, tracer, kernels, mesh, time_step, n_step):
    def scan_step(vorticity_tracer, _):
        vorticity_tracer = step(*vorticity_tracer, kernels, mesh, time_step)
        return vorticity_tracer, vorticity_tracer
    _, (vorticity_sequence, tracer_sequence) = jax.lax.scan(
        scan_step, (vorticity, tracer), length=n_step
    )
    return vorticity_sequence, tracer_sequence

time_step, total_time = 1., 100.
u = rng.standard_normal(mesh.shape[0] * mesh.shape[1])
vorticity = generate_vorticity(u, mesh, kernels)
tracer = generate_initial_tracer(mesh)
vorticity_seq, tracer_seq = integrate_scan(
    vorticity, tracer, kernels, mesh, time_step, int(total_time / time_step)
)
animate_fields(mesh, kernels, vorticity, tracer, vorticity_seq, tracer_seq)

@jax.jit(static_argnames=("time_step", "total_time"))
def objective(
    u, initial_tracer, target_tracer, kernels, mesh, time_step, total_time
):
    vorticity = generate_vorticity(u, mesh, kernels)
    n_step = int(total_time / time_step)
    _, final_tracer = integrate_fori_loop(
        vorticity, initial_tracer, kernels, mesh, time_step, n_step
    )
    return ((final_tracer - target_tracer)**2).sum() + (u**2).sum()

We use a target final tracer field loaded from an image.

We can then use the jax.value_and_grad transformation to construct a function which evaluates both the value of the objective and its gradient, using reverse-mode automatic differentiation.
This can be composed with the jax.jit transformation to compile the gradient function.
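The composition of the two transformations can be sketched on a toy scalar function. The names toy_objective and value_and_grad_toy are hypothetical stand-ins, not the objective used in the simulation. The gradient here is taken with respect to the first argument only.

```python
import jax
import jax.numpy as jnp


def toy_objective(u, target):
    # Scalar-valued function of a vector input u.
    return ((jnp.sin(u) - target) ** 2).sum() + (u**2).sum()


# Differentiate with respect to the first argument (u), then compile.
value_and_grad_toy = jax.jit(jax.value_and_grad(toy_objective, argnums=0))

u = jnp.linspace(-1.0, 1.0, 5)
target = jnp.zeros(5)
value, gradient = value_and_grad_toy(u, target)
# gradient has the same shape as u; value is a scalar array.
```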
We can evaluate the transformed function at a test input to check that it gives the same value as the original (primal) function:
u = rng.standard_normal(mesh.shape[0] * mesh.shape[1])
objective_kwargs = {
    "initial_tracer": initial_tracer,
    "target_tracer": target_tracer,
    "kernels": kernels,
    "mesh": mesh,
    "time_step": time_step,
    "total_time": total_time
}
value, gradient = value_and_grad_objective(u, **objective_kwargs)
assert np.allclose(value, objective(u, **objective_kwargs))

Importantly, reverse-mode differentiation allows us to evaluate both the value and gradient of a scalar function with respect to all inputs at an operation cost that is a small constant multiple of the operation cost of evaluating the function itself:
CPU times: user 210 ms, sys: 34.6 ms, total: 245 ms
Wall time: 107 ms
Array(5187.07212124, dtype=float64)
from jax.example_libraries import optimizers

def optimize(initial_u, objective_kwargs, n_steps=100, step_size=5e-2):
    initialize, update, get_params = optimizers.adam(step_size)

    def optimizer_step(i, state):
        u = get_params(state)
        value, grad = value_and_grad_objective(u, **objective_kwargs)
        jax.debug.print("Iter {i:03}: obj = {value:.4g}", i=i, value=value)
        return update(i, grad, state)

    initial_state = initialize(initial_u)
    final_state = jax.lax.fori_loop(0, n_steps + 1, optimizer_step, initial_state)
    return get_params(final_state)

Using the optimize helper function, we can minimize the objective function to find an initial vorticity whose flow carries the tracer field close to the target.
Iter 000: obj = 1850
Iter 001: obj = 1462
Iter 002: obj = 1273
Iter 003: obj = 1131
Iter 004: obj = 993.7
Iter 005: obj = 925.6
Iter 006: obj = 848.9
Iter 007: obj = 787.8
Iter 008: obj = 758.2
Iter 009: obj = 718.3
Iter 010: obj = 687.6
Iter 011: obj = 656.3
Iter 012: obj = 627.6
Iter 013: obj = 600.3
Iter 014: obj = 570.7
Iter 015: obj = 555
Iter 016: obj = 530.6
Iter 017: obj = 512.2
Iter 018: obj = 500.9
Iter 019: obj = 480.7
Iter 020: obj = 460.9
Iter 021: obj = 449.9
Iter 022: obj = 438
Iter 023: obj = 424.8
Iter 024: obj = 410.9
Iter 025: obj = 402.5
Iter 026: obj = 393.1
Iter 027: obj = 382.8
Iter 028: obj = 374.8
Iter 029: obj = 364.6
Iter 030: obj = 356.4
Iter 031: obj = 350.6
Iter 032: obj = 343.2
Iter 033: obj = 335.5
Iter 034: obj = 328.8
Iter 035: obj = 322.6
Iter 036: obj = 316.8
Iter 037: obj = 311.8
Iter 038: obj = 306.3
Iter 039: obj = 300.5
Iter 040: obj = 297.2
Iter 041: obj = 293.2
Iter 042: obj = 289.6
Iter 043: obj = 286.6
Iter 044: obj = 283.5
Iter 045: obj = 280
Iter 046: obj = 277.2
Iter 047: obj = 275
Iter 048: obj = 273
Iter 049: obj = 271
Iter 050: obj = 269.3
Iter 051: obj = 267.3
Iter 052: obj = 265.5
Iter 053: obj = 264
Iter 054: obj = 262.4
Iter 055: obj = 260.6
Iter 056: obj = 259.1
Iter 057: obj = 257.7
Iter 058: obj = 256.4
Iter 059: obj = 255.1
Iter 060: obj = 254
Iter 061: obj = 252.7
Iter 062: obj = 251.3
Iter 063: obj = 249.9
Iter 064: obj = 248.7
Iter 065: obj = 247.6
Iter 066: obj = 246.5
Iter 067: obj = 245.5
Iter 068: obj = 244.5
Iter 069: obj = 243.5
Iter 070: obj = 242.6
Iter 071: obj = 241.7
Iter 072: obj = 240.8
Iter 073: obj = 240
Iter 074: obj = 239.1
Iter 075: obj = 238.3
Iter 076: obj = 237.5
Iter 077: obj = 236.8
Iter 078: obj = 236.1
Iter 079: obj = 235.4
Iter 080: obj = 234.8
Iter 081: obj = 234.2
Iter 082: obj = 233.6
Iter 083: obj = 233
Iter 084: obj = 232.4
Iter 085: obj = 232
Iter 086: obj = 231.6
Iter 087: obj = 231.3
Iter 088: obj = 231.4
Iter 089: obj = 231.9
Iter 090: obj = 232.6
Iter 091: obj = 232.2
Iter 092: obj = 230.4
Iter 093: obj = 228.3
Iter 094: obj = 228
Iter 095: obj = 228.8
Iter 096: obj = 228.9
Iter 097: obj = 227.8
Iter 098: obj = 226.4
Iter 099: obj = 226
Iter 100: obj = 226.5
Iter 101: obj = 226.5
Iter 102: obj = 225.7
Iter 103: obj = 224.7
Iter 104: obj = 224.3
Iter 105: obj = 224.4
Iter 106: obj = 224.5
Iter 107: obj = 224.1
Iter 108: obj = 223.4
Iter 109: obj = 222.7
Iter 110: obj = 222.4
Iter 111: obj = 222.2
Iter 112: obj = 222.2
Iter 113: obj = 222
Iter 114: obj = 221.6
Iter 115: obj = 221.1
Iter 116: obj = 220.6
Iter 117: obj = 220.2
Iter 118: obj = 220
Iter 119: obj = 219.8
Iter 120: obj = 219.7
Iter 121: obj = 219.6
Iter 122: obj = 219.5
Iter 123: obj = 219.4
Iter 124: obj = 219.3
Iter 125: obj = 219.1
Iter 126: obj = 218.9
Iter 127: obj = 218.7
Iter 128: obj = 218.5
Iter 129: obj = 218.2
Iter 130: obj = 217.9
Iter 131: obj = 217.6
Iter 132: obj = 217.3
Iter 133: obj = 217.1
Iter 134: obj = 216.8
Iter 135: obj = 216.6
Iter 136: obj = 216.4
Iter 137: obj = 216.1
Iter 138: obj = 215.9
Iter 139: obj = 215.8
Iter 140: obj = 215.6
Iter 141: obj = 215.5
Iter 142: obj = 215.4
Iter 143: obj = 215.5
Iter 144: obj = 215.7
Iter 145: obj = 216.1
Iter 146: obj = 216.8
Iter 147: obj = 217.4
Iter 148: obj = 217.9
Iter 149: obj = 217.3
Iter 150: obj = 215.9
Iter 151: obj = 214.2
Iter 152: obj = 213
Iter 153: obj = 212.7
Iter 154: obj = 213.1
Iter 155: obj = 213.8
Iter 156: obj = 214
Iter 157: obj = 213.6
Iter 158: obj = 212.7
Iter 159: obj = 211.8
Iter 160: obj = 211.3
Iter 161: obj = 211.2
Iter 162: obj = 211.4
Iter 163: obj = 211.7
Iter 164: obj = 211.8
Iter 165: obj = 211.6
Iter 166: obj = 211.2
Iter 167: obj = 210.7
Iter 168: obj = 210.3
Iter 169: obj = 210
Iter 170: obj = 209.8
Iter 171: obj = 209.8
Iter 172: obj = 209.8
Iter 173: obj = 209.8
Iter 174: obj = 209.8
Iter 175: obj = 209.8
Iter 176: obj = 209.7
Iter 177: obj = 209.6
Iter 178: obj = 209.5
Iter 179: obj = 209.4
Iter 180: obj = 209.2
Iter 181: obj = 209.1
Iter 182: obj = 208.9
Iter 183: obj = 208.8
Iter 184: obj = 208.6
Iter 185: obj = 208.5
Iter 186: obj = 208.4
Iter 187: obj = 208.3
Iter 188: obj = 208.2
Iter 189: obj = 208.1
Iter 190: obj = 208.1
Iter 191: obj = 208
Iter 192: obj = 208
Iter 193: obj = 208
Iter 194: obj = 208
Iter 195: obj = 208
Iter 196: obj = 208
Iter 197: obj = 208
Iter 198: obj = 208
Iter 199: obj = 208
Iter 200: obj = 208
We can then simulate trajectories of the fields from the optimized initial vorticity and animate:
vmap, shard_map.