Flax provides a flexible end-to-end user experience for JAX users. Flax NNX is a simplified API designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It has first-class support for Python reference semantics, which lets users express their models as regular Python objects. Flax NNX is an evolution of the previous Flax Linen API.