**Improving Mujoco-MJX with First Principles**
Han Wang
October 20th, 2024
[Han's Blog](https://hansolowang.com)
TLDR: The current MJX implementation of quaternion multiplication is not efficient on the GPU. Re-implementing quaternion multiplication with a GPU-friendly method leads to noticeable improvements in compile and execution time.
# Background
Last year Mujoco released [Mujoco XLA](https://mujoco.readthedocs.io/en/stable/mjx.html#) (MJX), a JAX implementation of physics simulation. MJX is best suited to simulating in parallel on the GPU. However, while profiling an MJX simulation of the kuka_iiwa14 on the GPU, I noticed that forward kinematics is unusually slow. The JIT compile time of forward kinematics is 1.83s and the execution time is 0.00213s, over 20x slower than the Python bindings of the C implementation (0.0001s).
```python
import jax
import mujoco
from mujoco import mjx

model_path = "kuka_iiwa_14/iiwa14.xml"
mjc_m = mujoco.MjModel.from_xml_path(model_path)
mjc_d = mujoco.MjData(mjc_m)
mjx_m = mjx.put_model(mjc_m)
mjx_d = mjx.make_data(mjx_m)

# C implementation: mj_kinematics updates mjc_d in place and returns None.
mujoco.mj_kinematics(mjc_m, mjc_d)  # 0.0001s

# MJX: lower and compile the jitted kinematics ahead of time, then run it.
f_jitted = jax.jit(mjx.kinematics).lower(mjx_m, mjx_d).compile()  # 1.83s compile time
mjx_d = f_jitted(mjx_m, mjx_d)  # 0.00213s execution time
# The timings are the average of 1000 runs each.
```
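The timing loop itself isn't shown above. Here is a minimal sketch of one way to get averages like these (an assumption, not necessarily how the numbers above were measured). JAX dispatches work asynchronously, so each call is synchronized with `jax.block_until_ready` before the clock is read, and a warm-up call keeps compilation out of the average. `average_runtime` is a hypothetical helper name.

```python
import time
from typing import Callable

import jax
import jax.numpy as jnp


def average_runtime(fn: Callable, *args, n_runs: int = 1000) -> float:
    """Average wall-clock seconds per call of fn(*args), excluding the warm-up call."""
    # Warm-up: for a jax.jit-wrapped fn this first call triggers compilation.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_runs):
        # Wait for the result; otherwise we would mostly time async dispatch.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_runs


if __name__ == "__main__":
    f = jax.jit(lambda x: jnp.sin(x) @ x)
    x = jnp.ones((256, 256))
    print(f"{average_runtime(f, x, n_runs=100):.6f} s per call")
```

With the objects above, `average_runtime(f_jitted, mjx_m, mjx_d)` would time the compiled MJX call; the compile time can be measured separately by timing the `.lower(...).compile()` step.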
Let's see what the intermediate stages look like under the hood:
- kinematics jaxpr: 7100 lines
- kinematics HLO: 5300 lines
Not very helpful... My brain is too smooth to retrieve useful information from thousands of lines of intermediate jaxpr and HLO.
If there were a way to map each jaxpr equation or HLO op back to the Python code that produced it, this would be much more useful. So if you know any good static/dynamic instrumentation tools for JAX, please let me know.
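To poke at the intermediate stages yourself, a sketch like the one below dumps the jaxpr and the lowered HLO and counts their lines. `ir_line_counts` is a hypothetical helper. `Lowered.as_text()` returns the StableHLO module before XLA's optimization passes, so its size differs from the optimized HLO, and exact counts depend on the JAX version.

```python
import jax
import jax.numpy as jnp


def ir_line_counts(fn, *args) -> dict:
    """Number of text lines in the jaxpr and in the lowered (pre-optimization) HLO of fn(*args)."""
    jaxpr_text = str(jax.make_jaxpr(fn)(*args))
    hlo_text = jax.jit(fn).lower(*args).as_text()
    return {
        "jaxpr": len(jaxpr_text.splitlines()),
        "hlo": len(hlo_text.splitlines()),
    }


if __name__ == "__main__":
    print(ir_line_counts(lambda x: jnp.tanh(x) + 1.0, jnp.ones(3)))
```

For the kinematics example one would call `ir_line_counts(mjx.kinematics, mjx_m, mjx_d)`; there is no guarantee it reproduces the exact numbers above.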
# Applying First Principles Thinking
OK, let's use first principles (probably should have started with this :p).
At the most basic level, forward kinematics is a chain of coordinate transformations: at each link you apply a translation and a rotation. Translation is just addition. Rotation is more complex; in MJX, it's done with quaternion multiplication. The implementation is the standard Hamilton product: 16 multiplications and 12 additions, for 28 floating point operations.
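For reference, here is a sketch of that standard component-wise Hamilton product for quaternions stored as `[w, x, y, z]`. It follows the same 28-operation formula, but it is written for this post rather than copied from the MJX source, and `quat_mul_componentwise` is a hypothetical name. Note that every scalar access such as `u[..., 0]` is traced as its own indexing operation.

```python
import jax
import jax.numpy as jnp


def quat_mul_componentwise(u: jax.Array, v: jax.Array) -> jax.Array:
    """Hamilton product u * v for quaternions [w, x, y, z] with shape [..., 4].

    16 multiplications and 12 additions/subtractions: 28 floating point operations.
    """
    w1, x1, y1, z1 = u[..., 0], u[..., 1], u[..., 2], u[..., 3]
    w2, x2, y2, z2 = v[..., 0], v[..., 1], v[..., 2], v[..., 3]
    return jnp.stack([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ], axis=-1)
```

As a sanity check, multiplying i = `[0, 1, 0, 0]` by j = `[0, 0, 1, 0]` gives k = `[0, 0, 0, 1]`, and reversing the order gives -k.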
Digging deeper and peeking at the HLO of the MJX implementation... whoa. For those 28 floating point operations, there are 96 HLO operations. Most of them are memory operations: slicing and reshaping.
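A rough way to see where HLO operations come from is to tally op names in the lowered module. The hypothetical helper below regex-matches `stablehlo.<op>` names in the text returned by `Lowered.as_text()`. It counts pre-optimization StableHLO, so totals will not match the optimized HLO exactly, but the slicing and reshaping that indexing introduces should show up next to the arithmetic.

```python
import re
from collections import Counter

import jax
import jax.numpy as jnp

# Matches an op name at the start of a line, optionally after "%result = ".
_OP_PATTERN = re.compile(r'^\s*(?:%\S+\s*=\s*)?"?stablehlo\.([a-z_]+)', re.MULTILINE)


def hlo_op_histogram(fn, *args) -> Counter:
    """Rough count of StableHLO ops, by name, in the lowered module of fn(*args)."""
    text = jax.jit(fn).lower(*args).as_text()
    return Counter(_OP_PATTERN.findall(text))


if __name__ == "__main__":
    # Hypothetical toy function: two products of individually indexed elements.
    def partial_dot(u, v):
        return u[0] * v[0] - u[1] * v[1]

    hist = hlo_op_histogram(partial_dot, jnp.ones(4), jnp.ones(4))
    for op, n in hist.most_common():
        print(f"{op:20s} {n}")
```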
OK, let's avoid the memory operations by reimplementing quaternion multiplication as [matrix-matrix multiplications](https://en.wikipedia.org/wiki/Quaternion#Matrix_representations). It will be over 200 floating point operations instead of 28, but it should save us the memory overhead.
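Written out, the tensor `QUAT_MULTIPLY` (call it $Q$) stores the coefficient of $u_i v_j$ in component $k$ of the product. Writing $\otimes$ for the Hamilton product and $u = (w, x, y, z)$, the whole product is a single contraction, and fixing $u$ recovers the left-multiplication matrix from the linked matrix representation:

$$
(u \otimes v)_k = \sum_{i=0}^{3} \sum_{j=0}^{3} u_i \, v_j \, Q_{ijk},
\qquad
L(u)_{kj} = \sum_{i=0}^{3} u_i \, Q_{ijk},
\qquad
u \otimes v = L(u)\, v,
$$

$$
L(u) = \begin{bmatrix}
w & -x & -y & -z \\
x & w & -z & y \\
y & z & w & -x \\
z & -y & x & w
\end{bmatrix}.
$$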
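Code for quaternion multiplication with matrices:
```python
import jax
import jax.numpy as jnp
import numpy as np

# QUAT_MULTIPLY[i, j, k] is the coefficient of u[i] * v[j] in component k
# of the Hamilton product u * v (quaternions stored as [w, x, y, z]).
QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float64)
QUAT_MULTIPLY[:, :, 0] = [[ 1,  0,  0,  0],
                          [ 0, -1,  0,  0],
                          [ 0,  0, -1,  0],
                          [ 0,  0,  0, -1]]
QUAT_MULTIPLY[:, :, 1] = [[ 0,  1,  0,  0],
                          [ 1,  0,  0,  0],
                          [ 0,  0,  0,  1],
                          [ 0,  0, -1,  0]]
QUAT_MULTIPLY[:, :, 2] = [[ 0,  0,  1,  0],
                          [ 0,  0,  0, -1],
                          [ 1,  0,  0,  0],
                          [ 0,  1,  0,  0]]
QUAT_MULTIPLY[:, :, 3] = [[ 0,  0,  0,  1],
                          [ 0,  0,  1,  0],
                          [ 0, -1,  0,  0],
                          [ 1,  0,  0,  0]]


def quat_mul(u: jax.Array, v: jax.Array) -> jax.Array:
    """Multiply quaternions u * v (shape [..., 4]) with one broadcast contraction."""
    # Outer product u_i * v_j, weighted by QUAT_MULTIPLY[i, j, k], summed over i and j.
    return jnp.sum(
        u[..., :, None, None] *
        v[..., None, :, None] *
        QUAT_MULTIPLY,
        axis=(-3, -2))
```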
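Hand-typed tables like `QUAT_MULTIPLY` are easy to get wrong. As a cross-check (a sketch written for this post, not part of MJX), the same tensor can be generated from the basis rules i² = j² = k² = ijk = -1, and `jnp.einsum` expresses the same contraction as `quat_mul`. `build_quat_multiply` and `quat_mul_einsum` are hypothetical names, and I have not measured whether XLA lowers the einsum form to the same kernels as the broadcast-and-sum version.

```python
import jax.numpy as jnp
import numpy as np

# Products of the imaginary units i, j, k (indices 1..3): (a, b) -> (sign, c)
# means e_a * e_b = sign * e_c, with e_0 = 1 and i^2 = j^2 = k^2 = ijk = -1.
_BASIS_PRODUCTS = {
    (1, 1): (-1, 0), (2, 2): (-1, 0), (3, 3): (-1, 0),
    (1, 2): (1, 3), (2, 3): (1, 1), (3, 1): (1, 2),
    (2, 1): (-1, 3), (3, 2): (-1, 1), (1, 3): (-1, 2),
}


def build_quat_multiply() -> np.ndarray:
    """Return Q with Q[a, b, c] = coefficient of e_c in e_a * e_b."""
    q = np.zeros((4, 4, 4))
    for a in range(4):
        for b in range(4):
            if a == 0:
                sign, c = 1, b  # 1 * e_b = e_b
            elif b == 0:
                sign, c = 1, a  # e_a * 1 = e_a
            else:
                sign, c = _BASIS_PRODUCTS[(a, b)]
            q[a, b, c] = sign
    return q


Q = build_quat_multiply()


def quat_mul_einsum(u, v):
    """Same contraction as quat_mul: out[..., k] = sum_ij u[..., i] v[..., j] Q[i, j, k]."""
    return jnp.einsum("...i,...j,ijk->...k", u, v, Q)
```

If the table above is transcribed correctly, `np.array_equal(build_quat_multiply(), QUAT_MULTIPLY)` should be `True`.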
Let's compare the memory performance with Nsight Compute.