AI Model Training with JAX

Last Updated: 23 Aug, 2025

JAX is a machine learning and numerical computing library developed by Google. It combines a familiar NumPy-style API with composable function transformations: automatic differentiation (jax.grad), just-in-time (JIT) compilation (jax.jit) and automatic vectorization (jax.vmap). Programs are compiled with XLA and run on CPUs, GPUs and TPUs, and arrays are placed on the default available device automatically, so no manual device-placement calls like .cuda() in PyTorch are needed.
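A minimal sketch of the three transformations named above, applied to a toy function (f is hypothetical, chosen only for illustration): jax.grad returns a new function that computes the gradient, jax.jit compiles a function with XLA, and jax.vmap maps a per-example function over a batch axis.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function: sum of squares, whose gradient is 2 * x
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)        # automatic differentiation
fast_f = jax.jit(f)         # just-in-time compilation via XLA
batched_f = jax.vmap(f)     # apply f independently to each row of a batch

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))                    # 2 * x = [2, 4, 6]
print(fast_f(x))                    # 1 + 4 + 9 = 14
print(batched_f(jnp.ones((4, 3))))  # one value per row: [3, 3, 3, 3]
```

Because each transformation returns an ordinary Python function, they compose, for example jax.jit(jax.grad(f)).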

Building on JAX, Flax is a neural network library that provides higher-level abstractions such as nn.Module for building deep neural architectures in a modular and scalable way. Flax supports features such as checkpointing, regularization and multi-device training, which makes it suitable for research and production workflows that take full advantage of JAX's performance on accelerators.


Implementation

Let's walk through an example of building and training a simple linear regression model with JAX:

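Step 1: Importing Required Libraries

jax.numpy (conventionally imported as jnp) provides a NumPy-like API for high-performance array operations, while the top-level jax module provides transformations such as jax.grad (automatic differentiation) and jax.jit (compilation), along with the jax.random module for random number generation.

```python
import jax
import jax.numpy as jnp
```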
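jax.numpy mirrors most of the NumPy API, but one difference matters early: JAX arrays are immutable, so NumPy-style in-place assignment raises a TypeError. This sketch (the array values are arbitrary) shows the functional .at[...].set(...) update that JAX uses instead, which returns a new array.

```python
import jax.numpy as jnp
import numpy as np

a_np = np.arange(4.0)
a_jnp = jnp.arange(4.0)

# The same NumPy-style call gives the same values
print(np.allclose(np.sin(a_np), jnp.sin(a_jnp)))  # True

# JAX arrays are immutable, so in-place assignment fails
try:
    a_jnp[0] = 10.0
except TypeError as err:
    print("In-place assignment raised", type(err).__name__)

# .at[...].set(...) returns an updated copy and leaves a_jnp unchanged
b = a_jnp.at[0].set(10.0)
print(a_jnp)  # still 0, 1, 2, 3
print(b)      # 10, 1, 2, 3
```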

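Step 2: Defining Model Initialization

Initialize the weights W and bias b of the linear regression model by sampling from a standard normal distribution. The input key rng_key is split into two subkeys so that W and b are drawn independently.

```python
def init_params(rng_key, input_dim, output_dim):
    # Split the key so W and b come from independent random streams
    w_key, b_key = jax.random.split(rng_key)
    W = jax.random.normal(w_key, (input_dim, output_dim))  # weight matrix
    b = jax.random.normal(b_key, (output_dim,))             # bias vector
    return {'W': W, 'b': b}
```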
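JAX's random number generation is explicit and stateless: a key fully determines the numbers drawn from it, so calling the same sampler twice with the same key returns identical values. A small sketch of that behavior (shapes and seed are arbitrary):

```python
import jax

key = jax.random.PRNGKey(0)

# Same key and shape -> identical "random" numbers (fully reproducible)
a1 = jax.random.normal(key, (3,))
a2 = jax.random.normal(key, (3,))

# Splitting yields independent subkeys, so their draws differ
k1, k2 = jax.random.split(key)
b1 = jax.random.normal(k1, (3,))
b2 = jax.random.normal(k2, (3,))

print(bool((a1 == a2).all()))  # True
print(bool((b1 == b2).all()))  # False
```

This is why the usual pattern is to split a key before every independent draw and never reuse a key once it has been consumed.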

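Step 3: Defining the Model (Linear Layer)

The model is a single linear layer that computes x · W + b for a batch of inputs x.

```python
def model(params, x):
    # x: (batch, input_dim) -> output: (batch, output_dim)
    return jnp.dot(x, params['W']) + params['b']
```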
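A quick shape check of the linear layer, using hand-picked parameter values (an assumption for illustration): for x of shape (batch, input_dim), jnp.dot(x, W) has shape (batch, output_dim), and b of shape (output_dim,) is broadcast across the batch.

```python
import jax.numpy as jnp

def model(params, x):
    # Linear layer: y = x W + b
    return jnp.dot(x, params['W']) + params['b']

params = {'W': jnp.array([[1.5], [-2.0]]),  # shape (2, 1)
          'b': jnp.array([0.5])}             # shape (1,)

x = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])                  # 3 examples, 2 features

y = model(params, x)
print(y.shape)  # (3, 1)
# Rows: 1.5 + 0.5 = 2.0, -2.0 + 0.5 = -1.5, 1.5 - 2.0 + 0.5 = 0.0
```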

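Step 4: Defining the Loss Function

We use mean squared error (MSE), the average squared difference between predictions and targets, as the loss function.

```python
def loss_fn(params, x, y):
    preds = model(params, x)
    # Mean squared error between predictions and targets
    return jnp.mean((preds - y) ** 2)
```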
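To make the MSE computation concrete, here it is on three made-up predictions and targets: the residuals are 0, 0 and -2, their squares are 0, 0 and 4, and the mean is 4 / 3.

```python
import jax.numpy as jnp

preds = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 5.0])

# Residuals: 0, 0, -2  ->  squared: 0, 0, 4  ->  mean: 4 / 3
mse = jnp.mean((preds - y) ** 2)
print(mse)  # about 1.3333
```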

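Step 5: Defining One Gradient Update Step

jax.grad(loss_fn) computes the gradient of the loss with respect to its first argument, params, and returns it in the same dictionary structure. The @jax.jit decorator compiles the whole update step with XLA so it runs efficiently on CPU, GPU or TPU.

```python
@jax.jit
def update(params, x, y, lr=0.01):
    # Gradients w.r.t. params, returned as a dict with keys 'W' and 'b'
    grads = jax.grad(loss_fn)(params, x, y)
    # One step of gradient descent on every parameter
    return {k: v - lr * grads[k] for k, v in params.items()}
```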
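For MSE on a linear model the gradients have a closed form, dL/dW = (2/n) Xᵀ(XW + b − y) and dL/db = (2/n) Σ(XW + b − y), so jax.grad can be sanity-checked against it. This sketch uses small random data (seed and sizes are arbitrary) and assumes a tolerance of 1e-4 is appropriate for float32.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

k_w, k_x, k_y = jax.random.split(jax.random.PRNGKey(42), 3)
params = {'W': jax.random.normal(k_w, (2, 1)), 'b': jnp.zeros((1,))}
x = jax.random.normal(k_x, (8, 2))
y = jax.random.normal(k_y, (8, 1))

# Autodiff: gradients come back as a dict with keys 'W' and 'b'
grads = jax.grad(loss_fn)(params, x, y)

# Closed-form MSE gradients for comparison
n = x.shape[0]
residual = model(params, x) - y               # shape (8, 1)
grad_W = 2.0 / n * x.T @ residual             # shape (2, 1)
grad_b = 2.0 / n * jnp.sum(residual, axis=0)  # shape (1,)

print(jnp.allclose(grads['W'], grad_W, atol=1e-4))  # True
print(jnp.allclose(grads['b'], grad_b, atol=1e-4))  # True
```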

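Step 6: Generating Training and Testing Data

We generate synthetic data from a known linear function, y = x · true_w + true_b plus Gaussian noise with standard deviation 0.1, so we can later check whether the model recovers true_w and true_b. With 256 training and 64 test examples, 80% of the data is used for training and 20% for testing. Each random draw gets its own subkey from jax.random.split, because reusing a single key across draws can produce correlated values.

```python
# Split the root key into independent subkeys, one per random draw
key = jax.random.PRNGKey(0)
key, k_xtr, k_ntr, k_xte, k_nte = jax.random.split(key, 5)

n_train, n_test = 256, 64  # 256 / 320 = 80% train, 64 / 320 = 20% test

# Ground-truth parameters the model should recover
true_w = jnp.array([[1.5], [-2.0]])
true_b = jnp.array([0.5])

# Targets: y = x @ true_w + true_b + Gaussian noise (std 0.1)
x_train = jax.random.normal(k_xtr, (n_train, 2))
y_train = x_train @ true_w + true_b + 0.1 * jax.random.normal(k_ntr, (n_train, 1))

x_test = jax.random.normal(k_xte, (n_test, 2))
y_test = x_test @ true_w + true_b + 0.1 * jax.random.normal(k_nte, (n_test, 1))
```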
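An alternative way to get an 80/20 split, common when all the data comes from one source, is to generate (or load) all 320 examples at once, shuffle the row indices with jax.random.permutation, and slice off the first 80% for training. A sketch under that assumption; split_80_20 is a hypothetical helper, not part of JAX.

```python
import jax
import jax.numpy as jnp

def split_80_20(key, x, y):
    """Hypothetical helper: shuffle rows, then take 80% for train, 20% for test."""
    n = x.shape[0]
    perm = jax.random.permutation(key, n)  # random ordering of 0..n-1
    n_train = (n * 4) // 5                 # 80% of n, using integer arithmetic
    train_idx, test_idx = perm[:n_train], perm[n_train:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]

k_x, k_noise, k_perm = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (320, 2))
y = x @ jnp.array([[1.5], [-2.0]]) + 0.5 + 0.1 * jax.random.normal(k_noise, (320, 1))

x_train, y_train, x_test, y_test = split_80_20(k_perm, x, y)
print(x_train.shape, x_test.shape)  # (256, 2) (64, 2)
```

Shuffling before splitting matters with real datasets, which are often stored in a sorted or grouped order.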

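Step 7: Initializing Model Parameters

Create the initial weights and bias for a model with 2 input features and 1 output. A fresh subkey is split off from key so that initialization does not reuse a key that was already consumed.

```python
# Split off a fresh subkey for parameter initialization
key, init_key = jax.random.split(key)
params = init_params(init_key, input_dim=2, output_dim=1)
```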
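params is a plain Python dict of arrays, which JAX treats as a pytree. One way to inspect it is to map a function over every leaf with jax.tree_util.tree_map; this sketch repeats init_params from Step 2 so it runs on its own (the seed is arbitrary).

```python
import jax
import jax.numpy as jnp

def init_params(rng_key, input_dim, output_dim):
    w_key, b_key = jax.random.split(rng_key)
    W = jax.random.normal(w_key, (input_dim, output_dim))
    b = jax.random.normal(b_key, (output_dim,))
    return {'W': W, 'b': b}

params = init_params(jax.random.PRNGKey(0), input_dim=2, output_dim=1)

# Replace every leaf with its shape while keeping the dict structure
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
print(shapes)  # {'W': (2, 1), 'b': (1,)}
```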

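Step 8: Running the Training Loop

Each epoch performs one full-batch gradient descent step on all training examples. We train for 100 epochs and print the training loss every 20 epochs.

```python
epochs = 100
for epoch in range(epochs):
    # One full-batch gradient descent step on the training set
    params = update(params, x_train, y_train)
    # Log the training loss every 20 epochs
    if (epoch + 1) % 20 == 0:
        train_loss = loss_fn(params, x_train, y_train)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")
```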
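The loop above calls loss_fn a second time just to log the loss. A common variation, sketched here with the same model, learning rate and number of epochs on freshly generated synthetic data, uses jax.value_and_grad so that one jitted step returns both the loss and the updated parameters. The name train_step is hypothetical, and the parameter update uses jax.tree_util.tree_map in place of the dict comprehension.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.01):
    # One forward/backward pass yields both the loss value and the gradients
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Gradient descent applied leaf by leaf over the params dict
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

k_x, k_n, k_p = jax.random.split(jax.random.PRNGKey(0), 3)
x_train = jax.random.normal(k_x, (256, 2))
y_train = x_train @ jnp.array([[1.5], [-2.0]]) + 0.5 + 0.1 * jax.random.normal(k_n, (256, 1))
params = {'W': jax.random.normal(k_p, (2, 1)), 'b': jnp.zeros((1,))}

losses = []
for epoch in range(100):
    params, loss = train_step(params, x_train, y_train)
    losses.append(float(loss))  # loss is measured before this step's update
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch + 1}, Train Loss: {loss:.4f}")
```

Note that the logged value is the loss before each update, so it lags one step behind the value printed by the original loop.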

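Step 9: Evaluating the Model

Evaluate the trained model on the held-out test set. A low test loss that is close to the training loss indicates that the model has learned the underlying linear relationship rather than memorizing the training data.

```python
# Mean squared error on the held-out test set
test_loss = loss_fn(params, x_test, y_test)
print(f"Test Loss: {test_loss:.4f}")
```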
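What counts as a "low" loss depends on the noise in the data. The targets include Gaussian noise with standard deviation 0.1, so even the true parameters cannot do better than an expected MSE of about 0.1² = 0.01. This sketch checks that floor by evaluating the true parameters on a large freshly generated sample; the sample size of 100,000 and the seed are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

true_params = {'W': jnp.array([[1.5], [-2.0]]), 'b': jnp.array([0.5])}

k_x, k_n = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(k_x, (100_000, 2))
y = model(true_params, x) + 0.1 * jax.random.normal(k_n, (100_000, 1))

# With perfect parameters only the noise remains, so MSE is close to 0.1 ** 2
noise_floor = loss_fn(true_params, x, y)
print(f"MSE of the true parameters: {noise_floor:.4f}")  # close to 0.0100
```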

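Step 10: Making a Sample Prediction

Finally, use the trained parameters to predict the output for a new input.

```python
sample_x = jnp.array([[0.0, 0.0]])  # one example with two features
pred_y = model(params, sample_x)
print("Prediction for input [0.0, 0.0]:", pred_y)
```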
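Because the model is linear, the prediction for the input [0.0, 0.0] is exactly the learned bias b, so it should approach the true bias 0.5 as training converges. This sketch uses hand-set parameters equal to the data-generating values (an assumption standing in for a fully trained model) and also predicts a small batch in a single call.

```python
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

# Hypothetical "perfectly trained" parameters equal to the true ones
params = {'W': jnp.array([[1.5], [-2.0]]), 'b': jnp.array([0.5])}

origin = jnp.array([[0.0, 0.0]])
batch = jnp.array([[0.0, 0.0],
                   [1.0, 1.0],
                   [2.0, -1.0]])

print(model(params, origin))  # equals b: 0.5
print(model(params, batch))   # 0.5, 0.0, 5.5 (one row per input)
```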

Output:

[Screenshot of the program output]

Google Colab Link: AI Model Training with JAX

Best Practices and Common Pitfalls

Practical Use Case: Training on Real Datasets