Dynamic vs. Static: Understanding Computation Graphs in PyTorch and JAX

Updated on April 30, 2025


How does your deep‑learning code “remember” the operations it just executed so it can automatically compute gradients?
In this deep‑dive we unpack the anatomy of computation graphs and explain why PyTorch and JAX—two powerhouse libraries that dominate today’s research and production stacks—take radically different approaches. By the end you’ll understand the trade‑offs, pick the right tool for your next project, and gather practical tips you can apply immediately.

📚 Want structured, mentor‑led practice? Apply to our next Data Science and AI Bootcamp and master both frameworks from first principles to deployment.


1 What is a Computation Graph?

A computation graph is a directed acyclic graph (DAG) in which:

  • Nodes represent values (tensors, scalars, RNG states)
  • Edges represent operations that transform inputs into outputs

During training we build two intertwined graphs:

  1. Forward graph – charts data flow from inputs to loss
  2. Backward graph – created by the autodiff engine so gradients can propagate from the loss back to parameters

Because the graph is just data, libraries can:

  • Reorder/fuse ops for speed
  • Slice away unused branches (multi‑task models)
  • Serialize the graph to run on different hardware or runtimes

Why a DAG, not a general graph? Cycles imply a value depends (directly or indirectly) on itself. We avoid that by unrolling loops or linking recurrent edges to previous‑time‑step nodes, so each iteration still forms a DAG.
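
To make the unrolling concrete, here is a minimal PyTorch sketch (an illustrative two-step recurrence, not taken from the article): each iteration reads the previous iteration's output node, so the recorded graph stays acyclic even though the Python code loops.

import torch

h = torch.zeros(3)
w = torch.randn(3, requires_grad=True)
for _ in range(2):            # the loop is unrolled into the graph: h_2 <- h_1 <- h_0
    h = torch.tanh(h + w)     # each step creates a new node; no cycle is formed
h.sum().backward()
print(w.grad)                 # gradients flow back through both unrolled steps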

Throughout the article we’ll toggle between two mindsets:

  • Tape – “record everything as we go and replay it backward.”
  • Trace – “peek at the function’s structure before running it, then hand the static graph to a compiler.”

2 PyTorch: Dynamic “Define‑By‑Run” Graphs

PyTorch’s design philosophy prizes flexibility. Every time your Python code runs, PyTorch records a fresh graph in its C++ autograd engine. The grad_fn chain you see when you print a tensor is the visible tip of that iceberg.

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

c = (a * b).sum()
print(c.grad_fn)  # <SumBackward0 object at …>

c.backward()

How the engine works

  1. Forward pass – each differentiable op subclasses torch.autograd.Function and pushes a node onto the autograd tape
  2. Backward pass – on tensor.backward(), the C++ engine topo‑sorts nodes in reverse and accumulates gradients on leaf tensors
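
As a minimal sketch of this mechanism, the custom op below (an illustrative Square function, not part of PyTorch itself) subclasses torch.autograd.Function so its forward pass lands on the tape and its backward pass supplies the local gradient:

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)       # stash inputs needed for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x        # chain rule: d(x^2)/dx = 2x

x = torch.randn(3, requires_grad=True)
y = Square.apply(x).sum()
y.backward()
print(x.grad)                          # equals 2 * x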

Perks of dynamism

| Benefit | Why it matters |
| --- | --- |
| Native Python control flow | Write if random.random() < p: without special APIs |
| Interactive debugging | Drop in a breakpoint or print(tensor.grad_fn) anywhere |
| Hot-reload research loops | Tweak layer sizes or hyper-params and rerun immediately |

Pitfalls

  • In‑place ops (tensor += 1) may discard needed history
  • Tensors silently detach from the graph when round-tripped through NumPy
  • Memory blow‑ups if loops build graphs without torch.no_grad()
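
A small illustration of the last pitfall (hypothetical model and data, just to show the pattern): wrapping an evaluation loop in torch.no_grad() keeps PyTorch from recording a graph on every iteration.

import torch

model = torch.nn.Linear(10, 1)
data = [torch.randn(4, 10) for _ in range(100)]

with torch.no_grad():                  # no graph is recorded inside this block
    total = sum(model(x).sum().item() for x in data)
print(total)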

3 JAX: Staged, Static Graphs via Tracing

JAX marries Autograd with XLA and a functional mindset. You write pure functions; JAX executes in three stages:

  1. Tracing – intercept Python and build a JAXPR (SSA‑style IR)
  2. Compilation – lower JAXPR to HLO, fuse kernels, emit binary for device
  3. Execution – cache the compiled binary; subsequent calls bypass Python entirely

import jax, jax.numpy as jnp

@jax.jit  # stage‑out entire function
def network(x, w):
    return jnp.tanh(x @ w)

print(jax.make_jaxpr(network)(jnp.ones((1,4)), jnp.ones((4,4))))
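
To see the caching in action, here is a rough timing sketch (illustrative only; numbers depend on backend and hardware). Because JAX dispatches asynchronously, block_until_ready() is needed for honest timings.

import time

x, w = jnp.ones((1, 4)), jnp.ones((4, 4))
t0 = time.perf_counter()
network(x, w).block_until_ready()      # first call: trace + compile + run
t1 = time.perf_counter()
network(x, w).block_until_ready()      # second call: cached binary only
t2 = time.perf_counter()
print(f"first: {t1 - t0:.4f}s, cached: {t2 - t1:.4f}s")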

Trademark superpowers

| Feature | What you get |
| --- | --- |
| AOT compilation | One graph → CPU, GPU, TPU |
| Transformations | jax.grad, vmap, pmap rewrite graphs algebraically |
| Fusion & layout | XLA fuses elementwise ops, chooses optimal tiling |
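
A quick illustration of how these transformations compose (a hypothetical loss function, reusing the jnp import above): grad and vmap each take a traced graph and return a new one, so they stack freely.

def loss(w, x):
    return jnp.sum(jnp.tanh(x @ w))

per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))   # grad w.r.t. w, mapped over a batch of x
w = jnp.ones((4, 4))
xs = jnp.ones((8, 1, 4))               # a batch of 8 inputs of shape (1, 4)
print(per_example_grads(w, xs).shape)  # (8, 4, 4): one gradient per example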

Pain points

  • Compile latency – seconds for big models
  • Purity constraints – no side‑effects in a jitted region
  • Debugging – think in graphs, not imperative traces

4 Hands‑On Comparison

Control flow

# PyTorch
def coin_net(x, w1, w2):
    return torch.relu(x @ w1) if torch.rand(()) < 0.5 else torch.sigmoid(x @ w2)

# JAX
def coin_net(x, key, w1, w2):
    branch = jax.random.bernoulli(key)
    # lax.cond stages out both branches; the predicate selects one at run time
    return jax.lax.cond(branch,
                        lambda: jax.nn.relu(x @ w1),
                        lambda: jax.nn.sigmoid(x @ w2))

Micro‑benchmark

| Batch size | PyTorch eager | PyTorch torch.compile | JAX @jit |
| --- | --- | --- | --- |
| 32 | 1.8 ms | 1.3 ms | 0.9 ms |
| 2 048 | 14 ms | 6.1 ms | 5.3 ms |

RTX 4090, CUDA 12, PyTorch 2.3, JAX 0.4.27 – compilation time excluded.
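
For reference, a minimal timing harness of the kind behind tables like this one (illustrative only, not the exact script used; assumes a CUDA device):

import time, torch

def bench(fn, x, iters=100):
    fn(x)                                   # warm-up (triggers compilation for torch.compile)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3   # ms per call

layer = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")
print("eager   :", bench(layer, x))
print("compiled:", bench(torch.compile(layer), x))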


5 Graph Optimisation & Memory Tricks

Gradient checkpointing

# PyTorch
from torch.utils.checkpoint import checkpoint
out = checkpoint(block, x, use_reentrant=False)  # block is any sub-module/function; its activations are recomputed during backward

# JAX
out = jax.checkpoint(block)(x)  # rematerialize block's intermediates instead of storing them

SPMD & model parallelism

  • PyTorch: torchrun, FSDP, tensor parallel plugins
  • JAX: pmap, pjit, GSPMD partition specs
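
A tiny pmap sketch (it replicates the function across however many local devices are visible; with a single device it still runs, just without speed-up):

n = jax.local_device_count()
xs = jnp.ones((n, 4))                      # leading axis = one shard per device
print(jax.pmap(lambda x: x * 2)(xs))       # runs the same program on every device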

Mixed precision

PyTorch: wrap the forward pass in torch.cuda.amp.autocast() and scale the loss with torch.cuda.amp.GradScaler
JAX: cast parameters/activations to jnp.bfloat16 or jnp.float16; handle loss scaling yourself or with a helper library such as jmp
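
A minimal PyTorch autocast/GradScaler sketch (hypothetical linear model; assumes a CUDA device):

import torch

model = torch.nn.Linear(128, 10).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()     # rescales the loss so fp16 gradients don't underflow

x = torch.randn(32, 128, device="cuda")
with torch.cuda.amp.autocast():          # eligible ops run in reduced precision
    loss = model(x).sum()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()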


6 Debugging Workflows You’ll Actually Use

| Task | PyTorch | JAX |
| --- | --- | --- |
| Graph viz | torchviz, TensorBoard | jax.make_jaxpr (text), TB via jax2tf |
| Kernel profiling | Nsight Systems, torch.profiler | XLA HLO trace, Perfetto |
| Inline shape prints | print(t.shape) | jax.debug.print("{x}", x=x) |
| Breakpoint | Native pdb.set_trace() | jax.debug.breakpoint() (works inside jit) |

Tip: Watch device memory live with torch.cuda.memory_summary() in PyTorch, or control how much JAX preallocates via the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable.
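
For example, jax.debug.print survives jit staging, unlike a plain Python print (a toy function for illustration):

@jax.jit
def f(x):
    jax.debug.print("x = {x}", x=x)   # executes when the compiled program runs
    return x * 2

f(jnp.arange(3.0))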


7 Advanced Use‑Cases

  1. Meta‑learning – PyTorch’s higher‑order autograd vs. JAX’s jax.hessian
  2. Probabilistic programming – Pyro vs. NumPyro
  3. Differentiable physics – PyTorch3D vs. Brax
  4. Large‑scale RL – Meta ReAgent (PyTorch) vs. DeepMind Acme (JAX)
  5. Edge deployment – PyTorch Mobile vs. jax2tf → TFLite

8 Selecting the Right Tool (or Both!)

| Project constraint | Choose | Rationale |
| --- | --- | --- |
| Rapid iteration | PyTorch | Zero compile overhead |
| Large transformer on TPU | JAX | XLA fusion & SPMD |
| ONNX pipeline | PyTorch | Mature exporter |
| Functional-programming codebase | JAX | Idiomatic |
| Mixed hardware fleet | Hybrid | Prototype in PyTorch; deploy JAX via TF Serving |

9 Conclusion

Computation graphs are the invisible scaffolding of modern deep learning. PyTorch’s dynamic tape feels like Python itself—immediate and malleable. JAX’s static graphs feel like a compiler—strict up front, blisteringly fast afterward. Mastering both gives you leverage across research prototypes and production inference.

Ready to bend graphs to your will? Apply to our next Data Science and AI Bootcamp—early‑bird spots vanish quickly!


FAQ

**How do I visualise a JAX graph?** Use `jax.make_jaxpr` for text or convert with `jax2tf` then open in TensorBoard.

**Is JAX always faster than PyTorch 2's `torch.compile`?** No. On NVIDIA GPUs they trade benchmarks; measure your workload.

**Can I convert PyTorch models to JAX?** There is no official converter; the usual route is to re-implement the model in JAX and port the weights (or go through ONNX with a community tool). Expect manual tweaks either way.


Appendix A – Autograd Math Refresher

For a composite function $f(g(h(x)))$, the chain rule gives
$$\frac{\partial f}{\partial x} = \frac{\partial f}{\partial g}\cdot\frac{\partial g}{\partial h}\cdot\frac{\partial h}{\partial x}.$$
Frameworks attach local Jacobian‑vector products (JVP) or vector‑Jacobian products (VJP) to each op; the engine feeds upstream gradients through these closures.
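
The VJP closure described above can be poked at directly; here is a minimal JAX sketch with a toy scalar function:

import jax

def h(x):
    return x ** 2

y, h_vjp = jax.vjp(h, 3.0)    # forward pass returns the output and a VJP closure
print(y, h_vjp(1.0))          # 9.0 (6.0,): the cotangent 1.0 pulled back through dh/dx = 2x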


Appendix B – DIY Graph Inspectors

PyTorch

def dump_graph(fn, visited=None, depth=0):
    """Walk the autograd graph from a grad_fn node, printing each node once."""
    if visited is None:
        visited = set()          # avoid the mutable-default-argument trap
    if fn is None or fn in visited:
        return
    visited.add(fn)
    print('  ' * depth, fn)
    for nxt, _ in fn.next_functions:
        dump_graph(nxt, visited, depth + 1)

dump_graph(c.grad_fn)            # c from the Section 2 example

JAX

def count_ops(f, *args):
    """Count the primitive equations in the jaxpr traced from f."""
    jaxpr = jax.make_jaxpr(f)(*args)
    return len(jaxpr.jaxpr.eqns)

print('Ops:', count_ops(network, jnp.ones((1, 4)), jnp.ones((4, 4))))

Looking Ahead: Graphs Beyond 2025

  • Kernel fusion DSLs – torch‑mlir and StableHLO will make cross‑framework graphs interchangeable
  • Composable compiler passes – plug‑ins to quantize or prune graphs with a few lines of Python
  • Neural compute fabric – edge chips will ingest ONNX/StableHLO directly; standards compliance will decide what gets deployed
  • Graph‑level privacy – Differential‑privacy passes inject noise per edge (Opacus, JAX prototypes)

Bottom line? Graph literacy isn’t optional—it’s the lens through which the next decade of AI tooling will be designed.


Ready to get hands‑on? Clone the companion GitHub repo, run the notebooks on a free Colab GPU, then challenge yourself to port the model back and forth between libraries. The muscle memory is priceless when mixed codebases land on your desk.