Dynamic vs. Static: Understanding Computation Graphs in PyTorch and JAX
Updated on April 30, 2025 · 6 minute read

How does your deep‑learning code “remember” the operations it just executed so it can automatically compute gradients?
In this deep‑dive we unpack the anatomy of computation graphs and explain why PyTorch and JAX—two powerhouse libraries that dominate today’s research and production stacks—take radically different approaches. By the end you’ll understand the trade‑offs, pick the right tool for your next project, and gather practical tips you can apply immediately.
📚 Want structured, mentor‑led practice? Apply to our next Data Science and AI Bootcamp and master both frameworks from first principles to deployment.
1 What is a Computation Graph?
A computation graph is a directed acyclic graph (DAG) in which:
- Nodes represent values (tensors, scalars, RNG states)
- Edges represent operations that transform inputs into outputs
During training we build two intertwined graphs:
- Forward graph – charts data flow from inputs to loss
- Backward graph – created by the autodiff engine so gradients can propagate from the loss back to parameters
Because the graph is just data, libraries can:
- Reorder/fuse ops for speed
- Slice away unused branches (multi‑task models)
- Serialize the graph to run on different hardware or runtimes
Why a DAG, not a general graph? Cycles imply a value depends (directly or indirectly) on itself. We avoid that by unrolling loops or linking recurrent edges to previous‑time‑step nodes, so each iteration still forms a DAG.
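To make that concrete, here’s a minimal PyTorch sketch: the Python loop is simply unrolled, so each iteration adds nodes that depend on the previous iteration’s output, never on themselves.

import torch

x = torch.ones(1, requires_grad=True)
h = x
for _ in range(3):      # the loop unrolls into three multiply nodes
    h = h * 2           # each new h depends on the previous h, never on itself

h.backward()
print(x.grad)           # tensor([8.]), i.e. d(8x)/dx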
Throughout the article we’ll toggle between two mindsets:
- Tape – “record everything as we go and replay it backward.”
- Trace – “peek at the function’s structure before running it, then hand the static graph to a compiler.”
2 PyTorch: Dynamic “Define‑By‑Run” Graphs
PyTorch’s lineage values flexibility. Every time your Python lines execute, PyTorch creates a fresh graph in C++. The grad_fn chain you see when you print a tensor is the visible tip of that iceberg.
import torch
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = (a * b).sum()
print(c.grad_fn) # <SumBackward0 object at …>
c.backward()
How the engine works
- Forward pass – each differentiable op subclasses torch.autograd.Function and pushes a node onto the autograd tape (a minimal custom Function is sketched below)
- Backward pass – on tensor.backward(), the C++ engine topo‑sorts the nodes in reverse and accumulates gradients on leaf tensors
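To illustrate the mechanism (built‑in ops implement this in C++, but the Python API mirrors it), here is a minimal custom Function sketch: forward saves what backward will need, and backward returns the vector‑Jacobian product.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # stash inputs the backward pass will need
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x       # vector-Jacobian product for y = x**2

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)                         # tensor([6.])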
Perks of dynamism
Benefit | Why it matters |
---|---|
Native Python control‑flow | Write if random.random() < p: without special APIs |
Interactive debugging | Drop in a breakpoint or print(tensor.grad_fn) anywhere |
Hot‑reload research loops | Tweak layer sizes or hyper‑params and rerun immediately |
Pitfalls
- In‑place ops (tensor += 1) may discard history needed for backward
- Silently detached tensors when round‑tripping through NumPy (gradients don’t flow through .numpy())
- Memory blow‑ups when loops accumulate live tensors or run inference without torch.no_grad() (see the sketch below)
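A minimal sketch of the last pitfall (the model and data here are placeholders): accumulating live outputs keeps every iteration’s graph alive, while torch.no_grad() avoids recording a graph at all during evaluation.

import torch

model = torch.nn.Linear(4, 1)
data = [torch.randn(8, 4) for _ in range(100)]

# BAD: each forward pass records a graph, and `total` keeps all of them alive
# total = sum(model(batch).sum() for batch in data)

# GOOD: no graph is recorded during evaluation
with torch.no_grad():
    total = sum(model(batch).sum().item() for batch in data)
print(total)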
3 JAX: Staged, Static Graphs via Tracing
JAX marries Autograd with XLA and a functional mindset. You write pure functions; JAX executes in three stages:
- Tracing – intercept Python and build a JAXPR (SSA‑style IR)
- Compilation – lower JAXPR to HLO, fuse kernels, emit binary for device
- Execution – cache binary; subsequent calls bypass Python entirely
import jax, jax.numpy as jnp

@jax.jit  # stage‑out the entire function
def network(x, w):
    return jnp.tanh(x @ w)

print(jax.make_jaxpr(network)(jnp.ones((1, 4)), jnp.ones((4, 4))))
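A quick sketch of stage 3, reusing the network function above: the first call pays for tracing and compilation, later calls with the same shapes and dtypes hit the compiled-binary cache, and a new shape triggers a fresh trace. (The timing pattern is the point, not the numbers.)

import time
import jax.numpy as jnp

x, w = jnp.ones((1, 4)), jnp.ones((4, 4))

t0 = time.perf_counter()
network(x, w).block_until_ready()   # trace + compile + run
t1 = time.perf_counter()
network(x, w).block_until_ready()   # cache hit: just runs the binary
t2 = time.perf_counter()
print(f"first call {t1 - t0:.4f}s, second call {t2 - t1:.4f}s")

network(jnp.ones((2, 4)), w)        # new shape -> new trace and compile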
Trademark superpowers
Feature | What you get |
---|---|
AOT compilation | One graph → CPU, GPU, TPU |
Transformations | jax.grad, vmap, pmap rewrite graphs algebraically |
Fusion & layout | XLA fuses elementwise ops, chooses optimal tiling |
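A short sketch of how transformations compose (the loss, w, x, y below are made up for illustration): jax.grad rewrites the graph into its gradient, and jax.vmap vectorises that gradient over a batch, all before XLA compiles anything.

import jax, jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))                              # d loss / d w, compiled
per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))   # one gradient per example

w, x, y = jnp.ones((4,)), jnp.ones((8, 4)), jnp.zeros((8,))
print(grad_fn(w, x, y).shape)       # (4,)   averaged gradient
print(per_example(w, x, y).shape)   # (8, 4) per-example gradients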
Pain points
- Compile latency – seconds for big models
- Purity constraints – no side‑effects inside a jitted region (see the sketch below)
- Debugging – think in graphs, not imperative traces
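The purity constraint in one small sketch: a plain Python print inside a jitted function fires only while tracing, and jax.debug.print is the supported way to log from inside the compiled graph.

import jax, jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing!")                  # Python side effect: runs once, at trace time
    jax.debug.print("x = {x}", x=x)    # staged into the graph: runs every call
    return x * 2

f(jnp.ones(3))   # prints "tracing!" and "x = ..."
f(jnp.ones(3))   # prints only "x = ..."; the Python print is gone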
4 Hands‑On Comparison
Control flow
# PyTorch – a plain Python if works because the graph is rebuilt on every call
def coin_net(x, w1, w2):
    return torch.relu(x @ w1) if torch.rand(()) < 0.5 else torch.sigmoid(x @ w2)

# JAX – the branch must live inside the graph, via lax.cond
def coin_net(x, key, w1, w2):
    branch = jax.random.bernoulli(key)
    return jax.lax.cond(branch,
                        lambda _: jax.nn.relu(x @ w1),
                        lambda _: jax.nn.sigmoid(x @ w2),
                        None)
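Why the detour through lax.cond? Under jax.jit a plain Python if on a traced value fails with a tracer error, so the branch has to be expressed inside the graph. A small usage sketch, assuming the JAX coin_net above is in scope:

import jax, jax.numpy as jnp

key = jax.random.PRNGKey(0)
w1 = w2 = jnp.ones((4, 4))
x = jnp.ones((1, 4))

jitted = jax.jit(coin_net)            # the lax.cond version jits cleanly
print(jitted(x, key, w1, w2).shape)   # (1, 4)

# A Python `if branch:` inside the jitted function would raise a tracer
# (concretization) error, because branch is abstract during tracing.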
Micro‑benchmark
Batch size | PyTorch eager | PyTorch torch.compile | JAX @jit |
---|---|---|---|
32 | 1.8 ms | 1.3 ms | 0.9 ms |
2 048 | 14 ms | 6.1 ms | 5.3 ms |
Measured on an RTX 4090 (CUDA 12, PyTorch 2.3, JAX 0.4.27); compile time excluded.
5 Graph Optimisation & Memory Tricks
Gradient checkpointing
# PyTorch – recompute block's activations during backward instead of storing them
from torch.utils.checkpoint import checkpoint
out = checkpoint(block, x)

# JAX – same idea: the wrapped function is rematerialised during the backward pass
out = jax.checkpoint(block)(x)
SPMD & model parallelism
- PyTorch: torchrun, FSDP, tensor‑parallel plugins
- JAX: pmap, pjit, GSPMD partition specs
Mixed precision
- PyTorch: torch.cuda.amp.autocast(), typically paired with a GradScaler for float16 loss scaling (sketched below)
- JAX: cast inputs and params to jnp.bfloat16 or jnp.float16; loss scaling isn’t automatic, so float16 training usually adds it by hand (bfloat16 generally avoids the need)
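A hedged sketch of a PyTorch mixed-precision training step with autocast plus GradScaler (the usual companion for float16 loss scaling); the model, optimizer and batch below are placeholders and a CUDA device is assumed.

import torch

model = torch.nn.Linear(128, 10).cuda()                 # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(32, 128, device="cuda")
y = torch.randint(0, 10, (32,), device="cuda")

with torch.cuda.amp.autocast():                         # run ops in reduced precision where safe
    loss = torch.nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()                           # scale the loss to avoid fp16 underflow
scaler.step(opt)                                        # unscale grads, then take the optimizer step
scaler.update()
opt.zero_grad(set_to_none=True)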
6 Debugging Workflows You’ll Actually Use
Task | PyTorch | JAX |
---|---|---|
Graph viz | torchviz, TensorBoard | jax.make_jaxpr, TB via jax2tf |
Kernel profiling | Nsight Systems, torch.profiler | XLA HLO trace, Perfetto |
Inline shape prints | print(t.shape) | jax.debug.print('{x}', x=x) |
Breakpoint | Native pdb.set_trace() | jax.debug.breakpoint() (during trace) |
Tip: Use torch.cuda.memory_summary() to inspect PyTorch’s GPU memory live, and the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to cap how much memory JAX preallocates.
7 Advanced Use‑Cases
- Meta‑learning – PyTorch’s higher‑order autograd vs. JAX’s jax.hessian
- Probabilistic programming – Pyro vs. NumPyro
- Differentiable physics – PyTorch3D vs. Brax
- Large‑scale RL – Meta ReAgent (PyTorch) vs. DeepMind Acme (JAX)
- Edge deployment – PyTorch Mobile vs. jax2tf → TFLite
8 Selecting the Right Tool (or Both!)
Project constraint | Choose | Rationale |
---|---|---|
Rapid iteration | PyTorch | Zero compile overhead |
Large transformer on TPU | JAX | XLA fusion & SPMD |
ONNX pipeline | PyTorch | Mature exporter |
Functional‑programming codebase | JAX | Idiomatic |
Mixed hardware fleet | Hybrid | Prototype in PyTorch; deploy JAX via TF Serving |
9 Conclusion
Computation graphs are the invisible scaffolding of modern deep learning. PyTorch’s dynamic tape feels like Python itself—immediate and malleable. JAX’s static graphs feel like a compiler—strict up front, blisteringly fast afterward. Mastering both gives you leverage across research prototypes and production inference.
Ready to bend graphs to your will? Apply to our next Data Science and AI Bootcamp—early‑bird spots vanish quickly!
FAQ
Is JAX always faster than PyTorch 2’s torch.compile?
No. On NVIDIA GPUs they trade benchmarks; measure your workload.
Can I convert PyTorch models to JAX?
There’s no official converter. The usual route is to export the weights (or an ONNX graph) and re‑implement the model in JAX; expect manual tweaks.
Appendix A – Autograd Math Refresher
For a composite function $f(g(h(x)))$, the chain rule gives
$$\frac{\partial f}{\partial x} = \frac{\partial f}{\partial g}\cdot\frac{\partial g}{\partial h}\cdot\frac{\partial h}{\partial x}.$$
Frameworks attach local Jacobian‑vector products (JVP) or vector‑Jacobian products (VJP) to each op; the engine feeds upstream gradients through these closures.
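To make the JVP/VJP distinction concrete, here’s a small JAX sketch: jax.jvp pushes a tangent forward through the graph, while jax.vjp returns a pullback that maps an upstream cotangent backward.

import jax, jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.array(1.0)

# Forward mode: JVP maps a tangent dx to df
y, df = jax.jvp(f, (x,), (jnp.array(1.0),))

# Reverse mode: VJP maps an upstream cotangent dy back to dx
y2, pullback = jax.vjp(f, x)
(dx,) = pullback(jnp.array(1.0))

print(df, dx)   # both equal f'(1.0) = sin(1) + cos(1)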
Appendix B – DIY Graph Inspectors
PyTorch
def dump_graph(fn, visited=None, depth=0):
    """Recursively print the grad_fn nodes reachable from fn."""
    if visited is None:
        visited = set()             # avoid sharing a mutable default between calls
    if fn is None or fn in visited:
        return
    visited.add(fn)
    print('  ' * depth, fn)
    for nxt, _ in fn.next_functions:
        dump_graph(nxt, visited, depth + 1)

dump_graph(c.grad_fn)               # e.g. the tensor c from section 2
JAX
def count_ops(f, *args):
    jaxpr = jax.make_jaxpr(f)(*args)
    return len(jaxpr.jaxpr.eqns)

print('Ops:', count_ops(network, jnp.ones((1, 4)), jnp.ones((4, 4))))
Looking Ahead: Graphs Beyond 2025
- Kernel fusion DSLs – torch‑mlir and StableHLO will make cross‑framework graphs interchangeable
- Composable compiler passes – plug‑ins that quantize or prune graphs with a few lines of Python
- Neural compute fabric – edge chips ingest ONNX/StableHLO directly; compliance wins deployment
- Graph‑level privacy – differential‑privacy passes inject noise into gradient computations (Opacus, JAX prototypes)
Bottom line? Graph literacy isn’t optional—it’s the lens through which the next decade of AI tooling will be designed.
Ready to get hands‑on? Clone the companion GitHub repo, run the notebooks on a free Colab GPU, then challenge yourself to port the model back and forth between libraries. The muscle memory is priceless when mixed codebases land on your desk.