
Checkpointing Flax NNX Models with Orbax: Saving and Restoring State

Google for Developers · December 3, 2025 · 13 min · 181 views

Understanding NNX State Management

  • Flax NNX provides a Pythonic, stateful approach to defining models: module instances hold their state variables (such as parameters or batch statistics) directly as attributes.
  • State variable types, such as nnx.Param for learnable parameters or nnx.BatchStat for batch normalization statistics, are subclasses of nnx.Variable.
  • Since Flax version 0.11, NNX module instances are native JAX PyTrees and can be passed directly to JAX functions. For checkpointing, nnx.state is still used to extract only the dynamic nnx.Variable objects.

Core NNX Functions for State Handling

  • nnx.split separates a module instance into its static structure (a graphdef) and its dynamic state (an nnx.State PyTree). The state is what gets saved or passed to JAX functions.
  • nnx.merge reconstructs a module instance from its static structure and a state PyTree, typically one loaded from a checkpoint.
  • nnx.update modifies an existing module instance in place by overwriting its variables with data from a state PyTree, rather than creating a new instance.

Orbax Checkpointing Workflow: Saving Models

  • Orbax is the standard checkpointing library in the JAX ecosystem, designed to reliably save and load state, especially in complex distributed settings.
  • The CheckpointManager is a wrapper that handles checkpoint logistics, including saving at specific steps, version tracking, automatic deletion of old checkpoints, and restoring the latest one.
  • To save an NNX model, first create a CheckpointManager, then use nnx.split to extract the nnx.State PyTree, and finally call manager.save with the training step and the state wrapped in an Orbax argument structure.
  • Call manager.wait_until_finished() after saving. Saves may run asynchronously in the background, and this call blocks until the checkpoint has been fully written.

Orbax Checkpointing Workflow: Restoring Models

  • Restoring requires a template to guide Orbax. An abstract model is created using nnx.eval_shape, which yields a PyTree of shape and dtype information without allocating actual data.
  • This abstract model is then split into its graphdef (static structure) and an abstract_state PyTree, which serves as the template for Orbax.
  • manager.restore is called with the abstract state as a template and returns a restored state PyTree containing the loaded JAX arrays.
  • The restored state and the graphdef are then passed to nnx.merge to reconstruct the model instance, ready for inference or continued training.

Topics

Orbax, Flax NNX, Checkpointing, JAX Ecosystem, Model State, Parameters, Optimizer State, Distributed Training, PyTrees, Serialization, State Management, NNX Module, NNX Variable, nnx.split, nnx.merge