Accelerated and differentiable
scientific computing with JAX

Matt Graham

UCL Centre for Advanced Research Computing

🤔 What is JAX?

JAX logo

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

✨ Key features

  • 🙌 Ease of use: Offers a NumPy-like API, making it accessible to those already familiar with the scientific Python ecosystem.
  • 🪄 Transformations: Provides composable function transformations for just-in-time compilation, batching, autodiff and parallelization.
  • 🚀 Acceleration: Allows easily executing the same code on CPUs and accelerator devices such as GPUs and TPUs.
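As a minimal sketch of how these transformations compose (the function `poly` and the input values are hypothetical examples, not taken from the talk):

```python
import jax
import jax.numpy as jnp


def poly(x):
    """Scalar function x^3 + 2x (hypothetical example)."""
    return x**3 + 2.0 * x


# grad differentiates the scalar function, vmap maps it over the leading
# axis of its input, and jit compiles the resulting batched function.
batched_derivative = jax.jit(jax.vmap(jax.grad(poly)))

xs = jnp.array([0.0, 1.0, 2.0])
derivatives = batched_derivative(xs)  # analytic derivative is 3x^2 + 2
```

Each transformation takes a function and returns a function, which is what allows them to be nested in this way.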

🔢 Python array libraries

NumPy

JAX

PyTorch

TensorFlow

CuPy

Dask

Xarray

Sparse

🔀 NumPy API substitutes

Wide user familiarity with NumPy’s API, and the large amount of existing code using NumPy, has led some libraries to provide NumPy-like APIs:

import numpy as np
import jax.numpy as jnp
import cupy as cp

⚠️ NumPy API as common standard?

However, the NumPy API was not designed for this purpose and has some shortcomings:

  • Complex datatype promotion semantics.
  • Copy-view behaviour.
  • Not designed for use with non-CPU devices.
  • Operations produce arrays with data-dependent shapes.
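Two of these shortcomings can be illustrated with plain NumPy. The arrays below are made-up values chosen only for illustration:

```python
import numpy as np

x = np.array([1.0, -2.0, 3.0, -4.0])

# Data-dependent shape: the length of the result depends on the values in x,
# so it cannot be determined from the shape and dtype of x alone.
positive = x[x > 0]

# Copy-view behaviour: basic slicing returns a view sharing memory with x,
# while boolean-mask indexing returns an independent copy.
view = x[:2]
view[0] = 10.0      # also modifies x[0]
positive[0] = -1.0  # leaves x unchanged
```

Both behaviours are awkward for libraries that compile computations ahead of execution or use immutable arrays, as JAX does.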

📐 Python data API standards

Python Data APIs logo

The Array API standard is defined by the Consortium for Python Data API Standards (https://data-apis.org/).

💠 Array API components

Array API components

💡 Array API example

Consider the following implementation of

\[\text{LogSumExp}(x_{1:N}) = x^* + \log \sum_{n=1}^N \exp(x_n - x^*), \\ x^* = \max x_{1:N}\]

def log_sum_exp(x):
    xp = x.__array_namespace__()
    max_x = xp.max(x)
    return max_x + xp.log(xp.sum(xp.exp(x - max_x)))

log_sum_exp can then be called with array objects from any library supporting array API.

🧩 Array API compatibility library

array_api_compat acts as a small wrapper around common array libraries to provide compatibility with the array API standard.

Useful for libraries with only partial support for array API.

import array_api_compat

def log_sum_exp(x):
    xp = array_api_compat.array_namespace(x)
    max_x = xp.max(x)
    return max_x + xp.log(xp.sum(xp.exp(x - max_x)))

NumPy v2.1+ and JAX v0.4.32+ fully support the array API, so using array_api_compat is not strictly necessary.

▶️ Demo: accelerated and differentiable fluid simulation with JAX

We will now demonstrate some of JAX and the Array API’s key features in an applied example.

🌍 Environment set up

We will leverage NumPy and JAX’s Array API support to write functions which can be run with either backend.

One important thing to note is that JAX defaults to using single-precision data types for many operations as this often gives better performance on devices such as GPUs.

The jax_enable_x64 configuration option can be used to use double-precision types by default, more closely matching NumPy’s behaviour.

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)
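A quick check of the effect of this option, as a sketch. It is intended to be run in a fresh Python process, since the option should be set before any JAX arrays are created:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# With the option enabled, Python floats are converted to double precision.
default_dtype = jnp.asarray(1.0).dtype

# Single precision remains available when requested explicitly.
explicit_dtype = jnp.asarray(1.0, dtype=jnp.float32).dtype
```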

🌊 Example task: fluid simulation

As an example of a scientific computing task, we will consider simulating motion of an incompressible inviscid fluid on a two-dimensional domain with periodic boundary conditions.

Mathematically such a flow can be described by

\[ \partial_t \omega = -(v \cdot \nabla) \omega, \quad \omega = \nabla \times v, \]

where \(\omega\) is the vorticity of a two-dimensional velocity field \(v\).

We will additionally consider the motion of a passive tracer field \(\tau\) advected by the velocity field \(v\)

\[ \partial_t \tau = -(v \cdot \nabla) \tau. \]

🌊 Example task: fluid simulation

Preview of example simulation output on a 64×64 mesh:

🔢 Numerical approach

On each step we

  1. Compute velocity (\(v\)) from vorticity (\(\omega\)) in spectral space.
  2. Advect vorticity and tracer (\(\tau\)) fields by computed velocity.

def step(
    vorticity: Array, tracer: Array, kernels: Kernels, mesh: Mesh, time_step: float
) -> tuple[Array, Array]:
    """Perform single time step update of vorticity and tracers."""
    velocity = velocity_from_vorticity(vorticity, kernels)
    new_vorticity = bfecc_advect(vorticity, velocity, mesh, time_step)
    new_tracer = bfecc_advect(tracer, velocity, mesh, time_step)
    return new_vorticity, new_tracer
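The first of these steps can be sketched as follows. This is not the talk's velocity_from_vorticity, which takes precomputed kernels and may use different conventions. The function name `velocity_from_vorticity_sketch`, the square periodic domain of side 2π and the stream-function sign convention are all assumptions made for illustration. The idea is to solve \(\nabla^2 \psi = -\omega\) for a stream function \(\psi\) in Fourier space and then take \(v = (\partial_y \psi, -\partial_x \psi)\).

```python
import numpy as np


def velocity_from_vorticity_sketch(vorticity, domain_size=2 * np.pi):
    """Recover velocity from vorticity on a periodic square grid (illustrative)."""
    n_x, n_y = vorticity.shape
    # Angular wavenumbers: full FFT along first axis, real FFT along last axis.
    k_x = 2 * np.pi * np.fft.fftfreq(n_x, d=domain_size / n_x)[:, None]
    k_y = 2 * np.pi * np.fft.rfftfreq(n_y, d=domain_size / n_y)[None, :]
    k_squared = k_x**2 + k_y**2
    k_squared[0, 0] = 1.0  # avoid division by zero for the mean mode
    # Solve laplacian(psi) = -omega, that is psi_hat = omega_hat / |k|^2.
    fft_stream = np.fft.rfft2(vorticity) / k_squared
    fft_stream[0, 0] = 0.0  # the mean of the stream function is arbitrary
    # Spatial derivatives become multiplications by i * k in spectral space.
    v_x = np.fft.irfft2(1j * k_y * fft_stream, s=vorticity.shape)
    v_y = np.fft.irfft2(-1j * k_x * fft_stream, s=vorticity.shape)
    return np.stack([v_x, v_y])


# Check against an analytic example with psi = sin(x) cos(y).
n = 32
grid = np.arange(n) * 2 * np.pi / n
x, y = np.meshgrid(grid, grid, indexing="ij")
vorticity = 2 * np.sin(x) * np.cos(y)
expected_velocity = np.stack([-np.sin(x) * np.sin(y), -np.cos(x) * np.cos(y)])
velocity = velocity_from_vorticity_sketch(vorticity)
```

For this test field \(\partial_x v_y - \partial_y v_x = 2 \sin x \cos y\), matching the definition \(\omega = \nabla \times v\) used above.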

🌀 Generating vorticity fields

To allow generating smooth vorticity fields we

  1. map a real-valued vector to complex spectral coefficients,
  2. multiply by a kernel corresponding to a squared exponential covariance function,
  3. and then transform to grid space.

def generate_vorticity(u: Array, mesh: Mesh, kernels: Kernels) -> Array:
    """Generate vorticity field from standard normal random vector `u`."""
    xp = u.__array_namespace__()
    fft_coefficients = real_array_to_rfft2_coeff(u, mesh.coordinates.shape[1:])
    fft_vorticity = fft_coefficients * kernels.initialization
    return xp.fft.irfft2(fft_vorticity, norm="ortho")

✨ Generating with NumPy & JAX

We can generate initial vorticity and velocity fields with generate_vorticity and velocity_from_vorticity using either NumPy or JAX, depending on the array argument types:

mesh = generate_mesh(shape=(64, 64))
kernels = generate_kernels(mesh)
rng = np.random.default_rng(1234)
u = rng.standard_normal(size=mesh.shape[0] * mesh.shape[1])
vorticity_numpy = generate_vorticity(np.asarray(u), mesh, kernels)
velocity_numpy = velocity_from_vorticity(vorticity_numpy, kernels)
type(vorticity_numpy), type(velocity_numpy)
(numpy.ndarray, numpy.ndarray)
vorticity_jax = generate_vorticity(jnp.asarray(u), mesh, kernels)
velocity_jax = velocity_from_vorticity(vorticity_jax, kernels)
type(vorticity_jax), type(velocity_jax)
(jaxlib._jax.ArrayImpl, jaxlib._jax.ArrayImpl)

🟰 Equivalence of implementations

The output arrays are numerically equivalent:

assert np.allclose(vorticity_numpy, vorticity_jax)
assert np.allclose(velocity_numpy, velocity_jax)

We can also visualize the equivalence of the function outputs

🚀 Performance: NumPy vs. JAX

How does calling functions using NumPy and JAX arrays compare in terms of compute time?

⏱️ Comparing performance

We time calling a subset of the model functions for both NumPy and JAX arrays and visualize the results:

We can see that when evaluating on JAX arrays the functions are consistently and significantly slower!

🚚 JAX & NumPy dispatch models

This may at first seem surprising as JAX describes itself as a performance-oriented library. However,

  • NumPy operations are executed eagerly and synchronously.
  • JAX operations may be executed eagerly or lazily after compilation and they are dispatched asynchronously.
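A small sketch of what asynchronous dispatch means in practice (the array size is an arbitrary choice for illustration):

```python
import jax.numpy as jnp

x = jnp.ones((500, 500))

# This call can return as soon as the work has been enqueued; the returned
# array acts as a handle to a value that may not have been computed yet.
y = jnp.dot(x, x)

# Block until the computation has actually finished, for example before
# reading a timer in a benchmark.
y.block_until_ready()
```

Without the call to block_until_ready, a timing measurement may capture only the dispatch time rather than the time spent computing.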

🎯 JAX & NumPy dispatch priorities

This has led to different priorities in the packages.

NumPy has put significant effort into decreasing the per-call dispatch overhead for individual array operations, because in NumPy’s computational model that overhead cannot be avoided.

JAX, on the other hand, has several ways to avoid dispatch overhead and so reducing per-call overhead has been less of a priority.

🙆 Reducing dispatch overhead

One particular way of reducing dispatch overhead offered by JAX is just-in-time (JIT) compilation.

As mentioned earlier, JAX builds on top of the XLA compiler toolchain, which allows compiling code for a variety of device backends (CPU, GPU, TPU).

Before we try out JIT compilation we will first take a look at how JAX’s function transformations work.

👣 JAX tracing and jaxprs

At a high level, JAX’s function transformations operate by first tracing the Python function to be transformed to an intermediate representation (IR), with this IR then interpreted with a transformation-specific interpreter.

Jaxprs are JAX’s internal IR of traced programs.

👀 Inspecting a jaxpr with make_jaxpr

Consider the advect_points function defined in the model:

def advect_points(points: Array, velocity: Array, time_step: float) -> Array:
    """Advect points along velocity with Euler method."""
    return points + time_step * velocity

JAX provides a special jax.make_jaxpr transformation that allows us to inspect the jaxpr representation of a function.

Applying to advect_points:

time_step = 1.
jax.make_jaxpr(advect_points)(mesh.coordinates, velocity_jax, time_step)
{ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
    d:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c
    e:f64[2,64,64] = mul d b
    f:f64[2,64,64] = add a e
  in (f,) }

🔠 Abstract values

Internally JAX traces functions with abstract values. For most transformations the abstraction is at the level of the shape and datatype of arguments, but not their values.

We can use jax.ShapeDtypeStruct to create an abstract array with just shape and dtype attributes (plus some additional attributes for JAX bookkeeping such as device layout and sharding).

abstract_coordinates = jax.ShapeDtypeStruct(
    mesh.coordinates.shape, mesh.coordinates.dtype
)
abstract_velocity = jax.ShapeDtypeStruct(
    velocity_jax.shape, velocity_jax.dtype
)

🔠 Tracing with abstract values

If we now call the jax.make_jaxpr transformed function with these abstract arguments we get exactly the same result as before

jax.make_jaxpr(advect_points)(abstract_coordinates, abstract_velocity, time_step)
{ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
    d:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c
    e:f64[2,64,64] = mul d b
    f:f64[2,64,64] = add a e
  in (f,) }
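A related utility is jax.eval_shape, which traces a function with abstract inputs and returns only the abstract output. The sketch below redefines advect_points so that it is self-contained, and uses float32 stand-ins so that the result does not depend on the jax_enable_x64 setting:

```python
import jax
import jax.numpy as jnp


def advect_points(points, velocity, time_step):
    """Advect points along velocity with Euler method (as defined above)."""
    return points + time_step * velocity


# Abstract stand-ins carrying only shape and dtype information.
abstract_points = jax.ShapeDtypeStruct((2, 64, 64), jnp.float32)
abstract_velocity = jax.ShapeDtypeStruct((2, 64, 64), jnp.float32)

# No array computation is performed; only shapes and dtypes are propagated.
abstract_output = jax.eval_shape(
    advect_points, abstract_points, abstract_velocity, 1.0
)
```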

🔀 Control flow and tracing

Importantly, any control flow in the traced function is transparent to the tracing.

Consider the semi_lagrangian_advect function, which calls advect_points:

def semi_lagrangian_advect(
    field: Array, velocity: Array, mesh: Mesh, time_step: float
) -> Array:
    """Use semi-Lagrangian method to advect a given field a single step."""
    origin_points = advect_points(mesh.coordinates, velocity, -time_step)
    return bilinear_interpolate(field, origin_points / mesh.cell_size[:, None, None])

The jaxpr for this function corresponds to tracing through the full computation.
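The same effect can be seen on a much smaller example. The functions `scale` and `scale_then_sin` below are hypothetical and used only for illustration:

```python
import jax
import jax.numpy as jnp


def scale(x):
    return 2.0 * x


def scale_then_sin(x):
    # An ordinary Python function call: tracing passes straight through it.
    return jnp.sin(scale(x))


jaxpr = jax.make_jaxpr(scale_then_sin)(jnp.ones(3))

# The traced program is a flat sequence of primitive operations, with no
# record of the nested Python call to scale.
primitive_names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
```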

🔀 Control flow and tracing

jax.make_jaxpr(semi_lagrangian_advect)(
    vorticity_jax, velocity_jax, mesh, time_step
)
let remainder = { lambda ; a:i64[2,64,64] b:i64[2,1,1]. let
    c:bool[2,1,1] = eq b 0:i64[]
    d:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 1, 1)
      sharding=None
    ] 1:i64[]
    e:i64[2,1,1] = jit[
      name=_where
      jaxpr={ lambda ; c:bool[2,1,1] d:i64[2,1,1] b:i64[2,1,1]. let
          e:i64[2,1,1] = select_n c b d
        in (e,) }
    ] c d b
    f:i64[2,64,64] = rem a e
    g:bool[2,64,64] = ne f 0:i64[]
    h:bool[2,64,64] = lt f 0:i64[]
    i:bool[2,1,1] = lt e 0:i64[]
    j:bool[2,64,64] = ne h i
    k:bool[2,64,64] = and j g
    l:i64[2,64,64] = add f e
    m:i64[2,64,64] = select_n k f l
  in (m,) } in
let _where = { lambda ; c:bool[2,1,1] d:i64[2,1,1] b:i64[2,1,1]. let
    e:i64[2,1,1] = select_n c b d
  in (e,) } in
{ lambda n:i64[2]; o:f64[64,64] p:f64[2,64,64] q:i64[] r:i64[] s:f64[2] t:f64[2,64,64]
    u:f64[]. let
    v:f64[] = neg u
    w:f64[] = convert_element_type[new_dtype=float64 weak_type=False] v
    x:f64[2,64,64] = mul w p
    y:f64[2,64,64] = add t x
    z:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] s
    ba:f64[2,64,64] = div y z
    bb:f64[2,64,64] = floor ba
    bc:f64[2,64,64] = sub ba bb
    bd:f64[2,64,64] = sub 1.0:f64[] bc
    be:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] n
    bf:f64[2,64,64] = sub ba bc
    bg:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bf
    bh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bg be
    bi:i64[2,64,64] = add bh 1:i64[]
    bj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bi be
    bk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bd
    bl:f64[64,64] = squeeze[dimensions=(0,)] bk
    bm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bd
    bn:f64[64,64] = squeeze[dimensions=(0,)] bm
    bo:f64[64,64] = mul bl bn
    bp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bh
    bq:i64[64,64] = squeeze[dimensions=(0,)] bp
    br:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bh
    bs:i64[64,64] = squeeze[dimensions=(0,)] br
    bt:bool[64,64] = lt bq 0:i64[]
    bu:i64[64,64] = add bq 64:i64[]
    bv:i64[64,64] = select_n bt bq bu
    bw:bool[64,64] = lt bs 0:i64[]
    bx:i64[64,64] = add bs 64:i64[]
    by:i64[64,64] = select_n bw bs bx
    bz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bv
    ca:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] by
    cb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bz
    cc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ca
    cd:i32[64,64,2] = concatenate[dimension=2] cb cc
    ce:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] o cd
    cf:f64[64,64] = mul bo ce
    cg:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bd
    ch:f64[64,64] = squeeze[dimensions=(0,)] cg
    ci:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bc
    cj:f64[64,64] = squeeze[dimensions=(0,)] ci
    ck:f64[64,64] = mul ch cj
    cl:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bh
    cm:i64[64,64] = squeeze[dimensions=(0,)] cl
    cn:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bj
    co:i64[64,64] = squeeze[dimensions=(0,)] cn
    cp:bool[64,64] = lt cm 0:i64[]
    cq:i64[64,64] = add cm 64:i64[]
    cr:i64[64,64] = select_n cp cm cq
    cs:bool[64,64] = lt co 0:i64[]
    ct:i64[64,64] = add co 64:i64[]
    cu:i64[64,64] = select_n cs co ct
    cv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cr
    cw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cu
    cx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cv
    cy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cw
    cz:i32[64,64,2] = concatenate[dimension=2] cx cy
    da:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] o cz
    db:f64[64,64] = mul ck da
    dc:f64[64,64] = add cf db
    dd:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bc
    de:f64[64,64] = squeeze[dimensions=(0,)] dd
    df:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bd
    dg:f64[64,64] = squeeze[dimensions=(0,)] df
    dh:f64[64,64] = mul de dg
    di:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bj
    dj:i64[64,64] = squeeze[dimensions=(0,)] di
    dk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bh
    dl:i64[64,64] = squeeze[dimensions=(0,)] dk
    dm:bool[64,64] = lt dj 0:i64[]
    dn:i64[64,64] = add dj 64:i64[]
    do:i64[64,64] = select_n dm dj dn
    dp:bool[64,64] = lt dl 0:i64[]
    dq:i64[64,64] = add dl 64:i64[]
    dr:i64[64,64] = select_n dp dl dq
    ds:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] do
    dt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dr
    du:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ds
    dv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dt
    dw:i32[64,64,2] = concatenate[dimension=2] du dv
    dx:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] o dw
    dy:f64[64,64] = mul dh dx
    dz:f64[64,64] = add dc dy
    ea:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bc
    eb:f64[64,64] = squeeze[dimensions=(0,)] ea
    ec:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bc
    ed:f64[64,64] = squeeze[dimensions=(0,)] ec
    ee:f64[64,64] = mul eb ed
    ef:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bj
    eg:i64[64,64] = squeeze[dimensions=(0,)] ef
    eh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bj
    ei:i64[64,64] = squeeze[dimensions=(0,)] eh
    ej:bool[64,64] = lt eg 0:i64[]
    ek:i64[64,64] = add eg 64:i64[]
    el:i64[64,64] = select_n ej eg ek
    em:bool[64,64] = lt ei 0:i64[]
    en:i64[64,64] = add ei 64:i64[]
    eo:i64[64,64] = select_n em ei en
    ep:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] el
    eq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eo
    er:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ep
    es:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eq
    et:i32[64,64,2] = concatenate[dimension=2] er es
    eu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] o et
    ev:f64[64,64] = mul ee eu
    ew:f64[64,64] = add dz ev
  in (ew,) }

🖥️ Just-in-time compilation: jax.jit

One of the key JAX transformations is the jax.jit function which allows functions to be JIT compiled with XLA for execution on a particular device (CPU, GPU or TPU).

jitted_advect_points = jax.jit(advect_points)
jax.make_jaxpr(jitted_advect_points)(
    abstract_coordinates, abstract_velocity, time_step
)
{ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
    d:f64[2,64,64] = jit[
      name=advect_points
      jaxpr={ lambda ; a:f64[2,64,64] b:f64[2,64,64] c:f64[]. let
          e:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c
          f:f64[2,64,64] = mul e b
          d:f64[2,64,64] = add a f
        in (d,) }
    ] a b c
  in (d,) }

🚀 Behaviour after jitting

The jitted and original functions give equivalent outputs to within floating point error

assert np.allclose(
    jitted_advect_points(mesh.coordinates, velocity_jax, time_step),
    advect_points(mesh.coordinates, velocity_jax, time_step)
)

The jitted function however is faster

print("Without jit")
%timeit advect_points(mesh.coordinates, velocity_jax, time_step).block_until_ready()
print("With jit")
%timeit jitted_advect_points(mesh.coordinates, velocity_jax, time_step).block_until_ready()
Without jit
59.9 μs ± 3.96 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
With jit
77.6 μs ± 2.44 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

🧐 jax.jit: under the hood

When calling a jitted function JAX does the following in order:

  1. Stage out a specialized version of the original Python callable to an internal representation (jaxpr).
  2. Lower this specialized, staged-out computation to the XLA compiler’s input language, StableHLO.
  3. Compile the lowered HLO program to produce an optimized executable for the target device (CPU, GPU, or TPU).
  4. Execute the compiled executable with the arrays as arguments.
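These stages can also be run one at a time using JAX's ahead-of-time lowering and compilation methods. The sketch below uses small made-up inputs and redefines advect_points so that it is self-contained:

```python
import jax
import jax.numpy as jnp


def advect_points(points, velocity, time_step):
    return points + time_step * velocity


points = jnp.zeros((2, 4, 4))
velocity = jnp.ones((2, 4, 4))

# Steps 1 and 2: trace the function for these argument types and lower it.
lowered = jax.jit(advect_points).lower(points, velocity, 0.5)
lowered_text = lowered.as_text()  # textual form of the lowered program

# Step 3: compile the lowered program for the default device.
compiled = lowered.compile()

# Step 4: execute the compiled program with concrete arguments.
result = compiled(points, velocity, 0.5)
```

Calling a jitted function directly performs the same stages implicitly on the first call and reuses the cached executable on later calls with matching argument shapes and dtypes.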

📊 jit’s effect on performance

We now time jitted versions of each of the model functions:

We see that jitting gives a consistent, significant performance improvement.

Performance is better than NumPy, but the gains are modest on CPU.

For this example the NumPy version is already relatively efficient, as much of the time is spent in calls to compiled libraries.

💪 Going larger

JAX typically gives greater speed-ups when JIT compiling larger function traces, due to the decreased dispatch overhead.

In our case the model functions we have been looking at so far are only substeps of the overall simulation of interest.

We can simulate the fields forward an arbitrary amount of time by iteratively applying the step function.

def integrate(vorticity, tracer, kernels, mesh, time_step, n_step):
    """Integrate fields forward in time a specified number of steps."""
    for s in range(n_step):
        vorticity, tracer = step(vorticity, tracer, kernels, mesh, time_step)
    return vorticity, tracer

▶️ Initializing tracer fields

We use a helper function to generate a vertically banded initial tracer field.

tracer = jnp.asarray(generate_initial_tracer(mesh))
_ = plot_fields(mesh, kernels, vorticity_jax, tracer)

➡️ Integrating fields forward

Using the integrate function we can then simulate forward:

n_step = 10
final_vorticity_jax, final_tracer = integrate(
    vorticity_jax, tracer, kernels, mesh, time_step, n_step
)
_ = plot_fields(mesh, kernels, final_vorticity_jax, final_tracer)

👣 Tracing integration

An important restriction is that JAX traces computations using abstract values, which carry only shape and data type information.

By default all arguments to a function are traced using abstract values.

This is problematic when attempting to trace integrate, as the number of iterations n_step is traced with an abstract value, but range requires a concrete integer.

jax.make_jaxpr(integrate)(
    vorticity_jax, tracer, kernels, mesh, time_step, n_step
)
---------------------------------------------------------------------------
TracerIntegerConversionError              Traceback (most recent call last)
Cell In[36], line 1
----> 1 jax.make_jaxpr(integrate)(
      2     vorticity_jax, tracer, kernels, mesh, time_step, n_step
      3 )

    [... skipping hidden 14 frame]

Cell In[32], line 3, in integrate(vorticity, tracer, kernels, mesh, time_step, n_step)
      1 def integrate(vorticity, tracer, kernels, mesh, time_step, n_step):
      2     """Integrate fields forward in time a specified number of steps."""
----> 3     for s in range(n_step):
      4         vorticity, tracer = step(vorticity, tracer, kernels, mesh, time_step)
      5     return vorticity, tracer

    [... skipping hidden 1 frame]

File ~/projects/jax-sci-comp-demo/.venv/lib/python3.13t/site-packages/jax/_src/core.py:1834, in concretization_function_error.<locals>.error(self, arg)
   1833 def error(self, arg):
-> 1834   raise TracerIntegerConversionError(arg)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]
The error occurred while tracing the function integrate at /tmp/ipykernel_21229/42931077.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n_step.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError

🏷️ Marking arguments as static

To side-step this issue we can mark arguments as static.

JAX will then re-trace the relevant function each time its static arguments are changed.
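The re-tracing behaviour can be observed with a small sketch. The function `repeat_double` and the list `trace_log` are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp

trace_log = []


def repeat_double(x, n):
    # Python side effects run only while the function is being traced.
    trace_log.append(n)
    for _ in range(n):  # n is a concrete Python int as it is static
        x = 2 * x
    return x


jitted_repeat_double = jax.jit(repeat_double, static_argnums=1)

x = jnp.ones(3)
first = jitted_repeat_double(x, 2)   # traces and compiles for n=2
second = jitted_repeat_double(x, 2)  # reuses the cached executable
third = jitted_repeat_double(x, 3)   # new static value, so traced again
```

Because each distinct static value triggers a new trace and compilation, static arguments are best suited to values that change rarely.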

For example applying to the n_step argument in integrate:

jax.make_jaxpr(integrate, static_argnums=5)(
    vorticity_jax, tracer, kernels, mesh, time_step, n_step
)
let fft = { lambda ; a:f64[64,64]. let
    b:c128[64,33] = fft[fft_lengths=(64, 64) fft_type=2] a
  in (b,) } in
let fft1 = { lambda ; c:c128[2,64,33]. let
    d:f64[2,64,64] = fft[fft_lengths=(64, 64) fft_type=3] c
  in (d,) } in
let remainder = { lambda ; e:i64[2,64,64] f:i64[2,1,1]. let
    g:bool[2,1,1] = eq f 0:i64[]
    h:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 1, 1)
      sharding=None
    ] 1:i64[]
    i:i64[2,1,1] = jit[
      name=_where
      jaxpr={ lambda ; g:bool[2,1,1] h:i64[2,1,1] f:i64[2,1,1]. let
          i:i64[2,1,1] = select_n g f h
        in (i,) }
    ] g h f
    j:i64[2,64,64] = rem e i
    k:bool[2,64,64] = ne j 0:i64[]
    l:bool[2,64,64] = lt j 0:i64[]
    m:bool[2,1,1] = lt i 0:i64[]
    n:bool[2,64,64] = ne l m
    o:bool[2,64,64] = and n k
    p:i64[2,64,64] = add j i
    q:i64[2,64,64] = select_n o j p
  in (q,) } in
let _where = { lambda ; g:bool[2,1,1] h:i64[2,1,1] f:i64[2,1,1]. let
    i:i64[2,1,1] = select_n g f h
  in (i,) } in
{ lambda r:c128[2] s:f64[2] t:i64[2] u:i64[2] v:i64[2] w:i64[2] x:i64[2] y:i64[2]
    z:c128[2] ba:f64[2] bb:i64[2] bc:i64[2] bd:i64[2] be:i64[2] bf:i64[2] bg:i64[2]
    bh:c128[2] bi:f64[2] bj:i64[2] bk:i64[2] bl:i64[2] bm:i64[2] bn:i64[2] bo:i64[2]
    bp:c128[2] bq:f64[2] br:i64[2] bs:i64[2] bt:i64[2] bu:i64[2] bv:i64[2] bw:i64[2]
    bx:c128[2] by:f64[2] bz:i64[2] ca:i64[2] cb:i64[2] cc:i64[2] cd:i64[2] ce:i64[2]
    cf:c128[2] cg:f64[2] ch:i64[2] ci:i64[2] cj:i64[2] ck:i64[2] cl:i64[2] cm:i64[2]
    cn:c128[2] co:f64[2] cp:i64[2] cq:i64[2] cr:i64[2] cs:i64[2] ct:i64[2] cu:i64[2]
    cv:c128[2] cw:f64[2] cx:i64[2] cy:i64[2] cz:i64[2] da:i64[2] db:i64[2] dc:i64[2]
    dd:c128[2] de:f64[2] df:i64[2] dg:i64[2] dh:i64[2] di:i64[2] dj:i64[2] dk:i64[2]
    dl:c128[2] dm:f64[2] dn:i64[2] do:i64[2] dp:i64[2] dq:i64[2] dr:i64[2] ds:i64[2]; dt:f64[64,64]
    du:f64[64,64] dv:f64[64,33] dw:c128[2,64,33] dx:i64[] dy:i64[] dz:f64[2] ea:f64[2,64,64]
    eb:f64[]. let
    ec:c128[64,33] = jit[name=fft jaxpr=fft] dt
    ed:c128[] = reduce_prod[axes=(0,)] r
    ee:c128[] = sqrt ed
    ef:c128[] = div (1+0j):c128[] ee
    eg:c128[64,33] = mul ec ef
    eh:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] eg
    ei:c128[2,64,33] = mul dw eh
    ej:f64[2,64,64] = jit[name=fft jaxpr=fft1] ei
    ek:f64[] = reduce_prod[axes=(0,)] s
    el:f64[] = sqrt ek
    em:f64[2,64,64] = mul ej el
    en:f64[] = neg eb
    eo:f64[] = convert_element_type[new_dtype=float64 weak_type=False] en
    ep:f64[2,64,64] = mul eo em
    eq:f64[2,64,64] = add ea ep
    er:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    es:f64[2,64,64] = div eq er
    et:f64[2,64,64] = floor es
    eu:f64[2,64,64] = sub es et
    ev:f64[2,64,64] = sub 1.0:f64[] eu
    ew:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] t
    ex:f64[2,64,64] = sub es eu
    ey:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ex
    ez:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ey ew
    fa:i64[2,64,64] = add ez 1:i64[]
    fb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fa ew
    fc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ev
    fd:f64[64,64] = squeeze[dimensions=(0,)] fc
    fe:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ev
    ff:f64[64,64] = squeeze[dimensions=(0,)] fe
    fg:f64[64,64] = mul fd ff
    fh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ez
    fi:i64[64,64] = squeeze[dimensions=(0,)] fh
    fj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ez
    fk:i64[64,64] = squeeze[dimensions=(0,)] fj
    fl:bool[64,64] = lt fi 0:i64[]
    fm:i64[64,64] = add fi 64:i64[]
    fn:i64[64,64] = select_n fl fi fm
    fo:bool[64,64] = lt fk 0:i64[]
    fp:i64[64,64] = add fk 64:i64[]
    fq:i64[64,64] = select_n fo fk fp
    fr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fn
    fs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fq
    ft:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fr
    fu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fs
    fv:i32[64,64,2] = concatenate[dimension=2] ft fu
    fw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dt fv
    fx:f64[64,64] = mul fg fw
    fy:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ev
    fz:f64[64,64] = squeeze[dimensions=(0,)] fy
    ga:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eu
    gb:f64[64,64] = squeeze[dimensions=(0,)] ga
    gc:f64[64,64] = mul fz gb
    gd:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ez
    ge:i64[64,64] = squeeze[dimensions=(0,)] gd
    gf:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fb
    gg:i64[64,64] = squeeze[dimensions=(0,)] gf
    gh:bool[64,64] = lt ge 0:i64[]
    gi:i64[64,64] = add ge 64:i64[]
    gj:i64[64,64] = select_n gh ge gi
    gk:bool[64,64] = lt gg 0:i64[]
    gl:i64[64,64] = add gg 64:i64[]
    gm:i64[64,64] = select_n gk gg gl
    gn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gj
    go:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gm
    gp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gn
    gq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] go
    gr:i32[64,64,2] = concatenate[dimension=2] gp gq
    gs:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dt gr
    gt:f64[64,64] = mul gc gs
    gu:f64[64,64] = add fx gt
    gv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eu
    gw:f64[64,64] = squeeze[dimensions=(0,)] gv
    gx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ev
    gy:f64[64,64] = squeeze[dimensions=(0,)] gx
    gz:f64[64,64] = mul gw gy
    ha:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fb
    hb:i64[64,64] = squeeze[dimensions=(0,)] ha
    hc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ez
    hd:i64[64,64] = squeeze[dimensions=(0,)] hc
    he:bool[64,64] = lt hb 0:i64[]
    hf:i64[64,64] = add hb 64:i64[]
    hg:i64[64,64] = select_n he hb hf
    hh:bool[64,64] = lt hd 0:i64[]
    hi:i64[64,64] = add hd 64:i64[]
    hj:i64[64,64] = select_n hh hd hi
    hk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hg
    hl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hj
    hm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hk
    hn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hl
    ho:i32[64,64,2] = concatenate[dimension=2] hm hn
    hp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dt ho
    hq:f64[64,64] = mul gz hp
    hr:f64[64,64] = add gu hq
    hs:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eu
    ht:f64[64,64] = squeeze[dimensions=(0,)] hs
    hu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eu
    hv:f64[64,64] = squeeze[dimensions=(0,)] hu
    hw:f64[64,64] = mul ht hv
    hx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fb
    hy:i64[64,64] = squeeze[dimensions=(0,)] hx
    hz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fb
    ia:i64[64,64] = squeeze[dimensions=(0,)] hz
    ib:bool[64,64] = lt hy 0:i64[]
    ic:i64[64,64] = add hy 64:i64[]
    id:i64[64,64] = select_n ib hy ic
    ie:bool[64,64] = lt ia 0:i64[]
    if:i64[64,64] = add ia 64:i64[]
    ig:i64[64,64] = select_n ie ia if
    ih:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] id
    ii:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ig
    ij:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ih
    ik:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ii
    il:i32[64,64,2] = concatenate[dimension=2] ij ik
    im:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dt il
    in:f64[64,64] = mul hw im
    io:f64[64,64] = add hr in
    ip:f64[2,64,64] = neg em
    iq:f64[] = neg eb
    ir:f64[] = convert_element_type[new_dtype=float64 weak_type=False] iq
    is:f64[2,64,64] = mul ir ip
    it:f64[2,64,64] = add ea is
    iu:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    iv:f64[2,64,64] = div it iu
    iw:f64[2,64,64] = floor iv
    ix:f64[2,64,64] = sub iv iw
    iy:f64[2,64,64] = sub 1.0:f64[] ix
    iz:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] u
    ja:f64[2,64,64] = sub iv ix
    jb:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ja
    jc:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jb iz
    jd:i64[2,64,64] = add jc 1:i64[]
    je:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jd iz
    jf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iy
    jg:f64[64,64] = squeeze[dimensions=(0,)] jf
    jh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iy
    ji:f64[64,64] = squeeze[dimensions=(0,)] jh
    jj:f64[64,64] = mul jg ji
    jk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jc
    jl:i64[64,64] = squeeze[dimensions=(0,)] jk
    jm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jc
    jn:i64[64,64] = squeeze[dimensions=(0,)] jm
    jo:bool[64,64] = lt jl 0:i64[]
    jp:i64[64,64] = add jl 64:i64[]
    jq:i64[64,64] = select_n jo jl jp
    jr:bool[64,64] = lt jn 0:i64[]
    js:i64[64,64] = add jn 64:i64[]
    jt:i64[64,64] = select_n jr jn js
    ju:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jq
    jv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jt
    jw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ju
    jx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jv
    jy:i32[64,64,2] = concatenate[dimension=2] jw jx
    jz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] io jy
    ka:f64[64,64] = mul jj jz
    kb:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iy
    kc:f64[64,64] = squeeze[dimensions=(0,)] kb
    kd:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ix
    ke:f64[64,64] = squeeze[dimensions=(0,)] kd
    kf:f64[64,64] = mul kc ke
    kg:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jc
    kh:i64[64,64] = squeeze[dimensions=(0,)] kg
    ki:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] je
    kj:i64[64,64] = squeeze[dimensions=(0,)] ki
    kk:bool[64,64] = lt kh 0:i64[]
    kl:i64[64,64] = add kh 64:i64[]
    km:i64[64,64] = select_n kk kh kl
    kn:bool[64,64] = lt kj 0:i64[]
    ko:i64[64,64] = add kj 64:i64[]
    kp:i64[64,64] = select_n kn kj ko
    kq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] km
    kr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] kp
    ks:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] kq
    kt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] kr
    ku:i32[64,64,2] = concatenate[dimension=2] ks kt
    kv:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] io ku
    kw:f64[64,64] = mul kf kv
    kx:f64[64,64] = add ka kw
    ky:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ix
    kz:f64[64,64] = squeeze[dimensions=(0,)] ky
    la:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iy
    lb:f64[64,64] = squeeze[dimensions=(0,)] la
    lc:f64[64,64] = mul kz lb
    ld:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] je
    le:i64[64,64] = squeeze[dimensions=(0,)] ld
    lf:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jc
    lg:i64[64,64] = squeeze[dimensions=(0,)] lf
    lh:bool[64,64] = lt le 0:i64[]
    li:i64[64,64] = add le 64:i64[]
    lj:i64[64,64] = select_n lh le li
    lk:bool[64,64] = lt lg 0:i64[]
    ll:i64[64,64] = add lg 64:i64[]
    lm:i64[64,64] = select_n lk lg ll
    ln:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] lj
    lo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] lm
    lp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ln
    lq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] lo
    lr:i32[64,64,2] = concatenate[dimension=2] lp lq
    ls:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] io lr
    lt:f64[64,64] = mul lc ls
    lu:f64[64,64] = add kx lt
    lv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ix
    lw:f64[64,64] = squeeze[dimensions=(0,)] lv
    lx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ix
    ly:f64[64,64] = squeeze[dimensions=(0,)] lx
    lz:f64[64,64] = mul lw ly
    ma:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] je
    mb:i64[64,64] = squeeze[dimensions=(0,)] ma
    mc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] je
    md:i64[64,64] = squeeze[dimensions=(0,)] mc
    me:bool[64,64] = lt mb 0:i64[]
    mf:i64[64,64] = add mb 64:i64[]
    mg:i64[64,64] = select_n me mb mf
    mh:bool[64,64] = lt md 0:i64[]
    mi:i64[64,64] = add md 64:i64[]
    mj:i64[64,64] = select_n mh md mi
    mk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] mg
    ml:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] mj
    mm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] mk
    mn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ml
    mo:i32[64,64,2] = concatenate[dimension=2] mm mn
    mp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] io mo
    mq:f64[64,64] = mul lz mp
    mr:f64[64,64] = add lu mq
    ms:f64[64,64] = sub dt mr
    mt:f64[64,64] = div ms 2.0:f64[]
    mu:f64[64,64] = add dt mt
    mv:f64[] = neg eb
    mw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] mv
    mx:f64[2,64,64] = mul mw em
    my:f64[2,64,64] = add ea mx
    mz:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    na:f64[2,64,64] = div my mz
    nb:f64[2,64,64] = floor na
    nc:f64[2,64,64] = sub na nb
    nd:f64[2,64,64] = sub 1.0:f64[] nc
    ne:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] v
    nf:f64[2,64,64] = sub na nc
    ng:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] nf
    nh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ng ne
    ni:i64[2,64,64] = add nh 1:i64[]
    nj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ni ne
    nk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nd
    nl:f64[64,64] = squeeze[dimensions=(0,)] nk
    nm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nd
    nn:f64[64,64] = squeeze[dimensions=(0,)] nm
    no:f64[64,64] = mul nl nn
    np:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nh
    nq:i64[64,64] = squeeze[dimensions=(0,)] np
    nr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nh
    ns:i64[64,64] = squeeze[dimensions=(0,)] nr
    nt:bool[64,64] = lt nq 0:i64[]
    nu:i64[64,64] = add nq 64:i64[]
    nv:i64[64,64] = select_n nt nq nu
    nw:bool[64,64] = lt ns 0:i64[]
    nx:i64[64,64] = add ns 64:i64[]
    ny:i64[64,64] = select_n nw ns nx
    nz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] nv
    oa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ny
    ob:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] nz
    oc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] oa
    od:i32[64,64,2] = concatenate[dimension=2] ob oc
    oe:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] mu od
    of:f64[64,64] = mul no oe
    og:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nd
    oh:f64[64,64] = squeeze[dimensions=(0,)] og
    oi:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nc
    oj:f64[64,64] = squeeze[dimensions=(0,)] oi
    ok:f64[64,64] = mul oh oj
    ol:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nh
    om:i64[64,64] = squeeze[dimensions=(0,)] ol
    on:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nj
    oo:i64[64,64] = squeeze[dimensions=(0,)] on
    op:bool[64,64] = lt om 0:i64[]
    oq:i64[64,64] = add om 64:i64[]
    or:i64[64,64] = select_n op om oq
    os:bool[64,64] = lt oo 0:i64[]
    ot:i64[64,64] = add oo 64:i64[]
    ou:i64[64,64] = select_n os oo ot
    ov:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] or
    ow:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ou
    ox:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ov
    oy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ow
    oz:i32[64,64,2] = concatenate[dimension=2] ox oy
    pa:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] mu oz
    pb:f64[64,64] = mul ok pa
    pc:f64[64,64] = add of pb
    pd:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nc
    pe:f64[64,64] = squeeze[dimensions=(0,)] pd
    pf:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nd
    pg:f64[64,64] = squeeze[dimensions=(0,)] pf
    ph:f64[64,64] = mul pe pg
    pi:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nj
    pj:i64[64,64] = squeeze[dimensions=(0,)] pi
    pk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nh
    pl:i64[64,64] = squeeze[dimensions=(0,)] pk
    pm:bool[64,64] = lt pj 0:i64[]
    pn:i64[64,64] = add pj 64:i64[]
    po:i64[64,64] = select_n pm pj pn
    pp:bool[64,64] = lt pl 0:i64[]
    pq:i64[64,64] = add pl 64:i64[]
    pr:i64[64,64] = select_n pp pl pq
    ps:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] po
    pt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] pr
    pu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ps
    pv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] pt
    pw:i32[64,64,2] = concatenate[dimension=2] pu pv
    px:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] mu pw
    py:f64[64,64] = mul ph px
    pz:f64[64,64] = add pc py
    qa:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nc
    qb:f64[64,64] = squeeze[dimensions=(0,)] qa
    qc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nc
    qd:f64[64,64] = squeeze[dimensions=(0,)] qc
    qe:f64[64,64] = mul qb qd
    qf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] nj
    qg:i64[64,64] = squeeze[dimensions=(0,)] qf
    qh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] nj
    qi:i64[64,64] = squeeze[dimensions=(0,)] qh
    qj:bool[64,64] = lt qg 0:i64[]
    qk:i64[64,64] = add qg 64:i64[]
    ql:i64[64,64] = select_n qj qg qk
    qm:bool[64,64] = lt qi 0:i64[]
    qn:i64[64,64] = add qi 64:i64[]
    qo:i64[64,64] = select_n qm qi qn
    qp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ql
    qq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] qo
    qr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] qp
    qs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] qq
    qt:i32[64,64,2] = concatenate[dimension=2] qr qs
    qu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] mu qt
    qv:f64[64,64] = mul qe qu
    qw:f64[64,64] = add pz qv
    qx:f64[] = neg eb
    qy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] qx
    qz:f64[2,64,64] = mul qy em
    ra:f64[2,64,64] = add ea qz
    rb:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    rc:f64[2,64,64] = div ra rb
    rd:f64[2,64,64] = floor rc
    re:f64[2,64,64] = sub rc rd
    rf:f64[2,64,64] = sub 1.0:f64[] re
    rg:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] w
    rh:f64[2,64,64] = sub rc re
    ri:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] rh
    rj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ri rg
    rk:i64[2,64,64] = add rj 1:i64[]
    rl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] rk rg
    rm:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] rf
    rn:f64[64,64] = squeeze[dimensions=(0,)] rm
    ro:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] rf
    rp:f64[64,64] = squeeze[dimensions=(0,)] ro
    rq:f64[64,64] = mul rn rp
    rr:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] rj
    rs:i64[64,64] = squeeze[dimensions=(0,)] rr
    rt:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] rj
    ru:i64[64,64] = squeeze[dimensions=(0,)] rt
    rv:bool[64,64] = lt rs 0:i64[]
    rw:i64[64,64] = add rs 64:i64[]
    rx:i64[64,64] = select_n rv rs rw
    ry:bool[64,64] = lt ru 0:i64[]
    rz:i64[64,64] = add ru 64:i64[]
    sa:i64[64,64] = select_n ry ru rz
    sb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] rx
    sc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] sa
    sd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] sb
    se:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] sc
    sf:i32[64,64,2] = concatenate[dimension=2] sd se
    sg:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] du sf
    sh:f64[64,64] = mul rq sg
    si:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] rf
    sj:f64[64,64] = squeeze[dimensions=(0,)] si
    sk:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] re
    sl:f64[64,64] = squeeze[dimensions=(0,)] sk
    sm:f64[64,64] = mul sj sl
    sn:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] rj
    so:i64[64,64] = squeeze[dimensions=(0,)] sn
    sp:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] rl
    sq:i64[64,64] = squeeze[dimensions=(0,)] sp
    sr:bool[64,64] = lt so 0:i64[]
    ss:i64[64,64] = add so 64:i64[]
    st:i64[64,64] = select_n sr so ss
    su:bool[64,64] = lt sq 0:i64[]
    sv:i64[64,64] = add sq 64:i64[]
    sw:i64[64,64] = select_n su sq sv
    sx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] st
    sy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] sw
    sz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] sx
    ta:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] sy
    tb:i32[64,64,2] = concatenate[dimension=2] sz ta
    tc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] du tb
    td:f64[64,64] = mul sm tc
    te:f64[64,64] = add sh td
    tf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] re
    tg:f64[64,64] = squeeze[dimensions=(0,)] tf
    th:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] rf
    ti:f64[64,64] = squeeze[dimensions=(0,)] th
    tj:f64[64,64] = mul tg ti
    tk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] rl
    tl:i64[64,64] = squeeze[dimensions=(0,)] tk
    tm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] rj
    tn:i64[64,64] = squeeze[dimensions=(0,)] tm
    to:bool[64,64] = lt tl 0:i64[]
    tp:i64[64,64] = add tl 64:i64[]
    tq:i64[64,64] = select_n to tl tp
    tr:bool[64,64] = lt tn 0:i64[]
    ts:i64[64,64] = add tn 64:i64[]
    tt:i64[64,64] = select_n tr tn ts
    tu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] tq
    tv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] tt
    tw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] tu
    tx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] tv
    ty:i32[64,64,2] = concatenate[dimension=2] tw tx
    tz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] du ty
    ua:f64[64,64] = mul tj tz
    ub:f64[64,64] = add te ua
    uc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] re
    ud:f64[64,64] = squeeze[dimensions=(0,)] uc
    ue:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] re
    uf:f64[64,64] = squeeze[dimensions=(0,)] ue
    ug:f64[64,64] = mul ud uf
    uh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] rl
    ui:i64[64,64] = squeeze[dimensions=(0,)] uh
    uj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] rl
    uk:i64[64,64] = squeeze[dimensions=(0,)] uj
    ul:bool[64,64] = lt ui 0:i64[]
    um:i64[64,64] = add ui 64:i64[]
    un:i64[64,64] = select_n ul ui um
    uo:bool[64,64] = lt uk 0:i64[]
    up:i64[64,64] = add uk 64:i64[]
    uq:i64[64,64] = select_n uo uk up
    ur:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] un
    us:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] uq
    ut:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ur
    uu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] us
    uv:i32[64,64,2] = concatenate[dimension=2] ut uu
    uw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] du uv
    ux:f64[64,64] = mul ug uw
    uy:f64[64,64] = add ub ux
    uz:f64[2,64,64] = neg em
    va:f64[] = neg eb
    vb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] va
    vc:f64[2,64,64] = mul vb uz
    vd:f64[2,64,64] = add ea vc
    ve:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    vf:f64[2,64,64] = div vd ve
    vg:f64[2,64,64] = floor vf
    vh:f64[2,64,64] = sub vf vg
    vi:f64[2,64,64] = sub 1.0:f64[] vh
    vj:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] x
    vk:f64[2,64,64] = sub vf vh
    vl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] vk
    vm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] vl vj
    vn:i64[2,64,64] = add vm 1:i64[]
    vo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] vn vj
    vp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vi
    vq:f64[64,64] = squeeze[dimensions=(0,)] vp
    vr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vi
    vs:f64[64,64] = squeeze[dimensions=(0,)] vr
    vt:f64[64,64] = mul vq vs
    vu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vm
    vv:i64[64,64] = squeeze[dimensions=(0,)] vu
    vw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vm
    vx:i64[64,64] = squeeze[dimensions=(0,)] vw
    vy:bool[64,64] = lt vv 0:i64[]
    vz:i64[64,64] = add vv 64:i64[]
    wa:i64[64,64] = select_n vy vv vz
    wb:bool[64,64] = lt vx 0:i64[]
    wc:i64[64,64] = add vx 64:i64[]
    wd:i64[64,64] = select_n wb vx wc
    we:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] wa
    wf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] wd
    wg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] we
    wh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] wf
    wi:i32[64,64,2] = concatenate[dimension=2] wg wh
    wj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] uy wi
    wk:f64[64,64] = mul vt wj
    wl:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vi
    wm:f64[64,64] = squeeze[dimensions=(0,)] wl
    wn:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vh
    wo:f64[64,64] = squeeze[dimensions=(0,)] wn
    wp:f64[64,64] = mul wm wo
    wq:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vm
    wr:i64[64,64] = squeeze[dimensions=(0,)] wq
    ws:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vo
    wt:i64[64,64] = squeeze[dimensions=(0,)] ws
    wu:bool[64,64] = lt wr 0:i64[]
    wv:i64[64,64] = add wr 64:i64[]
    ww:i64[64,64] = select_n wu wr wv
    wx:bool[64,64] = lt wt 0:i64[]
    wy:i64[64,64] = add wt 64:i64[]
    wz:i64[64,64] = select_n wx wt wy
    xa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ww
    xb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] wz
    xc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] xa
    xd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] xb
    xe:i32[64,64,2] = concatenate[dimension=2] xc xd
    xf:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] uy xe
    xg:f64[64,64] = mul wp xf
    xh:f64[64,64] = add wk xg
    xi:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vh
    xj:f64[64,64] = squeeze[dimensions=(0,)] xi
    xk:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vi
    xl:f64[64,64] = squeeze[dimensions=(0,)] xk
    xm:f64[64,64] = mul xj xl
    xn:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vo
    xo:i64[64,64] = squeeze[dimensions=(0,)] xn
    xp:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vm
    xq:i64[64,64] = squeeze[dimensions=(0,)] xp
    xr:bool[64,64] = lt xo 0:i64[]
    xs:i64[64,64] = add xo 64:i64[]
    xt:i64[64,64] = select_n xr xo xs
    xu:bool[64,64] = lt xq 0:i64[]
    xv:i64[64,64] = add xq 64:i64[]
    xw:i64[64,64] = select_n xu xq xv
    xx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] xt
    xy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] xw
    xz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] xx
    ya:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] xy
    yb:i32[64,64,2] = concatenate[dimension=2] xz ya
    yc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] uy yb
    yd:f64[64,64] = mul xm yc
    ye:f64[64,64] = add xh yd
    yf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vh
    yg:f64[64,64] = squeeze[dimensions=(0,)] yf
    yh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vh
    yi:f64[64,64] = squeeze[dimensions=(0,)] yh
    yj:f64[64,64] = mul yg yi
    yk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] vo
    yl:i64[64,64] = squeeze[dimensions=(0,)] yk
    ym:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] vo
    yn:i64[64,64] = squeeze[dimensions=(0,)] ym
    yo:bool[64,64] = lt yl 0:i64[]
    yp:i64[64,64] = add yl 64:i64[]
    yq:i64[64,64] = select_n yo yl yp
    yr:bool[64,64] = lt yn 0:i64[]
    ys:i64[64,64] = add yn 64:i64[]
    yt:i64[64,64] = select_n yr yn ys
    yu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] yq
    yv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] yt
    yw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] yu
    yx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] yv
    yy:i32[64,64,2] = concatenate[dimension=2] yw yx
    yz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] uy yy
    za:f64[64,64] = mul yj yz
    zb:f64[64,64] = add ye za
    zc:f64[64,64] = sub du zb
    zd:f64[64,64] = div zc 2.0:f64[]
    ze:f64[64,64] = add du zd
    zf:f64[] = neg eb
    zg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] zf
    zh:f64[2,64,64] = mul zg em
    zi:f64[2,64,64] = add ea zh
    zj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    zk:f64[2,64,64] = div zi zj
    zl:f64[2,64,64] = floor zk
    zm:f64[2,64,64] = sub zk zl
    zn:f64[2,64,64] = sub 1.0:f64[] zm
    zo:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] y
    zp:f64[2,64,64] = sub zk zm
    zq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] zp
    zr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] zq zo
    zs:i64[2,64,64] = add zr 1:i64[]
    zt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] zs zo
    zu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zn
    zv:f64[64,64] = squeeze[dimensions=(0,)] zu
    zw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zn
    zx:f64[64,64] = squeeze[dimensions=(0,)] zw
    zy:f64[64,64] = mul zv zx
    zz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zr
    baa:i64[64,64] = squeeze[dimensions=(0,)] zz
    bab:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zr
    bac:i64[64,64] = squeeze[dimensions=(0,)] bab
    bad:bool[64,64] = lt baa 0:i64[]
    bae:i64[64,64] = add baa 64:i64[]
    baf:i64[64,64] = select_n bad baa bae
    bag:bool[64,64] = lt bac 0:i64[]
    bah:i64[64,64] = add bac 64:i64[]
    bai:i64[64,64] = select_n bag bac bah
    baj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] baf
    bak:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bai
    bal:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] baj
    bam:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bak
    ban:i32[64,64,2] = concatenate[dimension=2] bal bam
    bao:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ze ban
    bap:f64[64,64] = mul zy bao
    baq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zn
    bar:f64[64,64] = squeeze[dimensions=(0,)] baq
    bas:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zm
    bat:f64[64,64] = squeeze[dimensions=(0,)] bas
    bau:f64[64,64] = mul bar bat
    bav:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zr
    baw:i64[64,64] = squeeze[dimensions=(0,)] bav
    bax:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zt
    bay:i64[64,64] = squeeze[dimensions=(0,)] bax
    baz:bool[64,64] = lt baw 0:i64[]
    bba:i64[64,64] = add baw 64:i64[]
    bbb:i64[64,64] = select_n baz baw bba
    bbc:bool[64,64] = lt bay 0:i64[]
    bbd:i64[64,64] = add bay 64:i64[]
    bbe:i64[64,64] = select_n bbc bay bbd
    bbf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bbb
    bbg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bbe
    bbh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bbf
    bbi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bbg
    bbj:i32[64,64,2] = concatenate[dimension=2] bbh bbi
    bbk:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ze bbj
    bbl:f64[64,64] = mul bau bbk
    bbm:f64[64,64] = add bap bbl
    bbn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zm
    bbo:f64[64,64] = squeeze[dimensions=(0,)] bbn
    bbp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zn
    bbq:f64[64,64] = squeeze[dimensions=(0,)] bbp
    bbr:f64[64,64] = mul bbo bbq
    bbs:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zt
    bbt:i64[64,64] = squeeze[dimensions=(0,)] bbs
    bbu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zr
    bbv:i64[64,64] = squeeze[dimensions=(0,)] bbu
    bbw:bool[64,64] = lt bbt 0:i64[]
    bbx:i64[64,64] = add bbt 64:i64[]
    bby:i64[64,64] = select_n bbw bbt bbx
    bbz:bool[64,64] = lt bbv 0:i64[]
    bca:i64[64,64] = add bbv 64:i64[]
    bcb:i64[64,64] = select_n bbz bbv bca
    bcc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bby
    bcd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bcb
    bce:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bcc
    bcf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bcd
    bcg:i32[64,64,2] = concatenate[dimension=2] bce bcf
    bch:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ze bcg
    bci:f64[64,64] = mul bbr bch
    bcj:f64[64,64] = add bbm bci
    bck:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zm
    bcl:f64[64,64] = squeeze[dimensions=(0,)] bck
    bcm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zm
    bcn:f64[64,64] = squeeze[dimensions=(0,)] bcm
    bco:f64[64,64] = mul bcl bcn
    bcp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] zt
    bcq:i64[64,64] = squeeze[dimensions=(0,)] bcp
    bcr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] zt
    bcs:i64[64,64] = squeeze[dimensions=(0,)] bcr
    bct:bool[64,64] = lt bcq 0:i64[]
    bcu:i64[64,64] = add bcq 64:i64[]
    bcv:i64[64,64] = select_n bct bcq bcu
    bcw:bool[64,64] = lt bcs 0:i64[]
    bcx:i64[64,64] = add bcs 64:i64[]
    bcy:i64[64,64] = select_n bcw bcs bcx
    bcz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bcv
    bda:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bcy
    bdb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bcz
    bdc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bda
    bdd:i32[64,64,2] = concatenate[dimension=2] bdb bdc
    bde:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ze bdd
    bdf:f64[64,64] = mul bco bde
    bdg:f64[64,64] = add bcj bdf
    bdh:c128[64,33] = jit[name=fft jaxpr=fft] qw
    bdi:c128[] = reduce_prod[axes=(0,)] z
    bdj:c128[] = sqrt bdi
    bdk:c128[] = div (1+0j):c128[] bdj
    bdl:c128[64,33] = mul bdh bdk
    bdm:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] bdl
    bdn:c128[2,64,33] = mul dw bdm
    bdo:f64[2,64,64] = jit[name=fft jaxpr=fft1] bdn
    bdp:f64[] = reduce_prod[axes=(0,)] ba
    bdq:f64[] = sqrt bdp
    bdr:f64[2,64,64] = mul bdo bdq
    bds:f64[] = neg eb
    bdt:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bds
    bdu:f64[2,64,64] = mul bdt bdr
    bdv:f64[2,64,64] = add ea bdu
    bdw:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    bdx:f64[2,64,64] = div bdv bdw
    bdy:f64[2,64,64] = floor bdx
    bdz:f64[2,64,64] = sub bdx bdy
    bea:f64[2,64,64] = sub 1.0:f64[] bdz
    beb:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bb
    bec:f64[2,64,64] = sub bdx bdz
    bed:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bec
    bee:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bed beb
    bef:i64[2,64,64] = add bee 1:i64[]
    beg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bef beb
    beh:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bea
    bei:f64[64,64] = squeeze[dimensions=(0,)] beh
    bej:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bea
    bek:f64[64,64] = squeeze[dimensions=(0,)] bej
    bel:f64[64,64] = mul bei bek
    bem:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bee
    ben:i64[64,64] = squeeze[dimensions=(0,)] bem
    beo:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bee
    bep:i64[64,64] = squeeze[dimensions=(0,)] beo
    beq:bool[64,64] = lt ben 0:i64[]
    ber:i64[64,64] = add ben 64:i64[]
    bes:i64[64,64] = select_n beq ben ber
    bet:bool[64,64] = lt bep 0:i64[]
    beu:i64[64,64] = add bep 64:i64[]
    bev:i64[64,64] = select_n bet bep beu
    bew:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bes
    bex:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bev
    bey:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bew
    bez:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bex
    bfa:i32[64,64,2] = concatenate[dimension=2] bey bez
    bfb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] qw bfa
    bfc:f64[64,64] = mul bel bfb
    bfd:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bea
    bfe:f64[64,64] = squeeze[dimensions=(0,)] bfd
    bff:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bdz
    bfg:f64[64,64] = squeeze[dimensions=(0,)] bff
    bfh:f64[64,64] = mul bfe bfg
    bfi:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bee
    bfj:i64[64,64] = squeeze[dimensions=(0,)] bfi
    bfk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] beg
    bfl:i64[64,64] = squeeze[dimensions=(0,)] bfk
    bfm:bool[64,64] = lt bfj 0:i64[]
    bfn:i64[64,64] = add bfj 64:i64[]
    bfo:i64[64,64] = select_n bfm bfj bfn
    bfp:bool[64,64] = lt bfl 0:i64[]
    bfq:i64[64,64] = add bfl 64:i64[]
    bfr:i64[64,64] = select_n bfp bfl bfq
    bfs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bfo
    bft:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bfr
    bfu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bfs
    bfv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bft
    bfw:i32[64,64,2] = concatenate[dimension=2] bfu bfv
    bfx:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] qw bfw
    bfy:f64[64,64] = mul bfh bfx
    bfz:f64[64,64] = add bfc bfy
    bga:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bdz
    bgb:f64[64,64] = squeeze[dimensions=(0,)] bga
    bgc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bea
    bgd:f64[64,64] = squeeze[dimensions=(0,)] bgc
    bge:f64[64,64] = mul bgb bgd
    bgf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] beg
    bgg:i64[64,64] = squeeze[dimensions=(0,)] bgf
    bgh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bee
    bgi:i64[64,64] = squeeze[dimensions=(0,)] bgh
    bgj:bool[64,64] = lt bgg 0:i64[]
    bgk:i64[64,64] = add bgg 64:i64[]
    bgl:i64[64,64] = select_n bgj bgg bgk
    bgm:bool[64,64] = lt bgi 0:i64[]
    bgn:i64[64,64] = add bgi 64:i64[]
    bgo:i64[64,64] = select_n bgm bgi bgn
    bgp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bgl
    bgq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bgo
    bgr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bgp
    bgs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bgq
    bgt:i32[64,64,2] = concatenate[dimension=2] bgr bgs
    bgu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] qw bgt
    bgv:f64[64,64] = mul bge bgu
    bgw:f64[64,64] = add bfz bgv
    bgx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bdz
    bgy:f64[64,64] = squeeze[dimensions=(0,)] bgx
    bgz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bdz
    bha:f64[64,64] = squeeze[dimensions=(0,)] bgz
    bhb:f64[64,64] = mul bgy bha
    bhc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] beg
    bhd:i64[64,64] = squeeze[dimensions=(0,)] bhc
    bhe:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] beg
    bhf:i64[64,64] = squeeze[dimensions=(0,)] bhe
    bhg:bool[64,64] = lt bhd 0:i64[]
    bhh:i64[64,64] = add bhd 64:i64[]
    bhi:i64[64,64] = select_n bhg bhd bhh
    bhj:bool[64,64] = lt bhf 0:i64[]
    bhk:i64[64,64] = add bhf 64:i64[]
    bhl:i64[64,64] = select_n bhj bhf bhk
    bhm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bhi
    bhn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bhl
    bho:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bhm
    bhp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bhn
    bhq:i32[64,64,2] = concatenate[dimension=2] bho bhp
    bhr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] qw bhq
    bhs:f64[64,64] = mul bhb bhr
    bht:f64[64,64] = add bgw bhs
    bhu:f64[2,64,64] = neg bdr
    bhv:f64[] = neg eb
    bhw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bhv
    bhx:f64[2,64,64] = mul bhw bhu
    bhy:f64[2,64,64] = add ea bhx
    bhz:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    bia:f64[2,64,64] = div bhy bhz
    bib:f64[2,64,64] = floor bia
    bic:f64[2,64,64] = sub bia bib
    bid:f64[2,64,64] = sub 1.0:f64[] bic
    bie:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bc
    bif:f64[2,64,64] = sub bia bic
    big:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bif
    bih:i64[2,64,64] = jit[name=remainder jaxpr=remainder] big bie
    bii:i64[2,64,64] = add bih 1:i64[]
    bij:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bii bie
    bik:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bid
    bil:f64[64,64] = squeeze[dimensions=(0,)] bik
    bim:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bid
    bin:f64[64,64] = squeeze[dimensions=(0,)] bim
    bio:f64[64,64] = mul bil bin
    bip:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bih
    biq:i64[64,64] = squeeze[dimensions=(0,)] bip
    bir:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bih
    bis:i64[64,64] = squeeze[dimensions=(0,)] bir
    bit:bool[64,64] = lt biq 0:i64[]
    biu:i64[64,64] = add biq 64:i64[]
    biv:i64[64,64] = select_n bit biq biu
    biw:bool[64,64] = lt bis 0:i64[]
    bix:i64[64,64] = add bis 64:i64[]
    biy:i64[64,64] = select_n biw bis bix
    biz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] biv
    bja:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] biy
    bjb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] biz
    bjc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bja
    bjd:i32[64,64,2] = concatenate[dimension=2] bjb bjc
    bje:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bht bjd
    bjf:f64[64,64] = mul bio bje
    bjg:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bid
    bjh:f64[64,64] = squeeze[dimensions=(0,)] bjg
    bji:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bic
    bjj:f64[64,64] = squeeze[dimensions=(0,)] bji
    bjk:f64[64,64] = mul bjh bjj
    bjl:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bih
    bjm:i64[64,64] = squeeze[dimensions=(0,)] bjl
    bjn:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bij
    bjo:i64[64,64] = squeeze[dimensions=(0,)] bjn
    bjp:bool[64,64] = lt bjm 0:i64[]
    bjq:i64[64,64] = add bjm 64:i64[]
    bjr:i64[64,64] = select_n bjp bjm bjq
    bjs:bool[64,64] = lt bjo 0:i64[]
    bjt:i64[64,64] = add bjo 64:i64[]
    bju:i64[64,64] = select_n bjs bjo bjt
    bjv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bjr
    bjw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bju
    bjx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bjv
    bjy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bjw
    bjz:i32[64,64,2] = concatenate[dimension=2] bjx bjy
    bka:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bht bjz
    bkb:f64[64,64] = mul bjk bka
    bkc:f64[64,64] = add bjf bkb
    bkd:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bic
    bke:f64[64,64] = squeeze[dimensions=(0,)] bkd
    bkf:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bid
    bkg:f64[64,64] = squeeze[dimensions=(0,)] bkf
    bkh:f64[64,64] = mul bke bkg
    bki:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bij
    bkj:i64[64,64] = squeeze[dimensions=(0,)] bki
    bkk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bih
    bkl:i64[64,64] = squeeze[dimensions=(0,)] bkk
    bkm:bool[64,64] = lt bkj 0:i64[]
    bkn:i64[64,64] = add bkj 64:i64[]
    bko:i64[64,64] = select_n bkm bkj bkn
    bkp:bool[64,64] = lt bkl 0:i64[]
    bkq:i64[64,64] = add bkl 64:i64[]
    bkr:i64[64,64] = select_n bkp bkl bkq
    bks:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bko
    bkt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bkr
    bku:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bks
    bkv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bkt
    bkw:i32[64,64,2] = concatenate[dimension=2] bku bkv
    bkx:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bht bkw
    bky:f64[64,64] = mul bkh bkx
    bkz:f64[64,64] = add bkc bky
    bla:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bic
    blb:f64[64,64] = squeeze[dimensions=(0,)] bla
    blc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bic
    bld:f64[64,64] = squeeze[dimensions=(0,)] blc
    ble:f64[64,64] = mul blb bld
    blf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bij
    blg:i64[64,64] = squeeze[dimensions=(0,)] blf
    blh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bij
    bli:i64[64,64] = squeeze[dimensions=(0,)] blh
    blj:bool[64,64] = lt blg 0:i64[]
    blk:i64[64,64] = add blg 64:i64[]
    bll:i64[64,64] = select_n blj blg blk
    blm:bool[64,64] = lt bli 0:i64[]
    bln:i64[64,64] = add bli 64:i64[]
    blo:i64[64,64] = select_n blm bli bln
    blp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bll
    blq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] blo
    blr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] blp
    bls:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] blq
    blt:i32[64,64,2] = concatenate[dimension=2] blr bls
    blu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bht blt
    blv:f64[64,64] = mul ble blu
    blw:f64[64,64] = add bkz blv
    blx:f64[64,64] = sub qw blw
    bly:f64[64,64] = div blx 2.0:f64[]
    blz:f64[64,64] = add qw bly
    bma:f64[] = neg eb
    bmb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bma
    bmc:f64[2,64,64] = mul bmb bdr
    bmd:f64[2,64,64] = add ea bmc
    bme:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    bmf:f64[2,64,64] = div bmd bme
    bmg:f64[2,64,64] = floor bmf
    bmh:f64[2,64,64] = sub bmf bmg
    bmi:f64[2,64,64] = sub 1.0:f64[] bmh
    bmj:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bd
    bmk:f64[2,64,64] = sub bmf bmh
    bml:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bmk
    bmm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bml bmj
    bmn:i64[2,64,64] = add bmm 1:i64[]
    bmo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bmn bmj
    bmp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmi
    bmq:f64[64,64] = squeeze[dimensions=(0,)] bmp
    bmr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmi
    bms:f64[64,64] = squeeze[dimensions=(0,)] bmr
    bmt:f64[64,64] = mul bmq bms
    bmu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmm
    bmv:i64[64,64] = squeeze[dimensions=(0,)] bmu
    bmw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmm
    bmx:i64[64,64] = squeeze[dimensions=(0,)] bmw
    bmy:bool[64,64] = lt bmv 0:i64[]
    bmz:i64[64,64] = add bmv 64:i64[]
    bna:i64[64,64] = select_n bmy bmv bmz
    bnb:bool[64,64] = lt bmx 0:i64[]
    bnc:i64[64,64] = add bmx 64:i64[]
    bnd:i64[64,64] = select_n bnb bmx bnc
    bne:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bna
    bnf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bnd
    bng:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bne
    bnh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bnf
    bni:i32[64,64,2] = concatenate[dimension=2] bng bnh
    bnj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] blz bni
    bnk:f64[64,64] = mul bmt bnj
    bnl:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmi
    bnm:f64[64,64] = squeeze[dimensions=(0,)] bnl
    bnn:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmh
    bno:f64[64,64] = squeeze[dimensions=(0,)] bnn
    bnp:f64[64,64] = mul bnm bno
    bnq:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmm
    bnr:i64[64,64] = squeeze[dimensions=(0,)] bnq
    bns:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmo
    bnt:i64[64,64] = squeeze[dimensions=(0,)] bns
    bnu:bool[64,64] = lt bnr 0:i64[]
    bnv:i64[64,64] = add bnr 64:i64[]
    bnw:i64[64,64] = select_n bnu bnr bnv
    bnx:bool[64,64] = lt bnt 0:i64[]
    bny:i64[64,64] = add bnt 64:i64[]
    bnz:i64[64,64] = select_n bnx bnt bny
    boa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bnw
    bob:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bnz
    boc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] boa
    bod:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bob
    boe:i32[64,64,2] = concatenate[dimension=2] boc bod
    bof:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] blz boe
    bog:f64[64,64] = mul bnp bof
    boh:f64[64,64] = add bnk bog
    boi:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmh
    boj:f64[64,64] = squeeze[dimensions=(0,)] boi
    bok:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmi
    bol:f64[64,64] = squeeze[dimensions=(0,)] bok
    bom:f64[64,64] = mul boj bol
    bon:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmo
    boo:i64[64,64] = squeeze[dimensions=(0,)] bon
    bop:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmm
    boq:i64[64,64] = squeeze[dimensions=(0,)] bop
    bor:bool[64,64] = lt boo 0:i64[]
    bos:i64[64,64] = add boo 64:i64[]
    bot:i64[64,64] = select_n bor boo bos
    bou:bool[64,64] = lt boq 0:i64[]
    bov:i64[64,64] = add boq 64:i64[]
    bow:i64[64,64] = select_n bou boq bov
    box:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bot
    boy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bow
    boz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] box
    bpa:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] boy
    bpb:i32[64,64,2] = concatenate[dimension=2] boz bpa
    bpc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] blz bpb
    bpd:f64[64,64] = mul bom bpc
    bpe:f64[64,64] = add boh bpd
    bpf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmh
    bpg:f64[64,64] = squeeze[dimensions=(0,)] bpf
    bph:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmh
    bpi:f64[64,64] = squeeze[dimensions=(0,)] bph
    bpj:f64[64,64] = mul bpg bpi
    bpk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bmo
    bpl:i64[64,64] = squeeze[dimensions=(0,)] bpk
    bpm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bmo
    bpn:i64[64,64] = squeeze[dimensions=(0,)] bpm
    bpo:bool[64,64] = lt bpl 0:i64[]
    bpp:i64[64,64] = add bpl 64:i64[]
    bpq:i64[64,64] = select_n bpo bpl bpp
    bpr:bool[64,64] = lt bpn 0:i64[]
    bps:i64[64,64] = add bpn 64:i64[]
    bpt:i64[64,64] = select_n bpr bpn bps
    bpu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bpq
    bpv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bpt
    bpw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bpu
    bpx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bpv
    bpy:i32[64,64,2] = concatenate[dimension=2] bpw bpx
    bpz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] blz bpy
    bqa:f64[64,64] = mul bpj bpz
    bqb:f64[64,64] = add bpe bqa
    bqc:f64[] = neg eb
    bqd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] bqc
    bqe:f64[2,64,64] = mul bqd bdr
    bqf:f64[2,64,64] = add ea bqe
    bqg:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    bqh:f64[2,64,64] = div bqf bqg
    bqi:f64[2,64,64] = floor bqh
    bqj:f64[2,64,64] = sub bqh bqi
    bqk:f64[2,64,64] = sub 1.0:f64[] bqj
    bql:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] be
    bqm:f64[2,64,64] = sub bqh bqj
    bqn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bqm
    bqo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bqn bql
    bqp:i64[2,64,64] = add bqo 1:i64[]
    bqq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bqp bql
    bqr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqk
    bqs:f64[64,64] = squeeze[dimensions=(0,)] bqr
    bqt:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqk
    bqu:f64[64,64] = squeeze[dimensions=(0,)] bqt
    bqv:f64[64,64] = mul bqs bqu
    bqw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqo
    bqx:i64[64,64] = squeeze[dimensions=(0,)] bqw
    bqy:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqo
    bqz:i64[64,64] = squeeze[dimensions=(0,)] bqy
    bra:bool[64,64] = lt bqx 0:i64[]
    brb:i64[64,64] = add bqx 64:i64[]
    brc:i64[64,64] = select_n bra bqx brb
    brd:bool[64,64] = lt bqz 0:i64[]
    bre:i64[64,64] = add bqz 64:i64[]
    brf:i64[64,64] = select_n brd bqz bre
    brg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] brc
    brh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] brf
    bri:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] brg
    brj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] brh
    brk:i32[64,64,2] = concatenate[dimension=2] bri brj
    brl:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bdg brk
    brm:f64[64,64] = mul bqv brl
    brn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqk
    bro:f64[64,64] = squeeze[dimensions=(0,)] brn
    brp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqj
    brq:f64[64,64] = squeeze[dimensions=(0,)] brp
    brr:f64[64,64] = mul bro brq
    brs:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqo
    brt:i64[64,64] = squeeze[dimensions=(0,)] brs
    bru:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqq
    brv:i64[64,64] = squeeze[dimensions=(0,)] bru
    brw:bool[64,64] = lt brt 0:i64[]
    brx:i64[64,64] = add brt 64:i64[]
    bry:i64[64,64] = select_n brw brt brx
    brz:bool[64,64] = lt brv 0:i64[]
    bsa:i64[64,64] = add brv 64:i64[]
    bsb:i64[64,64] = select_n brz brv bsa
    bsc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bry
    bsd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bsb
    bse:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bsc
    bsf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bsd
    bsg:i32[64,64,2] = concatenate[dimension=2] bse bsf
    bsh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bdg bsg
    bsi:f64[64,64] = mul brr bsh
    bsj:f64[64,64] = add brm bsi
    bsk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqj
    bsl:f64[64,64] = squeeze[dimensions=(0,)] bsk
    bsm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqk
    bsn:f64[64,64] = squeeze[dimensions=(0,)] bsm
    bso:f64[64,64] = mul bsl bsn
    bsp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqq
    bsq:i64[64,64] = squeeze[dimensions=(0,)] bsp
    bsr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqo
    bss:i64[64,64] = squeeze[dimensions=(0,)] bsr
    bst:bool[64,64] = lt bsq 0:i64[]
    bsu:i64[64,64] = add bsq 64:i64[]
    bsv:i64[64,64] = select_n bst bsq bsu
    bsw:bool[64,64] = lt bss 0:i64[]
    bsx:i64[64,64] = add bss 64:i64[]
    bsy:i64[64,64] = select_n bsw bss bsx
    bsz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bsv
    bta:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bsy
    btb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bsz
    btc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bta
    btd:i32[64,64,2] = concatenate[dimension=2] btb btc
    bte:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bdg btd
    btf:f64[64,64] = mul bso bte
    btg:f64[64,64] = add bsj btf
    bth:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqj
    bti:f64[64,64] = squeeze[dimensions=(0,)] bth
    btj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqj
    btk:f64[64,64] = squeeze[dimensions=(0,)] btj
    btl:f64[64,64] = mul bti btk
    btm:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bqq
    btn:i64[64,64] = squeeze[dimensions=(0,)] btm
    bto:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bqq
    btp:i64[64,64] = squeeze[dimensions=(0,)] bto
    btq:bool[64,64] = lt btn 0:i64[]
    btr:i64[64,64] = add btn 64:i64[]
    bts:i64[64,64] = select_n btq btn btr
    btt:bool[64,64] = lt btp 0:i64[]
    btu:i64[64,64] = add btp 64:i64[]
    btv:i64[64,64] = select_n btt btp btu
    btw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bts
    btx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] btv
    bty:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] btw
    btz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] btx
    bua:i32[64,64,2] = concatenate[dimension=2] bty btz
    bub:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bdg bua
    buc:f64[64,64] = mul btl bub
    bud:f64[64,64] = add btg buc
    bue:f64[2,64,64] = neg bdr
    buf:f64[] = neg eb
    bug:f64[] = convert_element_type[new_dtype=float64 weak_type=False] buf
    buh:f64[2,64,64] = mul bug bue
    bui:f64[2,64,64] = add ea buh
    buj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    buk:f64[2,64,64] = div bui buj
    bul:f64[2,64,64] = floor buk
    bum:f64[2,64,64] = sub buk bul
    bun:f64[2,64,64] = sub 1.0:f64[] bum
    buo:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bf
    bup:f64[2,64,64] = sub buk bum
    buq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] bup
    bur:i64[2,64,64] = jit[name=remainder jaxpr=remainder] buq buo
    bus:i64[2,64,64] = add bur 1:i64[]
    but:i64[2,64,64] = jit[name=remainder jaxpr=remainder] bus buo
    buu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bun
    buv:f64[64,64] = squeeze[dimensions=(0,)] buu
    buw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bun
    bux:f64[64,64] = squeeze[dimensions=(0,)] buw
    buy:f64[64,64] = mul buv bux
    buz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bur
    bva:i64[64,64] = squeeze[dimensions=(0,)] buz
    bvb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bur
    bvc:i64[64,64] = squeeze[dimensions=(0,)] bvb
    bvd:bool[64,64] = lt bva 0:i64[]
    bve:i64[64,64] = add bva 64:i64[]
    bvf:i64[64,64] = select_n bvd bva bve
    bvg:bool[64,64] = lt bvc 0:i64[]
    bvh:i64[64,64] = add bvc 64:i64[]
    bvi:i64[64,64] = select_n bvg bvc bvh
    bvj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bvf
    bvk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bvi
    bvl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bvj
    bvm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bvk
    bvn:i32[64,64,2] = concatenate[dimension=2] bvl bvm
    bvo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bud bvn
    bvp:f64[64,64] = mul buy bvo
    bvq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bun
    bvr:f64[64,64] = squeeze[dimensions=(0,)] bvq
    bvs:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bum
    bvt:f64[64,64] = squeeze[dimensions=(0,)] bvs
    bvu:f64[64,64] = mul bvr bvt
    bvv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bur
    bvw:i64[64,64] = squeeze[dimensions=(0,)] bvv
    bvx:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] but
    bvy:i64[64,64] = squeeze[dimensions=(0,)] bvx
    bvz:bool[64,64] = lt bvw 0:i64[]
    bwa:i64[64,64] = add bvw 64:i64[]
    bwb:i64[64,64] = select_n bvz bvw bwa
    bwc:bool[64,64] = lt bvy 0:i64[]
    bwd:i64[64,64] = add bvy 64:i64[]
    bwe:i64[64,64] = select_n bwc bvy bwd
    bwf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bwb
    bwg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bwe
    bwh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bwf
    bwi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bwg
    bwj:i32[64,64,2] = concatenate[dimension=2] bwh bwi
    bwk:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bud bwj
    bwl:f64[64,64] = mul bvu bwk
    bwm:f64[64,64] = add bvp bwl
    bwn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bum
    bwo:f64[64,64] = squeeze[dimensions=(0,)] bwn
    bwp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bun
    bwq:f64[64,64] = squeeze[dimensions=(0,)] bwp
    bwr:f64[64,64] = mul bwo bwq
    bws:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] but
    bwt:i64[64,64] = squeeze[dimensions=(0,)] bws
    bwu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bur
    bwv:i64[64,64] = squeeze[dimensions=(0,)] bwu
    bww:bool[64,64] = lt bwt 0:i64[]
    bwx:i64[64,64] = add bwt 64:i64[]
    bwy:i64[64,64] = select_n bww bwt bwx
    bwz:bool[64,64] = lt bwv 0:i64[]
    bxa:i64[64,64] = add bwv 64:i64[]
    bxb:i64[64,64] = select_n bwz bwv bxa
    bxc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bwy
    bxd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bxb
    bxe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bxc
    bxf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bxd
    bxg:i32[64,64,2] = concatenate[dimension=2] bxe bxf
    bxh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bud bxg
    bxi:f64[64,64] = mul bwr bxh
    bxj:f64[64,64] = add bwm bxi
    bxk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bum
    bxl:f64[64,64] = squeeze[dimensions=(0,)] bxk
    bxm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bum
    bxn:f64[64,64] = squeeze[dimensions=(0,)] bxm
    bxo:f64[64,64] = mul bxl bxn
    bxp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] but
    bxq:i64[64,64] = squeeze[dimensions=(0,)] bxp
    bxr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] but
    bxs:i64[64,64] = squeeze[dimensions=(0,)] bxr
    bxt:bool[64,64] = lt bxq 0:i64[]
    bxu:i64[64,64] = add bxq 64:i64[]
    bxv:i64[64,64] = select_n bxt bxq bxu
    bxw:bool[64,64] = lt bxs 0:i64[]
    bxx:i64[64,64] = add bxs 64:i64[]
    bxy:i64[64,64] = select_n bxw bxs bxx
    bxz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bxv
    bya:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bxy
    byb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bxz
    byc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bya
    byd:i32[64,64,2] = concatenate[dimension=2] byb byc
    bye:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bud byd
    byf:f64[64,64] = mul bxo bye
    byg:f64[64,64] = add bxj byf
    byh:f64[64,64] = sub bdg byg
    byi:f64[64,64] = div byh 2.0:f64[]
    byj:f64[64,64] = add bdg byi
    byk:f64[] = neg eb
    byl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] byk
    bym:f64[2,64,64] = mul byl bdr
    byn:f64[2,64,64] = add ea bym
    byo:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    byp:f64[2,64,64] = div byn byo
    byq:f64[2,64,64] = floor byp
    byr:f64[2,64,64] = sub byp byq
    bys:f64[2,64,64] = sub 1.0:f64[] byr
    byt:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bg
    byu:f64[2,64,64] = sub byp byr
    byv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] byu
    byw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] byv byt
    byx:i64[2,64,64] = add byw 1:i64[]
    byy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] byx byt
    byz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bys
    bza:f64[64,64] = squeeze[dimensions=(0,)] byz
    bzb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bys
    bzc:f64[64,64] = squeeze[dimensions=(0,)] bzb
    bzd:f64[64,64] = mul bza bzc
    bze:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] byw
    bzf:i64[64,64] = squeeze[dimensions=(0,)] bze
    bzg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] byw
    bzh:i64[64,64] = squeeze[dimensions=(0,)] bzg
    bzi:bool[64,64] = lt bzf 0:i64[]
    bzj:i64[64,64] = add bzf 64:i64[]
    bzk:i64[64,64] = select_n bzi bzf bzj
    bzl:bool[64,64] = lt bzh 0:i64[]
    bzm:i64[64,64] = add bzh 64:i64[]
    bzn:i64[64,64] = select_n bzl bzh bzm
    bzo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bzk
    bzp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] bzn
    bzq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bzo
    bzr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] bzp
    bzs:i32[64,64,2] = concatenate[dimension=2] bzq bzr
    bzt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] byj bzs
    bzu:f64[64,64] = mul bzd bzt
    bzv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] bys
    bzw:f64[64,64] = squeeze[dimensions=(0,)] bzv
    bzx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] byr
    bzy:f64[64,64] = squeeze[dimensions=(0,)] bzx
    bzz:f64[64,64] = mul bzw bzy
    caa:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] byw
    cab:i64[64,64] = squeeze[dimensions=(0,)] caa
    cac:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] byy
    cad:i64[64,64] = squeeze[dimensions=(0,)] cac
    cae:bool[64,64] = lt cab 0:i64[]
    caf:i64[64,64] = add cab 64:i64[]
    cag:i64[64,64] = select_n cae cab caf
    cah:bool[64,64] = lt cad 0:i64[]
    cai:i64[64,64] = add cad 64:i64[]
    caj:i64[64,64] = select_n cah cad cai
    cak:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cag
    cal:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] caj
    cam:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cak
    can:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cal
    cao:i32[64,64,2] = concatenate[dimension=2] cam can
    cap:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] byj cao
    caq:f64[64,64] = mul bzz cap
    car:f64[64,64] = add bzu caq
    cas:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] byr
    cat:f64[64,64] = squeeze[dimensions=(0,)] cas
    cau:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] bys
    cav:f64[64,64] = squeeze[dimensions=(0,)] cau
    caw:f64[64,64] = mul cat cav
    cax:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] byy
    cay:i64[64,64] = squeeze[dimensions=(0,)] cax
    caz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] byw
    cba:i64[64,64] = squeeze[dimensions=(0,)] caz
    cbb:bool[64,64] = lt cay 0:i64[]
    cbc:i64[64,64] = add cay 64:i64[]
    cbd:i64[64,64] = select_n cbb cay cbc
    cbe:bool[64,64] = lt cba 0:i64[]
    cbf:i64[64,64] = add cba 64:i64[]
    cbg:i64[64,64] = select_n cbe cba cbf
    cbh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cbd
    cbi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cbg
    cbj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cbh
    cbk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cbi
    cbl:i32[64,64,2] = concatenate[dimension=2] cbj cbk
    cbm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] byj cbl
    cbn:f64[64,64] = mul caw cbm
    cbo:f64[64,64] = add car cbn
    cbp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] byr
    cbq:f64[64,64] = squeeze[dimensions=(0,)] cbp
    cbr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] byr
    cbs:f64[64,64] = squeeze[dimensions=(0,)] cbr
    cbt:f64[64,64] = mul cbq cbs
    cbu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] byy
    cbv:i64[64,64] = squeeze[dimensions=(0,)] cbu
    cbw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] byy
    cbx:i64[64,64] = squeeze[dimensions=(0,)] cbw
    cby:bool[64,64] = lt cbv 0:i64[]
    cbz:i64[64,64] = add cbv 64:i64[]
    cca:i64[64,64] = select_n cby cbv cbz
    ccb:bool[64,64] = lt cbx 0:i64[]
    ccc:i64[64,64] = add cbx 64:i64[]
    ccd:i64[64,64] = select_n ccb cbx ccc
    cce:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cca
    ccf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ccd
    ccg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cce
    cch:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ccf
    cci:i32[64,64,2] = concatenate[dimension=2] ccg cch
    ccj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] byj cci
    cck:f64[64,64] = mul cbt ccj
    ccl:f64[64,64] = add cbo cck
    ccm:c128[64,33] = jit[name=fft jaxpr=fft] bqb
    ccn:c128[] = reduce_prod[axes=(0,)] bh
    cco:c128[] = sqrt ccn
    ccp:c128[] = div (1+0j):c128[] cco
    ccq:c128[64,33] = mul ccm ccp
    ccr:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] ccq
    ccs:c128[2,64,33] = mul dw ccr
    cct:f64[2,64,64] = jit[name=fft jaxpr=fft1] ccs
    ccu:f64[] = reduce_prod[axes=(0,)] bi
    ccv:f64[] = sqrt ccu
    ccw:f64[2,64,64] = mul cct ccv
    ccx:f64[] = neg eb
    ccy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ccx
    ccz:f64[2,64,64] = mul ccy ccw
    cda:f64[2,64,64] = add ea ccz
    cdb:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    cdc:f64[2,64,64] = div cda cdb
    cdd:f64[2,64,64] = floor cdc
    cde:f64[2,64,64] = sub cdc cdd
    cdf:f64[2,64,64] = sub 1.0:f64[] cde
    cdg:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bj
    cdh:f64[2,64,64] = sub cdc cde
    cdi:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] cdh
    cdj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cdi cdg
    cdk:i64[2,64,64] = add cdj 1:i64[]
    cdl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cdk cdg
    cdm:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cdf
    cdn:f64[64,64] = squeeze[dimensions=(0,)] cdm
    cdo:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cdf
    cdp:f64[64,64] = squeeze[dimensions=(0,)] cdo
    cdq:f64[64,64] = mul cdn cdp
    cdr:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cdj
    cds:i64[64,64] = squeeze[dimensions=(0,)] cdr
    cdt:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cdj
    cdu:i64[64,64] = squeeze[dimensions=(0,)] cdt
    cdv:bool[64,64] = lt cds 0:i64[]
    cdw:i64[64,64] = add cds 64:i64[]
    cdx:i64[64,64] = select_n cdv cds cdw
    cdy:bool[64,64] = lt cdu 0:i64[]
    cdz:i64[64,64] = add cdu 64:i64[]
    cea:i64[64,64] = select_n cdy cdu cdz
    ceb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cdx
    cec:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cea
    ced:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ceb
    cee:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cec
    cef:i32[64,64,2] = concatenate[dimension=2] ced cee
    ceg:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bqb cef
    ceh:f64[64,64] = mul cdq ceg
    cei:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cdf
    cej:f64[64,64] = squeeze[dimensions=(0,)] cei
    cek:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cde
    cel:f64[64,64] = squeeze[dimensions=(0,)] cek
    cem:f64[64,64] = mul cej cel
    cen:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cdj
    ceo:i64[64,64] = squeeze[dimensions=(0,)] cen
    cep:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cdl
    ceq:i64[64,64] = squeeze[dimensions=(0,)] cep
    cer:bool[64,64] = lt ceo 0:i64[]
    ces:i64[64,64] = add ceo 64:i64[]
    cet:i64[64,64] = select_n cer ceo ces
    ceu:bool[64,64] = lt ceq 0:i64[]
    cev:i64[64,64] = add ceq 64:i64[]
    cew:i64[64,64] = select_n ceu ceq cev
    cex:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cet
    cey:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cew
    cez:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cex
    cfa:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cey
    cfb:i32[64,64,2] = concatenate[dimension=2] cez cfa
    cfc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bqb cfb
    cfd:f64[64,64] = mul cem cfc
    cfe:f64[64,64] = add ceh cfd
    cff:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cde
    cfg:f64[64,64] = squeeze[dimensions=(0,)] cff
    cfh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cdf
    cfi:f64[64,64] = squeeze[dimensions=(0,)] cfh
    cfj:f64[64,64] = mul cfg cfi
    cfk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cdl
    cfl:i64[64,64] = squeeze[dimensions=(0,)] cfk
    cfm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cdj
    cfn:i64[64,64] = squeeze[dimensions=(0,)] cfm
    cfo:bool[64,64] = lt cfl 0:i64[]
    cfp:i64[64,64] = add cfl 64:i64[]
    cfq:i64[64,64] = select_n cfo cfl cfp
    cfr:bool[64,64] = lt cfn 0:i64[]
    cfs:i64[64,64] = add cfn 64:i64[]
    cft:i64[64,64] = select_n cfr cfn cfs
    cfu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cfq
    cfv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cft
    cfw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cfu
    cfx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cfv
    cfy:i32[64,64,2] = concatenate[dimension=2] cfw cfx
    cfz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bqb cfy
    cga:f64[64,64] = mul cfj cfz
    cgb:f64[64,64] = add cfe cga
    cgc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cde
    cgd:f64[64,64] = squeeze[dimensions=(0,)] cgc
    cge:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cde
    cgf:f64[64,64] = squeeze[dimensions=(0,)] cge
    cgg:f64[64,64] = mul cgd cgf
    cgh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cdl
    cgi:i64[64,64] = squeeze[dimensions=(0,)] cgh
    cgj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cdl
    cgk:i64[64,64] = squeeze[dimensions=(0,)] cgj
    cgl:bool[64,64] = lt cgi 0:i64[]
    cgm:i64[64,64] = add cgi 64:i64[]
    cgn:i64[64,64] = select_n cgl cgi cgm
    cgo:bool[64,64] = lt cgk 0:i64[]
    cgp:i64[64,64] = add cgk 64:i64[]
    cgq:i64[64,64] = select_n cgo cgk cgp
    cgr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cgn
    cgs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cgq
    cgt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cgr
    cgu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cgs
    cgv:i32[64,64,2] = concatenate[dimension=2] cgt cgu
    cgw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] bqb cgv
    cgx:f64[64,64] = mul cgg cgw
    cgy:f64[64,64] = add cgb cgx
    cgz:f64[2,64,64] = neg ccw
    cha:f64[] = neg eb
    chb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cha
    chc:f64[2,64,64] = mul chb cgz
    chd:f64[2,64,64] = add ea chc
    che:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    chf:f64[2,64,64] = div chd che
    chg:f64[2,64,64] = floor chf
    chh:f64[2,64,64] = sub chf chg
    chi:f64[2,64,64] = sub 1.0:f64[] chh
    chj:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bk
    chk:f64[2,64,64] = sub chf chh
    chl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] chk
    chm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] chl chj
    chn:i64[2,64,64] = add chm 1:i64[]
    cho:i64[2,64,64] = jit[name=remainder jaxpr=remainder] chn chj
    chp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] chi
    chq:f64[64,64] = squeeze[dimensions=(0,)] chp
    chr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] chi
    chs:f64[64,64] = squeeze[dimensions=(0,)] chr
    cht:f64[64,64] = mul chq chs
    chu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] chm
    chv:i64[64,64] = squeeze[dimensions=(0,)] chu
    chw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] chm
    chx:i64[64,64] = squeeze[dimensions=(0,)] chw
    chy:bool[64,64] = lt chv 0:i64[]
    chz:i64[64,64] = add chv 64:i64[]
    cia:i64[64,64] = select_n chy chv chz
    cib:bool[64,64] = lt chx 0:i64[]
    cic:i64[64,64] = add chx 64:i64[]
    cid:i64[64,64] = select_n cib chx cic
    cie:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cia
    cif:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cid
    cig:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cie
    cih:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cif
    cii:i32[64,64,2] = concatenate[dimension=2] cig cih
    cij:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cgy cii
    cik:f64[64,64] = mul cht cij
    cil:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] chi
    cim:f64[64,64] = squeeze[dimensions=(0,)] cil
    cin:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] chh
    cio:f64[64,64] = squeeze[dimensions=(0,)] cin
    cip:f64[64,64] = mul cim cio
    ciq:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] chm
    cir:i64[64,64] = squeeze[dimensions=(0,)] ciq
    cis:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cho
    cit:i64[64,64] = squeeze[dimensions=(0,)] cis
    ciu:bool[64,64] = lt cir 0:i64[]
    civ:i64[64,64] = add cir 64:i64[]
    ciw:i64[64,64] = select_n ciu cir civ
    cix:bool[64,64] = lt cit 0:i64[]
    ciy:i64[64,64] = add cit 64:i64[]
    ciz:i64[64,64] = select_n cix cit ciy
    cja:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ciw
    cjb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ciz
    cjc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cja
    cjd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cjb
    cje:i32[64,64,2] = concatenate[dimension=2] cjc cjd
    cjf:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cgy cje
    cjg:f64[64,64] = mul cip cjf
    cjh:f64[64,64] = add cik cjg
    cji:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] chh
    cjj:f64[64,64] = squeeze[dimensions=(0,)] cji
    cjk:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] chi
    cjl:f64[64,64] = squeeze[dimensions=(0,)] cjk
    cjm:f64[64,64] = mul cjj cjl
    cjn:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cho
    cjo:i64[64,64] = squeeze[dimensions=(0,)] cjn
    cjp:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] chm
    cjq:i64[64,64] = squeeze[dimensions=(0,)] cjp
    cjr:bool[64,64] = lt cjo 0:i64[]
    cjs:i64[64,64] = add cjo 64:i64[]
    cjt:i64[64,64] = select_n cjr cjo cjs
    cju:bool[64,64] = lt cjq 0:i64[]
    cjv:i64[64,64] = add cjq 64:i64[]
    cjw:i64[64,64] = select_n cju cjq cjv
    cjx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cjt
    cjy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cjw
    cjz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cjx
    cka:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cjy
    ckb:i32[64,64,2] = concatenate[dimension=2] cjz cka
    ckc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cgy ckb
    ckd:f64[64,64] = mul cjm ckc
    cke:f64[64,64] = add cjh ckd
    ckf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] chh
    ckg:f64[64,64] = squeeze[dimensions=(0,)] ckf
    ckh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] chh
    cki:f64[64,64] = squeeze[dimensions=(0,)] ckh
    ckj:f64[64,64] = mul ckg cki
    ckk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cho
    ckl:i64[64,64] = squeeze[dimensions=(0,)] ckk
    ckm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cho
    ckn:i64[64,64] = squeeze[dimensions=(0,)] ckm
    cko:bool[64,64] = lt ckl 0:i64[]
    ckp:i64[64,64] = add ckl 64:i64[]
    ckq:i64[64,64] = select_n cko ckl ckp
    ckr:bool[64,64] = lt ckn 0:i64[]
    cks:i64[64,64] = add ckn 64:i64[]
    ckt:i64[64,64] = select_n ckr ckn cks
    cku:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ckq
    ckv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ckt
    ckw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cku
    ckx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ckv
    cky:i32[64,64,2] = concatenate[dimension=2] ckw ckx
    ckz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cgy cky
    cla:f64[64,64] = mul ckj ckz
    clb:f64[64,64] = add cke cla
    clc:f64[64,64] = sub bqb clb
    cld:f64[64,64] = div clc 2.0:f64[]
    cle:f64[64,64] = add bqb cld
    clf:f64[] = neg eb
    clg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] clf
    clh:f64[2,64,64] = mul clg ccw
    cli:f64[2,64,64] = add ea clh
    clj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    clk:f64[2,64,64] = div cli clj
    cll:f64[2,64,64] = floor clk
    clm:f64[2,64,64] = sub clk cll
    cln:f64[2,64,64] = sub 1.0:f64[] clm
    clo:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bl
    clp:f64[2,64,64] = sub clk clm
    clq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] clp
    clr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] clq clo
    cls:i64[2,64,64] = add clr 1:i64[]
    clt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cls clo
    clu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cln
    clv:f64[64,64] = squeeze[dimensions=(0,)] clu
    clw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cln
    clx:f64[64,64] = squeeze[dimensions=(0,)] clw
    cly:f64[64,64] = mul clv clx
    clz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] clr
    cma:i64[64,64] = squeeze[dimensions=(0,)] clz
    cmb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] clr
    cmc:i64[64,64] = squeeze[dimensions=(0,)] cmb
    cmd:bool[64,64] = lt cma 0:i64[]
    cme:i64[64,64] = add cma 64:i64[]
    cmf:i64[64,64] = select_n cmd cma cme
    cmg:bool[64,64] = lt cmc 0:i64[]
    cmh:i64[64,64] = add cmc 64:i64[]
    cmi:i64[64,64] = select_n cmg cmc cmh
    cmj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cmf
    cmk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cmi
    cml:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cmj
    cmm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cmk
    cmn:i32[64,64,2] = concatenate[dimension=2] cml cmm
    cmo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cle cmn
    cmp:f64[64,64] = mul cly cmo
    cmq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cln
    cmr:f64[64,64] = squeeze[dimensions=(0,)] cmq
    cms:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] clm
    cmt:f64[64,64] = squeeze[dimensions=(0,)] cms
    cmu:f64[64,64] = mul cmr cmt
    cmv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] clr
    cmw:i64[64,64] = squeeze[dimensions=(0,)] cmv
    cmx:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] clt
    cmy:i64[64,64] = squeeze[dimensions=(0,)] cmx
    cmz:bool[64,64] = lt cmw 0:i64[]
    cna:i64[64,64] = add cmw 64:i64[]
    cnb:i64[64,64] = select_n cmz cmw cna
    cnc:bool[64,64] = lt cmy 0:i64[]
    cnd:i64[64,64] = add cmy 64:i64[]
    cne:i64[64,64] = select_n cnc cmy cnd
    cnf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cnb
    cng:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cne
    cnh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cnf
    cni:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cng
    cnj:i32[64,64,2] = concatenate[dimension=2] cnh cni
    cnk:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cle cnj
    cnl:f64[64,64] = mul cmu cnk
    cnm:f64[64,64] = add cmp cnl
    cnn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] clm
    cno:f64[64,64] = squeeze[dimensions=(0,)] cnn
    cnp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cln
    cnq:f64[64,64] = squeeze[dimensions=(0,)] cnp
    cnr:f64[64,64] = mul cno cnq
    cns:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] clt
    cnt:i64[64,64] = squeeze[dimensions=(0,)] cns
    cnu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] clr
    cnv:i64[64,64] = squeeze[dimensions=(0,)] cnu
    cnw:bool[64,64] = lt cnt 0:i64[]
    cnx:i64[64,64] = add cnt 64:i64[]
    cny:i64[64,64] = select_n cnw cnt cnx
    cnz:bool[64,64] = lt cnv 0:i64[]
    coa:i64[64,64] = add cnv 64:i64[]
    cob:i64[64,64] = select_n cnz cnv coa
    coc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cny
    cod:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cob
    coe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] coc
    cof:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cod
    cog:i32[64,64,2] = concatenate[dimension=2] coe cof
    coh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cle cog
    coi:f64[64,64] = mul cnr coh
    coj:f64[64,64] = add cnm coi
    cok:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] clm
    col:f64[64,64] = squeeze[dimensions=(0,)] cok
    com:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] clm
    con:f64[64,64] = squeeze[dimensions=(0,)] com
    coo:f64[64,64] = mul col con
    cop:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] clt
    coq:i64[64,64] = squeeze[dimensions=(0,)] cop
    cor:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] clt
    cos:i64[64,64] = squeeze[dimensions=(0,)] cor
    cot:bool[64,64] = lt coq 0:i64[]
    cou:i64[64,64] = add coq 64:i64[]
    cov:i64[64,64] = select_n cot coq cou
    cow:bool[64,64] = lt cos 0:i64[]
    cox:i64[64,64] = add cos 64:i64[]
    coy:i64[64,64] = select_n cow cos cox
    coz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cov
    cpa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] coy
    cpb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] coz
    cpc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cpa
    cpd:i32[64,64,2] = concatenate[dimension=2] cpb cpc
    cpe:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cle cpd
    cpf:f64[64,64] = mul coo cpe
    cpg:f64[64,64] = add coj cpf
    cph:f64[] = neg eb
    cpi:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cph
    cpj:f64[2,64,64] = mul cpi ccw
    cpk:f64[2,64,64] = add ea cpj
    cpl:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    cpm:f64[2,64,64] = div cpk cpl
    cpn:f64[2,64,64] = floor cpm
    cpo:f64[2,64,64] = sub cpm cpn
    cpp:f64[2,64,64] = sub 1.0:f64[] cpo
    cpq:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bm
    cpr:f64[2,64,64] = sub cpm cpo
    cps:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] cpr
    cpt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cps cpq
    cpu:i64[2,64,64] = add cpt 1:i64[]
    cpv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cpu cpq
    cpw:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpp
    cpx:f64[64,64] = squeeze[dimensions=(0,)] cpw
    cpy:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpp
    cpz:f64[64,64] = squeeze[dimensions=(0,)] cpy
    cqa:f64[64,64] = mul cpx cpz
    cqb:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpt
    cqc:i64[64,64] = squeeze[dimensions=(0,)] cqb
    cqd:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpt
    cqe:i64[64,64] = squeeze[dimensions=(0,)] cqd
    cqf:bool[64,64] = lt cqc 0:i64[]
    cqg:i64[64,64] = add cqc 64:i64[]
    cqh:i64[64,64] = select_n cqf cqc cqg
    cqi:bool[64,64] = lt cqe 0:i64[]
    cqj:i64[64,64] = add cqe 64:i64[]
    cqk:i64[64,64] = select_n cqi cqe cqj
    cql:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cqh
    cqm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cqk
    cqn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cql
    cqo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cqm
    cqp:i32[64,64,2] = concatenate[dimension=2] cqn cqo
    cqq:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ccl cqp
    cqr:f64[64,64] = mul cqa cqq
    cqs:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpp
    cqt:f64[64,64] = squeeze[dimensions=(0,)] cqs
    cqu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpo
    cqv:f64[64,64] = squeeze[dimensions=(0,)] cqu
    cqw:f64[64,64] = mul cqt cqv
    cqx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpt
    cqy:i64[64,64] = squeeze[dimensions=(0,)] cqx
    cqz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpv
    cra:i64[64,64] = squeeze[dimensions=(0,)] cqz
    crb:bool[64,64] = lt cqy 0:i64[]
    crc:i64[64,64] = add cqy 64:i64[]
    crd:i64[64,64] = select_n crb cqy crc
    cre:bool[64,64] = lt cra 0:i64[]
    crf:i64[64,64] = add cra 64:i64[]
    crg:i64[64,64] = select_n cre cra crf
    crh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] crd
    cri:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] crg
    crj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] crh
    crk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cri
    crl:i32[64,64,2] = concatenate[dimension=2] crj crk
    crm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ccl crl
    crn:f64[64,64] = mul cqw crm
    cro:f64[64,64] = add cqr crn
    crp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpo
    crq:f64[64,64] = squeeze[dimensions=(0,)] crp
    crr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpp
    crs:f64[64,64] = squeeze[dimensions=(0,)] crr
    crt:f64[64,64] = mul crq crs
    cru:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpv
    crv:i64[64,64] = squeeze[dimensions=(0,)] cru
    crw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpt
    crx:i64[64,64] = squeeze[dimensions=(0,)] crw
    cry:bool[64,64] = lt crv 0:i64[]
    crz:i64[64,64] = add crv 64:i64[]
    csa:i64[64,64] = select_n cry crv crz
    csb:bool[64,64] = lt crx 0:i64[]
    csc:i64[64,64] = add crx 64:i64[]
    csd:i64[64,64] = select_n csb crx csc
    cse:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] csa
    csf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] csd
    csg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cse
    csh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] csf
    csi:i32[64,64,2] = concatenate[dimension=2] csg csh
    csj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ccl csi
    csk:f64[64,64] = mul crt csj
    csl:f64[64,64] = add cro csk
    csm:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpo
    csn:f64[64,64] = squeeze[dimensions=(0,)] csm
    cso:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpo
    csp:f64[64,64] = squeeze[dimensions=(0,)] cso
    csq:f64[64,64] = mul csn csp
    csr:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cpv
    css:i64[64,64] = squeeze[dimensions=(0,)] csr
    cst:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cpv
    csu:i64[64,64] = squeeze[dimensions=(0,)] cst
    csv:bool[64,64] = lt css 0:i64[]
    csw:i64[64,64] = add css 64:i64[]
    csx:i64[64,64] = select_n csv css csw
    csy:bool[64,64] = lt csu 0:i64[]
    csz:i64[64,64] = add csu 64:i64[]
    cta:i64[64,64] = select_n csy csu csz
    ctb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] csx
    ctc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cta
    ctd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ctb
    cte:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ctc
    ctf:i32[64,64,2] = concatenate[dimension=2] ctd cte
    ctg:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ccl ctf
    cth:f64[64,64] = mul csq ctg
    cti:f64[64,64] = add csl cth
    ctj:f64[2,64,64] = neg ccw
    ctk:f64[] = neg eb
    ctl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ctk
    ctm:f64[2,64,64] = mul ctl ctj
    ctn:f64[2,64,64] = add ea ctm
    cto:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    ctp:f64[2,64,64] = div ctn cto
    ctq:f64[2,64,64] = floor ctp
    ctr:f64[2,64,64] = sub ctp ctq
    cts:f64[2,64,64] = sub 1.0:f64[] ctr
    ctt:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bn
    ctu:f64[2,64,64] = sub ctp ctr
    ctv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ctu
    ctw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ctv ctt
    ctx:i64[2,64,64] = add ctw 1:i64[]
    cty:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ctx ctt
    ctz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cts
    cua:f64[64,64] = squeeze[dimensions=(0,)] ctz
    cub:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cts
    cuc:f64[64,64] = squeeze[dimensions=(0,)] cub
    cud:f64[64,64] = mul cua cuc
    cue:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ctw
    cuf:i64[64,64] = squeeze[dimensions=(0,)] cue
    cug:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ctw
    cuh:i64[64,64] = squeeze[dimensions=(0,)] cug
    cui:bool[64,64] = lt cuf 0:i64[]
    cuj:i64[64,64] = add cuf 64:i64[]
    cuk:i64[64,64] = select_n cui cuf cuj
    cul:bool[64,64] = lt cuh 0:i64[]
    cum:i64[64,64] = add cuh 64:i64[]
    cun:i64[64,64] = select_n cul cuh cum
    cuo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cuk
    cup:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cun
    cuq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cuo
    cur:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cup
    cus:i32[64,64,2] = concatenate[dimension=2] cuq cur
    cut:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cti cus
    cuu:f64[64,64] = mul cud cut
    cuv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cts
    cuw:f64[64,64] = squeeze[dimensions=(0,)] cuv
    cux:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ctr
    cuy:f64[64,64] = squeeze[dimensions=(0,)] cux
    cuz:f64[64,64] = mul cuw cuy
    cva:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ctw
    cvb:i64[64,64] = squeeze[dimensions=(0,)] cva
    cvc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cty
    cvd:i64[64,64] = squeeze[dimensions=(0,)] cvc
    cve:bool[64,64] = lt cvb 0:i64[]
    cvf:i64[64,64] = add cvb 64:i64[]
    cvg:i64[64,64] = select_n cve cvb cvf
    cvh:bool[64,64] = lt cvd 0:i64[]
    cvi:i64[64,64] = add cvd 64:i64[]
    cvj:i64[64,64] = select_n cvh cvd cvi
    cvk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cvg
    cvl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cvj
    cvm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cvk
    cvn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cvl
    cvo:i32[64,64,2] = concatenate[dimension=2] cvm cvn
    cvp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cti cvo
    cvq:f64[64,64] = mul cuz cvp
    cvr:f64[64,64] = add cuu cvq
    cvs:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ctr
    cvt:f64[64,64] = squeeze[dimensions=(0,)] cvs
    cvu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cts
    cvv:f64[64,64] = squeeze[dimensions=(0,)] cvu
    cvw:f64[64,64] = mul cvt cvv
    cvx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cty
    cvy:i64[64,64] = squeeze[dimensions=(0,)] cvx
    cvz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ctw
    cwa:i64[64,64] = squeeze[dimensions=(0,)] cvz
    cwb:bool[64,64] = lt cvy 0:i64[]
    cwc:i64[64,64] = add cvy 64:i64[]
    cwd:i64[64,64] = select_n cwb cvy cwc
    cwe:bool[64,64] = lt cwa 0:i64[]
    cwf:i64[64,64] = add cwa 64:i64[]
    cwg:i64[64,64] = select_n cwe cwa cwf
    cwh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cwd
    cwi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cwg
    cwj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cwh
    cwk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cwi
    cwl:i32[64,64,2] = concatenate[dimension=2] cwj cwk
    cwm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cti cwl
    cwn:f64[64,64] = mul cvw cwm
    cwo:f64[64,64] = add cvr cwn
    cwp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ctr
    cwq:f64[64,64] = squeeze[dimensions=(0,)] cwp
    cwr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ctr
    cws:f64[64,64] = squeeze[dimensions=(0,)] cwr
    cwt:f64[64,64] = mul cwq cws
    cwu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cty
    cwv:i64[64,64] = squeeze[dimensions=(0,)] cwu
    cww:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cty
    cwx:i64[64,64] = squeeze[dimensions=(0,)] cww
    cwy:bool[64,64] = lt cwv 0:i64[]
    cwz:i64[64,64] = add cwv 64:i64[]
    cxa:i64[64,64] = select_n cwy cwv cwz
    cxb:bool[64,64] = lt cwx 0:i64[]
    cxc:i64[64,64] = add cwx 64:i64[]
    cxd:i64[64,64] = select_n cxb cwx cxc
    cxe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cxa
    cxf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cxd
    cxg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cxe
    cxh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cxf
    cxi:i32[64,64,2] = concatenate[dimension=2] cxg cxh
    cxj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cti cxi
    cxk:f64[64,64] = mul cwt cxj
    cxl:f64[64,64] = add cwo cxk
    cxm:f64[64,64] = sub ccl cxl
    cxn:f64[64,64] = div cxm 2.0:f64[]
    cxo:f64[64,64] = add ccl cxn
    cxp:f64[] = neg eb
    cxq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] cxp
    cxr:f64[2,64,64] = mul cxq ccw
    cxs:f64[2,64,64] = add ea cxr
    cxt:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    cxu:f64[2,64,64] = div cxs cxt
    cxv:f64[2,64,64] = floor cxu
    cxw:f64[2,64,64] = sub cxu cxv
    cxx:f64[2,64,64] = sub 1.0:f64[] cxw
    cxy:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bo
    cxz:f64[2,64,64] = sub cxu cxw
    cya:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] cxz
    cyb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cya cxy
    cyc:i64[2,64,64] = add cyb 1:i64[]
    cyd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] cyc cxy
    cye:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cxx
    cyf:f64[64,64] = squeeze[dimensions=(0,)] cye
    cyg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cxx
    cyh:f64[64,64] = squeeze[dimensions=(0,)] cyg
    cyi:f64[64,64] = mul cyf cyh
    cyj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cyb
    cyk:i64[64,64] = squeeze[dimensions=(0,)] cyj
    cyl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cyb
    cym:i64[64,64] = squeeze[dimensions=(0,)] cyl
    cyn:bool[64,64] = lt cyk 0:i64[]
    cyo:i64[64,64] = add cyk 64:i64[]
    cyp:i64[64,64] = select_n cyn cyk cyo
    cyq:bool[64,64] = lt cym 0:i64[]
    cyr:i64[64,64] = add cym 64:i64[]
    cys:i64[64,64] = select_n cyq cym cyr
    cyt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cyp
    cyu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] cys
    cyv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cyt
    cyw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] cyu
    cyx:i32[64,64,2] = concatenate[dimension=2] cyv cyw
    cyy:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cxo cyx
    cyz:f64[64,64] = mul cyi cyy
    cza:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cxx
    czb:f64[64,64] = squeeze[dimensions=(0,)] cza
    czc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cxw
    czd:f64[64,64] = squeeze[dimensions=(0,)] czc
    cze:f64[64,64] = mul czb czd
    czf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cyb
    czg:i64[64,64] = squeeze[dimensions=(0,)] czf
    czh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cyd
    czi:i64[64,64] = squeeze[dimensions=(0,)] czh
    czj:bool[64,64] = lt czg 0:i64[]
    czk:i64[64,64] = add czg 64:i64[]
    czl:i64[64,64] = select_n czj czg czk
    czm:bool[64,64] = lt czi 0:i64[]
    czn:i64[64,64] = add czi 64:i64[]
    czo:i64[64,64] = select_n czm czi czn
    czp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] czl
    czq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] czo
    czr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] czp
    czs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] czq
    czt:i32[64,64,2] = concatenate[dimension=2] czr czs
    czu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cxo czt
    czv:f64[64,64] = mul cze czu
    czw:f64[64,64] = add cyz czv
    czx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cxw
    czy:f64[64,64] = squeeze[dimensions=(0,)] czx
    czz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cxx
    daa:f64[64,64] = squeeze[dimensions=(0,)] czz
    dab:f64[64,64] = mul czy daa
    dac:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cyd
    dad:i64[64,64] = squeeze[dimensions=(0,)] dac
    dae:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cyb
    daf:i64[64,64] = squeeze[dimensions=(0,)] dae
    dag:bool[64,64] = lt dad 0:i64[]
    dah:i64[64,64] = add dad 64:i64[]
    dai:i64[64,64] = select_n dag dad dah
    daj:bool[64,64] = lt daf 0:i64[]
    dak:i64[64,64] = add daf 64:i64[]
    dal:i64[64,64] = select_n daj daf dak
    dam:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dai
    dan:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dal
    dao:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dam
    dap:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dan
    daq:i32[64,64,2] = concatenate[dimension=2] dao dap
    dar:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cxo daq
    das:f64[64,64] = mul dab dar
    dat:f64[64,64] = add czw das
    dau:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cxw
    dav:f64[64,64] = squeeze[dimensions=(0,)] dau
    daw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cxw
    dax:f64[64,64] = squeeze[dimensions=(0,)] daw
    day:f64[64,64] = mul dav dax
    daz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] cyd
    dba:i64[64,64] = squeeze[dimensions=(0,)] daz
    dbb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] cyd
    dbc:i64[64,64] = squeeze[dimensions=(0,)] dbb
    dbd:bool[64,64] = lt dba 0:i64[]
    dbe:i64[64,64] = add dba 64:i64[]
    dbf:i64[64,64] = select_n dbd dba dbe
    dbg:bool[64,64] = lt dbc 0:i64[]
    dbh:i64[64,64] = add dbc 64:i64[]
    dbi:i64[64,64] = select_n dbg dbc dbh
    dbj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dbf
    dbk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dbi
    dbl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dbj
    dbm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dbk
    dbn:i32[64,64,2] = concatenate[dimension=2] dbl dbm
    dbo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cxo dbn
    dbp:f64[64,64] = mul day dbo
    dbq:f64[64,64] = add dat dbp
    dbr:c128[64,33] = jit[name=fft jaxpr=fft] cpg
    dbs:c128[] = reduce_prod[axes=(0,)] bp
    dbt:c128[] = sqrt dbs
    dbu:c128[] = div (1+0j):c128[] dbt
    dbv:c128[64,33] = mul dbr dbu
    dbw:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] dbv
    dbx:c128[2,64,33] = mul dw dbw
    dby:f64[2,64,64] = jit[name=fft jaxpr=fft1] dbx
    dbz:f64[] = reduce_prod[axes=(0,)] bq
    dca:f64[] = sqrt dbz
    dcb:f64[2,64,64] = mul dby dca
    dcc:f64[] = neg eb
    dcd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dcc
    dce:f64[2,64,64] = mul dcd dcb
    dcf:f64[2,64,64] = add ea dce
    dcg:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    dch:f64[2,64,64] = div dcf dcg
    dci:f64[2,64,64] = floor dch
    dcj:f64[2,64,64] = sub dch dci
    dck:f64[2,64,64] = sub 1.0:f64[] dcj
    dcl:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] br
    dcm:f64[2,64,64] = sub dch dcj
    dcn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dcm
    dco:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dcn dcl
    dcp:i64[2,64,64] = add dco 1:i64[]
    dcq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dcp dcl
    dcr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dck
    dcs:f64[64,64] = squeeze[dimensions=(0,)] dcr
    dct:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dck
    dcu:f64[64,64] = squeeze[dimensions=(0,)] dct
    dcv:f64[64,64] = mul dcs dcu
    dcw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dco
    dcx:i64[64,64] = squeeze[dimensions=(0,)] dcw
    dcy:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dco
    dcz:i64[64,64] = squeeze[dimensions=(0,)] dcy
    dda:bool[64,64] = lt dcx 0:i64[]
    ddb:i64[64,64] = add dcx 64:i64[]
    ddc:i64[64,64] = select_n dda dcx ddb
    ddd:bool[64,64] = lt dcz 0:i64[]
    dde:i64[64,64] = add dcz 64:i64[]
    ddf:i64[64,64] = select_n ddd dcz dde
    ddg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ddc
    ddh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ddf
    ddi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ddg
    ddj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ddh
    ddk:i32[64,64,2] = concatenate[dimension=2] ddi ddj
    ddl:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cpg ddk
    ddm:f64[64,64] = mul dcv ddl
    ddn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dck
    ddo:f64[64,64] = squeeze[dimensions=(0,)] ddn
    ddp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dcj
    ddq:f64[64,64] = squeeze[dimensions=(0,)] ddp
    ddr:f64[64,64] = mul ddo ddq
    dds:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dco
    ddt:i64[64,64] = squeeze[dimensions=(0,)] dds
    ddu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dcq
    ddv:i64[64,64] = squeeze[dimensions=(0,)] ddu
    ddw:bool[64,64] = lt ddt 0:i64[]
    ddx:i64[64,64] = add ddt 64:i64[]
    ddy:i64[64,64] = select_n ddw ddt ddx
    ddz:bool[64,64] = lt ddv 0:i64[]
    dea:i64[64,64] = add ddv 64:i64[]
    deb:i64[64,64] = select_n ddz ddv dea
    dec:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ddy
    ded:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] deb
    dee:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dec
    def:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ded
    deg:i32[64,64,2] = concatenate[dimension=2] dee def
    deh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cpg deg
    dei:f64[64,64] = mul ddr deh
    dej:f64[64,64] = add ddm dei
    dek:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dcj
    del:f64[64,64] = squeeze[dimensions=(0,)] dek
    dem:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dck
    den:f64[64,64] = squeeze[dimensions=(0,)] dem
    deo:f64[64,64] = mul del den
    dep:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dcq
    deq:i64[64,64] = squeeze[dimensions=(0,)] dep
    der:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dco
    des:i64[64,64] = squeeze[dimensions=(0,)] der
    det:bool[64,64] = lt deq 0:i64[]
    deu:i64[64,64] = add deq 64:i64[]
    dev:i64[64,64] = select_n det deq deu
    dew:bool[64,64] = lt des 0:i64[]
    dex:i64[64,64] = add des 64:i64[]
    dey:i64[64,64] = select_n dew des dex
    dez:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dev
    dfa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dey
    dfb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dez
    dfc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dfa
    dfd:i32[64,64,2] = concatenate[dimension=2] dfb dfc
    dfe:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cpg dfd
    dff:f64[64,64] = mul deo dfe
    dfg:f64[64,64] = add dej dff
    dfh:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dcj
    dfi:f64[64,64] = squeeze[dimensions=(0,)] dfh
    dfj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dcj
    dfk:f64[64,64] = squeeze[dimensions=(0,)] dfj
    dfl:f64[64,64] = mul dfi dfk
    dfm:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dcq
    dfn:i64[64,64] = squeeze[dimensions=(0,)] dfm
    dfo:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dcq
    dfp:i64[64,64] = squeeze[dimensions=(0,)] dfo
    dfq:bool[64,64] = lt dfn 0:i64[]
    dfr:i64[64,64] = add dfn 64:i64[]
    dfs:i64[64,64] = select_n dfq dfn dfr
    dft:bool[64,64] = lt dfp 0:i64[]
    dfu:i64[64,64] = add dfp 64:i64[]
    dfv:i64[64,64] = select_n dft dfp dfu
    dfw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dfs
    dfx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dfv
    dfy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dfw
    dfz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dfx
    dga:i32[64,64,2] = concatenate[dimension=2] dfy dfz
    dgb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] cpg dga
    dgc:f64[64,64] = mul dfl dgb
    dgd:f64[64,64] = add dfg dgc
    dge:f64[2,64,64] = neg dcb
    dgf:f64[] = neg eb
    dgg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dgf
    dgh:f64[2,64,64] = mul dgg dge
    dgi:f64[2,64,64] = add ea dgh
    dgj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    dgk:f64[2,64,64] = div dgi dgj
    dgl:f64[2,64,64] = floor dgk
    dgm:f64[2,64,64] = sub dgk dgl
    dgn:f64[2,64,64] = sub 1.0:f64[] dgm
    dgo:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bs
    dgp:f64[2,64,64] = sub dgk dgm
    dgq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dgp
    dgr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dgq dgo
    dgs:i64[2,64,64] = add dgr 1:i64[]
    dgt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dgs dgo
    dgu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgn
    dgv:f64[64,64] = squeeze[dimensions=(0,)] dgu
    dgw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgn
    dgx:f64[64,64] = squeeze[dimensions=(0,)] dgw
    dgy:f64[64,64] = mul dgv dgx
    dgz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgr
    dha:i64[64,64] = squeeze[dimensions=(0,)] dgz
    dhb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgr
    dhc:i64[64,64] = squeeze[dimensions=(0,)] dhb
    dhd:bool[64,64] = lt dha 0:i64[]
    dhe:i64[64,64] = add dha 64:i64[]
    dhf:i64[64,64] = select_n dhd dha dhe
    dhg:bool[64,64] = lt dhc 0:i64[]
    dhh:i64[64,64] = add dhc 64:i64[]
    dhi:i64[64,64] = select_n dhg dhc dhh
    dhj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dhf
    dhk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dhi
    dhl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dhj
    dhm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dhk
    dhn:i32[64,64,2] = concatenate[dimension=2] dhl dhm
    dho:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dgd dhn
    dhp:f64[64,64] = mul dgy dho
    dhq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgn
    dhr:f64[64,64] = squeeze[dimensions=(0,)] dhq
    dhs:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgm
    dht:f64[64,64] = squeeze[dimensions=(0,)] dhs
    dhu:f64[64,64] = mul dhr dht
    dhv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgr
    dhw:i64[64,64] = squeeze[dimensions=(0,)] dhv
    dhx:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgt
    dhy:i64[64,64] = squeeze[dimensions=(0,)] dhx
    dhz:bool[64,64] = lt dhw 0:i64[]
    dia:i64[64,64] = add dhw 64:i64[]
    dib:i64[64,64] = select_n dhz dhw dia
    dic:bool[64,64] = lt dhy 0:i64[]
    did:i64[64,64] = add dhy 64:i64[]
    die:i64[64,64] = select_n dic dhy did
    dif:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dib
    dig:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] die
    dih:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dif
    dii:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dig
    dij:i32[64,64,2] = concatenate[dimension=2] dih dii
    dik:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dgd dij
    dil:f64[64,64] = mul dhu dik
    dim:f64[64,64] = add dhp dil
    din:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgm
    dio:f64[64,64] = squeeze[dimensions=(0,)] din
    dip:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgn
    diq:f64[64,64] = squeeze[dimensions=(0,)] dip
    dir:f64[64,64] = mul dio diq
    dis:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgt
    dit:i64[64,64] = squeeze[dimensions=(0,)] dis
    diu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgr
    div:i64[64,64] = squeeze[dimensions=(0,)] diu
    diw:bool[64,64] = lt dit 0:i64[]
    dix:i64[64,64] = add dit 64:i64[]
    diy:i64[64,64] = select_n diw dit dix
    diz:bool[64,64] = lt div 0:i64[]
    dja:i64[64,64] = add div 64:i64[]
    djb:i64[64,64] = select_n diz div dja
    djc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] diy
    djd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] djb
    dje:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] djc
    djf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] djd
    djg:i32[64,64,2] = concatenate[dimension=2] dje djf
    djh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dgd djg
    dji:f64[64,64] = mul dir djh
    djj:f64[64,64] = add dim dji
    djk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgm
    djl:f64[64,64] = squeeze[dimensions=(0,)] djk
    djm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgm
    djn:f64[64,64] = squeeze[dimensions=(0,)] djm
    djo:f64[64,64] = mul djl djn
    djp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dgt
    djq:i64[64,64] = squeeze[dimensions=(0,)] djp
    djr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dgt
    djs:i64[64,64] = squeeze[dimensions=(0,)] djr
    djt:bool[64,64] = lt djq 0:i64[]
    dju:i64[64,64] = add djq 64:i64[]
    djv:i64[64,64] = select_n djt djq dju
    djw:bool[64,64] = lt djs 0:i64[]
    djx:i64[64,64] = add djs 64:i64[]
    djy:i64[64,64] = select_n djw djs djx
    djz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] djv
    dka:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] djy
    dkb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] djz
    dkc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dka
    dkd:i32[64,64,2] = concatenate[dimension=2] dkb dkc
    dke:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dgd dkd
    dkf:f64[64,64] = mul djo dke
    dkg:f64[64,64] = add djj dkf
    dkh:f64[64,64] = sub cpg dkg
    dki:f64[64,64] = div dkh 2.0:f64[]
    dkj:f64[64,64] = add cpg dki
    dkk:f64[] = neg eb
    dkl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dkk
    dkm:f64[2,64,64] = mul dkl dcb
    dkn:f64[2,64,64] = add ea dkm
    dko:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    dkp:f64[2,64,64] = div dkn dko
    dkq:f64[2,64,64] = floor dkp
    dkr:f64[2,64,64] = sub dkp dkq
    dks:f64[2,64,64] = sub 1.0:f64[] dkr
    dkt:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bt
    dku:f64[2,64,64] = sub dkp dkr
    dkv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dku
    dkw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dkv dkt
    dkx:i64[2,64,64] = add dkw 1:i64[]
    dky:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dkx dkt
    dkz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dks
    dla:f64[64,64] = squeeze[dimensions=(0,)] dkz
    dlb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dks
    dlc:f64[64,64] = squeeze[dimensions=(0,)] dlb
    dld:f64[64,64] = mul dla dlc
    dle:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dkw
    dlf:i64[64,64] = squeeze[dimensions=(0,)] dle
    dlg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dkw
    dlh:i64[64,64] = squeeze[dimensions=(0,)] dlg
    dli:bool[64,64] = lt dlf 0:i64[]
    dlj:i64[64,64] = add dlf 64:i64[]
    dlk:i64[64,64] = select_n dli dlf dlj
    dll:bool[64,64] = lt dlh 0:i64[]
    dlm:i64[64,64] = add dlh 64:i64[]
    dln:i64[64,64] = select_n dll dlh dlm
    dlo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dlk
    dlp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dln
    dlq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dlo
    dlr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dlp
    dls:i32[64,64,2] = concatenate[dimension=2] dlq dlr
    dlt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dkj dls
    dlu:f64[64,64] = mul dld dlt
    dlv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dks
    dlw:f64[64,64] = squeeze[dimensions=(0,)] dlv
    dlx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dkr
    dly:f64[64,64] = squeeze[dimensions=(0,)] dlx
    dlz:f64[64,64] = mul dlw dly
    dma:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dkw
    dmb:i64[64,64] = squeeze[dimensions=(0,)] dma
    dmc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dky
    dmd:i64[64,64] = squeeze[dimensions=(0,)] dmc
    dme:bool[64,64] = lt dmb 0:i64[]
    dmf:i64[64,64] = add dmb 64:i64[]
    dmg:i64[64,64] = select_n dme dmb dmf
    dmh:bool[64,64] = lt dmd 0:i64[]
    dmi:i64[64,64] = add dmd 64:i64[]
    dmj:i64[64,64] = select_n dmh dmd dmi
    dmk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dmg
    dml:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dmj
    dmm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dmk
    dmn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dml
    dmo:i32[64,64,2] = concatenate[dimension=2] dmm dmn
    dmp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dkj dmo
    dmq:f64[64,64] = mul dlz dmp
    dmr:f64[64,64] = add dlu dmq
    dms:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dkr
    dmt:f64[64,64] = squeeze[dimensions=(0,)] dms
    dmu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dks
    dmv:f64[64,64] = squeeze[dimensions=(0,)] dmu
    dmw:f64[64,64] = mul dmt dmv
    dmx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dky
    dmy:i64[64,64] = squeeze[dimensions=(0,)] dmx
    dmz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dkw
    dna:i64[64,64] = squeeze[dimensions=(0,)] dmz
    dnb:bool[64,64] = lt dmy 0:i64[]
    dnc:i64[64,64] = add dmy 64:i64[]
    dnd:i64[64,64] = select_n dnb dmy dnc
    dne:bool[64,64] = lt dna 0:i64[]
    dnf:i64[64,64] = add dna 64:i64[]
    dng:i64[64,64] = select_n dne dna dnf
    dnh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dnd
    dni:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dng
    dnj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dnh
    dnk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dni
    dnl:i32[64,64,2] = concatenate[dimension=2] dnj dnk
    dnm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dkj dnl
    dnn:f64[64,64] = mul dmw dnm
    dno:f64[64,64] = add dmr dnn
    dnp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dkr
    dnq:f64[64,64] = squeeze[dimensions=(0,)] dnp
    dnr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dkr
    dns:f64[64,64] = squeeze[dimensions=(0,)] dnr
    dnt:f64[64,64] = mul dnq dns
    dnu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dky
    dnv:i64[64,64] = squeeze[dimensions=(0,)] dnu
    dnw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dky
    dnx:i64[64,64] = squeeze[dimensions=(0,)] dnw
    dny:bool[64,64] = lt dnv 0:i64[]
    dnz:i64[64,64] = add dnv 64:i64[]
    doa:i64[64,64] = select_n dny dnv dnz
    dob:bool[64,64] = lt dnx 0:i64[]
    doc:i64[64,64] = add dnx 64:i64[]
    dod:i64[64,64] = select_n dob dnx doc
    doe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] doa
    dof:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dod
    dog:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] doe
    doh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dof
    doi:i32[64,64,2] = concatenate[dimension=2] dog doh
    doj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dkj doi
    dok:f64[64,64] = mul dnt doj
    dol:f64[64,64] = add dno dok
    dom:f64[] = neg eb
    don:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dom
    doo:f64[2,64,64] = mul don dcb
    dop:f64[2,64,64] = add ea doo
    doq:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    dor:f64[2,64,64] = div dop doq
    dos:f64[2,64,64] = floor dor
    dot:f64[2,64,64] = sub dor dos
    dou:f64[2,64,64] = sub 1.0:f64[] dot
    dov:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bu
    dow:f64[2,64,64] = sub dor dot
    dox:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dow
    doy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dox dov
    doz:i64[2,64,64] = add doy 1:i64[]
    dpa:i64[2,64,64] = jit[name=remainder jaxpr=remainder] doz dov
    dpb:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dou
    dpc:f64[64,64] = squeeze[dimensions=(0,)] dpb
    dpd:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dou
    dpe:f64[64,64] = squeeze[dimensions=(0,)] dpd
    dpf:f64[64,64] = mul dpc dpe
    dpg:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] doy
    dph:i64[64,64] = squeeze[dimensions=(0,)] dpg
    dpi:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] doy
    dpj:i64[64,64] = squeeze[dimensions=(0,)] dpi
    dpk:bool[64,64] = lt dph 0:i64[]
    dpl:i64[64,64] = add dph 64:i64[]
    dpm:i64[64,64] = select_n dpk dph dpl
    dpn:bool[64,64] = lt dpj 0:i64[]
    dpo:i64[64,64] = add dpj 64:i64[]
    dpp:i64[64,64] = select_n dpn dpj dpo
    dpq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dpm
    dpr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dpp
    dps:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dpq
    dpt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dpr
    dpu:i32[64,64,2] = concatenate[dimension=2] dps dpt
    dpv:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dbq dpu
    dpw:f64[64,64] = mul dpf dpv
    dpx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dou
    dpy:f64[64,64] = squeeze[dimensions=(0,)] dpx
    dpz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dot
    dqa:f64[64,64] = squeeze[dimensions=(0,)] dpz
    dqb:f64[64,64] = mul dpy dqa
    dqc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] doy
    dqd:i64[64,64] = squeeze[dimensions=(0,)] dqc
    dqe:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dpa
    dqf:i64[64,64] = squeeze[dimensions=(0,)] dqe
    dqg:bool[64,64] = lt dqd 0:i64[]
    dqh:i64[64,64] = add dqd 64:i64[]
    dqi:i64[64,64] = select_n dqg dqd dqh
    dqj:bool[64,64] = lt dqf 0:i64[]
    dqk:i64[64,64] = add dqf 64:i64[]
    dql:i64[64,64] = select_n dqj dqf dqk
    dqm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dqi
    dqn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dql
    dqo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dqm
    dqp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dqn
    dqq:i32[64,64,2] = concatenate[dimension=2] dqo dqp
    dqr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dbq dqq
    dqs:f64[64,64] = mul dqb dqr
    dqt:f64[64,64] = add dpw dqs
    dqu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dot
    dqv:f64[64,64] = squeeze[dimensions=(0,)] dqu
    dqw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dou
    dqx:f64[64,64] = squeeze[dimensions=(0,)] dqw
    dqy:f64[64,64] = mul dqv dqx
    dqz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dpa
    dra:i64[64,64] = squeeze[dimensions=(0,)] dqz
    drb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] doy
    drc:i64[64,64] = squeeze[dimensions=(0,)] drb
    drd:bool[64,64] = lt dra 0:i64[]
    dre:i64[64,64] = add dra 64:i64[]
    drf:i64[64,64] = select_n drd dra dre
    drg:bool[64,64] = lt drc 0:i64[]
    drh:i64[64,64] = add drc 64:i64[]
    dri:i64[64,64] = select_n drg drc drh
    drj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] drf
    drk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dri
    drl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] drj
    drm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] drk
    drn:i32[64,64,2] = concatenate[dimension=2] drl drm
    dro:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dbq drn
    drp:f64[64,64] = mul dqy dro
    drq:f64[64,64] = add dqt drp
    drr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dot
    drs:f64[64,64] = squeeze[dimensions=(0,)] drr
    drt:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dot
    dru:f64[64,64] = squeeze[dimensions=(0,)] drt
    drv:f64[64,64] = mul drs dru
    drw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dpa
    drx:i64[64,64] = squeeze[dimensions=(0,)] drw
    dry:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dpa
    drz:i64[64,64] = squeeze[dimensions=(0,)] dry
    dsa:bool[64,64] = lt drx 0:i64[]
    dsb:i64[64,64] = add drx 64:i64[]
    dsc:i64[64,64] = select_n dsa drx dsb
    dsd:bool[64,64] = lt drz 0:i64[]
    dse:i64[64,64] = add drz 64:i64[]
    dsf:i64[64,64] = select_n dsd drz dse
    dsg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dsc
    dsh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dsf
    dsi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dsg
    dsj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dsh
    dsk:i32[64,64,2] = concatenate[dimension=2] dsi dsj
    dsl:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dbq dsk
    dsm:f64[64,64] = mul drv dsl
    dsn:f64[64,64] = add drq dsm
    dso:f64[2,64,64] = neg dcb
    dsp:f64[] = neg eb
    dsq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dsp
    dsr:f64[2,64,64] = mul dsq dso
    dss:f64[2,64,64] = add ea dsr
    dst:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    dsu:f64[2,64,64] = div dss dst
    dsv:f64[2,64,64] = floor dsu
    dsw:f64[2,64,64] = sub dsu dsv
    dsx:f64[2,64,64] = sub 1.0:f64[] dsw
    dsy:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bv
    dsz:f64[2,64,64] = sub dsu dsw
    dta:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dsz
    dtb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dta dsy
    dtc:i64[2,64,64] = add dtb 1:i64[]
    dtd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dtc dsy
    dte:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dsx
    dtf:f64[64,64] = squeeze[dimensions=(0,)] dte
    dtg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dsx
    dth:f64[64,64] = squeeze[dimensions=(0,)] dtg
    dti:f64[64,64] = mul dtf dth
    dtj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dtb
    dtk:i64[64,64] = squeeze[dimensions=(0,)] dtj
    dtl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dtb
    dtm:i64[64,64] = squeeze[dimensions=(0,)] dtl
    dtn:bool[64,64] = lt dtk 0:i64[]
    dto:i64[64,64] = add dtk 64:i64[]
    dtp:i64[64,64] = select_n dtn dtk dto
    dtq:bool[64,64] = lt dtm 0:i64[]
    dtr:i64[64,64] = add dtm 64:i64[]
    dts:i64[64,64] = select_n dtq dtm dtr
    dtt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dtp
    dtu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dts
    dtv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dtt
    dtw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dtu
    dtx:i32[64,64,2] = concatenate[dimension=2] dtv dtw
    dty:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dsn dtx
    dtz:f64[64,64] = mul dti dty
    dua:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dsx
    dub:f64[64,64] = squeeze[dimensions=(0,)] dua
    duc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dsw
    dud:f64[64,64] = squeeze[dimensions=(0,)] duc
    due:f64[64,64] = mul dub dud
    duf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dtb
    dug:i64[64,64] = squeeze[dimensions=(0,)] duf
    duh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dtd
    dui:i64[64,64] = squeeze[dimensions=(0,)] duh
    duj:bool[64,64] = lt dug 0:i64[]
    duk:i64[64,64] = add dug 64:i64[]
    dul:i64[64,64] = select_n duj dug duk
    dum:bool[64,64] = lt dui 0:i64[]
    dun:i64[64,64] = add dui 64:i64[]
    duo:i64[64,64] = select_n dum dui dun
    dup:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dul
    duq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] duo
    dur:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dup
    dus:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] duq
    dut:i32[64,64,2] = concatenate[dimension=2] dur dus
    duu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dsn dut
    duv:f64[64,64] = mul due duu
    duw:f64[64,64] = add dtz duv
    dux:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dsw
    duy:f64[64,64] = squeeze[dimensions=(0,)] dux
    duz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dsx
    dva:f64[64,64] = squeeze[dimensions=(0,)] duz
    dvb:f64[64,64] = mul duy dva
    dvc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dtd
    dvd:i64[64,64] = squeeze[dimensions=(0,)] dvc
    dve:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dtb
    dvf:i64[64,64] = squeeze[dimensions=(0,)] dve
    dvg:bool[64,64] = lt dvd 0:i64[]
    dvh:i64[64,64] = add dvd 64:i64[]
    dvi:i64[64,64] = select_n dvg dvd dvh
    dvj:bool[64,64] = lt dvf 0:i64[]
    dvk:i64[64,64] = add dvf 64:i64[]
    dvl:i64[64,64] = select_n dvj dvf dvk
    dvm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dvi
    dvn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dvl
    dvo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dvm
    dvp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dvn
    dvq:i32[64,64,2] = concatenate[dimension=2] dvo dvp
    dvr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dsn dvq
    dvs:f64[64,64] = mul dvb dvr
    dvt:f64[64,64] = add duw dvs
    dvu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dsw
    dvv:f64[64,64] = squeeze[dimensions=(0,)] dvu
    dvw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dsw
    dvx:f64[64,64] = squeeze[dimensions=(0,)] dvw
    dvy:f64[64,64] = mul dvv dvx
    dvz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dtd
    dwa:i64[64,64] = squeeze[dimensions=(0,)] dvz
    dwb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dtd
    dwc:i64[64,64] = squeeze[dimensions=(0,)] dwb
    dwd:bool[64,64] = lt dwa 0:i64[]
    dwe:i64[64,64] = add dwa 64:i64[]
    dwf:i64[64,64] = select_n dwd dwa dwe
    dwg:bool[64,64] = lt dwc 0:i64[]
    dwh:i64[64,64] = add dwc 64:i64[]
    dwi:i64[64,64] = select_n dwg dwc dwh
    dwj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dwf
    dwk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dwi
    dwl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dwj
    dwm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dwk
    dwn:i32[64,64,2] = concatenate[dimension=2] dwl dwm
    dwo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dsn dwn
    dwp:f64[64,64] = mul dvy dwo
    dwq:f64[64,64] = add dvt dwp
    dwr:f64[64,64] = sub dbq dwq
    dws:f64[64,64] = div dwr 2.0:f64[]
    dwt:f64[64,64] = add dbq dws
    dwu:f64[] = neg eb
    dwv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] dwu
    dww:f64[2,64,64] = mul dwv dcb
    dwx:f64[2,64,64] = add ea dww
    dwy:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    dwz:f64[2,64,64] = div dwx dwy
    dxa:f64[2,64,64] = floor dwz
    dxb:f64[2,64,64] = sub dwz dxa
    dxc:f64[2,64,64] = sub 1.0:f64[] dxb
    dxd:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bw
    dxe:f64[2,64,64] = sub dwz dxb
    dxf:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] dxe
    dxg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dxf dxd
    dxh:i64[2,64,64] = add dxg 1:i64[]
    dxi:i64[2,64,64] = jit[name=remainder jaxpr=remainder] dxh dxd
    dxj:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxc
    dxk:f64[64,64] = squeeze[dimensions=(0,)] dxj
    dxl:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxc
    dxm:f64[64,64] = squeeze[dimensions=(0,)] dxl
    dxn:f64[64,64] = mul dxk dxm
    dxo:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxg
    dxp:i64[64,64] = squeeze[dimensions=(0,)] dxo
    dxq:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxg
    dxr:i64[64,64] = squeeze[dimensions=(0,)] dxq
    dxs:bool[64,64] = lt dxp 0:i64[]
    dxt:i64[64,64] = add dxp 64:i64[]
    dxu:i64[64,64] = select_n dxs dxp dxt
    dxv:bool[64,64] = lt dxr 0:i64[]
    dxw:i64[64,64] = add dxr 64:i64[]
    dxx:i64[64,64] = select_n dxv dxr dxw
    dxy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dxu
    dxz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dxx
    dya:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dxy
    dyb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dxz
    dyc:i32[64,64,2] = concatenate[dimension=2] dya dyb
    dyd:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dwt dyc
    dye:f64[64,64] = mul dxn dyd
    dyf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxc
    dyg:f64[64,64] = squeeze[dimensions=(0,)] dyf
    dyh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxb
    dyi:f64[64,64] = squeeze[dimensions=(0,)] dyh
    dyj:f64[64,64] = mul dyg dyi
    dyk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxg
    dyl:i64[64,64] = squeeze[dimensions=(0,)] dyk
    dym:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxi
    dyn:i64[64,64] = squeeze[dimensions=(0,)] dym
    dyo:bool[64,64] = lt dyl 0:i64[]
    dyp:i64[64,64] = add dyl 64:i64[]
    dyq:i64[64,64] = select_n dyo dyl dyp
    dyr:bool[64,64] = lt dyn 0:i64[]
    dys:i64[64,64] = add dyn 64:i64[]
    dyt:i64[64,64] = select_n dyr dyn dys
    dyu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dyq
    dyv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dyt
    dyw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dyu
    dyx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dyv
    dyy:i32[64,64,2] = concatenate[dimension=2] dyw dyx
    dyz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dwt dyy
    dza:f64[64,64] = mul dyj dyz
    dzb:f64[64,64] = add dye dza
    dzc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxb
    dzd:f64[64,64] = squeeze[dimensions=(0,)] dzc
    dze:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxc
    dzf:f64[64,64] = squeeze[dimensions=(0,)] dze
    dzg:f64[64,64] = mul dzd dzf
    dzh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxi
    dzi:i64[64,64] = squeeze[dimensions=(0,)] dzh
    dzj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxg
    dzk:i64[64,64] = squeeze[dimensions=(0,)] dzj
    dzl:bool[64,64] = lt dzi 0:i64[]
    dzm:i64[64,64] = add dzi 64:i64[]
    dzn:i64[64,64] = select_n dzl dzi dzm
    dzo:bool[64,64] = lt dzk 0:i64[]
    dzp:i64[64,64] = add dzk 64:i64[]
    dzq:i64[64,64] = select_n dzo dzk dzp
    dzr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dzn
    dzs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] dzq
    dzt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dzr
    dzu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] dzs
    dzv:i32[64,64,2] = concatenate[dimension=2] dzt dzu
    dzw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dwt dzv
    dzx:f64[64,64] = mul dzg dzw
    dzy:f64[64,64] = add dzb dzx
    dzz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxb
    eaa:f64[64,64] = squeeze[dimensions=(0,)] dzz
    eab:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxb
    eac:f64[64,64] = squeeze[dimensions=(0,)] eab
    ead:f64[64,64] = mul eaa eac
    eae:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] dxi
    eaf:i64[64,64] = squeeze[dimensions=(0,)] eae
    eag:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] dxi
    eah:i64[64,64] = squeeze[dimensions=(0,)] eag
    eai:bool[64,64] = lt eaf 0:i64[]
    eaj:i64[64,64] = add eaf 64:i64[]
    eak:i64[64,64] = select_n eai eaf eaj
    eal:bool[64,64] = lt eah 0:i64[]
    eam:i64[64,64] = add eah 64:i64[]
    ean:i64[64,64] = select_n eal eah eam
    eao:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eak
    eap:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ean
    eaq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eao
    ear:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eap
    eas:i32[64,64,2] = concatenate[dimension=2] eaq ear
    eat:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dwt eas
    eau:f64[64,64] = mul ead eat
    eav:f64[64,64] = add dzy eau
    eaw:c128[64,33] = jit[name=fft jaxpr=fft] dol
    eax:c128[] = reduce_prod[axes=(0,)] bx
    eay:c128[] = sqrt eax
    eaz:c128[] = div (1+0j):c128[] eay
    eba:c128[64,33] = mul eaw eaz
    ebb:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] eba
    ebc:c128[2,64,33] = mul dw ebb
    ebd:f64[2,64,64] = jit[name=fft jaxpr=fft1] ebc
    ebe:f64[] = reduce_prod[axes=(0,)] by
    ebf:f64[] = sqrt ebe
    ebg:f64[2,64,64] = mul ebd ebf
    ebh:f64[] = neg eb
    ebi:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ebh
    ebj:f64[2,64,64] = mul ebi ebg
    ebk:f64[2,64,64] = add ea ebj
    ebl:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    ebm:f64[2,64,64] = div ebk ebl
    ebn:f64[2,64,64] = floor ebm
    ebo:f64[2,64,64] = sub ebm ebn
    ebp:f64[2,64,64] = sub 1.0:f64[] ebo
    ebq:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] bz
    ebr:f64[2,64,64] = sub ebm ebo
    ebs:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ebr
    ebt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ebs ebq
    ebu:i64[2,64,64] = add ebt 1:i64[]
    ebv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ebu ebq
    ebw:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebp
    ebx:f64[64,64] = squeeze[dimensions=(0,)] ebw
    eby:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebp
    ebz:f64[64,64] = squeeze[dimensions=(0,)] eby
    eca:f64[64,64] = mul ebx ebz
    ecb:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebt
    ecc:i64[64,64] = squeeze[dimensions=(0,)] ecb
    ecd:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebt
    ece:i64[64,64] = squeeze[dimensions=(0,)] ecd
    ecf:bool[64,64] = lt ecc 0:i64[]
    ecg:i64[64,64] = add ecc 64:i64[]
    ech:i64[64,64] = select_n ecf ecc ecg
    eci:bool[64,64] = lt ece 0:i64[]
    ecj:i64[64,64] = add ece 64:i64[]
    eck:i64[64,64] = select_n eci ece ecj
    ecl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ech
    ecm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eck
    ecn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ecl
    eco:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ecm
    ecp:i32[64,64,2] = concatenate[dimension=2] ecn eco
    ecq:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dol ecp
    ecr:f64[64,64] = mul eca ecq
    ecs:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebp
    ect:f64[64,64] = squeeze[dimensions=(0,)] ecs
    ecu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebo
    ecv:f64[64,64] = squeeze[dimensions=(0,)] ecu
    ecw:f64[64,64] = mul ect ecv
    ecx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebt
    ecy:i64[64,64] = squeeze[dimensions=(0,)] ecx
    ecz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebv
    eda:i64[64,64] = squeeze[dimensions=(0,)] ecz
    edb:bool[64,64] = lt ecy 0:i64[]
    edc:i64[64,64] = add ecy 64:i64[]
    edd:i64[64,64] = select_n edb ecy edc
    ede:bool[64,64] = lt eda 0:i64[]
    edf:i64[64,64] = add eda 64:i64[]
    edg:i64[64,64] = select_n ede eda edf
    edh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] edd
    edi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] edg
    edj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] edh
    edk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] edi
    edl:i32[64,64,2] = concatenate[dimension=2] edj edk
    edm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dol edl
    edn:f64[64,64] = mul ecw edm
    edo:f64[64,64] = add ecr edn
    edp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebo
    edq:f64[64,64] = squeeze[dimensions=(0,)] edp
    edr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebp
    eds:f64[64,64] = squeeze[dimensions=(0,)] edr
    edt:f64[64,64] = mul edq eds
    edu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebv
    edv:i64[64,64] = squeeze[dimensions=(0,)] edu
    edw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebt
    edx:i64[64,64] = squeeze[dimensions=(0,)] edw
    edy:bool[64,64] = lt edv 0:i64[]
    edz:i64[64,64] = add edv 64:i64[]
    eea:i64[64,64] = select_n edy edv edz
    eeb:bool[64,64] = lt edx 0:i64[]
    eec:i64[64,64] = add edx 64:i64[]
    eed:i64[64,64] = select_n eeb edx eec
    eee:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eea
    eef:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eed
    eeg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eee
    eeh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eef
    eei:i32[64,64,2] = concatenate[dimension=2] eeg eeh
    eej:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dol eei
    eek:f64[64,64] = mul edt eej
    eel:f64[64,64] = add edo eek
    eem:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebo
    een:f64[64,64] = squeeze[dimensions=(0,)] eem
    eeo:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebo
    eep:f64[64,64] = squeeze[dimensions=(0,)] eeo
    eeq:f64[64,64] = mul een eep
    eer:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ebv
    ees:i64[64,64] = squeeze[dimensions=(0,)] eer
    eet:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ebv
    eeu:i64[64,64] = squeeze[dimensions=(0,)] eet
    eev:bool[64,64] = lt ees 0:i64[]
    eew:i64[64,64] = add ees 64:i64[]
    eex:i64[64,64] = select_n eev ees eew
    eey:bool[64,64] = lt eeu 0:i64[]
    eez:i64[64,64] = add eeu 64:i64[]
    efa:i64[64,64] = select_n eey eeu eez
    efb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eex
    efc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] efa
    efd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] efb
    efe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] efc
    eff:i32[64,64,2] = concatenate[dimension=2] efd efe
    efg:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] dol eff
    efh:f64[64,64] = mul eeq efg
    efi:f64[64,64] = add eel efh
    efj:f64[2,64,64] = neg ebg
    efk:f64[] = neg eb
    efl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] efk
    efm:f64[2,64,64] = mul efl efj
    efn:f64[2,64,64] = add ea efm
    efo:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    efp:f64[2,64,64] = div efn efo
    efq:f64[2,64,64] = floor efp
    efr:f64[2,64,64] = sub efp efq
    efs:f64[2,64,64] = sub 1.0:f64[] efr
    eft:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ca
    efu:f64[2,64,64] = sub efp efr
    efv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] efu
    efw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] efv eft
    efx:i64[2,64,64] = add efw 1:i64[]
    efy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] efx eft
    efz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efs
    ega:f64[64,64] = squeeze[dimensions=(0,)] efz
    egb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efs
    egc:f64[64,64] = squeeze[dimensions=(0,)] egb
    egd:f64[64,64] = mul ega egc
    ege:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efw
    egf:i64[64,64] = squeeze[dimensions=(0,)] ege
    egg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efw
    egh:i64[64,64] = squeeze[dimensions=(0,)] egg
    egi:bool[64,64] = lt egf 0:i64[]
    egj:i64[64,64] = add egf 64:i64[]
    egk:i64[64,64] = select_n egi egf egj
    egl:bool[64,64] = lt egh 0:i64[]
    egm:i64[64,64] = add egh 64:i64[]
    egn:i64[64,64] = select_n egl egh egm
    ego:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] egk
    egp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] egn
    egq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ego
    egr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] egp
    egs:i32[64,64,2] = concatenate[dimension=2] egq egr
    egt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] efi egs
    egu:f64[64,64] = mul egd egt
    egv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efs
    egw:f64[64,64] = squeeze[dimensions=(0,)] egv
    egx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efr
    egy:f64[64,64] = squeeze[dimensions=(0,)] egx
    egz:f64[64,64] = mul egw egy
    eha:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efw
    ehb:i64[64,64] = squeeze[dimensions=(0,)] eha
    ehc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efy
    ehd:i64[64,64] = squeeze[dimensions=(0,)] ehc
    ehe:bool[64,64] = lt ehb 0:i64[]
    ehf:i64[64,64] = add ehb 64:i64[]
    ehg:i64[64,64] = select_n ehe ehb ehf
    ehh:bool[64,64] = lt ehd 0:i64[]
    ehi:i64[64,64] = add ehd 64:i64[]
    ehj:i64[64,64] = select_n ehh ehd ehi
    ehk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ehg
    ehl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ehj
    ehm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ehk
    ehn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ehl
    eho:i32[64,64,2] = concatenate[dimension=2] ehm ehn
    ehp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] efi eho
    ehq:f64[64,64] = mul egz ehp
    ehr:f64[64,64] = add egu ehq
    ehs:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efr
    eht:f64[64,64] = squeeze[dimensions=(0,)] ehs
    ehu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efs
    ehv:f64[64,64] = squeeze[dimensions=(0,)] ehu
    ehw:f64[64,64] = mul eht ehv
    ehx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efy
    ehy:i64[64,64] = squeeze[dimensions=(0,)] ehx
    ehz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efw
    eia:i64[64,64] = squeeze[dimensions=(0,)] ehz
    eib:bool[64,64] = lt ehy 0:i64[]
    eic:i64[64,64] = add ehy 64:i64[]
    eid:i64[64,64] = select_n eib ehy eic
    eie:bool[64,64] = lt eia 0:i64[]
    eif:i64[64,64] = add eia 64:i64[]
    eig:i64[64,64] = select_n eie eia eif
    eih:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eid
    eii:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eig
    eij:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eih
    eik:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eii
    eil:i32[64,64,2] = concatenate[dimension=2] eij eik
    eim:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] efi eil
    ein:f64[64,64] = mul ehw eim
    eio:f64[64,64] = add ehr ein
    eip:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efr
    eiq:f64[64,64] = squeeze[dimensions=(0,)] eip
    eir:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efr
    eis:f64[64,64] = squeeze[dimensions=(0,)] eir
    eit:f64[64,64] = mul eiq eis
    eiu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] efy
    eiv:i64[64,64] = squeeze[dimensions=(0,)] eiu
    eiw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] efy
    eix:i64[64,64] = squeeze[dimensions=(0,)] eiw
    eiy:bool[64,64] = lt eiv 0:i64[]
    eiz:i64[64,64] = add eiv 64:i64[]
    eja:i64[64,64] = select_n eiy eiv eiz
    ejb:bool[64,64] = lt eix 0:i64[]
    ejc:i64[64,64] = add eix 64:i64[]
    ejd:i64[64,64] = select_n ejb eix ejc
    eje:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eja
    ejf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ejd
    ejg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eje
    ejh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ejf
    eji:i32[64,64,2] = concatenate[dimension=2] ejg ejh
    ejj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] efi eji
    ejk:f64[64,64] = mul eit ejj
    ejl:f64[64,64] = add eio ejk
    ejm:f64[64,64] = sub dol ejl
    ejn:f64[64,64] = div ejm 2.0:f64[]
    ejo:f64[64,64] = add dol ejn
    ejp:f64[] = neg eb
    ejq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ejp
    ejr:f64[2,64,64] = mul ejq ebg
    ejs:f64[2,64,64] = add ea ejr
    ejt:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    eju:f64[2,64,64] = div ejs ejt
    ejv:f64[2,64,64] = floor eju
    ejw:f64[2,64,64] = sub eju ejv
    ejx:f64[2,64,64] = sub 1.0:f64[] ejw
    ejy:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cb
    ejz:f64[2,64,64] = sub eju ejw
    eka:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ejz
    ekb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] eka ejy
    ekc:i64[2,64,64] = add ekb 1:i64[]
    ekd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ekc ejy
    eke:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ejx
    ekf:f64[64,64] = squeeze[dimensions=(0,)] eke
    ekg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ejx
    ekh:f64[64,64] = squeeze[dimensions=(0,)] ekg
    eki:f64[64,64] = mul ekf ekh
    ekj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ekb
    ekk:i64[64,64] = squeeze[dimensions=(0,)] ekj
    ekl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ekb
    ekm:i64[64,64] = squeeze[dimensions=(0,)] ekl
    ekn:bool[64,64] = lt ekk 0:i64[]
    eko:i64[64,64] = add ekk 64:i64[]
    ekp:i64[64,64] = select_n ekn ekk eko
    ekq:bool[64,64] = lt ekm 0:i64[]
    ekr:i64[64,64] = add ekm 64:i64[]
    eks:i64[64,64] = select_n ekq ekm ekr
    ekt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ekp
    eku:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eks
    ekv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ekt
    ekw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eku
    ekx:i32[64,64,2] = concatenate[dimension=2] ekv ekw
    eky:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ejo ekx
    ekz:f64[64,64] = mul eki eky
    ela:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ejx
    elb:f64[64,64] = squeeze[dimensions=(0,)] ela
    elc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ejw
    eld:f64[64,64] = squeeze[dimensions=(0,)] elc
    ele:f64[64,64] = mul elb eld
    elf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ekb
    elg:i64[64,64] = squeeze[dimensions=(0,)] elf
    elh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ekd
    eli:i64[64,64] = squeeze[dimensions=(0,)] elh
    elj:bool[64,64] = lt elg 0:i64[]
    elk:i64[64,64] = add elg 64:i64[]
    ell:i64[64,64] = select_n elj elg elk
    elm:bool[64,64] = lt eli 0:i64[]
    eln:i64[64,64] = add eli 64:i64[]
    elo:i64[64,64] = select_n elm eli eln
    elp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ell
    elq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] elo
    elr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] elp
    els:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] elq
    elt:i32[64,64,2] = concatenate[dimension=2] elr els
    elu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ejo elt
    elv:f64[64,64] = mul ele elu
    elw:f64[64,64] = add ekz elv
    elx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ejw
    ely:f64[64,64] = squeeze[dimensions=(0,)] elx
    elz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ejx
    ema:f64[64,64] = squeeze[dimensions=(0,)] elz
    emb:f64[64,64] = mul ely ema
    emc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ekd
    emd:i64[64,64] = squeeze[dimensions=(0,)] emc
    eme:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ekb
    emf:i64[64,64] = squeeze[dimensions=(0,)] eme
    emg:bool[64,64] = lt emd 0:i64[]
    emh:i64[64,64] = add emd 64:i64[]
    emi:i64[64,64] = select_n emg emd emh
    emj:bool[64,64] = lt emf 0:i64[]
    emk:i64[64,64] = add emf 64:i64[]
    eml:i64[64,64] = select_n emj emf emk
    emm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] emi
    emn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eml
    emo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] emm
    emp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] emn
    emq:i32[64,64,2] = concatenate[dimension=2] emo emp
    emr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ejo emq
    ems:f64[64,64] = mul emb emr
    emt:f64[64,64] = add elw ems
    emu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ejw
    emv:f64[64,64] = squeeze[dimensions=(0,)] emu
    emw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ejw
    emx:f64[64,64] = squeeze[dimensions=(0,)] emw
    emy:f64[64,64] = mul emv emx
    emz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ekd
    ena:i64[64,64] = squeeze[dimensions=(0,)] emz
    enb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ekd
    enc:i64[64,64] = squeeze[dimensions=(0,)] enb
    end:bool[64,64] = lt ena 0:i64[]
    ene:i64[64,64] = add ena 64:i64[]
    enf:i64[64,64] = select_n end ena ene
    eng:bool[64,64] = lt enc 0:i64[]
    enh:i64[64,64] = add enc 64:i64[]
    eni:i64[64,64] = select_n eng enc enh
    enj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] enf
    enk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eni
    enl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] enj
    enm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] enk
    enn:i32[64,64,2] = concatenate[dimension=2] enl enm
    eno:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ejo enn
    enp:f64[64,64] = mul emy eno
    enq:f64[64,64] = add emt enp
    enr:f64[] = neg eb
    ens:f64[] = convert_element_type[new_dtype=float64 weak_type=False] enr
    ent:f64[2,64,64] = mul ens ebg
    enu:f64[2,64,64] = add ea ent
    env:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    enw:f64[2,64,64] = div enu env
    enx:f64[2,64,64] = floor enw
    eny:f64[2,64,64] = sub enw enx
    enz:f64[2,64,64] = sub 1.0:f64[] eny
    eoa:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cc
    eob:f64[2,64,64] = sub enw eny
    eoc:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] eob
    eod:i64[2,64,64] = jit[name=remainder jaxpr=remainder] eoc eoa
    eoe:i64[2,64,64] = add eod 1:i64[]
    eof:i64[2,64,64] = jit[name=remainder jaxpr=remainder] eoe eoa
    eog:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] enz
    eoh:f64[64,64] = squeeze[dimensions=(0,)] eog
    eoi:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] enz
    eoj:f64[64,64] = squeeze[dimensions=(0,)] eoi
    eok:f64[64,64] = mul eoh eoj
    eol:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eod
    eom:i64[64,64] = squeeze[dimensions=(0,)] eol
    eon:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eod
    eoo:i64[64,64] = squeeze[dimensions=(0,)] eon
    eop:bool[64,64] = lt eom 0:i64[]
    eoq:i64[64,64] = add eom 64:i64[]
    eor:i64[64,64] = select_n eop eom eoq
    eos:bool[64,64] = lt eoo 0:i64[]
    eot:i64[64,64] = add eoo 64:i64[]
    eou:i64[64,64] = select_n eos eoo eot
    eov:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eor
    eow:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eou
    eox:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eov
    eoy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eow
    eoz:i32[64,64,2] = concatenate[dimension=2] eox eoy
    epa:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] eav eoz
    epb:f64[64,64] = mul eok epa
    epc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] enz
    epd:f64[64,64] = squeeze[dimensions=(0,)] epc
    epe:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eny
    epf:f64[64,64] = squeeze[dimensions=(0,)] epe
    epg:f64[64,64] = mul epd epf
    eph:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eod
    epi:i64[64,64] = squeeze[dimensions=(0,)] eph
    epj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eof
    epk:i64[64,64] = squeeze[dimensions=(0,)] epj
    epl:bool[64,64] = lt epi 0:i64[]
    epm:i64[64,64] = add epi 64:i64[]
    epn:i64[64,64] = select_n epl epi epm
    epo:bool[64,64] = lt epk 0:i64[]
    epp:i64[64,64] = add epk 64:i64[]
    epq:i64[64,64] = select_n epo epk epp
    epr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] epn
    eps:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] epq
    ept:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] epr
    epu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eps
    epv:i32[64,64,2] = concatenate[dimension=2] ept epu
    epw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] eav epv
    epx:f64[64,64] = mul epg epw
    epy:f64[64,64] = add epb epx
    epz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eny
    eqa:f64[64,64] = squeeze[dimensions=(0,)] epz
    eqb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] enz
    eqc:f64[64,64] = squeeze[dimensions=(0,)] eqb
    eqd:f64[64,64] = mul eqa eqc
    eqe:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eof
    eqf:i64[64,64] = squeeze[dimensions=(0,)] eqe
    eqg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eod
    eqh:i64[64,64] = squeeze[dimensions=(0,)] eqg
    eqi:bool[64,64] = lt eqf 0:i64[]
    eqj:i64[64,64] = add eqf 64:i64[]
    eqk:i64[64,64] = select_n eqi eqf eqj
    eql:bool[64,64] = lt eqh 0:i64[]
    eqm:i64[64,64] = add eqh 64:i64[]
    eqn:i64[64,64] = select_n eql eqh eqm
    eqo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eqk
    eqp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eqn
    eqq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eqo
    eqr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eqp
    eqs:i32[64,64,2] = concatenate[dimension=2] eqq eqr
    eqt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] eav eqs
    equ:f64[64,64] = mul eqd eqt
    eqv:f64[64,64] = add epy equ
    eqw:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eny
    eqx:f64[64,64] = squeeze[dimensions=(0,)] eqw
    eqy:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eny
    eqz:f64[64,64] = squeeze[dimensions=(0,)] eqy
    era:f64[64,64] = mul eqx eqz
    erb:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] eof
    erc:i64[64,64] = squeeze[dimensions=(0,)] erb
    erd:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] eof
    ere:i64[64,64] = squeeze[dimensions=(0,)] erd
    erf:bool[64,64] = lt erc 0:i64[]
    erg:i64[64,64] = add erc 64:i64[]
    erh:i64[64,64] = select_n erf erc erg
    eri:bool[64,64] = lt ere 0:i64[]
    erj:i64[64,64] = add ere 64:i64[]
    erk:i64[64,64] = select_n eri ere erj
    erl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] erh
    erm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] erk
    ern:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] erl
    ero:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] erm
    erp:i32[64,64,2] = concatenate[dimension=2] ern ero
    erq:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] eav erp
    err:f64[64,64] = mul era erq
    ers:f64[64,64] = add eqv err
    ert:f64[2,64,64] = neg ebg
    eru:f64[] = neg eb
    erv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] eru
    erw:f64[2,64,64] = mul erv ert
    erx:f64[2,64,64] = add ea erw
    ery:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    erz:f64[2,64,64] = div erx ery
    esa:f64[2,64,64] = floor erz
    esb:f64[2,64,64] = sub erz esa
    esc:f64[2,64,64] = sub 1.0:f64[] esb
    esd:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cd
    ese:f64[2,64,64] = sub erz esb
    esf:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ese
    esg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] esf esd
    esh:i64[2,64,64] = add esg 1:i64[]
    esi:i64[2,64,64] = jit[name=remainder jaxpr=remainder] esh esd
    esj:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esc
    esk:f64[64,64] = squeeze[dimensions=(0,)] esj
    esl:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esc
    esm:f64[64,64] = squeeze[dimensions=(0,)] esl
    esn:f64[64,64] = mul esk esm
    eso:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esg
    esp:i64[64,64] = squeeze[dimensions=(0,)] eso
    esq:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esg
    esr:i64[64,64] = squeeze[dimensions=(0,)] esq
    ess:bool[64,64] = lt esp 0:i64[]
    est:i64[64,64] = add esp 64:i64[]
    esu:i64[64,64] = select_n ess esp est
    esv:bool[64,64] = lt esr 0:i64[]
    esw:i64[64,64] = add esr 64:i64[]
    esx:i64[64,64] = select_n esv esr esw
    esy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] esu
    esz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] esx
    eta:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] esy
    etb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] esz
    etc:i32[64,64,2] = concatenate[dimension=2] eta etb
    etd:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ers etc
    ete:f64[64,64] = mul esn etd
    etf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esc
    etg:f64[64,64] = squeeze[dimensions=(0,)] etf
    eth:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esb
    eti:f64[64,64] = squeeze[dimensions=(0,)] eth
    etj:f64[64,64] = mul etg eti
    etk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esg
    etl:i64[64,64] = squeeze[dimensions=(0,)] etk
    etm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esi
    etn:i64[64,64] = squeeze[dimensions=(0,)] etm
    eto:bool[64,64] = lt etl 0:i64[]
    etp:i64[64,64] = add etl 64:i64[]
    etq:i64[64,64] = select_n eto etl etp
    etr:bool[64,64] = lt etn 0:i64[]
    ets:i64[64,64] = add etn 64:i64[]
    ett:i64[64,64] = select_n etr etn ets
    etu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] etq
    etv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ett
    etw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] etu
    etx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] etv
    ety:i32[64,64,2] = concatenate[dimension=2] etw etx
    etz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ers ety
    eua:f64[64,64] = mul etj etz
    eub:f64[64,64] = add ete eua
    euc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esb
    eud:f64[64,64] = squeeze[dimensions=(0,)] euc
    eue:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esc
    euf:f64[64,64] = squeeze[dimensions=(0,)] eue
    eug:f64[64,64] = mul eud euf
    euh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esi
    eui:i64[64,64] = squeeze[dimensions=(0,)] euh
    euj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esg
    euk:i64[64,64] = squeeze[dimensions=(0,)] euj
    eul:bool[64,64] = lt eui 0:i64[]
    eum:i64[64,64] = add eui 64:i64[]
    eun:i64[64,64] = select_n eul eui eum
    euo:bool[64,64] = lt euk 0:i64[]
    eup:i64[64,64] = add euk 64:i64[]
    euq:i64[64,64] = select_n euo euk eup
    eur:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eun
    eus:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] euq
    eut:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eur
    euu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eus
    euv:i32[64,64,2] = concatenate[dimension=2] eut euu
    euw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ers euv
    eux:f64[64,64] = mul eug euw
    euy:f64[64,64] = add eub eux
    euz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esb
    eva:f64[64,64] = squeeze[dimensions=(0,)] euz
    evb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esb
    evc:f64[64,64] = squeeze[dimensions=(0,)] evb
    evd:f64[64,64] = mul eva evc
    eve:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] esi
    evf:i64[64,64] = squeeze[dimensions=(0,)] eve
    evg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] esi
    evh:i64[64,64] = squeeze[dimensions=(0,)] evg
    evi:bool[64,64] = lt evf 0:i64[]
    evj:i64[64,64] = add evf 64:i64[]
    evk:i64[64,64] = select_n evi evf evj
    evl:bool[64,64] = lt evh 0:i64[]
    evm:i64[64,64] = add evh 64:i64[]
    evn:i64[64,64] = select_n evl evh evm
    evo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] evk
    evp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] evn
    evq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] evo
    evr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] evp
    evs:i32[64,64,2] = concatenate[dimension=2] evq evr
    evt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ers evs
    evu:f64[64,64] = mul evd evt
    evv:f64[64,64] = add euy evu
    evw:f64[64,64] = sub eav evv
    evx:f64[64,64] = div evw 2.0:f64[]
    evy:f64[64,64] = add eav evx
    evz:f64[] = neg eb
    ewa:f64[] = convert_element_type[new_dtype=float64 weak_type=False] evz
    ewb:f64[2,64,64] = mul ewa ebg
    ewc:f64[2,64,64] = add ea ewb
    ewd:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    ewe:f64[2,64,64] = div ewc ewd
    ewf:f64[2,64,64] = floor ewe
    ewg:f64[2,64,64] = sub ewe ewf
    ewh:f64[2,64,64] = sub 1.0:f64[] ewg
    ewi:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ce
    ewj:f64[2,64,64] = sub ewe ewg
    ewk:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ewj
    ewl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ewk ewi
    ewm:i64[2,64,64] = add ewl 1:i64[]
    ewn:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ewm ewi
    ewo:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewh
    ewp:f64[64,64] = squeeze[dimensions=(0,)] ewo
    ewq:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewh
    ewr:f64[64,64] = squeeze[dimensions=(0,)] ewq
    ews:f64[64,64] = mul ewp ewr
    ewt:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewl
    ewu:i64[64,64] = squeeze[dimensions=(0,)] ewt
    ewv:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewl
    eww:i64[64,64] = squeeze[dimensions=(0,)] ewv
    ewx:bool[64,64] = lt ewu 0:i64[]
    ewy:i64[64,64] = add ewu 64:i64[]
    ewz:i64[64,64] = select_n ewx ewu ewy
    exa:bool[64,64] = lt eww 0:i64[]
    exb:i64[64,64] = add eww 64:i64[]
    exc:i64[64,64] = select_n exa eww exb
    exd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ewz
    exe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] exc
    exf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] exd
    exg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] exe
    exh:i32[64,64,2] = concatenate[dimension=2] exf exg
    exi:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] evy exh
    exj:f64[64,64] = mul ews exi
    exk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewh
    exl:f64[64,64] = squeeze[dimensions=(0,)] exk
    exm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewg
    exn:f64[64,64] = squeeze[dimensions=(0,)] exm
    exo:f64[64,64] = mul exl exn
    exp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewl
    exq:i64[64,64] = squeeze[dimensions=(0,)] exp
    exr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewn
    exs:i64[64,64] = squeeze[dimensions=(0,)] exr
    ext:bool[64,64] = lt exq 0:i64[]
    exu:i64[64,64] = add exq 64:i64[]
    exv:i64[64,64] = select_n ext exq exu
    exw:bool[64,64] = lt exs 0:i64[]
    exx:i64[64,64] = add exs 64:i64[]
    exy:i64[64,64] = select_n exw exs exx
    exz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] exv
    eya:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] exy
    eyb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] exz
    eyc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eya
    eyd:i32[64,64,2] = concatenate[dimension=2] eyb eyc
    eye:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] evy eyd
    eyf:f64[64,64] = mul exo eye
    eyg:f64[64,64] = add exj eyf
    eyh:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewg
    eyi:f64[64,64] = squeeze[dimensions=(0,)] eyh
    eyj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewh
    eyk:f64[64,64] = squeeze[dimensions=(0,)] eyj
    eyl:f64[64,64] = mul eyi eyk
    eym:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewn
    eyn:i64[64,64] = squeeze[dimensions=(0,)] eym
    eyo:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewl
    eyp:i64[64,64] = squeeze[dimensions=(0,)] eyo
    eyq:bool[64,64] = lt eyn 0:i64[]
    eyr:i64[64,64] = add eyn 64:i64[]
    eys:i64[64,64] = select_n eyq eyn eyr
    eyt:bool[64,64] = lt eyp 0:i64[]
    eyu:i64[64,64] = add eyp 64:i64[]
    eyv:i64[64,64] = select_n eyt eyp eyu
    eyw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eys
    eyx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] eyv
    eyy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eyw
    eyz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] eyx
    eza:i32[64,64,2] = concatenate[dimension=2] eyy eyz
    ezb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] evy eza
    ezc:f64[64,64] = mul eyl ezb
    ezd:f64[64,64] = add eyg ezc
    eze:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewg
    ezf:f64[64,64] = squeeze[dimensions=(0,)] eze
    ezg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewg
    ezh:f64[64,64] = squeeze[dimensions=(0,)] ezg
    ezi:f64[64,64] = mul ezf ezh
    ezj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ewn
    ezk:i64[64,64] = squeeze[dimensions=(0,)] ezj
    ezl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ewn
    ezm:i64[64,64] = squeeze[dimensions=(0,)] ezl
    ezn:bool[64,64] = lt ezk 0:i64[]
    ezo:i64[64,64] = add ezk 64:i64[]
    ezp:i64[64,64] = select_n ezn ezk ezo
    ezq:bool[64,64] = lt ezm 0:i64[]
    ezr:i64[64,64] = add ezm 64:i64[]
    ezs:i64[64,64] = select_n ezq ezm ezr
    ezt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ezp
    ezu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ezs
    ezv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ezt
    ezw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ezu
    ezx:i32[64,64,2] = concatenate[dimension=2] ezv ezw
    ezy:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] evy ezx
    ezz:f64[64,64] = mul ezi ezy
    faa:f64[64,64] = add ezd ezz
    fab:c128[64,33] = jit[name=fft jaxpr=fft] enq
    fac:c128[] = reduce_prod[axes=(0,)] cf
    fad:c128[] = sqrt fac
    fae:c128[] = div (1+0j):c128[] fad
    faf:c128[64,33] = mul fab fae
    fag:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] faf
    fah:c128[2,64,33] = mul dw fag
    fai:f64[2,64,64] = jit[name=fft jaxpr=fft1] fah
    faj:f64[] = reduce_prod[axes=(0,)] cg
    fak:f64[] = sqrt faj
    fal:f64[2,64,64] = mul fai fak
    fam:f64[] = neg eb
    fan:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fam
    fao:f64[2,64,64] = mul fan fal
    fap:f64[2,64,64] = add ea fao
    faq:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    far:f64[2,64,64] = div fap faq
    fas:f64[2,64,64] = floor far
    fat:f64[2,64,64] = sub far fas
    fau:f64[2,64,64] = sub 1.0:f64[] fat
    fav:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ch
    faw:f64[2,64,64] = sub far fat
    fax:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] faw
    fay:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fax fav
    faz:i64[2,64,64] = add fay 1:i64[]
    fba:i64[2,64,64] = jit[name=remainder jaxpr=remainder] faz fav
    fbb:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fau
    fbc:f64[64,64] = squeeze[dimensions=(0,)] fbb
    fbd:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fau
    fbe:f64[64,64] = squeeze[dimensions=(0,)] fbd
    fbf:f64[64,64] = mul fbc fbe
    fbg:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fay
    fbh:i64[64,64] = squeeze[dimensions=(0,)] fbg
    fbi:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fay
    fbj:i64[64,64] = squeeze[dimensions=(0,)] fbi
    fbk:bool[64,64] = lt fbh 0:i64[]
    fbl:i64[64,64] = add fbh 64:i64[]
    fbm:i64[64,64] = select_n fbk fbh fbl
    fbn:bool[64,64] = lt fbj 0:i64[]
    fbo:i64[64,64] = add fbj 64:i64[]
    fbp:i64[64,64] = select_n fbn fbj fbo
    fbq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fbm
    fbr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fbp
    fbs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fbq
    fbt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fbr
    fbu:i32[64,64,2] = concatenate[dimension=2] fbs fbt
    fbv:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] enq fbu
    fbw:f64[64,64] = mul fbf fbv
    fbx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fau
    fby:f64[64,64] = squeeze[dimensions=(0,)] fbx
    fbz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fat
    fca:f64[64,64] = squeeze[dimensions=(0,)] fbz
    fcb:f64[64,64] = mul fby fca
    fcc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fay
    fcd:i64[64,64] = squeeze[dimensions=(0,)] fcc
    fce:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fba
    fcf:i64[64,64] = squeeze[dimensions=(0,)] fce
    fcg:bool[64,64] = lt fcd 0:i64[]
    fch:i64[64,64] = add fcd 64:i64[]
    fci:i64[64,64] = select_n fcg fcd fch
    fcj:bool[64,64] = lt fcf 0:i64[]
    fck:i64[64,64] = add fcf 64:i64[]
    fcl:i64[64,64] = select_n fcj fcf fck
    fcm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fci
    fcn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fcl
    fco:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fcm
    fcp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fcn
    fcq:i32[64,64,2] = concatenate[dimension=2] fco fcp
    fcr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] enq fcq
    fcs:f64[64,64] = mul fcb fcr
    fct:f64[64,64] = add fbw fcs
    fcu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fat
    fcv:f64[64,64] = squeeze[dimensions=(0,)] fcu
    fcw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fau
    fcx:f64[64,64] = squeeze[dimensions=(0,)] fcw
    fcy:f64[64,64] = mul fcv fcx
    fcz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fba
    fda:i64[64,64] = squeeze[dimensions=(0,)] fcz
    fdb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fay
    fdc:i64[64,64] = squeeze[dimensions=(0,)] fdb
    fdd:bool[64,64] = lt fda 0:i64[]
    fde:i64[64,64] = add fda 64:i64[]
    fdf:i64[64,64] = select_n fdd fda fde
    fdg:bool[64,64] = lt fdc 0:i64[]
    fdh:i64[64,64] = add fdc 64:i64[]
    fdi:i64[64,64] = select_n fdg fdc fdh
    fdj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fdf
    fdk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fdi
    fdl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fdj
    fdm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fdk
    fdn:i32[64,64,2] = concatenate[dimension=2] fdl fdm
    fdo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] enq fdn
    fdp:f64[64,64] = mul fcy fdo
    fdq:f64[64,64] = add fct fdp
    fdr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fat
    fds:f64[64,64] = squeeze[dimensions=(0,)] fdr
    fdt:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fat
    fdu:f64[64,64] = squeeze[dimensions=(0,)] fdt
    fdv:f64[64,64] = mul fds fdu
    fdw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fba
    fdx:i64[64,64] = squeeze[dimensions=(0,)] fdw
    fdy:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fba
    fdz:i64[64,64] = squeeze[dimensions=(0,)] fdy
    fea:bool[64,64] = lt fdx 0:i64[]
    feb:i64[64,64] = add fdx 64:i64[]
    fec:i64[64,64] = select_n fea fdx feb
    fed:bool[64,64] = lt fdz 0:i64[]
    fee:i64[64,64] = add fdz 64:i64[]
    fef:i64[64,64] = select_n fed fdz fee
    feg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fec
    feh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fef
    fei:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] feg
    fej:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] feh
    fek:i32[64,64,2] = concatenate[dimension=2] fei fej
    fel:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] enq fek
    fem:f64[64,64] = mul fdv fel
    fen:f64[64,64] = add fdq fem
    feo:f64[2,64,64] = neg fal
    fep:f64[] = neg eb
    feq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fep
    fer:f64[2,64,64] = mul feq feo
    fes:f64[2,64,64] = add ea fer
    fet:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    feu:f64[2,64,64] = div fes fet
    fev:f64[2,64,64] = floor feu
    few:f64[2,64,64] = sub feu fev
    fex:f64[2,64,64] = sub 1.0:f64[] few
    fey:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ci
    fez:f64[2,64,64] = sub feu few
    ffa:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fez
    ffb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ffa fey
    ffc:i64[2,64,64] = add ffb 1:i64[]
    ffd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ffc fey
    ffe:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fex
    fff:f64[64,64] = squeeze[dimensions=(0,)] ffe
    ffg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fex
    ffh:f64[64,64] = squeeze[dimensions=(0,)] ffg
    ffi:f64[64,64] = mul fff ffh
    ffj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ffb
    ffk:i64[64,64] = squeeze[dimensions=(0,)] ffj
    ffl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ffb
    ffm:i64[64,64] = squeeze[dimensions=(0,)] ffl
    ffn:bool[64,64] = lt ffk 0:i64[]
    ffo:i64[64,64] = add ffk 64:i64[]
    ffp:i64[64,64] = select_n ffn ffk ffo
    ffq:bool[64,64] = lt ffm 0:i64[]
    ffr:i64[64,64] = add ffm 64:i64[]
    ffs:i64[64,64] = select_n ffq ffm ffr
    ffu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ffp
    ffv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ffs
    ffw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ffu
    ffx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ffv
    ffy:i32[64,64,2] = concatenate[dimension=2] ffw ffx
    ffz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fen ffy
    fga:f64[64,64] = mul ffi ffz
    fgb:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fex
    fgc:f64[64,64] = squeeze[dimensions=(0,)] fgb
    fgd:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] few
    fge:f64[64,64] = squeeze[dimensions=(0,)] fgd
    fgf:f64[64,64] = mul fgc fge
    fgg:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ffb
    fgh:i64[64,64] = squeeze[dimensions=(0,)] fgg
    fgi:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ffd
    fgj:i64[64,64] = squeeze[dimensions=(0,)] fgi
    fgk:bool[64,64] = lt fgh 0:i64[]
    fgl:i64[64,64] = add fgh 64:i64[]
    fgm:i64[64,64] = select_n fgk fgh fgl
    fgn:bool[64,64] = lt fgj 0:i64[]
    fgo:i64[64,64] = add fgj 64:i64[]
    fgp:i64[64,64] = select_n fgn fgj fgo
    fgq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fgm
    fgr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fgp
    fgs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fgq
    fgt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fgr
    fgu:i32[64,64,2] = concatenate[dimension=2] fgs fgt
    fgv:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fen fgu
    fgw:f64[64,64] = mul fgf fgv
    fgx:f64[64,64] = add fga fgw
    fgy:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] few
    fgz:f64[64,64] = squeeze[dimensions=(0,)] fgy
    fha:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fex
    fhb:f64[64,64] = squeeze[dimensions=(0,)] fha
    fhc:f64[64,64] = mul fgz fhb
    fhd:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ffd
    fhe:i64[64,64] = squeeze[dimensions=(0,)] fhd
    fhf:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ffb
    fhg:i64[64,64] = squeeze[dimensions=(0,)] fhf
    fhh:bool[64,64] = lt fhe 0:i64[]
    fhi:i64[64,64] = add fhe 64:i64[]
    fhj:i64[64,64] = select_n fhh fhe fhi
    fhk:bool[64,64] = lt fhg 0:i64[]
    fhl:i64[64,64] = add fhg 64:i64[]
    fhm:i64[64,64] = select_n fhk fhg fhl
    fhn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fhj
    fho:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fhm
    fhp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fhn
    fhq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fho
    fhr:i32[64,64,2] = concatenate[dimension=2] fhp fhq
    fhs:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fen fhr
    fht:f64[64,64] = mul fhc fhs
    fhu:f64[64,64] = add fgx fht
    fhv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] few
    fhw:f64[64,64] = squeeze[dimensions=(0,)] fhv
    fhx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] few
    fhy:f64[64,64] = squeeze[dimensions=(0,)] fhx
    fhz:f64[64,64] = mul fhw fhy
    fia:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ffd
    fib:i64[64,64] = squeeze[dimensions=(0,)] fia
    fic:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ffd
    fid:i64[64,64] = squeeze[dimensions=(0,)] fic
    fie:bool[64,64] = lt fib 0:i64[]
    fif:i64[64,64] = add fib 64:i64[]
    fig:i64[64,64] = select_n fie fib fif
    fih:bool[64,64] = lt fid 0:i64[]
    fii:i64[64,64] = add fid 64:i64[]
    fij:i64[64,64] = select_n fih fid fii
    fik:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fig
    fil:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fij
    fim:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fik
    fin:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fil
    fio:i32[64,64,2] = concatenate[dimension=2] fim fin
    fip:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fen fio
    fiq:f64[64,64] = mul fhz fip
    fir:f64[64,64] = add fhu fiq
    fis:f64[64,64] = sub enq fir
    fit:f64[64,64] = div fis 2.0:f64[]
    fiu:f64[64,64] = add enq fit
    fiv:f64[] = neg eb
    fiw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fiv
    fix:f64[2,64,64] = mul fiw fal
    fiy:f64[2,64,64] = add ea fix
    fiz:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    fja:f64[2,64,64] = div fiy fiz
    fjb:f64[2,64,64] = floor fja
    fjc:f64[2,64,64] = sub fja fjb
    fjd:f64[2,64,64] = sub 1.0:f64[] fjc
    fje:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cj
    fjf:f64[2,64,64] = sub fja fjc
    fjg:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fjf
    fjh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fjg fje
    fji:i64[2,64,64] = add fjh 1:i64[]
    fjj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fji fje
    fjk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjd
    fjl:f64[64,64] = squeeze[dimensions=(0,)] fjk
    fjm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjd
    fjn:f64[64,64] = squeeze[dimensions=(0,)] fjm
    fjo:f64[64,64] = mul fjl fjn
    fjp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjh
    fjq:i64[64,64] = squeeze[dimensions=(0,)] fjp
    fjr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjh
    fjs:i64[64,64] = squeeze[dimensions=(0,)] fjr
    fjt:bool[64,64] = lt fjq 0:i64[]
    fju:i64[64,64] = add fjq 64:i64[]
    fjv:i64[64,64] = select_n fjt fjq fju
    fjw:bool[64,64] = lt fjs 0:i64[]
    fjx:i64[64,64] = add fjs 64:i64[]
    fjy:i64[64,64] = select_n fjw fjs fjx
    fjz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fjv
    fka:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fjy
    fkb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fjz
    fkc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fka
    fkd:i32[64,64,2] = concatenate[dimension=2] fkb fkc
    fke:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fiu fkd
    fkf:f64[64,64] = mul fjo fke
    fkg:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjd
    fkh:f64[64,64] = squeeze[dimensions=(0,)] fkg
    fki:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjc
    fkj:f64[64,64] = squeeze[dimensions=(0,)] fki
    fkk:f64[64,64] = mul fkh fkj
    fkl:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjh
    fkm:i64[64,64] = squeeze[dimensions=(0,)] fkl
    fkn:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjj
    fko:i64[64,64] = squeeze[dimensions=(0,)] fkn
    fkp:bool[64,64] = lt fkm 0:i64[]
    fkq:i64[64,64] = add fkm 64:i64[]
    fkr:i64[64,64] = select_n fkp fkm fkq
    fks:bool[64,64] = lt fko 0:i64[]
    fkt:i64[64,64] = add fko 64:i64[]
    fku:i64[64,64] = select_n fks fko fkt
    fkv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fkr
    fkw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fku
    fkx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fkv
    fky:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fkw
    fkz:i32[64,64,2] = concatenate[dimension=2] fkx fky
    fla:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fiu fkz
    flb:f64[64,64] = mul fkk fla
    flc:f64[64,64] = add fkf flb
    fld:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjc
    fle:f64[64,64] = squeeze[dimensions=(0,)] fld
    flf:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjd
    flg:f64[64,64] = squeeze[dimensions=(0,)] flf
    flh:f64[64,64] = mul fle flg
    fli:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjj
    flj:i64[64,64] = squeeze[dimensions=(0,)] fli
    flk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjh
    fll:i64[64,64] = squeeze[dimensions=(0,)] flk
    flm:bool[64,64] = lt flj 0:i64[]
    fln:i64[64,64] = add flj 64:i64[]
    flo:i64[64,64] = select_n flm flj fln
    flp:bool[64,64] = lt fll 0:i64[]
    flq:i64[64,64] = add fll 64:i64[]
    flr:i64[64,64] = select_n flp fll flq
    fls:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] flo
    flt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] flr
    flu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fls
    flv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] flt
    flw:i32[64,64,2] = concatenate[dimension=2] flu flv
    flx:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fiu flw
    fly:f64[64,64] = mul flh flx
    flz:f64[64,64] = add flc fly
    fma:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjc
    fmb:f64[64,64] = squeeze[dimensions=(0,)] fma
    fmc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjc
    fmd:f64[64,64] = squeeze[dimensions=(0,)] fmc
    fme:f64[64,64] = mul fmb fmd
    fmf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fjj
    fmg:i64[64,64] = squeeze[dimensions=(0,)] fmf
    fmh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fjj
    fmi:i64[64,64] = squeeze[dimensions=(0,)] fmh
    fmj:bool[64,64] = lt fmg 0:i64[]
    fmk:i64[64,64] = add fmg 64:i64[]
    fml:i64[64,64] = select_n fmj fmg fmk
    fmm:bool[64,64] = lt fmi 0:i64[]
    fmn:i64[64,64] = add fmi 64:i64[]
    fmo:i64[64,64] = select_n fmm fmi fmn
    fmp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fml
    fmq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fmo
    fmr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fmp
    fms:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fmq
    fmt:i32[64,64,2] = concatenate[dimension=2] fmr fms
    fmu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fiu fmt
    fmv:f64[64,64] = mul fme fmu
    fmw:f64[64,64] = add flz fmv
    fmx:f64[] = neg eb
    fmy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fmx
    fmz:f64[2,64,64] = mul fmy fal
    fna:f64[2,64,64] = add ea fmz
    fnb:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    fnc:f64[2,64,64] = div fna fnb
    fnd:f64[2,64,64] = floor fnc
    fne:f64[2,64,64] = sub fnc fnd
    fnf:f64[2,64,64] = sub 1.0:f64[] fne
    fng:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ck
    fnh:f64[2,64,64] = sub fnc fne
    fni:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fnh
    fnj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fni fng
    fnk:i64[2,64,64] = add fnj 1:i64[]
    fnl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fnk fng
    fnm:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fnf
    fnn:f64[64,64] = squeeze[dimensions=(0,)] fnm
    fno:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fnf
    fnp:f64[64,64] = squeeze[dimensions=(0,)] fno
    fnq:f64[64,64] = mul fnn fnp
    fnr:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fnj
    fns:i64[64,64] = squeeze[dimensions=(0,)] fnr
    fnt:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fnj
    fnu:i64[64,64] = squeeze[dimensions=(0,)] fnt
    fnv:bool[64,64] = lt fns 0:i64[]
    fnw:i64[64,64] = add fns 64:i64[]
    fnx:i64[64,64] = select_n fnv fns fnw
    fny:bool[64,64] = lt fnu 0:i64[]
    fnz:i64[64,64] = add fnu 64:i64[]
    foa:i64[64,64] = select_n fny fnu fnz
    fob:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fnx
    foc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] foa
    fod:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fob
    foe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] foc
    fof:i32[64,64,2] = concatenate[dimension=2] fod foe
    fog:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] faa fof
    foh:f64[64,64] = mul fnq fog
    foi:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fnf
    foj:f64[64,64] = squeeze[dimensions=(0,)] foi
    fok:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fne
    fol:f64[64,64] = squeeze[dimensions=(0,)] fok
    fom:f64[64,64] = mul foj fol
    fon:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fnj
    foo:i64[64,64] = squeeze[dimensions=(0,)] fon
    fop:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fnl
    foq:i64[64,64] = squeeze[dimensions=(0,)] fop
    for:bool[64,64] = lt foo 0:i64[]
    fos:i64[64,64] = add foo 64:i64[]
    fot:i64[64,64] = select_n for foo fos
    fou:bool[64,64] = lt foq 0:i64[]
    fov:i64[64,64] = add foq 64:i64[]
    fow:i64[64,64] = select_n fou foq fov
    fox:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fot
    foy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fow
    foz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fox
    fpa:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] foy
    fpb:i32[64,64,2] = concatenate[dimension=2] foz fpa
    fpc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] faa fpb
    fpd:f64[64,64] = mul fom fpc
    fpe:f64[64,64] = add foh fpd
    fpf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fne
    fpg:f64[64,64] = squeeze[dimensions=(0,)] fpf
    fph:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fnf
    fpi:f64[64,64] = squeeze[dimensions=(0,)] fph
    fpj:f64[64,64] = mul fpg fpi
    fpk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fnl
    fpl:i64[64,64] = squeeze[dimensions=(0,)] fpk
    fpm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fnj
    fpn:i64[64,64] = squeeze[dimensions=(0,)] fpm
    fpo:bool[64,64] = lt fpl 0:i64[]
    fpp:i64[64,64] = add fpl 64:i64[]
    fpq:i64[64,64] = select_n fpo fpl fpp
    fpr:bool[64,64] = lt fpn 0:i64[]
    fps:i64[64,64] = add fpn 64:i64[]
    fpt:i64[64,64] = select_n fpr fpn fps
    fpu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fpq
    fpv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fpt
    fpw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fpu
    fpx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fpv
    fpy:i32[64,64,2] = concatenate[dimension=2] fpw fpx
    fpz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] faa fpy
    fqa:f64[64,64] = mul fpj fpz
    fqb:f64[64,64] = add fpe fqa
    fqc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fne
    fqd:f64[64,64] = squeeze[dimensions=(0,)] fqc
    fqe:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fne
    fqf:f64[64,64] = squeeze[dimensions=(0,)] fqe
    fqg:f64[64,64] = mul fqd fqf
    fqh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fnl
    fqi:i64[64,64] = squeeze[dimensions=(0,)] fqh
    fqj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fnl
    fqk:i64[64,64] = squeeze[dimensions=(0,)] fqj
    fql:bool[64,64] = lt fqi 0:i64[]
    fqm:i64[64,64] = add fqi 64:i64[]
    fqn:i64[64,64] = select_n fql fqi fqm
    fqo:bool[64,64] = lt fqk 0:i64[]
    fqp:i64[64,64] = add fqk 64:i64[]
    fqq:i64[64,64] = select_n fqo fqk fqp
    fqr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fqn
    fqs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fqq
    fqt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fqr
    fqu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fqs
    fqv:i32[64,64,2] = concatenate[dimension=2] fqt fqu
    fqw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] faa fqv
    fqx:f64[64,64] = mul fqg fqw
    fqy:f64[64,64] = add fqb fqx
    fqz:f64[2,64,64] = neg fal
    fra:f64[] = neg eb
    frb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fra
    frc:f64[2,64,64] = mul frb fqz
    frd:f64[2,64,64] = add ea frc
    fre:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    frf:f64[2,64,64] = div frd fre
    frg:f64[2,64,64] = floor frf
    frh:f64[2,64,64] = sub frf frg
    fri:f64[2,64,64] = sub 1.0:f64[] frh
    frj:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cl
    frk:f64[2,64,64] = sub frf frh
    frl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] frk
    frm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] frl frj
    frn:i64[2,64,64] = add frm 1:i64[]
    fro:i64[2,64,64] = jit[name=remainder jaxpr=remainder] frn frj
    frp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fri
    frq:f64[64,64] = squeeze[dimensions=(0,)] frp
    frr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fri
    frs:f64[64,64] = squeeze[dimensions=(0,)] frr
    frt:f64[64,64] = mul frq frs
    fru:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] frm
    frv:i64[64,64] = squeeze[dimensions=(0,)] fru
    frw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] frm
    frx:i64[64,64] = squeeze[dimensions=(0,)] frw
    fry:bool[64,64] = lt frv 0:i64[]
    frz:i64[64,64] = add frv 64:i64[]
    fsa:i64[64,64] = select_n fry frv frz
    fsb:bool[64,64] = lt frx 0:i64[]
    fsc:i64[64,64] = add frx 64:i64[]
    fsd:i64[64,64] = select_n fsb frx fsc
    fse:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsa
    fsf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsd
    fsg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fse
    fsh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fsf
    fsi:i32[64,64,2] = concatenate[dimension=2] fsg fsh
    fsj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fqy fsi
    fsk:f64[64,64] = mul frt fsj
    fsl:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fri
    fsm:f64[64,64] = squeeze[dimensions=(0,)] fsl
    fsn:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] frh
    fso:f64[64,64] = squeeze[dimensions=(0,)] fsn
    fsp:f64[64,64] = mul fsm fso
    fsq:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] frm
    fsr:i64[64,64] = squeeze[dimensions=(0,)] fsq
    fss:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fro
    fst:i64[64,64] = squeeze[dimensions=(0,)] fss
    fsu:bool[64,64] = lt fsr 0:i64[]
    fsv:i64[64,64] = add fsr 64:i64[]
    fsw:i64[64,64] = select_n fsu fsr fsv
    fsx:bool[64,64] = lt fst 0:i64[]
    fsy:i64[64,64] = add fst 64:i64[]
    fsz:i64[64,64] = select_n fsx fst fsy
    fta:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsw
    ftb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fsz
    ftc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fta
    ftd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ftb
    fte:i32[64,64,2] = concatenate[dimension=2] ftc ftd
    ftf:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fqy fte
    ftg:f64[64,64] = mul fsp ftf
    fth:f64[64,64] = add fsk ftg
    fti:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] frh
    ftj:f64[64,64] = squeeze[dimensions=(0,)] fti
    ftk:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fri
    ftl:f64[64,64] = squeeze[dimensions=(0,)] ftk
    ftm:f64[64,64] = mul ftj ftl
    ftn:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fro
    fto:i64[64,64] = squeeze[dimensions=(0,)] ftn
    ftp:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] frm
    ftq:i64[64,64] = squeeze[dimensions=(0,)] ftp
    ftr:bool[64,64] = lt fto 0:i64[]
    fts:i64[64,64] = add fto 64:i64[]
    ftt:i64[64,64] = select_n ftr fto fts
    ftu:bool[64,64] = lt ftq 0:i64[]
    ftv:i64[64,64] = add ftq 64:i64[]
    ftw:i64[64,64] = select_n ftu ftq ftv
    ftx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ftt
    fty:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ftw
    ftz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ftx
    fua:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fty
    fub:i32[64,64,2] = concatenate[dimension=2] ftz fua
    fuc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fqy fub
    fud:f64[64,64] = mul ftm fuc
    fue:f64[64,64] = add fth fud
    fuf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] frh
    fug:f64[64,64] = squeeze[dimensions=(0,)] fuf
    fuh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] frh
    fui:f64[64,64] = squeeze[dimensions=(0,)] fuh
    fuj:f64[64,64] = mul fug fui
    fuk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fro
    ful:i64[64,64] = squeeze[dimensions=(0,)] fuk
    fum:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fro
    fun:i64[64,64] = squeeze[dimensions=(0,)] fum
    fuo:bool[64,64] = lt ful 0:i64[]
    fup:i64[64,64] = add ful 64:i64[]
    fuq:i64[64,64] = select_n fuo ful fup
    fur:bool[64,64] = lt fun 0:i64[]
    fus:i64[64,64] = add fun 64:i64[]
    fut:i64[64,64] = select_n fur fun fus
    fuu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fuq
    fuv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fut
    fuw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fuu
    fux:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fuv
    fuy:i32[64,64,2] = concatenate[dimension=2] fuw fux
    fuz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fqy fuy
    fva:f64[64,64] = mul fuj fuz
    fvb:f64[64,64] = add fue fva
    fvc:f64[64,64] = sub faa fvb
    fvd:f64[64,64] = div fvc 2.0:f64[]
    fve:f64[64,64] = add faa fvd
    fvf:f64[] = neg eb
    fvg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fvf
    fvh:f64[2,64,64] = mul fvg fal
    fvi:f64[2,64,64] = add ea fvh
    fvj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    fvk:f64[2,64,64] = div fvi fvj
    fvl:f64[2,64,64] = floor fvk
    fvm:f64[2,64,64] = sub fvk fvl
    fvn:f64[2,64,64] = sub 1.0:f64[] fvm
    fvo:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cm
    fvp:f64[2,64,64] = sub fvk fvm
    fvq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] fvp
    fvr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fvq fvo
    fvs:i64[2,64,64] = add fvr 1:i64[]
    fvt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] fvs fvo
    fvu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvn
    fvv:f64[64,64] = squeeze[dimensions=(0,)] fvu
    fvw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvn
    fvx:f64[64,64] = squeeze[dimensions=(0,)] fvw
    fvy:f64[64,64] = mul fvv fvx
    fvz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvr
    fwa:i64[64,64] = squeeze[dimensions=(0,)] fvz
    fwb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvr
    fwc:i64[64,64] = squeeze[dimensions=(0,)] fwb
    fwd:bool[64,64] = lt fwa 0:i64[]
    fwe:i64[64,64] = add fwa 64:i64[]
    fwf:i64[64,64] = select_n fwd fwa fwe
    fwg:bool[64,64] = lt fwc 0:i64[]
    fwh:i64[64,64] = add fwc 64:i64[]
    fwi:i64[64,64] = select_n fwg fwc fwh
    fwj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fwf
    fwk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fwi
    fwl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fwj
    fwm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fwk
    fwn:i32[64,64,2] = concatenate[dimension=2] fwl fwm
    fwo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fve fwn
    fwp:f64[64,64] = mul fvy fwo
    fwq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvn
    fwr:f64[64,64] = squeeze[dimensions=(0,)] fwq
    fws:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvm
    fwt:f64[64,64] = squeeze[dimensions=(0,)] fws
    fwu:f64[64,64] = mul fwr fwt
    fwv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvr
    fww:i64[64,64] = squeeze[dimensions=(0,)] fwv
    fwx:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvt
    fwy:i64[64,64] = squeeze[dimensions=(0,)] fwx
    fwz:bool[64,64] = lt fww 0:i64[]
    fxa:i64[64,64] = add fww 64:i64[]
    fxb:i64[64,64] = select_n fwz fww fxa
    fxc:bool[64,64] = lt fwy 0:i64[]
    fxd:i64[64,64] = add fwy 64:i64[]
    fxe:i64[64,64] = select_n fxc fwy fxd
    fxf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fxb
    fxg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fxe
    fxh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fxf
    fxi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fxg
    fxj:i32[64,64,2] = concatenate[dimension=2] fxh fxi
    fxk:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fve fxj
    fxl:f64[64,64] = mul fwu fxk
    fxm:f64[64,64] = add fwp fxl
    fxn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvm
    fxo:f64[64,64] = squeeze[dimensions=(0,)] fxn
    fxp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvn
    fxq:f64[64,64] = squeeze[dimensions=(0,)] fxp
    fxr:f64[64,64] = mul fxo fxq
    fxs:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvt
    fxt:i64[64,64] = squeeze[dimensions=(0,)] fxs
    fxu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvr
    fxv:i64[64,64] = squeeze[dimensions=(0,)] fxu
    fxw:bool[64,64] = lt fxt 0:i64[]
    fxx:i64[64,64] = add fxt 64:i64[]
    fxy:i64[64,64] = select_n fxw fxt fxx
    fxz:bool[64,64] = lt fxv 0:i64[]
    fya:i64[64,64] = add fxv 64:i64[]
    fyb:i64[64,64] = select_n fxz fxv fya
    fyc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fxy
    fyd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fyb
    fye:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fyc
    fyf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fyd
    fyg:i32[64,64,2] = concatenate[dimension=2] fye fyf
    fyh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fve fyg
    fyi:f64[64,64] = mul fxr fyh
    fyj:f64[64,64] = add fxm fyi
    fyk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvm
    fyl:f64[64,64] = squeeze[dimensions=(0,)] fyk
    fym:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvm
    fyn:f64[64,64] = squeeze[dimensions=(0,)] fym
    fyo:f64[64,64] = mul fyl fyn
    fyp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fvt
    fyq:i64[64,64] = squeeze[dimensions=(0,)] fyp
    fyr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fvt
    fys:i64[64,64] = squeeze[dimensions=(0,)] fyr
    fyt:bool[64,64] = lt fyq 0:i64[]
    fyu:i64[64,64] = add fyq 64:i64[]
    fyv:i64[64,64] = select_n fyt fyq fyu
    fyw:bool[64,64] = lt fys 0:i64[]
    fyx:i64[64,64] = add fys 64:i64[]
    fyy:i64[64,64] = select_n fyw fys fyx
    fyz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fyv
    fza:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] fyy
    fzb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fyz
    fzc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] fza
    fzd:i32[64,64,2] = concatenate[dimension=2] fzb fzc
    fze:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fve fzd
    fzf:f64[64,64] = mul fyo fze
    fzg:f64[64,64] = add fyj fzf
    fzh:c128[64,33] = jit[name=fft jaxpr=fft] fmw
    fzi:c128[] = reduce_prod[axes=(0,)] cn
    fzj:c128[] = sqrt fzi
    fzk:c128[] = div (1+0j):c128[] fzj
    fzl:c128[64,33] = mul fzh fzk
    fzm:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] fzl
    fzn:c128[2,64,33] = mul dw fzm
    fzo:f64[2,64,64] = jit[name=fft jaxpr=fft1] fzn
    fzp:f64[] = reduce_prod[axes=(0,)] co
    fzq:f64[] = sqrt fzp
    fzr:f64[2,64,64] = mul fzo fzq
    fzs:f64[] = neg eb
    fzt:f64[] = convert_element_type[new_dtype=float64 weak_type=False] fzs
    fzu:f64[2,64,64] = mul fzt fzr
    fzv:f64[2,64,64] = add ea fzu
    fzw:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    fzx:f64[2,64,64] = div fzv fzw
    fzy:f64[2,64,64] = floor fzx
    fzz:f64[2,64,64] = sub fzx fzy
    gaa:f64[2,64,64] = sub 1.0:f64[] fzz
    gab:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cp
    gac:f64[2,64,64] = sub fzx fzz
    gad:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gac
    gae:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gad gab
    gaf:i64[2,64,64] = add gae 1:i64[]
    gag:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gaf gab
    gah:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gaa
    gai:f64[64,64] = squeeze[dimensions=(0,)] gah
    gaj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gaa
    gak:f64[64,64] = squeeze[dimensions=(0,)] gaj
    gal:f64[64,64] = mul gai gak
    gam:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gae
    gan:i64[64,64] = squeeze[dimensions=(0,)] gam
    gao:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gae
    gap:i64[64,64] = squeeze[dimensions=(0,)] gao
    gaq:bool[64,64] = lt gan 0:i64[]
    gar:i64[64,64] = add gan 64:i64[]
    gas:i64[64,64] = select_n gaq gan gar
    gat:bool[64,64] = lt gap 0:i64[]
    gau:i64[64,64] = add gap 64:i64[]
    gav:i64[64,64] = select_n gat gap gau
    gaw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gas
    gax:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gav
    gay:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gaw
    gaz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gax
    gba:i32[64,64,2] = concatenate[dimension=2] gay gaz
    gbb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fmw gba
    gbc:f64[64,64] = mul gal gbb
    gbd:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gaa
    gbe:f64[64,64] = squeeze[dimensions=(0,)] gbd
    gbf:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fzz
    gbg:f64[64,64] = squeeze[dimensions=(0,)] gbf
    gbh:f64[64,64] = mul gbe gbg
    gbi:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gae
    gbj:i64[64,64] = squeeze[dimensions=(0,)] gbi
    gbk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gag
    gbl:i64[64,64] = squeeze[dimensions=(0,)] gbk
    gbm:bool[64,64] = lt gbj 0:i64[]
    gbn:i64[64,64] = add gbj 64:i64[]
    gbo:i64[64,64] = select_n gbm gbj gbn
    gbp:bool[64,64] = lt gbl 0:i64[]
    gbq:i64[64,64] = add gbl 64:i64[]
    gbr:i64[64,64] = select_n gbp gbl gbq
    gbs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gbo
    gbt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gbr
    gbu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gbs
    gbv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gbt
    gbw:i32[64,64,2] = concatenate[dimension=2] gbu gbv
    gbx:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fmw gbw
    gby:f64[64,64] = mul gbh gbx
    gbz:f64[64,64] = add gbc gby
    gca:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fzz
    gcb:f64[64,64] = squeeze[dimensions=(0,)] gca
    gcc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gaa
    gcd:f64[64,64] = squeeze[dimensions=(0,)] gcc
    gce:f64[64,64] = mul gcb gcd
    gcf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gag
    gcg:i64[64,64] = squeeze[dimensions=(0,)] gcf
    gch:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gae
    gci:i64[64,64] = squeeze[dimensions=(0,)] gch
    gcj:bool[64,64] = lt gcg 0:i64[]
    gck:i64[64,64] = add gcg 64:i64[]
    gcl:i64[64,64] = select_n gcj gcg gck
    gcm:bool[64,64] = lt gci 0:i64[]
    gcn:i64[64,64] = add gci 64:i64[]
    gco:i64[64,64] = select_n gcm gci gcn
    gcp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gcl
    gcq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gco
    gcr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gcp
    gcs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gcq
    gct:i32[64,64,2] = concatenate[dimension=2] gcr gcs
    gcu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fmw gct
    gcv:f64[64,64] = mul gce gcu
    gcw:f64[64,64] = add gbz gcv
    gcx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] fzz
    gcy:f64[64,64] = squeeze[dimensions=(0,)] gcx
    gcz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] fzz
    gda:f64[64,64] = squeeze[dimensions=(0,)] gcz
    gdb:f64[64,64] = mul gcy gda
    gdc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gag
    gdd:i64[64,64] = squeeze[dimensions=(0,)] gdc
    gde:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gag
    gdf:i64[64,64] = squeeze[dimensions=(0,)] gde
    gdg:bool[64,64] = lt gdd 0:i64[]
    gdh:i64[64,64] = add gdd 64:i64[]
    gdi:i64[64,64] = select_n gdg gdd gdh
    gdj:bool[64,64] = lt gdf 0:i64[]
    gdk:i64[64,64] = add gdf 64:i64[]
    gdl:i64[64,64] = select_n gdj gdf gdk
    gdm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gdi
    gdn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gdl
    gdo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gdm
    gdp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gdn
    gdq:i32[64,64,2] = concatenate[dimension=2] gdo gdp
    gdr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fmw gdq
    gds:f64[64,64] = mul gdb gdr
    gdt:f64[64,64] = add gcw gds
    gdu:f64[2,64,64] = neg fzr
    gdv:f64[] = neg eb
    gdw:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gdv
    gdx:f64[2,64,64] = mul gdw gdu
    gdy:f64[2,64,64] = add ea gdx
    gdz:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    gea:f64[2,64,64] = div gdy gdz
    geb:f64[2,64,64] = floor gea
    gec:f64[2,64,64] = sub gea geb
    ged:f64[2,64,64] = sub 1.0:f64[] gec
    gee:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cq
    gef:f64[2,64,64] = sub gea gec
    geg:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gef
    geh:i64[2,64,64] = jit[name=remainder jaxpr=remainder] geg gee
    gei:i64[2,64,64] = add geh 1:i64[]
    gej:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gei gee
    gek:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ged
    gel:f64[64,64] = squeeze[dimensions=(0,)] gek
    gem:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ged
    gen:f64[64,64] = squeeze[dimensions=(0,)] gem
    geo:f64[64,64] = mul gel gen
    gep:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] geh
    geq:i64[64,64] = squeeze[dimensions=(0,)] gep
    ger:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] geh
    ges:i64[64,64] = squeeze[dimensions=(0,)] ger
    get:bool[64,64] = lt geq 0:i64[]
    geu:i64[64,64] = add geq 64:i64[]
    gev:i64[64,64] = select_n get geq geu
    gew:bool[64,64] = lt ges 0:i64[]
    gex:i64[64,64] = add ges 64:i64[]
    gey:i64[64,64] = select_n gew ges gex
    gez:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gev
    gfa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gey
    gfb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gez
    gfc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gfa
    gfd:i32[64,64,2] = concatenate[dimension=2] gfb gfc
    gfe:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gdt gfd
    gff:f64[64,64] = mul geo gfe
    gfg:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ged
    gfh:f64[64,64] = squeeze[dimensions=(0,)] gfg
    gfi:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gec
    gfj:f64[64,64] = squeeze[dimensions=(0,)] gfi
    gfk:f64[64,64] = mul gfh gfj
    gfl:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] geh
    gfm:i64[64,64] = squeeze[dimensions=(0,)] gfl
    gfn:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gej
    gfo:i64[64,64] = squeeze[dimensions=(0,)] gfn
    gfp:bool[64,64] = lt gfm 0:i64[]
    gfq:i64[64,64] = add gfm 64:i64[]
    gfr:i64[64,64] = select_n gfp gfm gfq
    gfs:bool[64,64] = lt gfo 0:i64[]
    gft:i64[64,64] = add gfo 64:i64[]
    gfu:i64[64,64] = select_n gfs gfo gft
    gfv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gfr
    gfw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gfu
    gfx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gfv
    gfy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gfw
    gfz:i32[64,64,2] = concatenate[dimension=2] gfx gfy
    gga:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gdt gfz
    ggb:f64[64,64] = mul gfk gga
    ggc:f64[64,64] = add gff ggb
    ggd:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gec
    gge:f64[64,64] = squeeze[dimensions=(0,)] ggd
    ggf:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ged
    ggg:f64[64,64] = squeeze[dimensions=(0,)] ggf
    ggh:f64[64,64] = mul gge ggg
    ggi:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gej
    ggj:i64[64,64] = squeeze[dimensions=(0,)] ggi
    ggk:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] geh
    ggl:i64[64,64] = squeeze[dimensions=(0,)] ggk
    ggm:bool[64,64] = lt ggj 0:i64[]
    ggn:i64[64,64] = add ggj 64:i64[]
    ggo:i64[64,64] = select_n ggm ggj ggn
    ggp:bool[64,64] = lt ggl 0:i64[]
    ggq:i64[64,64] = add ggl 64:i64[]
    ggr:i64[64,64] = select_n ggp ggl ggq
    ggs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ggo
    ggt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ggr
    ggu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ggs
    ggv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ggt
    ggw:i32[64,64,2] = concatenate[dimension=2] ggu ggv
    ggx:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gdt ggw
    ggy:f64[64,64] = mul ggh ggx
    ggz:f64[64,64] = add ggc ggy
    gha:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gec
    ghb:f64[64,64] = squeeze[dimensions=(0,)] gha
    ghc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gec
    ghd:f64[64,64] = squeeze[dimensions=(0,)] ghc
    ghe:f64[64,64] = mul ghb ghd
    ghf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gej
    ghg:i64[64,64] = squeeze[dimensions=(0,)] ghf
    ghh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gej
    ghi:i64[64,64] = squeeze[dimensions=(0,)] ghh
    ghj:bool[64,64] = lt ghg 0:i64[]
    ghk:i64[64,64] = add ghg 64:i64[]
    ghl:i64[64,64] = select_n ghj ghg ghk
    ghm:bool[64,64] = lt ghi 0:i64[]
    ghn:i64[64,64] = add ghi 64:i64[]
    gho:i64[64,64] = select_n ghm ghi ghn
    ghp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ghl
    ghq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gho
    ghr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ghp
    ghs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ghq
    ght:i32[64,64,2] = concatenate[dimension=2] ghr ghs
    ghu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gdt ght
    ghv:f64[64,64] = mul ghe ghu
    ghw:f64[64,64] = add ggz ghv
    ghx:f64[64,64] = sub fmw ghw
    ghy:f64[64,64] = div ghx 2.0:f64[]
    ghz:f64[64,64] = add fmw ghy
    gia:f64[] = neg eb
    gib:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gia
    gic:f64[2,64,64] = mul gib fzr
    gid:f64[2,64,64] = add ea gic
    gie:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    gif:f64[2,64,64] = div gid gie
    gig:f64[2,64,64] = floor gif
    gih:f64[2,64,64] = sub gif gig
    gii:f64[2,64,64] = sub 1.0:f64[] gih
    gij:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cr
    gik:f64[2,64,64] = sub gif gih
    gil:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gik
    gim:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gil gij
    gin:i64[2,64,64] = add gim 1:i64[]
    gio:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gin gij
    gip:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gii
    giq:f64[64,64] = squeeze[dimensions=(0,)] gip
    gir:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gii
    gis:f64[64,64] = squeeze[dimensions=(0,)] gir
    git:f64[64,64] = mul giq gis
    giu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gim
    giv:i64[64,64] = squeeze[dimensions=(0,)] giu
    giw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gim
    gix:i64[64,64] = squeeze[dimensions=(0,)] giw
    giy:bool[64,64] = lt giv 0:i64[]
    giz:i64[64,64] = add giv 64:i64[]
    gja:i64[64,64] = select_n giy giv giz
    gjb:bool[64,64] = lt gix 0:i64[]
    gjc:i64[64,64] = add gix 64:i64[]
    gjd:i64[64,64] = select_n gjb gix gjc
    gje:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gja
    gjf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gjd
    gjg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gje
    gjh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gjf
    gji:i32[64,64,2] = concatenate[dimension=2] gjg gjh
    gjj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ghz gji
    gjk:f64[64,64] = mul git gjj
    gjl:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gii
    gjm:f64[64,64] = squeeze[dimensions=(0,)] gjl
    gjn:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gih
    gjo:f64[64,64] = squeeze[dimensions=(0,)] gjn
    gjp:f64[64,64] = mul gjm gjo
    gjq:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gim
    gjr:i64[64,64] = squeeze[dimensions=(0,)] gjq
    gjs:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gio
    gjt:i64[64,64] = squeeze[dimensions=(0,)] gjs
    gju:bool[64,64] = lt gjr 0:i64[]
    gjv:i64[64,64] = add gjr 64:i64[]
    gjw:i64[64,64] = select_n gju gjr gjv
    gjx:bool[64,64] = lt gjt 0:i64[]
    gjy:i64[64,64] = add gjt 64:i64[]
    gjz:i64[64,64] = select_n gjx gjt gjy
    gka:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gjw
    gkb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gjz
    gkc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gka
    gkd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gkb
    gke:i32[64,64,2] = concatenate[dimension=2] gkc gkd
    gkf:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ghz gke
    gkg:f64[64,64] = mul gjp gkf
    gkh:f64[64,64] = add gjk gkg
    gki:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gih
    gkj:f64[64,64] = squeeze[dimensions=(0,)] gki
    gkk:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gii
    gkl:f64[64,64] = squeeze[dimensions=(0,)] gkk
    gkm:f64[64,64] = mul gkj gkl
    gkn:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gio
    gko:i64[64,64] = squeeze[dimensions=(0,)] gkn
    gkp:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gim
    gkq:i64[64,64] = squeeze[dimensions=(0,)] gkp
    gkr:bool[64,64] = lt gko 0:i64[]
    gks:i64[64,64] = add gko 64:i64[]
    gkt:i64[64,64] = select_n gkr gko gks
    gku:bool[64,64] = lt gkq 0:i64[]
    gkv:i64[64,64] = add gkq 64:i64[]
    gkw:i64[64,64] = select_n gku gkq gkv
    gkx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gkt
    gky:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gkw
    gkz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gkx
    gla:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gky
    glb:i32[64,64,2] = concatenate[dimension=2] gkz gla
    glc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ghz glb
    gld:f64[64,64] = mul gkm glc
    gle:f64[64,64] = add gkh gld
    glf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gih
    glg:f64[64,64] = squeeze[dimensions=(0,)] glf
    glh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gih
    gli:f64[64,64] = squeeze[dimensions=(0,)] glh
    glj:f64[64,64] = mul glg gli
    glk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gio
    gll:i64[64,64] = squeeze[dimensions=(0,)] glk
    glm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gio
    gln:i64[64,64] = squeeze[dimensions=(0,)] glm
    glo:bool[64,64] = lt gll 0:i64[]
    glp:i64[64,64] = add gll 64:i64[]
    glq:i64[64,64] = select_n glo gll glp
    glr:bool[64,64] = lt gln 0:i64[]
    gls:i64[64,64] = add gln 64:i64[]
    glt:i64[64,64] = select_n glr gln gls
    glu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] glq
    glv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] glt
    glw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] glu
    glx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] glv
    gly:i32[64,64,2] = concatenate[dimension=2] glw glx
    glz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ghz gly
    gma:f64[64,64] = mul glj glz
    gmb:f64[64,64] = add gle gma
    gmc:f64[] = neg eb
    gmd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gmc
    gme:f64[2,64,64] = mul gmd fzr
    gmf:f64[2,64,64] = add ea gme
    gmg:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    gmh:f64[2,64,64] = div gmf gmg
    gmi:f64[2,64,64] = floor gmh
    gmj:f64[2,64,64] = sub gmh gmi
    gmk:f64[2,64,64] = sub 1.0:f64[] gmj
    gml:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cs
    gmm:f64[2,64,64] = sub gmh gmj
    gmn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gmm
    gmo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gmn gml
    gmp:i64[2,64,64] = add gmo 1:i64[]
    gmq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gmp gml
    gmr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmk
    gms:f64[64,64] = squeeze[dimensions=(0,)] gmr
    gmt:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmk
    gmu:f64[64,64] = squeeze[dimensions=(0,)] gmt
    gmv:f64[64,64] = mul gms gmu
    gmw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmo
    gmx:i64[64,64] = squeeze[dimensions=(0,)] gmw
    gmy:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmo
    gmz:i64[64,64] = squeeze[dimensions=(0,)] gmy
    gna:bool[64,64] = lt gmx 0:i64[]
    gnb:i64[64,64] = add gmx 64:i64[]
    gnc:i64[64,64] = select_n gna gmx gnb
    gnd:bool[64,64] = lt gmz 0:i64[]
    gne:i64[64,64] = add gmz 64:i64[]
    gnf:i64[64,64] = select_n gnd gmz gne
    gng:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gnc
    gnh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gnf
    gni:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gng
    gnj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gnh
    gnk:i32[64,64,2] = concatenate[dimension=2] gni gnj
    gnl:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fzg gnk
    gnm:f64[64,64] = mul gmv gnl
    gnn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmk
    gno:f64[64,64] = squeeze[dimensions=(0,)] gnn
    gnp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmj
    gnq:f64[64,64] = squeeze[dimensions=(0,)] gnp
    gnr:f64[64,64] = mul gno gnq
    gns:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmo
    gnt:i64[64,64] = squeeze[dimensions=(0,)] gns
    gnu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmq
    gnv:i64[64,64] = squeeze[dimensions=(0,)] gnu
    gnw:bool[64,64] = lt gnt 0:i64[]
    gnx:i64[64,64] = add gnt 64:i64[]
    gny:i64[64,64] = select_n gnw gnt gnx
    gnz:bool[64,64] = lt gnv 0:i64[]
    goa:i64[64,64] = add gnv 64:i64[]
    gob:i64[64,64] = select_n gnz gnv goa
    goc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gny
    god:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gob
    goe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] goc
    gof:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] god
    gog:i32[64,64,2] = concatenate[dimension=2] goe gof
    goh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fzg gog
    goi:f64[64,64] = mul gnr goh
    goj:f64[64,64] = add gnm goi
    gok:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmj
    gol:f64[64,64] = squeeze[dimensions=(0,)] gok
    gom:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmk
    gon:f64[64,64] = squeeze[dimensions=(0,)] gom
    goo:f64[64,64] = mul gol gon
    gop:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmq
    goq:i64[64,64] = squeeze[dimensions=(0,)] gop
    gor:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmo
    gos:i64[64,64] = squeeze[dimensions=(0,)] gor
    got:bool[64,64] = lt goq 0:i64[]
    gou:i64[64,64] = add goq 64:i64[]
    gov:i64[64,64] = select_n got goq gou
    gow:bool[64,64] = lt gos 0:i64[]
    gox:i64[64,64] = add gos 64:i64[]
    goy:i64[64,64] = select_n gow gos gox
    goz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gov
    gpa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] goy
    gpb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] goz
    gpc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gpa
    gpd:i32[64,64,2] = concatenate[dimension=2] gpb gpc
    gpe:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fzg gpd
    gpf:f64[64,64] = mul goo gpe
    gpg:f64[64,64] = add goj gpf
    gph:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmj
    gpi:f64[64,64] = squeeze[dimensions=(0,)] gph
    gpj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmj
    gpk:f64[64,64] = squeeze[dimensions=(0,)] gpj
    gpl:f64[64,64] = mul gpi gpk
    gpm:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gmq
    gpn:i64[64,64] = squeeze[dimensions=(0,)] gpm
    gpo:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gmq
    gpp:i64[64,64] = squeeze[dimensions=(0,)] gpo
    gpq:bool[64,64] = lt gpn 0:i64[]
    gpr:i64[64,64] = add gpn 64:i64[]
    gps:i64[64,64] = select_n gpq gpn gpr
    gpt:bool[64,64] = lt gpp 0:i64[]
    gpu:i64[64,64] = add gpp 64:i64[]
    gpv:i64[64,64] = select_n gpt gpp gpu
    gpw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gps
    gpx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gpv
    gpy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gpw
    gpz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gpx
    gqa:i32[64,64,2] = concatenate[dimension=2] gpy gpz
    gqb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] fzg gqa
    gqc:f64[64,64] = mul gpl gqb
    gqd:f64[64,64] = add gpg gqc
    gqe:f64[2,64,64] = neg fzr
    gqf:f64[] = neg eb
    gqg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gqf
    gqh:f64[2,64,64] = mul gqg gqe
    gqi:f64[2,64,64] = add ea gqh
    gqj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    gqk:f64[2,64,64] = div gqi gqj
    gql:f64[2,64,64] = floor gqk
    gqm:f64[2,64,64] = sub gqk gql
    gqn:f64[2,64,64] = sub 1.0:f64[] gqm
    gqo:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ct
    gqp:f64[2,64,64] = sub gqk gqm
    gqq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gqp
    gqr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gqq gqo
    gqs:i64[2,64,64] = add gqr 1:i64[]
    gqt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gqs gqo
    gqu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqn
    gqv:f64[64,64] = squeeze[dimensions=(0,)] gqu
    gqw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqn
    gqx:f64[64,64] = squeeze[dimensions=(0,)] gqw
    gqy:f64[64,64] = mul gqv gqx
    gqz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqr
    gra:i64[64,64] = squeeze[dimensions=(0,)] gqz
    grb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqr
    grc:i64[64,64] = squeeze[dimensions=(0,)] grb
    grd:bool[64,64] = lt gra 0:i64[]
    gre:i64[64,64] = add gra 64:i64[]
    grf:i64[64,64] = select_n grd gra gre
    grg:bool[64,64] = lt grc 0:i64[]
    grh:i64[64,64] = add grc 64:i64[]
    gri:i64[64,64] = select_n grg grc grh
    grj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] grf
    grk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gri
    grl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] grj
    grm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] grk
    grn:i32[64,64,2] = concatenate[dimension=2] grl grm
    gro:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gqd grn
    grp:f64[64,64] = mul gqy gro
    grq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqn
    grr:f64[64,64] = squeeze[dimensions=(0,)] grq
    grs:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqm
    grt:f64[64,64] = squeeze[dimensions=(0,)] grs
    gru:f64[64,64] = mul grr grt
    grv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqr
    grw:i64[64,64] = squeeze[dimensions=(0,)] grv
    grx:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqt
    gry:i64[64,64] = squeeze[dimensions=(0,)] grx
    grz:bool[64,64] = lt grw 0:i64[]
    gsa:i64[64,64] = add grw 64:i64[]
    gsb:i64[64,64] = select_n grz grw gsa
    gsc:bool[64,64] = lt gry 0:i64[]
    gsd:i64[64,64] = add gry 64:i64[]
    gse:i64[64,64] = select_n gsc gry gsd
    gsf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gsb
    gsg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gse
    gsh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gsf
    gsi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gsg
    gsj:i32[64,64,2] = concatenate[dimension=2] gsh gsi
    gsk:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gqd gsj
    gsl:f64[64,64] = mul gru gsk
    gsm:f64[64,64] = add grp gsl
    gsn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqm
    gso:f64[64,64] = squeeze[dimensions=(0,)] gsn
    gsp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqn
    gsq:f64[64,64] = squeeze[dimensions=(0,)] gsp
    gsr:f64[64,64] = mul gso gsq
    gss:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqt
    gst:i64[64,64] = squeeze[dimensions=(0,)] gss
    gsu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqr
    gsv:i64[64,64] = squeeze[dimensions=(0,)] gsu
    gsw:bool[64,64] = lt gst 0:i64[]
    gsx:i64[64,64] = add gst 64:i64[]
    gsy:i64[64,64] = select_n gsw gst gsx
    gsz:bool[64,64] = lt gsv 0:i64[]
    gta:i64[64,64] = add gsv 64:i64[]
    gtb:i64[64,64] = select_n gsz gsv gta
    gtc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gsy
    gtd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gtb
    gte:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gtc
    gtf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gtd
    gtg:i32[64,64,2] = concatenate[dimension=2] gte gtf
    gth:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gqd gtg
    gti:f64[64,64] = mul gsr gth
    gtj:f64[64,64] = add gsm gti
    gtk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqm
    gtl:f64[64,64] = squeeze[dimensions=(0,)] gtk
    gtm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqm
    gtn:f64[64,64] = squeeze[dimensions=(0,)] gtm
    gto:f64[64,64] = mul gtl gtn
    gtp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gqt
    gtq:i64[64,64] = squeeze[dimensions=(0,)] gtp
    gtr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gqt
    gts:i64[64,64] = squeeze[dimensions=(0,)] gtr
    gtt:bool[64,64] = lt gtq 0:i64[]
    gtu:i64[64,64] = add gtq 64:i64[]
    gtv:i64[64,64] = select_n gtt gtq gtu
    gtw:bool[64,64] = lt gts 0:i64[]
    gtx:i64[64,64] = add gts 64:i64[]
    gty:i64[64,64] = select_n gtw gts gtx
    gtz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gtv
    gua:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gty
    gub:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gtz
    guc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gua
    gud:i32[64,64,2] = concatenate[dimension=2] gub guc
    gue:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gqd gud
    guf:f64[64,64] = mul gto gue
    gug:f64[64,64] = add gtj guf
    guh:f64[64,64] = sub fzg gug
    gui:f64[64,64] = div guh 2.0:f64[]
    guj:f64[64,64] = add fzg gui
    guk:f64[] = neg eb
    gul:f64[] = convert_element_type[new_dtype=float64 weak_type=False] guk
    gum:f64[2,64,64] = mul gul fzr
    gun:f64[2,64,64] = add ea gum
    guo:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    gup:f64[2,64,64] = div gun guo
    guq:f64[2,64,64] = floor gup
    gur:f64[2,64,64] = sub gup guq
    gus:f64[2,64,64] = sub 1.0:f64[] gur
    gut:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cu
    guu:f64[2,64,64] = sub gup gur
    guv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] guu
    guw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] guv gut
    gux:i64[2,64,64] = add guw 1:i64[]
    guy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gux gut
    guz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gus
    gva:f64[64,64] = squeeze[dimensions=(0,)] guz
    gvb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gus
    gvc:f64[64,64] = squeeze[dimensions=(0,)] gvb
    gvd:f64[64,64] = mul gva gvc
    gve:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] guw
    gvf:i64[64,64] = squeeze[dimensions=(0,)] gve
    gvg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] guw
    gvh:i64[64,64] = squeeze[dimensions=(0,)] gvg
    gvi:bool[64,64] = lt gvf 0:i64[]
    gvj:i64[64,64] = add gvf 64:i64[]
    gvk:i64[64,64] = select_n gvi gvf gvj
    gvl:bool[64,64] = lt gvh 0:i64[]
    gvm:i64[64,64] = add gvh 64:i64[]
    gvn:i64[64,64] = select_n gvl gvh gvm
    gvo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gvk
    gvp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gvn
    gvq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gvo
    gvr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gvp
    gvs:i32[64,64,2] = concatenate[dimension=2] gvq gvr
    gvt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] guj gvs
    gvu:f64[64,64] = mul gvd gvt
    gvv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gus
    gvw:f64[64,64] = squeeze[dimensions=(0,)] gvv
    gvx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gur
    gvy:f64[64,64] = squeeze[dimensions=(0,)] gvx
    gvz:f64[64,64] = mul gvw gvy
    gwa:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] guw
    gwb:i64[64,64] = squeeze[dimensions=(0,)] gwa
    gwc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] guy
    gwd:i64[64,64] = squeeze[dimensions=(0,)] gwc
    gwe:bool[64,64] = lt gwb 0:i64[]
    gwf:i64[64,64] = add gwb 64:i64[]
    gwg:i64[64,64] = select_n gwe gwb gwf
    gwh:bool[64,64] = lt gwd 0:i64[]
    gwi:i64[64,64] = add gwd 64:i64[]
    gwj:i64[64,64] = select_n gwh gwd gwi
    gwk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gwg
    gwl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gwj
    gwm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gwk
    gwn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gwl
    gwo:i32[64,64,2] = concatenate[dimension=2] gwm gwn
    gwp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] guj gwo
    gwq:f64[64,64] = mul gvz gwp
    gwr:f64[64,64] = add gvu gwq
    gws:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gur
    gwt:f64[64,64] = squeeze[dimensions=(0,)] gws
    gwu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gus
    gwv:f64[64,64] = squeeze[dimensions=(0,)] gwu
    gww:f64[64,64] = mul gwt gwv
    gwx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] guy
    gwy:i64[64,64] = squeeze[dimensions=(0,)] gwx
    gwz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] guw
    gxa:i64[64,64] = squeeze[dimensions=(0,)] gwz
    gxb:bool[64,64] = lt gwy 0:i64[]
    gxc:i64[64,64] = add gwy 64:i64[]
    gxd:i64[64,64] = select_n gxb gwy gxc
    gxe:bool[64,64] = lt gxa 0:i64[]
    gxf:i64[64,64] = add gxa 64:i64[]
    gxg:i64[64,64] = select_n gxe gxa gxf
    gxh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gxd
    gxi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gxg
    gxj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gxh
    gxk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gxi
    gxl:i32[64,64,2] = concatenate[dimension=2] gxj gxk
    gxm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] guj gxl
    gxn:f64[64,64] = mul gww gxm
    gxo:f64[64,64] = add gwr gxn
    gxp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gur
    gxq:f64[64,64] = squeeze[dimensions=(0,)] gxp
    gxr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gur
    gxs:f64[64,64] = squeeze[dimensions=(0,)] gxr
    gxt:f64[64,64] = mul gxq gxs
    gxu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] guy
    gxv:i64[64,64] = squeeze[dimensions=(0,)] gxu
    gxw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] guy
    gxx:i64[64,64] = squeeze[dimensions=(0,)] gxw
    gxy:bool[64,64] = lt gxv 0:i64[]
    gxz:i64[64,64] = add gxv 64:i64[]
    gya:i64[64,64] = select_n gxy gxv gxz
    gyb:bool[64,64] = lt gxx 0:i64[]
    gyc:i64[64,64] = add gxx 64:i64[]
    gyd:i64[64,64] = select_n gyb gxx gyc
    gye:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gya
    gyf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gyd
    gyg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gye
    gyh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] gyf
    gyi:i32[64,64,2] = concatenate[dimension=2] gyg gyh
    gyj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] guj gyi
    gyk:f64[64,64] = mul gxt gyj
    gyl:f64[64,64] = add gxo gyk
    gym:c128[64,33] = jit[name=fft jaxpr=fft] gmb
    gyn:c128[] = reduce_prod[axes=(0,)] cv
    gyo:c128[] = sqrt gyn
    gyp:c128[] = div (1+0j):c128[] gyo
    gyq:c128[64,33] = mul gym gyp
    gyr:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] gyq
    gys:c128[2,64,33] = mul dw gyr
    gyt:f64[2,64,64] = jit[name=fft jaxpr=fft1] gys
    gyu:f64[] = reduce_prod[axes=(0,)] cw
    gyv:f64[] = sqrt gyu
    gyw:f64[2,64,64] = mul gyt gyv
    gyx:f64[] = neg eb
    gyy:f64[] = convert_element_type[new_dtype=float64 weak_type=False] gyx
    gyz:f64[2,64,64] = mul gyy gyw
    gza:f64[2,64,64] = add ea gyz
    gzb:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    gzc:f64[2,64,64] = div gza gzb
    gzd:f64[2,64,64] = floor gzc
    gze:f64[2,64,64] = sub gzc gzd
    gzf:f64[2,64,64] = sub 1.0:f64[] gze
    gzg:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cx
    gzh:f64[2,64,64] = sub gzc gze
    gzi:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] gzh
    gzj:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gzi gzg
    gzk:i64[2,64,64] = add gzj 1:i64[]
    gzl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] gzk gzg
    gzm:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gzf
    gzn:f64[64,64] = squeeze[dimensions=(0,)] gzm
    gzo:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gzf
    gzp:f64[64,64] = squeeze[dimensions=(0,)] gzo
    gzq:f64[64,64] = mul gzn gzp
    gzr:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gzj
    gzs:i64[64,64] = squeeze[dimensions=(0,)] gzr
    gzt:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gzj
    gzu:i64[64,64] = squeeze[dimensions=(0,)] gzt
    gzv:bool[64,64] = lt gzs 0:i64[]
    gzw:i64[64,64] = add gzs 64:i64[]
    gzx:i64[64,64] = select_n gzv gzs gzw
    gzy:bool[64,64] = lt gzu 0:i64[]
    gzz:i64[64,64] = add gzu 64:i64[]
    haa:i64[64,64] = select_n gzy gzu gzz
    hab:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] gzx
    hac:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] haa
    had:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hab
    hae:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hac
    haf:i32[64,64,2] = concatenate[dimension=2] had hae
    hag:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gmb haf
    hah:f64[64,64] = mul gzq hag
    hai:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gzf
    haj:f64[64,64] = squeeze[dimensions=(0,)] hai
    hak:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gze
    hal:f64[64,64] = squeeze[dimensions=(0,)] hak
    ham:f64[64,64] = mul haj hal
    han:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gzj
    hao:i64[64,64] = squeeze[dimensions=(0,)] han
    hap:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gzl
    haq:i64[64,64] = squeeze[dimensions=(0,)] hap
    har:bool[64,64] = lt hao 0:i64[]
    has:i64[64,64] = add hao 64:i64[]
    hat:i64[64,64] = select_n har hao has
    hau:bool[64,64] = lt haq 0:i64[]
    hav:i64[64,64] = add haq 64:i64[]
    haw:i64[64,64] = select_n hau haq hav
    hax:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hat
    hay:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] haw
    haz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hax
    hba:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hay
    hbb:i32[64,64,2] = concatenate[dimension=2] haz hba
    hbc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gmb hbb
    hbd:f64[64,64] = mul ham hbc
    hbe:f64[64,64] = add hah hbd
    hbf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gze
    hbg:f64[64,64] = squeeze[dimensions=(0,)] hbf
    hbh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gzf
    hbi:f64[64,64] = squeeze[dimensions=(0,)] hbh
    hbj:f64[64,64] = mul hbg hbi
    hbk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gzl
    hbl:i64[64,64] = squeeze[dimensions=(0,)] hbk
    hbm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gzj
    hbn:i64[64,64] = squeeze[dimensions=(0,)] hbm
    hbo:bool[64,64] = lt hbl 0:i64[]
    hbp:i64[64,64] = add hbl 64:i64[]
    hbq:i64[64,64] = select_n hbo hbl hbp
    hbr:bool[64,64] = lt hbn 0:i64[]
    hbs:i64[64,64] = add hbn 64:i64[]
    hbt:i64[64,64] = select_n hbr hbn hbs
    hbu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hbq
    hbv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hbt
    hbw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hbu
    hbx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hbv
    hby:i32[64,64,2] = concatenate[dimension=2] hbw hbx
    hbz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gmb hby
    hca:f64[64,64] = mul hbj hbz
    hcb:f64[64,64] = add hbe hca
    hcc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gze
    hcd:f64[64,64] = squeeze[dimensions=(0,)] hcc
    hce:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gze
    hcf:f64[64,64] = squeeze[dimensions=(0,)] hce
    hcg:f64[64,64] = mul hcd hcf
    hch:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] gzl
    hci:i64[64,64] = squeeze[dimensions=(0,)] hch
    hcj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] gzl
    hck:i64[64,64] = squeeze[dimensions=(0,)] hcj
    hcl:bool[64,64] = lt hci 0:i64[]
    hcm:i64[64,64] = add hci 64:i64[]
    hcn:i64[64,64] = select_n hcl hci hcm
    hco:bool[64,64] = lt hck 0:i64[]
    hcp:i64[64,64] = add hck 64:i64[]
    hcq:i64[64,64] = select_n hco hck hcp
    hcr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hcn
    hcs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hcq
    hct:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hcr
    hcu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hcs
    hcv:i32[64,64,2] = concatenate[dimension=2] hct hcu
    hcw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gmb hcv
    hcx:f64[64,64] = mul hcg hcw
    hcy:f64[64,64] = add hcb hcx
    hcz:f64[2,64,64] = neg gyw
    hda:f64[] = neg eb
    hdb:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hda
    hdc:f64[2,64,64] = mul hdb hcz
    hdd:f64[2,64,64] = add ea hdc
    hde:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    hdf:f64[2,64,64] = div hdd hde
    hdg:f64[2,64,64] = floor hdf
    hdh:f64[2,64,64] = sub hdf hdg
    hdi:f64[2,64,64] = sub 1.0:f64[] hdh
    hdj:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cy
    hdk:f64[2,64,64] = sub hdf hdh
    hdl:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hdk
    hdm:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hdl hdj
    hdn:i64[2,64,64] = add hdm 1:i64[]
    hdo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hdn hdj
    hdp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdi
    hdq:f64[64,64] = squeeze[dimensions=(0,)] hdp
    hdr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdi
    hds:f64[64,64] = squeeze[dimensions=(0,)] hdr
    hdt:f64[64,64] = mul hdq hds
    hdu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdm
    hdv:i64[64,64] = squeeze[dimensions=(0,)] hdu
    hdw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdm
    hdx:i64[64,64] = squeeze[dimensions=(0,)] hdw
    hdy:bool[64,64] = lt hdv 0:i64[]
    hdz:i64[64,64] = add hdv 64:i64[]
    hea:i64[64,64] = select_n hdy hdv hdz
    heb:bool[64,64] = lt hdx 0:i64[]
    hec:i64[64,64] = add hdx 64:i64[]
    hed:i64[64,64] = select_n heb hdx hec
    hee:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hea
    hef:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hed
    heg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hee
    heh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hef
    hei:i32[64,64,2] = concatenate[dimension=2] heg heh
    hej:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hcy hei
    hek:f64[64,64] = mul hdt hej
    hel:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdi
    hem:f64[64,64] = squeeze[dimensions=(0,)] hel
    hen:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdh
    heo:f64[64,64] = squeeze[dimensions=(0,)] hen
    hep:f64[64,64] = mul hem heo
    heq:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdm
    her:i64[64,64] = squeeze[dimensions=(0,)] heq
    hes:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdo
    het:i64[64,64] = squeeze[dimensions=(0,)] hes
    heu:bool[64,64] = lt her 0:i64[]
    hev:i64[64,64] = add her 64:i64[]
    hew:i64[64,64] = select_n heu her hev
    hex:bool[64,64] = lt het 0:i64[]
    hey:i64[64,64] = add het 64:i64[]
    hez:i64[64,64] = select_n hex het hey
    hfa:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hew
    hfb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hez
    hfc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hfa
    hfd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hfb
    hfe:i32[64,64,2] = concatenate[dimension=2] hfc hfd
    hff:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hcy hfe
    hfg:f64[64,64] = mul hep hff
    hfh:f64[64,64] = add hek hfg
    hfi:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdh
    hfj:f64[64,64] = squeeze[dimensions=(0,)] hfi
    hfk:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdi
    hfl:f64[64,64] = squeeze[dimensions=(0,)] hfk
    hfm:f64[64,64] = mul hfj hfl
    hfn:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdo
    hfo:i64[64,64] = squeeze[dimensions=(0,)] hfn
    hfp:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdm
    hfq:i64[64,64] = squeeze[dimensions=(0,)] hfp
    hfr:bool[64,64] = lt hfo 0:i64[]
    hfs:i64[64,64] = add hfo 64:i64[]
    hft:i64[64,64] = select_n hfr hfo hfs
    hfu:bool[64,64] = lt hfq 0:i64[]
    hfv:i64[64,64] = add hfq 64:i64[]
    hfw:i64[64,64] = select_n hfu hfq hfv
    hfx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hft
    hfy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hfw
    hfz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hfx
    hga:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hfy
    hgb:i32[64,64,2] = concatenate[dimension=2] hfz hga
    hgc:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hcy hgb
    hgd:f64[64,64] = mul hfm hgc
    hge:f64[64,64] = add hfh hgd
    hgf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdh
    hgg:f64[64,64] = squeeze[dimensions=(0,)] hgf
    hgh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdh
    hgi:f64[64,64] = squeeze[dimensions=(0,)] hgh
    hgj:f64[64,64] = mul hgg hgi
    hgk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hdo
    hgl:i64[64,64] = squeeze[dimensions=(0,)] hgk
    hgm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hdo
    hgn:i64[64,64] = squeeze[dimensions=(0,)] hgm
    hgo:bool[64,64] = lt hgl 0:i64[]
    hgp:i64[64,64] = add hgl 64:i64[]
    hgq:i64[64,64] = select_n hgo hgl hgp
    hgr:bool[64,64] = lt hgn 0:i64[]
    hgs:i64[64,64] = add hgn 64:i64[]
    hgt:i64[64,64] = select_n hgr hgn hgs
    hgu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hgq
    hgv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hgt
    hgw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hgu
    hgx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hgv
    hgy:i32[64,64,2] = concatenate[dimension=2] hgw hgx
    hgz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hcy hgy
    hha:f64[64,64] = mul hgj hgz
    hhb:f64[64,64] = add hge hha
    hhc:f64[64,64] = sub gmb hhb
    hhd:f64[64,64] = div hhc 2.0:f64[]
    hhe:f64[64,64] = add gmb hhd
    hhf:f64[] = neg eb
    hhg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hhf
    hhh:f64[2,64,64] = mul hhg gyw
    hhi:f64[2,64,64] = add ea hhh
    hhj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    hhk:f64[2,64,64] = div hhi hhj
    hhl:f64[2,64,64] = floor hhk
    hhm:f64[2,64,64] = sub hhk hhl
    hhn:f64[2,64,64] = sub 1.0:f64[] hhm
    hho:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] cz
    hhp:f64[2,64,64] = sub hhk hhm
    hhq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hhp
    hhr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hhq hho
    hhs:i64[2,64,64] = add hhr 1:i64[]
    hht:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hhs hho
    hhu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hhn
    hhv:f64[64,64] = squeeze[dimensions=(0,)] hhu
    hhw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hhn
    hhx:f64[64,64] = squeeze[dimensions=(0,)] hhw
    hhy:f64[64,64] = mul hhv hhx
    hhz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hhr
    hia:i64[64,64] = squeeze[dimensions=(0,)] hhz
    hib:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hhr
    hic:i64[64,64] = squeeze[dimensions=(0,)] hib
    hid:bool[64,64] = lt hia 0:i64[]
    hie:i64[64,64] = add hia 64:i64[]
    hif:i64[64,64] = select_n hid hia hie
    hig:bool[64,64] = lt hic 0:i64[]
    hih:i64[64,64] = add hic 64:i64[]
    hii:i64[64,64] = select_n hig hic hih
    hij:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hif
    hik:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hii
    hil:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hij
    him:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hik
    hin:i32[64,64,2] = concatenate[dimension=2] hil him
    hio:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hhe hin
    hip:f64[64,64] = mul hhy hio
    hiq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hhn
    hir:f64[64,64] = squeeze[dimensions=(0,)] hiq
    his:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hhm
    hit:f64[64,64] = squeeze[dimensions=(0,)] his
    hiu:f64[64,64] = mul hir hit
    hiv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hhr
    hiw:i64[64,64] = squeeze[dimensions=(0,)] hiv
    hix:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hht
    hiy:i64[64,64] = squeeze[dimensions=(0,)] hix
    hiz:bool[64,64] = lt hiw 0:i64[]
    hja:i64[64,64] = add hiw 64:i64[]
    hjb:i64[64,64] = select_n hiz hiw hja
    hjc:bool[64,64] = lt hiy 0:i64[]
    hjd:i64[64,64] = add hiy 64:i64[]
    hje:i64[64,64] = select_n hjc hiy hjd
    hjf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hjb
    hjg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hje
    hjh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hjf
    hji:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hjg
    hjj:i32[64,64,2] = concatenate[dimension=2] hjh hji
    hjk:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hhe hjj
    hjl:f64[64,64] = mul hiu hjk
    hjm:f64[64,64] = add hip hjl
    hjn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hhm
    hjo:f64[64,64] = squeeze[dimensions=(0,)] hjn
    hjp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hhn
    hjq:f64[64,64] = squeeze[dimensions=(0,)] hjp
    hjr:f64[64,64] = mul hjo hjq
    hjs:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hht
    hjt:i64[64,64] = squeeze[dimensions=(0,)] hjs
    hju:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hhr
    hjv:i64[64,64] = squeeze[dimensions=(0,)] hju
    hjw:bool[64,64] = lt hjt 0:i64[]
    hjx:i64[64,64] = add hjt 64:i64[]
    hjy:i64[64,64] = select_n hjw hjt hjx
    hjz:bool[64,64] = lt hjv 0:i64[]
    hka:i64[64,64] = add hjv 64:i64[]
    hkb:i64[64,64] = select_n hjz hjv hka
    hkc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hjy
    hkd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hkb
    hke:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hkc
    hkf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hkd
    hkg:i32[64,64,2] = concatenate[dimension=2] hke hkf
    hkh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hhe hkg
    hki:f64[64,64] = mul hjr hkh
    hkj:f64[64,64] = add hjm hki
    hkk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hhm
    hkl:f64[64,64] = squeeze[dimensions=(0,)] hkk
    hkm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hhm
    hkn:f64[64,64] = squeeze[dimensions=(0,)] hkm
    hko:f64[64,64] = mul hkl hkn
    hkp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hht
    hkq:i64[64,64] = squeeze[dimensions=(0,)] hkp
    hkr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hht
    hks:i64[64,64] = squeeze[dimensions=(0,)] hkr
    hkt:bool[64,64] = lt hkq 0:i64[]
    hku:i64[64,64] = add hkq 64:i64[]
    hkv:i64[64,64] = select_n hkt hkq hku
    hkw:bool[64,64] = lt hks 0:i64[]
    hkx:i64[64,64] = add hks 64:i64[]
    hky:i64[64,64] = select_n hkw hks hkx
    hkz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hkv
    hla:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hky
    hlb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hkz
    hlc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hla
    hld:i32[64,64,2] = concatenate[dimension=2] hlb hlc
    hle:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hhe hld
    hlf:f64[64,64] = mul hko hle
    hlg:f64[64,64] = add hkj hlf
    hlh:f64[] = neg eb
    hli:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hlh
    hlj:f64[2,64,64] = mul hli gyw
    hlk:f64[2,64,64] = add ea hlj
    hll:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    hlm:f64[2,64,64] = div hlk hll
    hln:f64[2,64,64] = floor hlm
    hlo:f64[2,64,64] = sub hlm hln
    hlp:f64[2,64,64] = sub 1.0:f64[] hlo
    hlq:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] da
    hlr:f64[2,64,64] = sub hlm hlo
    hls:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hlr
    hlt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hls hlq
    hlu:i64[2,64,64] = add hlt 1:i64[]
    hlv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hlu hlq
    hlw:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlp
    hlx:f64[64,64] = squeeze[dimensions=(0,)] hlw
    hly:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlp
    hlz:f64[64,64] = squeeze[dimensions=(0,)] hly
    hma:f64[64,64] = mul hlx hlz
    hmb:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlt
    hmc:i64[64,64] = squeeze[dimensions=(0,)] hmb
    hmd:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlt
    hme:i64[64,64] = squeeze[dimensions=(0,)] hmd
    hmf:bool[64,64] = lt hmc 0:i64[]
    hmg:i64[64,64] = add hmc 64:i64[]
    hmh:i64[64,64] = select_n hmf hmc hmg
    hmi:bool[64,64] = lt hme 0:i64[]
    hmj:i64[64,64] = add hme 64:i64[]
    hmk:i64[64,64] = select_n hmi hme hmj
    hml:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hmh
    hmm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hmk
    hmn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hml
    hmo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hmm
    hmp:i32[64,64,2] = concatenate[dimension=2] hmn hmo
    hmq:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gyl hmp
    hmr:f64[64,64] = mul hma hmq
    hms:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlp
    hmt:f64[64,64] = squeeze[dimensions=(0,)] hms
    hmu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlo
    hmv:f64[64,64] = squeeze[dimensions=(0,)] hmu
    hmw:f64[64,64] = mul hmt hmv
    hmx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlt
    hmy:i64[64,64] = squeeze[dimensions=(0,)] hmx
    hmz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlv
    hna:i64[64,64] = squeeze[dimensions=(0,)] hmz
    hnb:bool[64,64] = lt hmy 0:i64[]
    hnc:i64[64,64] = add hmy 64:i64[]
    hnd:i64[64,64] = select_n hnb hmy hnc
    hne:bool[64,64] = lt hna 0:i64[]
    hnf:i64[64,64] = add hna 64:i64[]
    hng:i64[64,64] = select_n hne hna hnf
    hnh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hnd
    hni:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hng
    hnj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hnh
    hnk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hni
    hnl:i32[64,64,2] = concatenate[dimension=2] hnj hnk
    hnm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gyl hnl
    hnn:f64[64,64] = mul hmw hnm
    hno:f64[64,64] = add hmr hnn
    hnp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlo
    hnq:f64[64,64] = squeeze[dimensions=(0,)] hnp
    hnr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlp
    hns:f64[64,64] = squeeze[dimensions=(0,)] hnr
    hnt:f64[64,64] = mul hnq hns
    hnu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlv
    hnv:i64[64,64] = squeeze[dimensions=(0,)] hnu
    hnw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlt
    hnx:i64[64,64] = squeeze[dimensions=(0,)] hnw
    hny:bool[64,64] = lt hnv 0:i64[]
    hnz:i64[64,64] = add hnv 64:i64[]
    hoa:i64[64,64] = select_n hny hnv hnz
    hob:bool[64,64] = lt hnx 0:i64[]
    hoc:i64[64,64] = add hnx 64:i64[]
    hod:i64[64,64] = select_n hob hnx hoc
    hoe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hoa
    hof:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hod
    hog:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hoe
    hoh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hof
    hoi:i32[64,64,2] = concatenate[dimension=2] hog hoh
    hoj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gyl hoi
    hok:f64[64,64] = mul hnt hoj
    hol:f64[64,64] = add hno hok
    hom:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlo
    hon:f64[64,64] = squeeze[dimensions=(0,)] hom
    hoo:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlo
    hop:f64[64,64] = squeeze[dimensions=(0,)] hoo
    hoq:f64[64,64] = mul hon hop
    hor:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hlv
    hos:i64[64,64] = squeeze[dimensions=(0,)] hor
    hot:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hlv
    hou:i64[64,64] = squeeze[dimensions=(0,)] hot
    hov:bool[64,64] = lt hos 0:i64[]
    how:i64[64,64] = add hos 64:i64[]
    hox:i64[64,64] = select_n hov hos how
    hoy:bool[64,64] = lt hou 0:i64[]
    hoz:i64[64,64] = add hou 64:i64[]
    hpa:i64[64,64] = select_n hoy hou hoz
    hpb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hox
    hpc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hpa
    hpd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hpb
    hpe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hpc
    hpf:i32[64,64,2] = concatenate[dimension=2] hpd hpe
    hpg:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] gyl hpf
    hph:f64[64,64] = mul hoq hpg
    hpi:f64[64,64] = add hol hph
    hpj:f64[2,64,64] = neg gyw
    hpk:f64[] = neg eb
    hpl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hpk
    hpm:f64[2,64,64] = mul hpl hpj
    hpn:f64[2,64,64] = add ea hpm
    hpo:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    hpp:f64[2,64,64] = div hpn hpo
    hpq:f64[2,64,64] = floor hpp
    hpr:f64[2,64,64] = sub hpp hpq
    hps:f64[2,64,64] = sub 1.0:f64[] hpr
    hpt:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] db
    hpu:f64[2,64,64] = sub hpp hpr
    hpv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hpu
    hpw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hpv hpt
    hpx:i64[2,64,64] = add hpw 1:i64[]
    hpy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hpx hpt
    hpz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hps
    hqa:f64[64,64] = squeeze[dimensions=(0,)] hpz
    hqb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hps
    hqc:f64[64,64] = squeeze[dimensions=(0,)] hqb
    hqd:f64[64,64] = mul hqa hqc
    hqe:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hpw
    hqf:i64[64,64] = squeeze[dimensions=(0,)] hqe
    hqg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hpw
    hqh:i64[64,64] = squeeze[dimensions=(0,)] hqg
    hqi:bool[64,64] = lt hqf 0:i64[]
    hqj:i64[64,64] = add hqf 64:i64[]
    hqk:i64[64,64] = select_n hqi hqf hqj
    hql:bool[64,64] = lt hqh 0:i64[]
    hqm:i64[64,64] = add hqh 64:i64[]
    hqn:i64[64,64] = select_n hql hqh hqm
    hqo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hqk
    hqp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hqn
    hqq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hqo
    hqr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hqp
    hqs:i32[64,64,2] = concatenate[dimension=2] hqq hqr
    hqt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hpi hqs
    hqu:f64[64,64] = mul hqd hqt
    hqv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hps
    hqw:f64[64,64] = squeeze[dimensions=(0,)] hqv
    hqx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hpr
    hqy:f64[64,64] = squeeze[dimensions=(0,)] hqx
    hqz:f64[64,64] = mul hqw hqy
    hra:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hpw
    hrb:i64[64,64] = squeeze[dimensions=(0,)] hra
    hrc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hpy
    hrd:i64[64,64] = squeeze[dimensions=(0,)] hrc
    hre:bool[64,64] = lt hrb 0:i64[]
    hrf:i64[64,64] = add hrb 64:i64[]
    hrg:i64[64,64] = select_n hre hrb hrf
    hrh:bool[64,64] = lt hrd 0:i64[]
    hri:i64[64,64] = add hrd 64:i64[]
    hrj:i64[64,64] = select_n hrh hrd hri
    hrk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hrg
    hrl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hrj
    hrm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hrk
    hrn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hrl
    hro:i32[64,64,2] = concatenate[dimension=2] hrm hrn
    hrp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hpi hro
    hrq:f64[64,64] = mul hqz hrp
    hrr:f64[64,64] = add hqu hrq
    hrs:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hpr
    hrt:f64[64,64] = squeeze[dimensions=(0,)] hrs
    hru:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hps
    hrv:f64[64,64] = squeeze[dimensions=(0,)] hru
    hrw:f64[64,64] = mul hrt hrv
    hrx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hpy
    hry:i64[64,64] = squeeze[dimensions=(0,)] hrx
    hrz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hpw
    hsa:i64[64,64] = squeeze[dimensions=(0,)] hrz
    hsb:bool[64,64] = lt hry 0:i64[]
    hsc:i64[64,64] = add hry 64:i64[]
    hsd:i64[64,64] = select_n hsb hry hsc
    hse:bool[64,64] = lt hsa 0:i64[]
    hsf:i64[64,64] = add hsa 64:i64[]
    hsg:i64[64,64] = select_n hse hsa hsf
    hsh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hsd
    hsi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hsg
    hsj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hsh
    hsk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hsi
    hsl:i32[64,64,2] = concatenate[dimension=2] hsj hsk
    hsm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hpi hsl
    hsn:f64[64,64] = mul hrw hsm
    hso:f64[64,64] = add hrr hsn
    hsp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hpr
    hsq:f64[64,64] = squeeze[dimensions=(0,)] hsp
    hsr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hpr
    hss:f64[64,64] = squeeze[dimensions=(0,)] hsr
    hst:f64[64,64] = mul hsq hss
    hsu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hpy
    hsv:i64[64,64] = squeeze[dimensions=(0,)] hsu
    hsw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hpy
    hsx:i64[64,64] = squeeze[dimensions=(0,)] hsw
    hsy:bool[64,64] = lt hsv 0:i64[]
    hsz:i64[64,64] = add hsv 64:i64[]
    hta:i64[64,64] = select_n hsy hsv hsz
    htb:bool[64,64] = lt hsx 0:i64[]
    htc:i64[64,64] = add hsx 64:i64[]
    htd:i64[64,64] = select_n htb hsx htc
    hte:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hta
    htf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] htd
    htg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hte
    hth:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] htf
    hti:i32[64,64,2] = concatenate[dimension=2] htg hth
    htj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hpi hti
    htk:f64[64,64] = mul hst htj
    htl:f64[64,64] = add hso htk
    htm:f64[64,64] = sub gyl htl
    htn:f64[64,64] = div htm 2.0:f64[]
    hto:f64[64,64] = add gyl htn
    htp:f64[] = neg eb
    htq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] htp
    htr:f64[2,64,64] = mul htq gyw
    hts:f64[2,64,64] = add ea htr
    htt:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    htu:f64[2,64,64] = div hts htt
    htv:f64[2,64,64] = floor htu
    htw:f64[2,64,64] = sub htu htv
    htx:f64[2,64,64] = sub 1.0:f64[] htw
    hty:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dc
    htz:f64[2,64,64] = sub htu htw
    hua:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] htz
    hub:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hua hty
    huc:i64[2,64,64] = add hub 1:i64[]
    hud:i64[2,64,64] = jit[name=remainder jaxpr=remainder] huc hty
    hue:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] htx
    huf:f64[64,64] = squeeze[dimensions=(0,)] hue
    hug:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] htx
    huh:f64[64,64] = squeeze[dimensions=(0,)] hug
    hui:f64[64,64] = mul huf huh
    huj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hub
    huk:i64[64,64] = squeeze[dimensions=(0,)] huj
    hul:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hub
    hum:i64[64,64] = squeeze[dimensions=(0,)] hul
    hun:bool[64,64] = lt huk 0:i64[]
    huo:i64[64,64] = add huk 64:i64[]
    hup:i64[64,64] = select_n hun huk huo
    huq:bool[64,64] = lt hum 0:i64[]
    hur:i64[64,64] = add hum 64:i64[]
    hus:i64[64,64] = select_n huq hum hur
    hut:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hup
    huu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hus
    huv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hut
    huw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] huu
    hux:i32[64,64,2] = concatenate[dimension=2] huv huw
    huy:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hto hux
    huz:f64[64,64] = mul hui huy
    hva:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] htx
    hvb:f64[64,64] = squeeze[dimensions=(0,)] hva
    hvc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] htw
    hvd:f64[64,64] = squeeze[dimensions=(0,)] hvc
    hve:f64[64,64] = mul hvb hvd
    hvf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hub
    hvg:i64[64,64] = squeeze[dimensions=(0,)] hvf
    hvh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hud
    hvi:i64[64,64] = squeeze[dimensions=(0,)] hvh
    hvj:bool[64,64] = lt hvg 0:i64[]
    hvk:i64[64,64] = add hvg 64:i64[]
    hvl:i64[64,64] = select_n hvj hvg hvk
    hvm:bool[64,64] = lt hvi 0:i64[]
    hvn:i64[64,64] = add hvi 64:i64[]
    hvo:i64[64,64] = select_n hvm hvi hvn
    hvp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hvl
    hvq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hvo
    hvr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hvp
    hvs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hvq
    hvt:i32[64,64,2] = concatenate[dimension=2] hvr hvs
    hvu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hto hvt
    hvv:f64[64,64] = mul hve hvu
    hvw:f64[64,64] = add huz hvv
    hvx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] htw
    hvy:f64[64,64] = squeeze[dimensions=(0,)] hvx
    hvz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] htx
    hwa:f64[64,64] = squeeze[dimensions=(0,)] hvz
    hwb:f64[64,64] = mul hvy hwa
    hwc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hud
    hwd:i64[64,64] = squeeze[dimensions=(0,)] hwc
    hwe:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hub
    hwf:i64[64,64] = squeeze[dimensions=(0,)] hwe
    hwg:bool[64,64] = lt hwd 0:i64[]
    hwh:i64[64,64] = add hwd 64:i64[]
    hwi:i64[64,64] = select_n hwg hwd hwh
    hwj:bool[64,64] = lt hwf 0:i64[]
    hwk:i64[64,64] = add hwf 64:i64[]
    hwl:i64[64,64] = select_n hwj hwf hwk
    hwm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hwi
    hwn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hwl
    hwo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hwm
    hwp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hwn
    hwq:i32[64,64,2] = concatenate[dimension=2] hwo hwp
    hwr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hto hwq
    hws:f64[64,64] = mul hwb hwr
    hwt:f64[64,64] = add hvw hws
    hwu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] htw
    hwv:f64[64,64] = squeeze[dimensions=(0,)] hwu
    hww:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] htw
    hwx:f64[64,64] = squeeze[dimensions=(0,)] hww
    hwy:f64[64,64] = mul hwv hwx
    hwz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hud
    hxa:i64[64,64] = squeeze[dimensions=(0,)] hwz
    hxb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hud
    hxc:i64[64,64] = squeeze[dimensions=(0,)] hxb
    hxd:bool[64,64] = lt hxa 0:i64[]
    hxe:i64[64,64] = add hxa 64:i64[]
    hxf:i64[64,64] = select_n hxd hxa hxe
    hxg:bool[64,64] = lt hxc 0:i64[]
    hxh:i64[64,64] = add hxc 64:i64[]
    hxi:i64[64,64] = select_n hxg hxc hxh
    hxj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hxf
    hxk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hxi
    hxl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hxj
    hxm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hxk
    hxn:i32[64,64,2] = concatenate[dimension=2] hxl hxm
    hxo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hto hxn
    hxp:f64[64,64] = mul hwy hxo
    hxq:f64[64,64] = add hwt hxp
    hxr:c128[64,33] = jit[name=fft jaxpr=fft] hlg
    hxs:c128[] = reduce_prod[axes=(0,)] dd
    hxt:c128[] = sqrt hxs
    hxu:c128[] = div (1+0j):c128[] hxt
    hxv:c128[64,33] = mul hxr hxu
    hxw:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] hxv
    hxx:c128[2,64,33] = mul dw hxw
    hxy:f64[2,64,64] = jit[name=fft jaxpr=fft1] hxx
    hxz:f64[] = reduce_prod[axes=(0,)] de
    hya:f64[] = sqrt hxz
    hyb:f64[2,64,64] = mul hxy hya
    hyc:f64[] = neg eb
    hyd:f64[] = convert_element_type[new_dtype=float64 weak_type=False] hyc
    hye:f64[2,64,64] = mul hyd hyb
    hyf:f64[2,64,64] = add ea hye
    hyg:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    hyh:f64[2,64,64] = div hyf hyg
    hyi:f64[2,64,64] = floor hyh
    hyj:f64[2,64,64] = sub hyh hyi
    hyk:f64[2,64,64] = sub 1.0:f64[] hyj
    hyl:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] df
    hym:f64[2,64,64] = sub hyh hyj
    hyn:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] hym
    hyo:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hyn hyl
    hyp:i64[2,64,64] = add hyo 1:i64[]
    hyq:i64[2,64,64] = jit[name=remainder jaxpr=remainder] hyp hyl
    hyr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyk
    hys:f64[64,64] = squeeze[dimensions=(0,)] hyr
    hyt:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyk
    hyu:f64[64,64] = squeeze[dimensions=(0,)] hyt
    hyv:f64[64,64] = mul hys hyu
    hyw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyo
    hyx:i64[64,64] = squeeze[dimensions=(0,)] hyw
    hyy:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyo
    hyz:i64[64,64] = squeeze[dimensions=(0,)] hyy
    hza:bool[64,64] = lt hyx 0:i64[]
    hzb:i64[64,64] = add hyx 64:i64[]
    hzc:i64[64,64] = select_n hza hyx hzb
    hzd:bool[64,64] = lt hyz 0:i64[]
    hze:i64[64,64] = add hyz 64:i64[]
    hzf:i64[64,64] = select_n hzd hyz hze
    hzg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hzc
    hzh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hzf
    hzi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hzg
    hzj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] hzh
    hzk:i32[64,64,2] = concatenate[dimension=2] hzi hzj
    hzl:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hlg hzk
    hzm:f64[64,64] = mul hyv hzl
    hzn:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyk
    hzo:f64[64,64] = squeeze[dimensions=(0,)] hzn
    hzp:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyj
    hzq:f64[64,64] = squeeze[dimensions=(0,)] hzp
    hzr:f64[64,64] = mul hzo hzq
    hzs:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyo
    hzt:i64[64,64] = squeeze[dimensions=(0,)] hzs
    hzu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyq
    hzv:i64[64,64] = squeeze[dimensions=(0,)] hzu
    hzw:bool[64,64] = lt hzt 0:i64[]
    hzx:i64[64,64] = add hzt 64:i64[]
    hzy:i64[64,64] = select_n hzw hzt hzx
    hzz:bool[64,64] = lt hzv 0:i64[]
    iaa:i64[64,64] = add hzv 64:i64[]
    iab:i64[64,64] = select_n hzz hzv iaa
    iac:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] hzy
    iad:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iab
    iae:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iac
    iaf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iad
    iag:i32[64,64,2] = concatenate[dimension=2] iae iaf
    iah:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hlg iag
    iai:f64[64,64] = mul hzr iah
    iaj:f64[64,64] = add hzm iai
    iak:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyj
    ial:f64[64,64] = squeeze[dimensions=(0,)] iak
    iam:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyk
    ian:f64[64,64] = squeeze[dimensions=(0,)] iam
    iao:f64[64,64] = mul ial ian
    iap:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyq
    iaq:i64[64,64] = squeeze[dimensions=(0,)] iap
    iar:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyo
    ias:i64[64,64] = squeeze[dimensions=(0,)] iar
    iat:bool[64,64] = lt iaq 0:i64[]
    iau:i64[64,64] = add iaq 64:i64[]
    iav:i64[64,64] = select_n iat iaq iau
    iaw:bool[64,64] = lt ias 0:i64[]
    iax:i64[64,64] = add ias 64:i64[]
    iay:i64[64,64] = select_n iaw ias iax
    iaz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iav
    iba:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iay
    ibb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iaz
    ibc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iba
    ibd:i32[64,64,2] = concatenate[dimension=2] ibb ibc
    ibe:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hlg ibd
    ibf:f64[64,64] = mul iao ibe
    ibg:f64[64,64] = add iaj ibf
    ibh:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyj
    ibi:f64[64,64] = squeeze[dimensions=(0,)] ibh
    ibj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyj
    ibk:f64[64,64] = squeeze[dimensions=(0,)] ibj
    ibl:f64[64,64] = mul ibi ibk
    ibm:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] hyq
    ibn:i64[64,64] = squeeze[dimensions=(0,)] ibm
    ibo:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] hyq
    ibp:i64[64,64] = squeeze[dimensions=(0,)] ibo
    ibq:bool[64,64] = lt ibn 0:i64[]
    ibr:i64[64,64] = add ibn 64:i64[]
    ibs:i64[64,64] = select_n ibq ibn ibr
    ibt:bool[64,64] = lt ibp 0:i64[]
    ibu:i64[64,64] = add ibp 64:i64[]
    ibv:i64[64,64] = select_n ibt ibp ibu
    ibw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ibs
    ibx:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ibv
    iby:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ibw
    ibz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ibx
    ica:i32[64,64,2] = concatenate[dimension=2] iby ibz
    icb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hlg ica
    icc:f64[64,64] = mul ibl icb
    icd:f64[64,64] = add ibg icc
    ice:f64[2,64,64] = neg hyb
    icf:f64[] = neg eb
    icg:f64[] = convert_element_type[new_dtype=float64 weak_type=False] icf
    ich:f64[2,64,64] = mul icg ice
    ici:f64[2,64,64] = add ea ich
    icj:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    ick:f64[2,64,64] = div ici icj
    icl:f64[2,64,64] = floor ick
    icm:f64[2,64,64] = sub ick icl
    icn:f64[2,64,64] = sub 1.0:f64[] icm
    ico:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dg
    icp:f64[2,64,64] = sub ick icm
    icq:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] icp
    icr:i64[2,64,64] = jit[name=remainder jaxpr=remainder] icq ico
    ics:i64[2,64,64] = add icr 1:i64[]
    ict:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ics ico
    icu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] icn
    icv:f64[64,64] = squeeze[dimensions=(0,)] icu
    icw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] icn
    icx:f64[64,64] = squeeze[dimensions=(0,)] icw
    icy:f64[64,64] = mul icv icx
    icz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] icr
    ida:i64[64,64] = squeeze[dimensions=(0,)] icz
    idb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] icr
    idc:i64[64,64] = squeeze[dimensions=(0,)] idb
    idd:bool[64,64] = lt ida 0:i64[]
    ide:i64[64,64] = add ida 64:i64[]
    idf:i64[64,64] = select_n idd ida ide
    idg:bool[64,64] = lt idc 0:i64[]
    idh:i64[64,64] = add idc 64:i64[]
    idi:i64[64,64] = select_n idg idc idh
    idj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] idf
    idk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] idi
    idl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] idj
    idm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] idk
    idn:i32[64,64,2] = concatenate[dimension=2] idl idm
    ido:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] icd idn
    idp:f64[64,64] = mul icy ido
    idq:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] icn
    idr:f64[64,64] = squeeze[dimensions=(0,)] idq
    ids:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] icm
    idt:f64[64,64] = squeeze[dimensions=(0,)] ids
    idu:f64[64,64] = mul idr idt
    idv:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] icr
    idw:i64[64,64] = squeeze[dimensions=(0,)] idv
    idx:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ict
    idy:i64[64,64] = squeeze[dimensions=(0,)] idx
    idz:bool[64,64] = lt idw 0:i64[]
    iea:i64[64,64] = add idw 64:i64[]
    ieb:i64[64,64] = select_n idz idw iea
    iec:bool[64,64] = lt idy 0:i64[]
    ied:i64[64,64] = add idy 64:i64[]
    iee:i64[64,64] = select_n iec idy ied
    ief:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ieb
    ieg:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iee
    ieh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ief
    iei:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ieg
    iej:i32[64,64,2] = concatenate[dimension=2] ieh iei
    iek:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] icd iej
    iel:f64[64,64] = mul idu iek
    iem:f64[64,64] = add idp iel
    ien:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] icm
    ieo:f64[64,64] = squeeze[dimensions=(0,)] ien
    iep:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] icn
    ieq:f64[64,64] = squeeze[dimensions=(0,)] iep
    ier:f64[64,64] = mul ieo ieq
    ies:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ict
    iet:i64[64,64] = squeeze[dimensions=(0,)] ies
    ieu:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] icr
    iev:i64[64,64] = squeeze[dimensions=(0,)] ieu
    iew:bool[64,64] = lt iet 0:i64[]
    iex:i64[64,64] = add iet 64:i64[]
    iey:i64[64,64] = select_n iew iet iex
    iez:bool[64,64] = lt iev 0:i64[]
    ifa:i64[64,64] = add iev 64:i64[]
    ifb:i64[64,64] = select_n iez iev ifa
    ifc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iey
    ifd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ifb
    ife:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ifc
    iff:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ifd
    ifg:i32[64,64,2] = concatenate[dimension=2] ife iff
    ifh:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] icd ifg
    ifi:f64[64,64] = mul ier ifh
    ifj:f64[64,64] = add iem ifi
    ifk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] icm
    ifl:f64[64,64] = squeeze[dimensions=(0,)] ifk
    ifm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] icm
    ifn:f64[64,64] = squeeze[dimensions=(0,)] ifm
    ifo:f64[64,64] = mul ifl ifn
    ifp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ict
    ifq:i64[64,64] = squeeze[dimensions=(0,)] ifp
    ifr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ict
    ifs:i64[64,64] = squeeze[dimensions=(0,)] ifr
    ift:bool[64,64] = lt ifq 0:i64[]
    ifu:i64[64,64] = add ifq 64:i64[]
    ifv:i64[64,64] = select_n ift ifq ifu
    ifw:bool[64,64] = lt ifs 0:i64[]
    ifx:i64[64,64] = add ifs 64:i64[]
    ify:i64[64,64] = select_n ifw ifs ifx
    ifz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ifv
    iga:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ify
    igb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ifz
    igc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iga
    igd:i32[64,64,2] = concatenate[dimension=2] igb igc
    ige:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] icd igd
    igf:f64[64,64] = mul ifo ige
    igg:f64[64,64] = add ifj igf
    igh:f64[64,64] = sub hlg igg
    igi:f64[64,64] = div igh 2.0:f64[]
    igj:f64[64,64] = add hlg igi
    igk:f64[] = neg eb
    igl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] igk
    igm:f64[2,64,64] = mul igl hyb
    ign:f64[2,64,64] = add ea igm
    igo:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    igp:f64[2,64,64] = div ign igo
    igq:f64[2,64,64] = floor igp
    igr:f64[2,64,64] = sub igp igq
    igs:f64[2,64,64] = sub 1.0:f64[] igr
    igt:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dh
    igu:f64[2,64,64] = sub igp igr
    igv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] igu
    igw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] igv igt
    igx:i64[2,64,64] = add igw 1:i64[]
    igy:i64[2,64,64] = jit[name=remainder jaxpr=remainder] igx igt
    igz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igs
    iha:f64[64,64] = squeeze[dimensions=(0,)] igz
    ihb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igs
    ihc:f64[64,64] = squeeze[dimensions=(0,)] ihb
    ihd:f64[64,64] = mul iha ihc
    ihe:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igw
    ihf:i64[64,64] = squeeze[dimensions=(0,)] ihe
    ihg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igw
    ihh:i64[64,64] = squeeze[dimensions=(0,)] ihg
    ihi:bool[64,64] = lt ihf 0:i64[]
    ihj:i64[64,64] = add ihf 64:i64[]
    ihk:i64[64,64] = select_n ihi ihf ihj
    ihl:bool[64,64] = lt ihh 0:i64[]
    ihm:i64[64,64] = add ihh 64:i64[]
    ihn:i64[64,64] = select_n ihl ihh ihm
    iho:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ihk
    ihp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ihn
    ihq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iho
    ihr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ihp
    ihs:i32[64,64,2] = concatenate[dimension=2] ihq ihr
    iht:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] igj ihs
    ihu:f64[64,64] = mul ihd iht
    ihv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igs
    ihw:f64[64,64] = squeeze[dimensions=(0,)] ihv
    ihx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igr
    ihy:f64[64,64] = squeeze[dimensions=(0,)] ihx
    ihz:f64[64,64] = mul ihw ihy
    iia:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igw
    iib:i64[64,64] = squeeze[dimensions=(0,)] iia
    iic:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igy
    iid:i64[64,64] = squeeze[dimensions=(0,)] iic
    iie:bool[64,64] = lt iib 0:i64[]
    iif:i64[64,64] = add iib 64:i64[]
    iig:i64[64,64] = select_n iie iib iif
    iih:bool[64,64] = lt iid 0:i64[]
    iii:i64[64,64] = add iid 64:i64[]
    iij:i64[64,64] = select_n iih iid iii
    iik:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iig
    iil:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iij
    iim:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iik
    iin:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iil
    iio:i32[64,64,2] = concatenate[dimension=2] iim iin
    iip:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] igj iio
    iiq:f64[64,64] = mul ihz iip
    iir:f64[64,64] = add ihu iiq
    iis:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igr
    iit:f64[64,64] = squeeze[dimensions=(0,)] iis
    iiu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igs
    iiv:f64[64,64] = squeeze[dimensions=(0,)] iiu
    iiw:f64[64,64] = mul iit iiv
    iix:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igy
    iiy:i64[64,64] = squeeze[dimensions=(0,)] iix
    iiz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igw
    ija:i64[64,64] = squeeze[dimensions=(0,)] iiz
    ijb:bool[64,64] = lt iiy 0:i64[]
    ijc:i64[64,64] = add iiy 64:i64[]
    ijd:i64[64,64] = select_n ijb iiy ijc
    ije:bool[64,64] = lt ija 0:i64[]
    ijf:i64[64,64] = add ija 64:i64[]
    ijg:i64[64,64] = select_n ije ija ijf
    ijh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ijd
    iji:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ijg
    ijj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ijh
    ijk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iji
    ijl:i32[64,64,2] = concatenate[dimension=2] ijj ijk
    ijm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] igj ijl
    ijn:f64[64,64] = mul iiw ijm
    ijo:f64[64,64] = add iir ijn
    ijp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igr
    ijq:f64[64,64] = squeeze[dimensions=(0,)] ijp
    ijr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igr
    ijs:f64[64,64] = squeeze[dimensions=(0,)] ijr
    ijt:f64[64,64] = mul ijq ijs
    iju:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] igy
    ijv:i64[64,64] = squeeze[dimensions=(0,)] iju
    ijw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] igy
    ijx:i64[64,64] = squeeze[dimensions=(0,)] ijw
    ijy:bool[64,64] = lt ijv 0:i64[]
    ijz:i64[64,64] = add ijv 64:i64[]
    ika:i64[64,64] = select_n ijy ijv ijz
    ikb:bool[64,64] = lt ijx 0:i64[]
    ikc:i64[64,64] = add ijx 64:i64[]
    ikd:i64[64,64] = select_n ikb ijx ikc
    ike:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ika
    ikf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ikd
    ikg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ike
    ikh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ikf
    iki:i32[64,64,2] = concatenate[dimension=2] ikg ikh
    ikj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] igj iki
    ikk:f64[64,64] = mul ijt ikj
    ikl:f64[64,64] = add ijo ikk
    ikm:f64[] = neg eb
    ikn:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ikm
    iko:f64[2,64,64] = mul ikn hyb
    ikp:f64[2,64,64] = add ea iko
    ikq:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    ikr:f64[2,64,64] = div ikp ikq
    iks:f64[2,64,64] = floor ikr
    ikt:f64[2,64,64] = sub ikr iks
    iku:f64[2,64,64] = sub 1.0:f64[] ikt
    ikv:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] di
    ikw:f64[2,64,64] = sub ikr ikt
    ikx:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ikw
    iky:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ikx ikv
    ikz:i64[2,64,64] = add iky 1:i64[]
    ila:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ikz ikv
    ilb:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iku
    ilc:f64[64,64] = squeeze[dimensions=(0,)] ilb
    ild:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iku
    ile:f64[64,64] = squeeze[dimensions=(0,)] ild
    ilf:f64[64,64] = mul ilc ile
    ilg:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iky
    ilh:i64[64,64] = squeeze[dimensions=(0,)] ilg
    ili:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iky
    ilj:i64[64,64] = squeeze[dimensions=(0,)] ili
    ilk:bool[64,64] = lt ilh 0:i64[]
    ill:i64[64,64] = add ilh 64:i64[]
    ilm:i64[64,64] = select_n ilk ilh ill
    iln:bool[64,64] = lt ilj 0:i64[]
    ilo:i64[64,64] = add ilj 64:i64[]
    ilp:i64[64,64] = select_n iln ilj ilo
    ilq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ilm
    ilr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ilp
    ils:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ilq
    ilt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ilr
    ilu:i32[64,64,2] = concatenate[dimension=2] ils ilt
    ilv:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hxq ilu
    ilw:f64[64,64] = mul ilf ilv
    ilx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iku
    ily:f64[64,64] = squeeze[dimensions=(0,)] ilx
    ilz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ikt
    ima:f64[64,64] = squeeze[dimensions=(0,)] ilz
    imb:f64[64,64] = mul ily ima
    imc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iky
    imd:i64[64,64] = squeeze[dimensions=(0,)] imc
    ime:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ila
    imf:i64[64,64] = squeeze[dimensions=(0,)] ime
    img:bool[64,64] = lt imd 0:i64[]
    imh:i64[64,64] = add imd 64:i64[]
    imi:i64[64,64] = select_n img imd imh
    imj:bool[64,64] = lt imf 0:i64[]
    imk:i64[64,64] = add imf 64:i64[]
    iml:i64[64,64] = select_n imj imf imk
    imm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] imi
    imn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iml
    imo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] imm
    imp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] imn
    imq:i32[64,64,2] = concatenate[dimension=2] imo imp
    imr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hxq imq
    ims:f64[64,64] = mul imb imr
    imt:f64[64,64] = add ilw ims
    imu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ikt
    imv:f64[64,64] = squeeze[dimensions=(0,)] imu
    imw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iku
    imx:f64[64,64] = squeeze[dimensions=(0,)] imw
    imy:f64[64,64] = mul imv imx
    imz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ila
    ina:i64[64,64] = squeeze[dimensions=(0,)] imz
    inb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iky
    inc:i64[64,64] = squeeze[dimensions=(0,)] inb
    ind:bool[64,64] = lt ina 0:i64[]
    ine:i64[64,64] = add ina 64:i64[]
    inf:i64[64,64] = select_n ind ina ine
    ing:bool[64,64] = lt inc 0:i64[]
    inh:i64[64,64] = add inc 64:i64[]
    ini:i64[64,64] = select_n ing inc inh
    inj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] inf
    ink:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ini
    inl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] inj
    inm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ink
    inn:i32[64,64,2] = concatenate[dimension=2] inl inm
    ino:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hxq inn
    inp:f64[64,64] = mul imy ino
    inq:f64[64,64] = add imt inp
    inr:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ikt
    ins:f64[64,64] = squeeze[dimensions=(0,)] inr
    int:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ikt
    inu:f64[64,64] = squeeze[dimensions=(0,)] int
    inv:f64[64,64] = mul ins inu
    inw:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ila
    inx:i64[64,64] = squeeze[dimensions=(0,)] inw
    iny:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ila
    inz:i64[64,64] = squeeze[dimensions=(0,)] iny
    ioa:bool[64,64] = lt inx 0:i64[]
    iob:i64[64,64] = add inx 64:i64[]
    ioc:i64[64,64] = select_n ioa inx iob
    iod:bool[64,64] = lt inz 0:i64[]
    ioe:i64[64,64] = add inz 64:i64[]
    iof:i64[64,64] = select_n iod inz ioe
    iog:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ioc
    ioh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iof
    ioi:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iog
    ioj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ioh
    iok:i32[64,64,2] = concatenate[dimension=2] ioi ioj
    iol:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] hxq iok
    iom:f64[64,64] = mul inv iol
    ion:f64[64,64] = add inq iom
    ioo:f64[2,64,64] = neg hyb
    iop:f64[] = neg eb
    ioq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] iop
    ior:f64[2,64,64] = mul ioq ioo
    ios:f64[2,64,64] = add ea ior
    iot:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    iou:f64[2,64,64] = div ios iot
    iov:f64[2,64,64] = floor iou
    iow:f64[2,64,64] = sub iou iov
    iox:f64[2,64,64] = sub 1.0:f64[] iow
    ioy:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dj
    ioz:f64[2,64,64] = sub iou iow
    ipa:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ioz
    ipb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ipa ioy
    ipc:i64[2,64,64] = add ipb 1:i64[]
    ipd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ipc ioy
    ipe:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iox
    ipf:f64[64,64] = squeeze[dimensions=(0,)] ipe
    ipg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iox
    iph:f64[64,64] = squeeze[dimensions=(0,)] ipg
    ipi:f64[64,64] = mul ipf iph
    ipj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ipb
    ipk:i64[64,64] = squeeze[dimensions=(0,)] ipj
    ipl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ipb
    ipm:i64[64,64] = squeeze[dimensions=(0,)] ipl
    ipn:bool[64,64] = lt ipk 0:i64[]
    ipo:i64[64,64] = add ipk 64:i64[]
    ipp:i64[64,64] = select_n ipn ipk ipo
    ipq:bool[64,64] = lt ipm 0:i64[]
    ipr:i64[64,64] = add ipm 64:i64[]
    ips:i64[64,64] = select_n ipq ipm ipr
    ipt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ipp
    ipu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ips
    ipv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ipt
    ipw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ipu
    ipx:i32[64,64,2] = concatenate[dimension=2] ipv ipw
    ipy:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ion ipx
    ipz:f64[64,64] = mul ipi ipy
    iqa:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iox
    iqb:f64[64,64] = squeeze[dimensions=(0,)] iqa
    iqc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iow
    iqd:f64[64,64] = squeeze[dimensions=(0,)] iqc
    iqe:f64[64,64] = mul iqb iqd
    iqf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ipb
    iqg:i64[64,64] = squeeze[dimensions=(0,)] iqf
    iqh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ipd
    iqi:i64[64,64] = squeeze[dimensions=(0,)] iqh
    iqj:bool[64,64] = lt iqg 0:i64[]
    iqk:i64[64,64] = add iqg 64:i64[]
    iql:i64[64,64] = select_n iqj iqg iqk
    iqm:bool[64,64] = lt iqi 0:i64[]
    iqn:i64[64,64] = add iqi 64:i64[]
    iqo:i64[64,64] = select_n iqm iqi iqn
    iqp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iql
    iqq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iqo
    iqr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iqp
    iqs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iqq
    iqt:i32[64,64,2] = concatenate[dimension=2] iqr iqs
    iqu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ion iqt
    iqv:f64[64,64] = mul iqe iqu
    iqw:f64[64,64] = add ipz iqv
    iqx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iow
    iqy:f64[64,64] = squeeze[dimensions=(0,)] iqx
    iqz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iox
    ira:f64[64,64] = squeeze[dimensions=(0,)] iqz
    irb:f64[64,64] = mul iqy ira
    irc:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ipd
    ird:i64[64,64] = squeeze[dimensions=(0,)] irc
    ire:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ipb
    irf:i64[64,64] = squeeze[dimensions=(0,)] ire
    irg:bool[64,64] = lt ird 0:i64[]
    irh:i64[64,64] = add ird 64:i64[]
    iri:i64[64,64] = select_n irg ird irh
    irj:bool[64,64] = lt irf 0:i64[]
    irk:i64[64,64] = add irf 64:i64[]
    irl:i64[64,64] = select_n irj irf irk
    irm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iri
    irn:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] irl
    iro:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] irm
    irp:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] irn
    irq:i32[64,64,2] = concatenate[dimension=2] iro irp
    irr:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ion irq
    irs:f64[64,64] = mul irb irr
    irt:f64[64,64] = add iqw irs
    iru:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iow
    irv:f64[64,64] = squeeze[dimensions=(0,)] iru
    irw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iow
    irx:f64[64,64] = squeeze[dimensions=(0,)] irw
    iry:f64[64,64] = mul irv irx
    irz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ipd
    isa:i64[64,64] = squeeze[dimensions=(0,)] irz
    isb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ipd
    isc:i64[64,64] = squeeze[dimensions=(0,)] isb
    isd:bool[64,64] = lt isa 0:i64[]
    ise:i64[64,64] = add isa 64:i64[]
    isf:i64[64,64] = select_n isd isa ise
    isg:bool[64,64] = lt isc 0:i64[]
    ish:i64[64,64] = add isc 64:i64[]
    isi:i64[64,64] = select_n isg isc ish
    isj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] isf
    isk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] isi
    isl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] isj
    ism:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] isk
    isn:i32[64,64,2] = concatenate[dimension=2] isl ism
    iso:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ion isn
    isp:f64[64,64] = mul iry iso
    isq:f64[64,64] = add irt isp
    isr:f64[64,64] = sub hxq isq
    iss:f64[64,64] = div isr 2.0:f64[]
    ist:f64[64,64] = add hxq iss
    isu:f64[] = neg eb
    isv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] isu
    isw:f64[2,64,64] = mul isv hyb
    isx:f64[2,64,64] = add ea isw
    isy:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    isz:f64[2,64,64] = div isx isy
    ita:f64[2,64,64] = floor isz
    itb:f64[2,64,64] = sub isz ita
    itc:f64[2,64,64] = sub 1.0:f64[] itb
    itd:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dk
    ite:f64[2,64,64] = sub isz itb
    itf:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ite
    itg:i64[2,64,64] = jit[name=remainder jaxpr=remainder] itf itd
    ith:i64[2,64,64] = add itg 1:i64[]
    iti:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ith itd
    itj:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] itc
    itk:f64[64,64] = squeeze[dimensions=(0,)] itj
    itl:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] itc
    itm:f64[64,64] = squeeze[dimensions=(0,)] itl
    itn:f64[64,64] = mul itk itm
    ito:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] itg
    itp:i64[64,64] = squeeze[dimensions=(0,)] ito
    itq:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] itg
    itr:i64[64,64] = squeeze[dimensions=(0,)] itq
    its:bool[64,64] = lt itp 0:i64[]
    itt:i64[64,64] = add itp 64:i64[]
    itu:i64[64,64] = select_n its itp itt
    itv:bool[64,64] = lt itr 0:i64[]
    itw:i64[64,64] = add itr 64:i64[]
    itx:i64[64,64] = select_n itv itr itw
    ity:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] itu
    itz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] itx
    iua:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ity
    iub:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] itz
    iuc:i32[64,64,2] = concatenate[dimension=2] iua iub
    iud:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ist iuc
    iue:f64[64,64] = mul itn iud
    iuf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] itc
    iug:f64[64,64] = squeeze[dimensions=(0,)] iuf
    iuh:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] itb
    iui:f64[64,64] = squeeze[dimensions=(0,)] iuh
    iuj:f64[64,64] = mul iug iui
    iuk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] itg
    iul:i64[64,64] = squeeze[dimensions=(0,)] iuk
    ium:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iti
    iun:i64[64,64] = squeeze[dimensions=(0,)] ium
    iuo:bool[64,64] = lt iul 0:i64[]
    iup:i64[64,64] = add iul 64:i64[]
    iuq:i64[64,64] = select_n iuo iul iup
    iur:bool[64,64] = lt iun 0:i64[]
    ius:i64[64,64] = add iun 64:i64[]
    iut:i64[64,64] = select_n iur iun ius
    iuu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iuq
    iuv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iut
    iuw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iuu
    iux:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iuv
    iuy:i32[64,64,2] = concatenate[dimension=2] iuw iux
    iuz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ist iuy
    iva:f64[64,64] = mul iuj iuz
    ivb:f64[64,64] = add iue iva
    ivc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] itb
    ivd:f64[64,64] = squeeze[dimensions=(0,)] ivc
    ive:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] itc
    ivf:f64[64,64] = squeeze[dimensions=(0,)] ive
    ivg:f64[64,64] = mul ivd ivf
    ivh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iti
    ivi:i64[64,64] = squeeze[dimensions=(0,)] ivh
    ivj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] itg
    ivk:i64[64,64] = squeeze[dimensions=(0,)] ivj
    ivl:bool[64,64] = lt ivi 0:i64[]
    ivm:i64[64,64] = add ivi 64:i64[]
    ivn:i64[64,64] = select_n ivl ivi ivm
    ivo:bool[64,64] = lt ivk 0:i64[]
    ivp:i64[64,64] = add ivk 64:i64[]
    ivq:i64[64,64] = select_n ivo ivk ivp
    ivr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ivn
    ivs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] ivq
    ivt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ivr
    ivu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] ivs
    ivv:i32[64,64,2] = concatenate[dimension=2] ivt ivu
    ivw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ist ivv
    ivx:f64[64,64] = mul ivg ivw
    ivy:f64[64,64] = add ivb ivx
    ivz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] itb
    iwa:f64[64,64] = squeeze[dimensions=(0,)] ivz
    iwb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] itb
    iwc:f64[64,64] = squeeze[dimensions=(0,)] iwb
    iwd:f64[64,64] = mul iwa iwc
    iwe:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] iti
    iwf:i64[64,64] = squeeze[dimensions=(0,)] iwe
    iwg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] iti
    iwh:i64[64,64] = squeeze[dimensions=(0,)] iwg
    iwi:bool[64,64] = lt iwf 0:i64[]
    iwj:i64[64,64] = add iwf 64:i64[]
    iwk:i64[64,64] = select_n iwi iwf iwj
    iwl:bool[64,64] = lt iwh 0:i64[]
    iwm:i64[64,64] = add iwh 64:i64[]
    iwn:i64[64,64] = select_n iwl iwh iwm
    iwo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iwk
    iwp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iwn
    iwq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iwo
    iwr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iwp
    iws:i32[64,64,2] = concatenate[dimension=2] iwq iwr
    iwt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ist iws
    iwu:f64[64,64] = mul iwd iwt
    iwv:f64[64,64] = add ivy iwu
    iww:c128[64,33] = jit[name=fft jaxpr=fft] ikl
    iwx:c128[] = reduce_prod[axes=(0,)] dl
    iwy:c128[] = sqrt iwx
    iwz:c128[] = div (1+0j):c128[] iwy
    ixa:c128[64,33] = mul iww iwz
    ixb:c128[1,64,33] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 64, 33)
      sharding=None
    ] ixa
    ixc:c128[2,64,33] = mul dw ixb
    ixd:f64[2,64,64] = jit[name=fft jaxpr=fft1] ixc
    ixe:f64[] = reduce_prod[axes=(0,)] dm
    ixf:f64[] = sqrt ixe
    ixg:f64[2,64,64] = mul ixd ixf
    ixh:f64[] = neg eb
    ixi:f64[] = convert_element_type[new_dtype=float64 weak_type=False] ixh
    ixj:f64[2,64,64] = mul ixi ixg
    ixk:f64[2,64,64] = add ea ixj
    ixl:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    ixm:f64[2,64,64] = div ixk ixl
    ixn:f64[2,64,64] = floor ixm
    ixo:f64[2,64,64] = sub ixm ixn
    ixp:f64[2,64,64] = sub 1.0:f64[] ixo
    ixq:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dn
    ixr:f64[2,64,64] = sub ixm ixo
    ixs:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] ixr
    ixt:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ixs ixq
    ixu:i64[2,64,64] = add ixt 1:i64[]
    ixv:i64[2,64,64] = jit[name=remainder jaxpr=remainder] ixu ixq
    ixw:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixp
    ixx:f64[64,64] = squeeze[dimensions=(0,)] ixw
    ixy:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixp
    ixz:f64[64,64] = squeeze[dimensions=(0,)] ixy
    iya:f64[64,64] = mul ixx ixz
    iyb:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixt
    iyc:i64[64,64] = squeeze[dimensions=(0,)] iyb
    iyd:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixt
    iye:i64[64,64] = squeeze[dimensions=(0,)] iyd
    iyf:bool[64,64] = lt iyc 0:i64[]
    iyg:i64[64,64] = add iyc 64:i64[]
    iyh:i64[64,64] = select_n iyf iyc iyg
    iyi:bool[64,64] = lt iye 0:i64[]
    iyj:i64[64,64] = add iye 64:i64[]
    iyk:i64[64,64] = select_n iyi iye iyj
    iyl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iyh
    iym:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] iyk
    iyn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iyl
    iyo:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] iym
    iyp:i32[64,64,2] = concatenate[dimension=2] iyn iyo
    iyq:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ikl iyp
    iyr:f64[64,64] = mul iya iyq
    iys:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixp
    iyt:f64[64,64] = squeeze[dimensions=(0,)] iys
    iyu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixo
    iyv:f64[64,64] = squeeze[dimensions=(0,)] iyu
    iyw:f64[64,64] = mul iyt iyv
    iyx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixt
    iyy:i64[64,64] = squeeze[dimensions=(0,)] iyx
    iyz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixv
    iza:i64[64,64] = squeeze[dimensions=(0,)] iyz
    izb:bool[64,64] = lt iyy 0:i64[]
    izc:i64[64,64] = add iyy 64:i64[]
    izd:i64[64,64] = select_n izb iyy izc
    ize:bool[64,64] = lt iza 0:i64[]
    izf:i64[64,64] = add iza 64:i64[]
    izg:i64[64,64] = select_n ize iza izf
    izh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] izd
    izi:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] izg
    izj:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] izh
    izk:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] izi
    izl:i32[64,64,2] = concatenate[dimension=2] izj izk
    izm:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ikl izl
    izn:f64[64,64] = mul iyw izm
    izo:f64[64,64] = add iyr izn
    izp:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixo
    izq:f64[64,64] = squeeze[dimensions=(0,)] izp
    izr:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixp
    izs:f64[64,64] = squeeze[dimensions=(0,)] izr
    izt:f64[64,64] = mul izq izs
    izu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixv
    izv:i64[64,64] = squeeze[dimensions=(0,)] izu
    izw:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixt
    izx:i64[64,64] = squeeze[dimensions=(0,)] izw
    izy:bool[64,64] = lt izv 0:i64[]
    izz:i64[64,64] = add izv 64:i64[]
    jaa:i64[64,64] = select_n izy izv izz
    jab:bool[64,64] = lt izx 0:i64[]
    jac:i64[64,64] = add izx 64:i64[]
    jad:i64[64,64] = select_n jab izx jac
    jae:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jaa
    jaf:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jad
    jag:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jae
    jah:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jaf
    jai:i32[64,64,2] = concatenate[dimension=2] jag jah
    jaj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ikl jai
    jak:f64[64,64] = mul izt jaj
    jal:f64[64,64] = add izo jak
    jam:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixo
    jan:f64[64,64] = squeeze[dimensions=(0,)] jam
    jao:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixo
    jap:f64[64,64] = squeeze[dimensions=(0,)] jao
    jaq:f64[64,64] = mul jan jap
    jar:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] ixv
    jas:i64[64,64] = squeeze[dimensions=(0,)] jar
    jat:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] ixv
    jau:i64[64,64] = squeeze[dimensions=(0,)] jat
    jav:bool[64,64] = lt jas 0:i64[]
    jaw:i64[64,64] = add jas 64:i64[]
    jax:i64[64,64] = select_n jav jas jaw
    jay:bool[64,64] = lt jau 0:i64[]
    jaz:i64[64,64] = add jau 64:i64[]
    jba:i64[64,64] = select_n jay jau jaz
    jbb:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jax
    jbc:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jba
    jbd:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jbb
    jbe:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jbc
    jbf:i32[64,64,2] = concatenate[dimension=2] jbd jbe
    jbg:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] ikl jbf
    jbh:f64[64,64] = mul jaq jbg
    jbi:f64[64,64] = add jal jbh
    jbj:f64[2,64,64] = neg ixg
    jbk:f64[] = neg eb
    jbl:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jbk
    jbm:f64[2,64,64] = mul jbl jbj
    jbn:f64[2,64,64] = add ea jbm
    jbo:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    jbp:f64[2,64,64] = div jbn jbo
    jbq:f64[2,64,64] = floor jbp
    jbr:f64[2,64,64] = sub jbp jbq
    jbs:f64[2,64,64] = sub 1.0:f64[] jbr
    jbt:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] do
    jbu:f64[2,64,64] = sub jbp jbr
    jbv:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jbu
    jbw:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jbv jbt
    jbx:i64[2,64,64] = add jbw 1:i64[]
    jby:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jbx jbt
    jbz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jbs
    jca:f64[64,64] = squeeze[dimensions=(0,)] jbz
    jcb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jbs
    jcc:f64[64,64] = squeeze[dimensions=(0,)] jcb
    jcd:f64[64,64] = mul jca jcc
    jce:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jbw
    jcf:i64[64,64] = squeeze[dimensions=(0,)] jce
    jcg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jbw
    jch:i64[64,64] = squeeze[dimensions=(0,)] jcg
    jci:bool[64,64] = lt jcf 0:i64[]
    jcj:i64[64,64] = add jcf 64:i64[]
    jck:i64[64,64] = select_n jci jcf jcj
    jcl:bool[64,64] = lt jch 0:i64[]
    jcm:i64[64,64] = add jch 64:i64[]
    jcn:i64[64,64] = select_n jcl jch jcm
    jco:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jck
    jcp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jcn
    jcq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jco
    jcr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jcp
    jcs:i32[64,64,2] = concatenate[dimension=2] jcq jcr
    jct:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jbi jcs
    jcu:f64[64,64] = mul jcd jct
    jcv:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jbs
    jcw:f64[64,64] = squeeze[dimensions=(0,)] jcv
    jcx:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jbr
    jcy:f64[64,64] = squeeze[dimensions=(0,)] jcx
    jcz:f64[64,64] = mul jcw jcy
    jda:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jbw
    jdb:i64[64,64] = squeeze[dimensions=(0,)] jda
    jdc:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jby
    jdd:i64[64,64] = squeeze[dimensions=(0,)] jdc
    jde:bool[64,64] = lt jdb 0:i64[]
    jdf:i64[64,64] = add jdb 64:i64[]
    jdg:i64[64,64] = select_n jde jdb jdf
    jdh:bool[64,64] = lt jdd 0:i64[]
    jdi:i64[64,64] = add jdd 64:i64[]
    jdj:i64[64,64] = select_n jdh jdd jdi
    jdk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jdg
    jdl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jdj
    jdm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jdk
    jdn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jdl
    jdo:i32[64,64,2] = concatenate[dimension=2] jdm jdn
    jdp:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jbi jdo
    jdq:f64[64,64] = mul jcz jdp
    jdr:f64[64,64] = add jcu jdq
    jds:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jbr
    jdt:f64[64,64] = squeeze[dimensions=(0,)] jds
    jdu:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jbs
    jdv:f64[64,64] = squeeze[dimensions=(0,)] jdu
    jdw:f64[64,64] = mul jdt jdv
    jdx:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jby
    jdy:i64[64,64] = squeeze[dimensions=(0,)] jdx
    jdz:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jbw
    jea:i64[64,64] = squeeze[dimensions=(0,)] jdz
    jeb:bool[64,64] = lt jdy 0:i64[]
    jec:i64[64,64] = add jdy 64:i64[]
    jed:i64[64,64] = select_n jeb jdy jec
    jee:bool[64,64] = lt jea 0:i64[]
    jef:i64[64,64] = add jea 64:i64[]
    jeg:i64[64,64] = select_n jee jea jef
    jeh:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jed
    jei:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jeg
    jej:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jeh
    jek:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jei
    jel:i32[64,64,2] = concatenate[dimension=2] jej jek
    jem:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jbi jel
    jen:f64[64,64] = mul jdw jem
    jeo:f64[64,64] = add jdr jen
    jep:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jbr
    jeq:f64[64,64] = squeeze[dimensions=(0,)] jep
    jer:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jbr
    jes:f64[64,64] = squeeze[dimensions=(0,)] jer
    jet:f64[64,64] = mul jeq jes
    jeu:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jby
    jev:i64[64,64] = squeeze[dimensions=(0,)] jeu
    jew:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jby
    jex:i64[64,64] = squeeze[dimensions=(0,)] jew
    jey:bool[64,64] = lt jev 0:i64[]
    jez:i64[64,64] = add jev 64:i64[]
    jfa:i64[64,64] = select_n jey jev jez
    jfb:bool[64,64] = lt jex 0:i64[]
    jfc:i64[64,64] = add jex 64:i64[]
    jfd:i64[64,64] = select_n jfb jex jfc
    jfe:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jfa
    jff:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jfd
    jfg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jfe
    jfh:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jff
    jfi:i32[64,64,2] = concatenate[dimension=2] jfg jfh
    jfj:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jbi jfi
    jfk:f64[64,64] = mul jet jfj
    jfl:f64[64,64] = add jeo jfk
    jfm:f64[64,64] = sub ikl jfl
    jfn:f64[64,64] = div jfm 2.0:f64[]
    jfo:f64[64,64] = add ikl jfn
    jfp:f64[] = neg eb
    jfq:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jfp
    jfr:f64[2,64,64] = mul jfq ixg
    jfs:f64[2,64,64] = add ea jfr
    jft:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    jfu:f64[2,64,64] = div jfs jft
    jfv:f64[2,64,64] = floor jfu
    jfw:f64[2,64,64] = sub jfu jfv
    jfx:f64[2,64,64] = sub 1.0:f64[] jfw
    jfy:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dp
    jfz:f64[2,64,64] = sub jfu jfw
    jga:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jfz
    jgb:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jga jfy
    jgc:i64[2,64,64] = add jgb 1:i64[]
    jgd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jgc jfy
    jge:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jfx
    jgf:f64[64,64] = squeeze[dimensions=(0,)] jge
    jgg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jfx
    jgh:f64[64,64] = squeeze[dimensions=(0,)] jgg
    jgi:f64[64,64] = mul jgf jgh
    jgj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jgb
    jgk:i64[64,64] = squeeze[dimensions=(0,)] jgj
    jgl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jgb
    jgm:i64[64,64] = squeeze[dimensions=(0,)] jgl
    jgn:bool[64,64] = lt jgk 0:i64[]
    jgo:i64[64,64] = add jgk 64:i64[]
    jgp:i64[64,64] = select_n jgn jgk jgo
    jgq:bool[64,64] = lt jgm 0:i64[]
    jgr:i64[64,64] = add jgm 64:i64[]
    jgs:i64[64,64] = select_n jgq jgm jgr
    jgt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jgp
    jgu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jgs
    jgv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jgt
    jgw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jgu
    jgx:i32[64,64,2] = concatenate[dimension=2] jgv jgw
    jgy:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jfo jgx
    jgz:f64[64,64] = mul jgi jgy
    jha:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jfx
    jhb:f64[64,64] = squeeze[dimensions=(0,)] jha
    jhc:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jfw
    jhd:f64[64,64] = squeeze[dimensions=(0,)] jhc
    jhe:f64[64,64] = mul jhb jhd
    jhf:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jgb
    jhg:i64[64,64] = squeeze[dimensions=(0,)] jhf
    jhh:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jgd
    jhi:i64[64,64] = squeeze[dimensions=(0,)] jhh
    jhj:bool[64,64] = lt jhg 0:i64[]
    jhk:i64[64,64] = add jhg 64:i64[]
    jhl:i64[64,64] = select_n jhj jhg jhk
    jhm:bool[64,64] = lt jhi 0:i64[]
    jhn:i64[64,64] = add jhi 64:i64[]
    jho:i64[64,64] = select_n jhm jhi jhn
    jhp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jhl
    jhq:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jho
    jhr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jhp
    jhs:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jhq
    jht:i32[64,64,2] = concatenate[dimension=2] jhr jhs
    jhu:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jfo jht
    jhv:f64[64,64] = mul jhe jhu
    jhw:f64[64,64] = add jgz jhv
    jhx:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jfw
    jhy:f64[64,64] = squeeze[dimensions=(0,)] jhx
    jhz:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jfx
    jia:f64[64,64] = squeeze[dimensions=(0,)] jhz
    jib:f64[64,64] = mul jhy jia
    jic:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jgd
    jid:i64[64,64] = squeeze[dimensions=(0,)] jic
    jie:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jgb
    jif:i64[64,64] = squeeze[dimensions=(0,)] jie
    jig:bool[64,64] = lt jid 0:i64[]
    jih:i64[64,64] = add jid 64:i64[]
    jii:i64[64,64] = select_n jig jid jih
    jij:bool[64,64] = lt jif 0:i64[]
    jik:i64[64,64] = add jif 64:i64[]
    jil:i64[64,64] = select_n jij jif jik
    jim:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jii
    jin:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jil
    jio:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jim
    jip:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jin
    jiq:i32[64,64,2] = concatenate[dimension=2] jio jip
    jir:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jfo jiq
    jis:f64[64,64] = mul jib jir
    jit:f64[64,64] = add jhw jis
    jiu:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jfw
    jiv:f64[64,64] = squeeze[dimensions=(0,)] jiu
    jiw:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jfw
    jix:f64[64,64] = squeeze[dimensions=(0,)] jiw
    jiy:f64[64,64] = mul jiv jix
    jiz:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jgd
    jja:i64[64,64] = squeeze[dimensions=(0,)] jiz
    jjb:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jgd
    jjc:i64[64,64] = squeeze[dimensions=(0,)] jjb
    jjd:bool[64,64] = lt jja 0:i64[]
    jje:i64[64,64] = add jja 64:i64[]
    jjf:i64[64,64] = select_n jjd jja jje
    jjg:bool[64,64] = lt jjc 0:i64[]
    jjh:i64[64,64] = add jjc 64:i64[]
    jji:i64[64,64] = select_n jjg jjc jjh
    jjj:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jjf
    jjk:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jji
    jjl:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jjj
    jjm:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jjk
    jjn:i32[64,64,2] = concatenate[dimension=2] jjl jjm
    jjo:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jfo jjn
    jjp:f64[64,64] = mul jiy jjo
    jjq:f64[64,64] = add jit jjp
    jjr:f64[] = neg eb
    jjs:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jjr
    jjt:f64[2,64,64] = mul jjs ixg
    jju:f64[2,64,64] = add ea jjt
    jjv:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    jjw:f64[2,64,64] = div jju jjv
    jjx:f64[2,64,64] = floor jjw
    jjy:f64[2,64,64] = sub jjw jjx
    jjz:f64[2,64,64] = sub 1.0:f64[] jjy
    jka:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dq
    jkb:f64[2,64,64] = sub jjw jjy
    jkc:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jkb
    jkd:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jkc jka
    jke:i64[2,64,64] = add jkd 1:i64[]
    jkf:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jke jka
    jkg:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jjz
    jkh:f64[64,64] = squeeze[dimensions=(0,)] jkg
    jki:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jjz
    jkj:f64[64,64] = squeeze[dimensions=(0,)] jki
    jkk:f64[64,64] = mul jkh jkj
    jkl:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jkd
    jkm:i64[64,64] = squeeze[dimensions=(0,)] jkl
    jkn:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jkd
    jko:i64[64,64] = squeeze[dimensions=(0,)] jkn
    jkp:bool[64,64] = lt jkm 0:i64[]
    jkq:i64[64,64] = add jkm 64:i64[]
    jkr:i64[64,64] = select_n jkp jkm jkq
    jks:bool[64,64] = lt jko 0:i64[]
    jkt:i64[64,64] = add jko 64:i64[]
    jku:i64[64,64] = select_n jks jko jkt
    jkv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jkr
    jkw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jku
    jkx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jkv
    jky:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jkw
    jkz:i32[64,64,2] = concatenate[dimension=2] jkx jky
    jla:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] iwv jkz
    jlb:f64[64,64] = mul jkk jla
    jlc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jjz
    jld:f64[64,64] = squeeze[dimensions=(0,)] jlc
    jle:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jjy
    jlf:f64[64,64] = squeeze[dimensions=(0,)] jle
    jlg:f64[64,64] = mul jld jlf
    jlh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jkd
    jli:i64[64,64] = squeeze[dimensions=(0,)] jlh
    jlj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jkf
    jlk:i64[64,64] = squeeze[dimensions=(0,)] jlj
    jll:bool[64,64] = lt jli 0:i64[]
    jlm:i64[64,64] = add jli 64:i64[]
    jln:i64[64,64] = select_n jll jli jlm
    jlo:bool[64,64] = lt jlk 0:i64[]
    jlp:i64[64,64] = add jlk 64:i64[]
    jlq:i64[64,64] = select_n jlo jlk jlp
    jlr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jln
    jls:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jlq
    jlt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jlr
    jlu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jls
    jlv:i32[64,64,2] = concatenate[dimension=2] jlt jlu
    jlw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] iwv jlv
    jlx:f64[64,64] = mul jlg jlw
    jly:f64[64,64] = add jlb jlx
    jlz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jjy
    jma:f64[64,64] = squeeze[dimensions=(0,)] jlz
    jmb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jjz
    jmc:f64[64,64] = squeeze[dimensions=(0,)] jmb
    jmd:f64[64,64] = mul jma jmc
    jme:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jkf
    jmf:i64[64,64] = squeeze[dimensions=(0,)] jme
    jmg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jkd
    jmh:i64[64,64] = squeeze[dimensions=(0,)] jmg
    jmi:bool[64,64] = lt jmf 0:i64[]
    jmj:i64[64,64] = add jmf 64:i64[]
    jmk:i64[64,64] = select_n jmi jmf jmj
    jml:bool[64,64] = lt jmh 0:i64[]
    jmm:i64[64,64] = add jmh 64:i64[]
    jmn:i64[64,64] = select_n jml jmh jmm
    jmo:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jmk
    jmp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jmn
    jmq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jmo
    jmr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jmp
    jms:i32[64,64,2] = concatenate[dimension=2] jmq jmr
    jmt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] iwv jms
    jmu:f64[64,64] = mul jmd jmt
    jmv:f64[64,64] = add jly jmu
    jmw:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jjy
    jmx:f64[64,64] = squeeze[dimensions=(0,)] jmw
    jmy:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jjy
    jmz:f64[64,64] = squeeze[dimensions=(0,)] jmy
    jna:f64[64,64] = mul jmx jmz
    jnb:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jkf
    jnc:i64[64,64] = squeeze[dimensions=(0,)] jnb
    jnd:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jkf
    jne:i64[64,64] = squeeze[dimensions=(0,)] jnd
    jnf:bool[64,64] = lt jnc 0:i64[]
    jng:i64[64,64] = add jnc 64:i64[]
    jnh:i64[64,64] = select_n jnf jnc jng
    jni:bool[64,64] = lt jne 0:i64[]
    jnj:i64[64,64] = add jne 64:i64[]
    jnk:i64[64,64] = select_n jni jne jnj
    jnl:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jnh
    jnm:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jnk
    jnn:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jnl
    jno:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jnm
    jnp:i32[64,64,2] = concatenate[dimension=2] jnn jno
    jnq:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] iwv jnp
    jnr:f64[64,64] = mul jna jnq
    jns:f64[64,64] = add jmv jnr
    jnt:f64[2,64,64] = neg ixg
    jnu:f64[] = neg eb
    jnv:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jnu
    jnw:f64[2,64,64] = mul jnv jnt
    jnx:f64[2,64,64] = add ea jnw
    jny:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    jnz:f64[2,64,64] = div jnx jny
    joa:f64[2,64,64] = floor jnz
    job:f64[2,64,64] = sub jnz joa
    joc:f64[2,64,64] = sub 1.0:f64[] job
    jod:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dr
    joe:f64[2,64,64] = sub jnz job
    jof:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] joe
    jog:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jof jod
    joh:i64[2,64,64] = add jog 1:i64[]
    joi:i64[2,64,64] = jit[name=remainder jaxpr=remainder] joh jod
    joj:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] joc
    jok:f64[64,64] = squeeze[dimensions=(0,)] joj
    jol:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] joc
    jom:f64[64,64] = squeeze[dimensions=(0,)] jol
    jon:f64[64,64] = mul jok jom
    joo:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jog
    jop:i64[64,64] = squeeze[dimensions=(0,)] joo
    joq:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jog
    jor:i64[64,64] = squeeze[dimensions=(0,)] joq
    jos:bool[64,64] = lt jop 0:i64[]
    jot:i64[64,64] = add jop 64:i64[]
    jou:i64[64,64] = select_n jos jop jot
    jov:bool[64,64] = lt jor 0:i64[]
    jow:i64[64,64] = add jor 64:i64[]
    jox:i64[64,64] = select_n jov jor jow
    joy:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jou
    joz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jox
    jpa:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] joy
    jpb:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] joz
    jpc:i32[64,64,2] = concatenate[dimension=2] jpa jpb
    jpd:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jns jpc
    jpe:f64[64,64] = mul jon jpd
    jpf:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] joc
    jpg:f64[64,64] = squeeze[dimensions=(0,)] jpf
    jph:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] job
    jpi:f64[64,64] = squeeze[dimensions=(0,)] jph
    jpj:f64[64,64] = mul jpg jpi
    jpk:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jog
    jpl:i64[64,64] = squeeze[dimensions=(0,)] jpk
    jpm:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] joi
    jpn:i64[64,64] = squeeze[dimensions=(0,)] jpm
    jpo:bool[64,64] = lt jpl 0:i64[]
    jpp:i64[64,64] = add jpl 64:i64[]
    jpq:i64[64,64] = select_n jpo jpl jpp
    jpr:bool[64,64] = lt jpn 0:i64[]
    jps:i64[64,64] = add jpn 64:i64[]
    jpt:i64[64,64] = select_n jpr jpn jps
    jpu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jpq
    jpv:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jpt
    jpw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jpu
    jpx:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jpv
    jpy:i32[64,64,2] = concatenate[dimension=2] jpw jpx
    jpz:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jns jpy
    jqa:f64[64,64] = mul jpj jpz
    jqb:f64[64,64] = add jpe jqa
    jqc:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] job
    jqd:f64[64,64] = squeeze[dimensions=(0,)] jqc
    jqe:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] joc
    jqf:f64[64,64] = squeeze[dimensions=(0,)] jqe
    jqg:f64[64,64] = mul jqd jqf
    jqh:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] joi
    jqi:i64[64,64] = squeeze[dimensions=(0,)] jqh
    jqj:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jog
    jqk:i64[64,64] = squeeze[dimensions=(0,)] jqj
    jql:bool[64,64] = lt jqi 0:i64[]
    jqm:i64[64,64] = add jqi 64:i64[]
    jqn:i64[64,64] = select_n jql jqi jqm
    jqo:bool[64,64] = lt jqk 0:i64[]
    jqp:i64[64,64] = add jqk 64:i64[]
    jqq:i64[64,64] = select_n jqo jqk jqp
    jqr:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jqn
    jqs:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jqq
    jqt:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jqr
    jqu:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jqs
    jqv:i32[64,64,2] = concatenate[dimension=2] jqt jqu
    jqw:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jns jqv
    jqx:f64[64,64] = mul jqg jqw
    jqy:f64[64,64] = add jqb jqx
    jqz:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] job
    jra:f64[64,64] = squeeze[dimensions=(0,)] jqz
    jrb:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] job
    jrc:f64[64,64] = squeeze[dimensions=(0,)] jrb
    jrd:f64[64,64] = mul jra jrc
    jre:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] joi
    jrf:i64[64,64] = squeeze[dimensions=(0,)] jre
    jrg:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] joi
    jrh:i64[64,64] = squeeze[dimensions=(0,)] jrg
    jri:bool[64,64] = lt jrf 0:i64[]
    jrj:i64[64,64] = add jrf 64:i64[]
    jrk:i64[64,64] = select_n jri jrf jrj
    jrl:bool[64,64] = lt jrh 0:i64[]
    jrm:i64[64,64] = add jrh 64:i64[]
    jrn:i64[64,64] = select_n jrl jrh jrm
    jro:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jrk
    jrp:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jrn
    jrq:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jro
    jrr:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jrp
    jrs:i32[64,64,2] = concatenate[dimension=2] jrq jrr
    jrt:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jns jrs
    jru:f64[64,64] = mul jrd jrt
    jrv:f64[64,64] = add jqy jru
    jrw:f64[64,64] = sub iwv jrv
    jrx:f64[64,64] = div jrw 2.0:f64[]
    jry:f64[64,64] = add iwv jrx
    jrz:f64[] = neg eb
    jsa:f64[] = convert_element_type[new_dtype=float64 weak_type=False] jrz
    jsb:f64[2,64,64] = mul jsa ixg
    jsc:f64[2,64,64] = add ea jsb
    jsd:f64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] dz
    jse:f64[2,64,64] = div jsc jsd
    jsf:f64[2,64,64] = floor jse
    jsg:f64[2,64,64] = sub jse jsf
    jsh:f64[2,64,64] = sub 1.0:f64[] jsg
    jsi:i64[2,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1, 1)
      sharding=None
    ] ds
    jsj:f64[2,64,64] = sub jse jsg
    jsk:i64[2,64,64] = convert_element_type[new_dtype=int64 weak_type=False] jsj
    jsl:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jsk jsi
    jsm:i64[2,64,64] = add jsl 1:i64[]
    jsn:i64[2,64,64] = jit[name=remainder jaxpr=remainder] jsm jsi
    jso:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsh
    jsp:f64[64,64] = squeeze[dimensions=(0,)] jso
    jsq:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsh
    jsr:f64[64,64] = squeeze[dimensions=(0,)] jsq
    jss:f64[64,64] = mul jsp jsr
    jst:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsl
    jsu:i64[64,64] = squeeze[dimensions=(0,)] jst
    jsv:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsl
    jsw:i64[64,64] = squeeze[dimensions=(0,)] jsv
    jsx:bool[64,64] = lt jsu 0:i64[]
    jsy:i64[64,64] = add jsu 64:i64[]
    jsz:i64[64,64] = select_n jsx jsu jsy
    jta:bool[64,64] = lt jsw 0:i64[]
    jtb:i64[64,64] = add jsw 64:i64[]
    jtc:i64[64,64] = select_n jta jsw jtb
    jtd:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jsz
    jte:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jtc
    jtf:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jtd
    jtg:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jte
    jth:i32[64,64,2] = concatenate[dimension=2] jtf jtg
    jti:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jry jth
    jtj:f64[64,64] = mul jss jti
    jtk:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsh
    jtl:f64[64,64] = squeeze[dimensions=(0,)] jtk
    jtm:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsg
    jtn:f64[64,64] = squeeze[dimensions=(0,)] jtm
    jto:f64[64,64] = mul jtl jtn
    jtp:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsl
    jtq:i64[64,64] = squeeze[dimensions=(0,)] jtp
    jtr:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsn
    jts:i64[64,64] = squeeze[dimensions=(0,)] jtr
    jtt:bool[64,64] = lt jtq 0:i64[]
    jtu:i64[64,64] = add jtq 64:i64[]
    jtv:i64[64,64] = select_n jtt jtq jtu
    jtw:bool[64,64] = lt jts 0:i64[]
    jtx:i64[64,64] = add jts 64:i64[]
    jty:i64[64,64] = select_n jtw jts jtx
    jtz:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jtv
    jua:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jty
    jub:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jtz
    juc:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jua
    jud:i32[64,64,2] = concatenate[dimension=2] jub juc
    jue:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jry jud
    juf:f64[64,64] = mul jto jue
    jug:f64[64,64] = add jtj juf
    juh:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsg
    jui:f64[64,64] = squeeze[dimensions=(0,)] juh
    juj:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsh
    juk:f64[64,64] = squeeze[dimensions=(0,)] juj
    jul:f64[64,64] = mul jui juk
    jum:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsn
    jun:i64[64,64] = squeeze[dimensions=(0,)] jum
    juo:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsl
    jup:i64[64,64] = squeeze[dimensions=(0,)] juo
    juq:bool[64,64] = lt jun 0:i64[]
    jur:i64[64,64] = add jun 64:i64[]
    jus:i64[64,64] = select_n juq jun jur
    jut:bool[64,64] = lt jup 0:i64[]
    juu:i64[64,64] = add jup 64:i64[]
    juv:i64[64,64] = select_n jut jup juu
    juw:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jus
    jux:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] juv
    juy:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] juw
    juz:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jux
    jva:i32[64,64,2] = concatenate[dimension=2] juy juz
    jvb:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jry jva
    jvc:f64[64,64] = mul jul jvb
    jvd:f64[64,64] = add jug jvc
    jve:f64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsg
    jvf:f64[64,64] = squeeze[dimensions=(0,)] jve
    jvg:f64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsg
    jvh:f64[64,64] = squeeze[dimensions=(0,)] jvg
    jvi:f64[64,64] = mul jvf jvh
    jvj:i64[1,64,64] = slice[
      limit_indices=(1, 64, 64)
      start_indices=(0, 0, 0)
      strides=None
    ] jsn
    jvk:i64[64,64] = squeeze[dimensions=(0,)] jvj
    jvl:i64[1,64,64] = slice[
      limit_indices=(2, 64, 64)
      start_indices=(1, 0, 0)
      strides=None
    ] jsn
    jvm:i64[64,64] = squeeze[dimensions=(0,)] jvl
    jvn:bool[64,64] = lt jvk 0:i64[]
    jvo:i64[64,64] = add jvk 64:i64[]
    jvp:i64[64,64] = select_n jvn jvk jvo
    jvq:bool[64,64] = lt jvm 0:i64[]
    jvr:i64[64,64] = add jvm 64:i64[]
    jvs:i64[64,64] = select_n jvq jvm jvr
    jvt:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jvp
    jvu:i32[64,64] = convert_element_type[new_dtype=int32 weak_type=False] jvs
    jvv:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jvt
    jvw:i32[64,64,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(64, 64, 1)
      sharding=None
    ] jvu
    jvx:i32[64,64,2] = concatenate[dimension=2] jvv jvw
    jvy:f64[64,64] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] jry jvx
    jvz:f64[64,64] = mul jvi jvy
    jwa:f64[64,64] = add jvd jvz
  in (jjq, jwa) }

🔀 Structured control flow primitives

The loop in the integrate function is unrolled on tracing.
This results in a jaxpr which grows in size as n_step increases, with a (steep) accompanying increase in compile time.
As an alternative, JAX provides functional structured control flow primitives in the jax.lax module.
For example, lax.fori_loop can be used in place of a native Python for loop.
def integrate_fori_loop(vorticity, tracer, kernels, mesh, time_step, n_step):
    
    def loop_step(i, vorticity_tracer):
        return step(*vorticity_tracer, kernels, mesh, time_step)

    return jax.lax.fori_loop(0, n_step, loop_step, (vorticity, tracer))

Importantly, structured loop control flow primitives like lax.fori_loop are not unrolled on tracing, so compile time does not increase with the number of iterations.
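
As a rough illustration of this difference, the sketch below counts the top-level equations in the traced jaxpr for an unrolled Python loop and for lax.fori_loop. The toy_step function is a hypothetical stand-in for the simulation step, and count_equations is an illustrative helper rather than part of JAX.

```python
import jax
import jax.numpy as jnp


def toy_step(x):
    # Hypothetical stand-in for the (much larger) simulation step.
    return x + 0.1 * jnp.sin(x)


def integrate_python_loop(x, n_step):
    # Native Python loop: every iteration is traced separately (unrolled).
    for _ in range(n_step):
        x = toy_step(x)
    return x


def integrate_structured(x, n_step):
    # Structured loop: the body is traced once, whatever the value of n_step.
    return jax.lax.fori_loop(0, n_step, lambda i, x: toy_step(x), x)


def count_equations(integrate_fn, n_step):
    # Trace with n_step fixed and count the top-level equations in the jaxpr.
    closed_jaxpr = jax.make_jaxpr(lambda x: integrate_fn(x, n_step))(jnp.ones(4))
    return len(closed_jaxpr.jaxpr.eqns)


for n_step in (10, 100):
    print(
        n_step,
        count_equations(integrate_python_loop, n_step),
        count_equations(integrate_structured, n_step),
    )
```

The count for the Python loop should grow with n_step, while the count for the structured loop should stay the same.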

➡️ Applying jit to integration

We can now JIT compile this integrate_fori_loop function

jitted_integrate = jax.jit(integrate_fori_loop, static_argnames="n_step")
assert np.allclose(
    jitted_integrate(
        vorticity_jax, tracer, kernels, mesh, time_step, n_step
    ),
    integrate(
        vorticity_numpy, tracer, kernels, mesh, time_step, n_step
    ),
)
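
Marking n_step as static means its value is treated as a compile-time constant, so each new value triggers a fresh trace and compilation. The sketch below illustrates this with a toy function; repeat_halve and trace_log are hypothetical names introduced only for this example.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Records each time the Python function body is executed (traced).


def repeat_halve(x, n_step):
    trace_log.append(n_step)  # Python side effect: only runs during tracing.
    return jax.lax.fori_loop(0, n_step, lambda i, x: 0.5 * x, x)


jitted_repeat_halve = jax.jit(repeat_halve, static_argnames="n_step")

x = jnp.array([8.0])
y_first = jitted_repeat_halve(x, 3)   # Traced and compiled for n_step=3.
y_cached = jitted_repeat_halve(x, 3)  # Same static value: cached executable reused.
y_other = jitted_repeat_halve(x, 2)   # New static value: traced and compiled again.
print(trace_log)
```

After these three calls, trace_log should contain one entry per distinct value of n_step rather than one per call.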

📊 Performance of jitted_integrate

Comparing performance, we see that JIT-compiled JAX now outperforms NumPy by a wide margin, roughly 50× in the timings below:

initial_tracer = generate_initial_tracer(mesh)
print("NumPy")
%timeit integrate(vorticity_numpy, tracer, kernels, mesh, time_step, n_step)
print("Jitted JAX")
%timeit jax.block_until_ready(jitted_integrate(vorticity_jax, tracer, kernels, mesh, time_step, n_step))
NumPy
448 ms ± 63.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jitted JAX
9.14 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

⛓️ Sequences with lax.scan

As well as lax.fori_loop, JAX provides other loop primitives such as lax.scan for sequence-to-sequence mapping.

In our case we can use lax.scan to define a variant of integrate which outputs sequences of simulated fields.

@jax.jit(static_argnames=("n_step",))
def integrate_scan(vorticity, tracer, kernels, mesh, time_step, n_step):

    def scan_step(vorticity_tracer, _):
        vorticity_tracer = step(*vorticity_tracer, kernels, mesh, time_step)
        return vorticity_tracer, vorticity_tracer

    _, (vorticity_sequence, tracer_sequence) = jax.lax.scan(
        scan_step, (vorticity, tracer), length=n_step
    )
    return vorticity_sequence, tracer_sequence
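
To make the carry and output convention of lax.scan concrete, here is a minimal sketch on a toy problem. The toy_step and simulate functions are hypothetical stand-ins, not the simulation code from these slides.

```python
import jax
import jax.numpy as jnp


def toy_step(x):
    # Hypothetical stand-in for the simulation step: halve the state.
    return 0.5 * x


def simulate(x_init, n_step):
    def scan_step(x, _):
        x = toy_step(x)
        # First output is the carry passed to the next iteration,
        # second output is stacked along a new leading axis.
        return x, x

    x_final, x_sequence = jax.lax.scan(scan_step, x_init, length=n_step)
    return x_final, x_sequence


x_final, x_sequence = simulate(jnp.array([8.0, 4.0]), n_step=3)
print(x_sequence.shape)  # (3, 2): one row per step
```

The last row of x_sequence matches x_final, since the same value is returned as both carry and per-step output.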

🙌 Putting it all together

time_step, total_time = 1., 100.
u = rng.standard_normal(mesh.shape[0] * mesh.shape[1])
vorticity = generate_vorticity(u, mesh, kernels)
tracer = generate_initial_tracer(mesh)
vorticity_seq, tracer_seq = integrate_scan(
    vorticity, tracer, kernels, mesh, time_step, int(total_time / time_step)
)
animate_fields(mesh, kernels, vorticity, tracer, vorticity_seq, tracer_seq)

🇩 Differentiable simulation

A common task in scientific computing settings is to fit a model to observed data.
While gradient-free approaches for calibrating models against observations exist, exploiting gradients allows much more efficient fitting in high-dimensional settings.
As a proxy for fitting to data, here we consider the task of optimizing the initial vorticity to bring the final tracer field close to a target field.
We use a differentiable objective function, which combines a squared error loss on the final tracer field with an L2 regularization term on u, the parameters used to generate the initial vorticity:
@jax.jit(static_argnames=("time_step", "total_time"))
def objective(
    u, initial_tracer, target_tracer, kernels, mesh, time_step, total_time
):
    vorticity = generate_vorticity(u, mesh, kernels)
    n_step = int(total_time / time_step)
    _, final_tracer = integrate_fori_loop(
        vorticity, initial_tracer, kernels, mesh, time_step, n_step
    )
    return ((final_tracer - target_tracer)**2).sum() + (u**2).sum()

🎯 Target tracer field

We use a target final tracer field loaded from an image.

🇩 Automatic differentiation with JAX

We can then use the jax.value_and_grad transformation to construct a function which evaluates both the value of the objective and its gradient with respect to its first argument u, using reverse-mode automatic differentiation:

value_and_grad_objective = jax.jit(
    jax.value_and_grad(objective), static_argnames=("time_step", "total_time")
)

This can be composed with the jax.jit transformation to compile the gradient function.
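
The same pattern can be tried on a small function with a known gradient. The sketch below uses a hypothetical toy_objective with the same structure (squared error plus L2 penalty) and compares against the analytic gradient 2(u - target) + 2u.

```python
import jax
import jax.numpy as jnp


def toy_objective(u, target):
    # Squared error to the target plus an L2 penalty on u.
    return ((u - target) ** 2).sum() + (u**2).sum()


# By default the gradient is taken with respect to the first argument (u).
value_and_grad_toy = jax.jit(jax.value_and_grad(toy_objective))

u = jnp.array([1.0, 2.0])
target = jnp.array([0.0, 1.0])
value, gradient = value_and_grad_toy(u, target)
analytic_gradient = 2 * (u - target) + 2 * u
print(value, gradient, analytic_gradient)
```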

🖥️ Evaluating value and gradient

We can evaluate the transformed function at a test input to check that it correctly gives the same value as the original (primal) function

u = rng.standard_normal(mesh.shape[0] * mesh.shape[1])
objective_kwargs = {
    "initial_tracer": initial_tracer,
    "target_tracer": target_tracer,
    "kernels": kernels,
    "mesh": mesh,
    "time_step": time_step,
    "total_time": total_time
}
value, gradient = value_and_grad_objective(u, **objective_kwargs)
assert np.allclose(value, objective(u, **objective_kwargs))
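
A similar sanity check can be made on the gradient itself by comparing a directional derivative against a central finite difference. This is a minimal sketch on a hypothetical toy_objective, assuming double precision is enabled; the names u_test, direction and eps are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # Double precision for a tighter check.


def toy_objective(u):
    # Hypothetical smooth scalar function standing in for `objective`.
    return jnp.sum(jnp.sin(u) ** 2) + jnp.sum(u**2)


u_test = jnp.array([0.3, -1.2, 2.0])
direction = jnp.array([1.0, 0.5, -0.25])
eps = 1e-4

_, gradient = jax.value_and_grad(toy_objective)(u_test)
# Directional derivative from the reverse-mode gradient...
derivative_autodiff = jnp.vdot(gradient, direction)
# ...and a central finite-difference approximation of the same quantity.
derivative_numeric = (
    toy_objective(u_test + eps * direction) - toy_objective(u_test - eps * direction)
) / (2 * eps)
print(derivative_autodiff, derivative_numeric)
```

The two values should agree closely, up to the truncation and rounding error of the finite difference.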

💵 Cost of automatic differentiation

Importantly, reverse-mode differentiation allows us to evaluate both the value and gradient of a scalar function with respect to all inputs at an operation cost that is a small constant multiple of the cost of evaluating the function itself:

%time jax.block_until_ready(objective(u, **objective_kwargs))
CPU times: user 210 ms, sys: 34.6 ms, total: 245 ms
Wall time: 107 ms
Array(5187.07212124, dtype=float64)
%time jax.block_until_ready(value_and_grad_objective(u, **objective_kwargs))
CPU times: user 619 ms, sys: 465 ms, total: 1.08 s
Wall time: 355 ms
(Array(5187.07212124, dtype=float64),
 Array([  -4.5890741 ,  407.31275407, -171.15432591, ...,   -0.40939672,
          -0.6479958 ,   -3.69480833], dtype=float64))

⚡ Optimization with JAX

JAX also provides some example implementations of optimization algorithms that can be used with its automatic differentiation support.
The helper function below uses the adaptive moments (Adam) optimizer to minimize the objective function from an initial state over a specified number of steps.
from jax.example_libraries import optimizers

def optimize(initial_u, objective_kwargs, n_steps=100, step_size=5e-2):
    initialize, update, get_params = optimizers.adam(step_size)
    
    def optimizer_step(i, state):
        u = get_params(state)
        value, grad = value_and_grad_objective(u, **objective_kwargs)
        jax.debug.print("Iter {i:03}: obj = {value:.4g}", i=i, value=value)
        return update(i, grad, state)

    initial_state = initialize(initial_u)
    final_state = jax.lax.fori_loop(0, n_steps + 1, optimizer_step, initial_state)
    return get_params(final_state)
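
The same optimizer interface can be tried in isolation on a problem with a known minimum. The sketch below applies the same pattern to a hypothetical quadratic toy_objective; optimize_toy, the step size and the number of steps are illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


def toy_objective(u):
    # Hypothetical quadratic with its minimum at u = 3 in every component.
    return ((u - 3.0) ** 2).sum()


def optimize_toy(initial_u, n_steps=500, step_size=0.1):
    initialize, update, get_params = optimizers.adam(step_size)

    def optimizer_step(i, state):
        # Gradient of the objective at the current parameter values.
        grad = jax.grad(toy_objective)(get_params(state))
        return update(i, grad, state)

    final_state = jax.lax.fori_loop(0, n_steps, optimizer_step, initialize(initial_u))
    return get_params(final_state)


u_opt = optimize_toy(jnp.zeros(5))
print(toy_objective(jnp.zeros(5)), toy_objective(u_opt))
```

The optimizer state is a pytree of arrays, which is what allows it to be carried through lax.fori_loop.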

⚡ Optimizing initial vorticity

Using the optimize helper function, we can minimize the objective function to find an initial vorticity which transports the tracer field to a final state close to the target.

u_init = jnp.zeros(mesh.shape[0] * mesh.shape[1])
u_opt = optimize(u_init, objective_kwargs, 200)
Iter 000: obj = 1850
Iter 001: obj = 1462
Iter 002: obj = 1273
Iter 003: obj = 1131
Iter 004: obj = 993.7
Iter 005: obj = 925.6
Iter 006: obj = 848.9
Iter 007: obj = 787.8
Iter 008: obj = 758.2
Iter 009: obj = 718.3
Iter 010: obj = 687.6
Iter 011: obj = 656.3
Iter 012: obj = 627.6
Iter 013: obj = 600.3
Iter 014: obj = 570.7
Iter 015: obj = 555
Iter 016: obj = 530.6
Iter 017: obj = 512.2
Iter 018: obj = 500.9
Iter 019: obj = 480.7
Iter 020: obj = 460.9
Iter 021: obj = 449.9
Iter 022: obj = 438
Iter 023: obj = 424.8
Iter 024: obj = 410.9
Iter 025: obj = 402.5
Iter 026: obj = 393.1
Iter 027: obj = 382.8
Iter 028: obj = 374.8
Iter 029: obj = 364.6
Iter 030: obj = 356.4
Iter 031: obj = 350.6
Iter 032: obj = 343.2
Iter 033: obj = 335.5
Iter 034: obj = 328.8
Iter 035: obj = 322.6
Iter 036: obj = 316.8
Iter 037: obj = 311.8
Iter 038: obj = 306.3
Iter 039: obj = 300.5
Iter 040: obj = 297.2
Iter 041: obj = 293.2
Iter 042: obj = 289.6
Iter 043: obj = 286.6
Iter 044: obj = 283.5
Iter 045: obj = 280
Iter 046: obj = 277.2
Iter 047: obj = 275
Iter 048: obj = 273
Iter 049: obj = 271
Iter 050: obj = 269.3
Iter 051: obj = 267.3
Iter 052: obj = 265.5
Iter 053: obj = 264
Iter 054: obj = 262.4
Iter 055: obj = 260.6
Iter 056: obj = 259.1
Iter 057: obj = 257.7
Iter 058: obj = 256.4
Iter 059: obj = 255.1
Iter 060: obj = 254
Iter 061: obj = 252.7
Iter 062: obj = 251.3
Iter 063: obj = 249.9
Iter 064: obj = 248.7
Iter 065: obj = 247.6
Iter 066: obj = 246.5
Iter 067: obj = 245.5
Iter 068: obj = 244.5
Iter 069: obj = 243.5
Iter 070: obj = 242.6
Iter 071: obj = 241.7
Iter 072: obj = 240.8
Iter 073: obj = 240
Iter 074: obj = 239.1
Iter 075: obj = 238.3
Iter 076: obj = 237.5
Iter 077: obj = 236.8
Iter 078: obj = 236.1
Iter 079: obj = 235.4
Iter 080: obj = 234.8
Iter 081: obj = 234.2
Iter 082: obj = 233.6
Iter 083: obj = 233
Iter 084: obj = 232.4
Iter 085: obj = 232
Iter 086: obj = 231.6
Iter 087: obj = 231.3
Iter 088: obj = 231.4
Iter 089: obj = 231.9
Iter 090: obj = 232.6
Iter 091: obj = 232.2
Iter 092: obj = 230.4
Iter 093: obj = 228.3
Iter 094: obj = 228
Iter 095: obj = 228.8
Iter 096: obj = 228.9
Iter 097: obj = 227.8
Iter 098: obj = 226.4
Iter 099: obj = 226
Iter 100: obj = 226.5
Iter 101: obj = 226.5
Iter 102: obj = 225.7
Iter 103: obj = 224.7
Iter 104: obj = 224.3
Iter 105: obj = 224.4
Iter 106: obj = 224.5
Iter 107: obj = 224.1
Iter 108: obj = 223.4
Iter 109: obj = 222.7
Iter 110: obj = 222.4
Iter 111: obj = 222.2
Iter 112: obj = 222.2
Iter 113: obj = 222
Iter 114: obj = 221.6
Iter 115: obj = 221.1
Iter 116: obj = 220.6
Iter 117: obj = 220.2
Iter 118: obj = 220
Iter 119: obj = 219.8
Iter 120: obj = 219.7
Iter 121: obj = 219.6
Iter 122: obj = 219.5
Iter 123: obj = 219.4
Iter 124: obj = 219.3
Iter 125: obj = 219.1
Iter 126: obj = 218.9
Iter 127: obj = 218.7
Iter 128: obj = 218.5
Iter 129: obj = 218.2
Iter 130: obj = 217.9
Iter 131: obj = 217.6
Iter 132: obj = 217.3
Iter 133: obj = 217.1
Iter 134: obj = 216.8
Iter 135: obj = 216.6
Iter 136: obj = 216.4
Iter 137: obj = 216.1
Iter 138: obj = 215.9
Iter 139: obj = 215.8
Iter 140: obj = 215.6
Iter 141: obj = 215.5
Iter 142: obj = 215.4
Iter 143: obj = 215.5
Iter 144: obj = 215.7
Iter 145: obj = 216.1
Iter 146: obj = 216.8
Iter 147: obj = 217.4
Iter 148: obj = 217.9
Iter 149: obj = 217.3
Iter 150: obj = 215.9
Iter 151: obj = 214.2
Iter 152: obj = 213
Iter 153: obj = 212.7
Iter 154: obj = 213.1
Iter 155: obj = 213.8
Iter 156: obj = 214
Iter 157: obj = 213.6
Iter 158: obj = 212.7
Iter 159: obj = 211.8
Iter 160: obj = 211.3
Iter 161: obj = 211.2
Iter 162: obj = 211.4
Iter 163: obj = 211.7
Iter 164: obj = 211.8
Iter 165: obj = 211.6
Iter 166: obj = 211.2
Iter 167: obj = 210.7
Iter 168: obj = 210.3
Iter 169: obj = 210
Iter 170: obj = 209.8
Iter 171: obj = 209.8
Iter 172: obj = 209.8
Iter 173: obj = 209.8
Iter 174: obj = 209.8
Iter 175: obj = 209.8
Iter 176: obj = 209.7
Iter 177: obj = 209.6
Iter 178: obj = 209.5
Iter 179: obj = 209.4
Iter 180: obj = 209.2
Iter 181: obj = 209.1
Iter 182: obj = 208.9
Iter 183: obj = 208.8
Iter 184: obj = 208.6
Iter 185: obj = 208.5
Iter 186: obj = 208.4
Iter 187: obj = 208.3
Iter 188: obj = 208.2
Iter 189: obj = 208.1
Iter 190: obj = 208.1
Iter 191: obj = 208
Iter 192: obj = 208
Iter 193: obj = 208
Iter 194: obj = 208
Iter 195: obj = 208
Iter 196: obj = 208
Iter 197: obj = 208
Iter 198: obj = 208
Iter 199: obj = 208
Iter 200: obj = 208

🎞️ Visualizing optimized simulation

We can then simulate trajectories of the fields from the optimized initial vorticity and animate:

📈 Scaling resolution on a GPU

We can run exactly the same code on a GPU-equipped system, and JAX will default to running operations on the GPU.
With GPU acceleration we get a much bigger speed-up with JAX over NumPy, allowing scaling to higher resolutions.
The animation shows a simulation at 256×256 resolution after optimization on a GPU (NVIDIA Tesla T4).
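
To check which device JAX is using on a given system, a short sketch such as the following can help. The output depends on the hardware and on which JAX build is installed, so none is shown here.

```python
import jax
import jax.numpy as jnp

# Name of the backend JAX selected by default on this system.
print("Default backend:", jax.default_backend())
# Devices visible to the default backend.
print("Available devices:", jax.devices())

# Newly created arrays are placed on the default device.
x = jnp.ones((256, 256))
print("Array placed on:", x.devices())
```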

📌 Summary

  • Writing numerical code using the Array API gives support across NumPy, JAX and other array libraries with minimal effort.
  • If you are already familiar with NumPy, JAX offers an accessible and intuitive route to exploiting (multi-vendor) GPU acceleration and automatic differentiation.
  • The requirement to be able to trace functions in order to apply (some) JAX transforms puts some constraints on the use of control flow.
  • JAX’s functional programming style and structured control primitives can take a bit of getting used to!
  • Important topics not covered: pytrees, vmap, shard_map.