Step-by-step demonstration of implementing a linear regression model with JAX

Last Modified: July 17th, 2023 | Reading Time: 8 minutes

To start off, I’ll be offering you a comprehensive implementation of a linear regression model using JAX, unaccompanied by any regularization techniques. This model comes equipped with three potential loss functions: **Mean Absolute Error (MAE)**, **Mean Squared Error (MSE)**, and **Root Mean Squared Error (RMSE)**.

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import NamedTuple
from tqdm.notebook import tqdm

LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2)),
}


class LinearParameters(NamedTuple):
    w: jnp.ndarray
    b: jnp.ndarray | None


def linear_model(params: LinearParameters, x: jnp.ndarray) -> jnp.ndarray:
    return jax.lax.cond(
        jnp.isnan(params.b),
        lambda: jnp.dot(x, params.w),
        lambda: jnp.dot(x, params.w) + params.b,
    )


batched_linear_model = jax.vmap(linear_model, in_axes=(None, 0))


def loss_fn(loss_fn_arg, params: LinearParameters, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    preds = batched_linear_model(params, x)
    return loss_fn_arg(y, preds)


@partial(jax.jit, static_argnames=("loss_fn_arg", "learning_rate"))
def update(learning_rate, loss_fn_arg, params, x, y):
    grad_loss_fn = jax.value_and_grad(partial(loss_fn, loss_fn_arg))
    loss, grad = grad_loss_fn(params, x, y)
    return jax.tree_map(lambda p, g: p - g * learning_rate, params, grad), loss


class BatchGenerator:
    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size
            yield self.X[start:end], self.Y[start:end]


class LinearRegression:
    def __init__(self, use_bias: bool = True):
        self.use_bias = use_bias
        self.params = None

    def fit(self, x: np.ndarray, y: np.ndarray, learning_rate: float = 0.01, **kwargs):
        loss = kwargs.get("loss", "mae")
        batch_size = kwargs.get("batch_size", 32)
        epochs = kwargs.get("epochs", 100)
        assert loss in LOSS_FN_MAPPING.keys(), f"loss must be one of {list(LOSS_FN_MAPPING.keys())}"

        number_of_features = x.shape[1]
        resolved_loss_fn = LOSS_FN_MAPPING[loss]
        batch_generator = BatchGenerator(x, y, batch_size)

        if self.use_bias:
            b = jnp.float32(1.0)
        else:
            b = jnp.nan
        w = jax.random.normal(jax.random.PRNGKey(42), (number_of_features,))
        self.params = LinearParameters(w, b)

        with tqdm(range(epochs), position=0) as pbar:
            for epoch in range(epochs):
                pbar.set_description(f"Epoch {epoch + 1}/{epochs}")
                for x_batch, y_batch in batch_generator:
                    self.params, loss_value = update(
                        learning_rate, resolved_loss_fn, self.params, x_batch, y_batch
                    )
                pbar.set_postfix({"loss": loss_value})
                pbar.update(1)

    def predict(self, x: np.ndarray) -> np.ndarray:
        assert self.params is not None, "Model not fitted yet"
        return batched_linear_model(self.params, jnp.asarray(x))
```

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import NamedTuple
from tqdm.notebook import tqdm
```

In the initial phase, the import section of our **LinearRegression** implementation comes into the limelight, showcasing the libraries we’ll be deploying. Primary among these are the `jax` and `jax.numpy` packages, which provide the array operations and function transformations at the heart of the model; `numpy` supplies the input arrays, `functools.partial` and `typing.NamedTuple` support the code’s structure, and `tqdm` renders a progress bar during training.

Transitioning to a brief overview of JAX’s jit function, we’ll avoid diving too deep for the purpose of this article. Nonetheless, it’s worth noting that it plays a critical role in facilitating just-in-time compilation with XLA. This powerful tool allows us to harness the benefits of performance optimization, eliminating the need to micromanage low-level details. As a result, we achieve a seamless blend of efficiency and cleanliness in our code, making it an appealing choice for our linear regression model implementation.
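As a minimal illustration of what `jit` buys us, here is a toy function of my own (not part of the article’s model) decorated with `jax.jit`:

```python
import jax
import jax.numpy as jnp

# A toy function, used only to illustrate just-in-time compilation;
# it is not part of the regression model.
@jax.jit
def scaled_sum(x):
    return jnp.sum(x * 2.0)

# The first call triggers tracing and XLA compilation; subsequent calls
# with the same input shape and dtype reuse the compiled executable.
result = scaled_sum(jnp.arange(4.0))  # 0 + 2 + 4 + 6 = 12.0
```

Because compilation happens once per input shape and dtype, the cost is amortized across the many `update` calls made during training.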

```python
LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2)),
}
```

The code above features a selection of loss functions for your convenience. These include **Mean Absolute Error (MAE)**, **Mean Squared Error (MSE)**, and **Root Mean Squared Error (RMSE)**, each expressed as a small lambda over the true and predicted values.
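To see the three loss functions in action, here is a quick check on a pair of toy vectors (the mapping is copied verbatim from the snippet above):

```python
import jax.numpy as jnp

LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2)),
}

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.0, 2.0, 4.0])  # a single error of 1 on the last element

mae = LOSS_FN_MAPPING["mae"](y_true, y_pred)    # mean(|0, 0, 1|) = 1/3
mse = LOSS_FN_MAPPING["mse"](y_true, y_pred)    # mean(0, 0, 1) = 1/3
rmse = LOSS_FN_MAPPING["rmse"](y_true, y_pred)  # sqrt(1/3) ≈ 0.577
```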

```python
class BatchGenerator:
    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size
            yield self.X[start:end], self.Y[start:end]
```

The **BatchGenerator** class plays a supportive role in our code implementation. Its primary function is to divide the given dataset into manageable batches. While this class doesn’t perform overly complex tasks, it is essential for handling larger datasets. It accomplishes this by calculating the number of batches and, based on this, generates subsets of the data. The output from this class is essentially a sequence of data batches, making it a useful tool in our linear regression model execution. Importantly, the utilization of this class allows our model to support mini-batch learning.
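As a quick sanity check of the batching logic, 100 samples with a batch size of 32 should yield three full batches plus one partial batch of 4 (the class is repeated here so the snippet runs standalone):

```python
import numpy as np

class BatchGenerator:
    # Same class as in the article, reproduced for a self-contained example.
    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size
            yield self.X[start:end], self.Y[start:end]

X = np.zeros((100, 3))
Y = np.zeros(100)
batch_sizes = [x_batch.shape[0] for x_batch, _ in BatchGenerator(X, Y, batch_size=32)]
# batch_sizes == [32, 32, 32, 4]
```

Note how the last batch is simply shorter; the slicing automatically clips at the end of the array.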

In this section, I will explain the functions and structures employed to implement the full model. Let us delve deeper into this topic.

```python
class LinearParameters(NamedTuple):
    w: jnp.ndarray
    b: jnp.ndarray | None
```

The *LinearParameters* class serves as a container for two crucial elements used in our regression model: the weight vector `w` and the bias term `b`. When the model is configured without a bias, `b` is filled with `jnp.nan` instead.

```python
def linear_model(params: LinearParameters, x: jnp.ndarray) -> jnp.ndarray:
    return jax.lax.cond(
        jnp.isnan(params.b),
        lambda: jnp.dot(x, params.w),
        lambda: jnp.dot(x, params.w) + params.b,
    )


batched_linear_model = jax.vmap(linear_model, in_axes=(None, 0))


def loss_fn(loss_fn_arg, params: LinearParameters, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    preds = batched_linear_model(params, x)
    return loss_fn_arg(y, preds)


@partial(jax.jit, static_argnames=("loss_fn_arg", "learning_rate"))
def update(learning_rate, loss_fn_arg, params, x, y):
    grad_loss_fn = jax.value_and_grad(partial(loss_fn, loss_fn_arg))
    loss, grad = grad_loss_fn(params, x, y)
    return jax.tree_map(lambda p, g: p - g * learning_rate, params, grad), loss
```

The upcoming suite of functions formulates the essence of the linear regression model. The functions labeled as **linear_model** and **batched_linear_model** compute the model’s predictions: the former handles a single sample (adding the bias only when it is not `NaN`), while the latter is a batched version produced by `jax.vmap`.

The function labeled **loss_fn** takes on the task of computing the loss for the supplied features and target data. In contrast, the **update** function performs a single gradient-descent step: it evaluates both the loss and its gradient, then moves every parameter a small step against the gradient.

Let’s take a moment to delve deeper into the ‘update’ function, specifically focusing on two significant calls nested within it:

- *jax.value_and_grad(…)*: this method serves a dual purpose by acting as a wrapper around the loss function. Not only does it compute the gradient, but it also returns the value of the loss function for the given input, essentially a two-for-one functionality that gives us both the gradient and the loss value.
  - The loss value is used in this example just to provide more information visible during the training process.
  - If you wish to calculate the gradient without the value of the loss function, you can do so by simply replacing *jax.value_and_grad(…)* with *jax.grad(…)*. Subsequently, instead of using `loss, grad = grad_loss_fn(…)`, you would just need to use `grad = grad_loss_fn(…)`.
- *jax.tree_map(…)*: a powerful function within the JAX library that enables applying a function to each element in a structure, while preserving that structure. The structure can be a nested combination of tuples, lists, dictionaries, or custom containers. The utility of *jax.tree_map* comes from its ability to handle complex, nested data structures with ease, reducing the amount of manual coding you need to do.
  - In simpler terms, when we use *tree_map*, it navigates through both the parameters and the gradients calculated by *value_and_grad*, updating each parameter based on the equation `p - g * learning_rate`. The end result is an updated set of parameters that incorporates the learnings from the gradient computation, thereby incrementally refining our regression model.
  - In essence, it is like an advanced version of the *map* function, equipped with the intelligence to navigate and operate on intricate data structures right out of the box.
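These two calls can also be seen in isolation on a toy quadratic loss of my own devising. Note that I use `jax.tree_util.tree_map` here, which behaves the same as the `jax.tree_map` alias used in the article (the alias has been deprecated in newer JAX releases):

```python
import jax
import jax.numpy as jnp

def toy_loss(params, x):
    # A toy quadratic loss, not the article's loss_fn.
    return (params["w"] * x) ** 2

value_and_grad_fn = jax.value_and_grad(toy_loss)
params = {"w": jnp.float32(3.0)}

loss_value, grads = value_and_grad_fn(params, jnp.float32(2.0))
# loss = (3 * 2)^2 = 36; d(loss)/dw = 2 * w * x^2 = 24

# One gradient-descent step applied across the (here trivial) parameter tree.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
# new_params["w"] == 3.0 - 0.1 * 24 = 0.6
```

The same pattern scales unchanged to parameter trees with many leaves, which is exactly why `update` uses it.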

This figure showcases an animation demonstrating the learning process as it seeks the minimum of the loss function. It’s an ideal representation of how we’d like our model to navigate through the complex landscape of the loss function. However, it’s essential to understand that this represents an optimal scenario. In real-world applications, it’s often challenging to locate the global minimum due to the presence of numerous local minima. These local minima could potentially trap the learning process, obstructing it from reaching the global minimum that we aim to achieve. The visual illustration in this figure, therefore, while representing the ideal learning journey, should be taken with the understanding that actual model training can be a more challenging and nuanced process.

To demonstrate our linear regression model in action, we’ll need some data. In this next section, we’ll create a synthetic dataset specifically tailored for regression tasks. To accomplish this, we’ll utilize the **make_regression(…)** function from the popular *scikit-learn* library.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error

xs, ys, coef = make_regression(
    n_features=1,
    n_informative=1,
    n_targets=1,
    n_samples=10_000,
    noise=2,
    coef=True,
    bias=5,
)

scaler = StandardScaler()
xs = scaler.fit_transform(xs)

x_train, x_test, y_train, y_test = train_test_split(xs, ys, test_size=0.2, random_state=42)
```

Let’s put our implemented linear regression class to use by creating a model instance with pre-determined hyperparameters. Here’s what we’re going to set:

- The *learning rate*, which guides the model’s learning speed, will be fixed at *0.5* for illustrative purposes
- We’ll run the training for a total of *300 epochs* to allow the model ample opportunity to learn from our data
- The *batch size*, which determines how many data points are processed together, will be set to *512*
- For our *loss function*, which quantifies how far off our model’s predictions are from the actual values, we’ll employ *Mean Absolute Error (MAE)*

```python
linear_regression = LinearRegression(use_bias=True)
linear_regression.fit(x_train, y_train, loss="mae", learning_rate=0.5, epochs=300, batch_size=512)

y_predictions = linear_regression.predict(x_test)
print(f"MAE: {mean_absolute_error(y_test, y_predictions)}")
```

Once training is complete, it’s time to assess the model’s performance using the test dataset we set aside earlier. We can also examine the learned weights of our model and compare them with the **coef** variable returned by the **make_regression(…)** function.

In my case, the weights obtained by our model align well with the coefficients used to generate the synthetic data. This strong correspondence indicates that our model has been quite successful in uncovering the underlying patterns within the dataset, a promising affirmation of the effectiveness of our linear regression model with this specific data.

However, there’s a slight divergence concerning the bias term. While the **make_regression** function set the bias to *5*, our model settled on a value of roughly *4.54*. Given the noise added to the synthetic data, a small deviation like this is to be expected.

```python
linear_regression.params
# >>> LinearParameters(w=Array([39.758873], dtype=float32), b=Array(4.5382257, dtype=float32))

# For comparison, the true coefficient returned by make_regression:
# coef = array(39.69681269)
```

To wrap things up, let’s visually represent our findings. We’ll plot our linear regression function, overlaid with the training and testing datasets. This graphical representation will help us better understand how well our model fits the data at hand.
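Since the original plot isn’t reproduced here, a minimal sketch along those lines might look as follows. The data and the learned parameters (w ≈ 39.76, b ≈ 4.54, taken from the printed output above) are stand-ins, not the article’s actual arrays:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt

# Stand-in data mimicking the standardized synthetic regression set.
rng = np.random.default_rng(42)
x = rng.normal(size=200)
y = 39.76 * x + 4.54 + rng.normal(scale=2.0, size=200)

fig, ax = plt.subplots()
ax.scatter(x, y, s=10, alpha=0.4, label="data")

# Overlay the fitted line across the observed range of x.
x_line = np.linspace(x.min(), x.max(), 100)
ax.plot(x_line, 39.76 * x_line + 4.54, color="red", label="fitted line")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
fig.savefig("regression_fit.png")
```

In practice you would scatter `x_train`/`y_train` and `x_test`/`y_test` separately and draw the line from `linear_regression.params`.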

In conclusion, this blog post provided a practical, step-by-step guide to implementing linear regression using JAX. It dove into the intricacies of coding the model, detailing key functionalities, and showcasing the application of the model on a sample dataset. The post illustrated the model’s learning process, highlighting how it effectively identifies underlying patterns.

Importantly, the intention behind this post was not to delve deeply into comprehensive data preprocessing, model fine-tuning, or in-depth model analysis. Rather, the aim was to swiftly illustrate the implementation of linear regression using JAX. This streamlined approach offers a starting point for those interested in the capabilities of JAX, showcasing its efficient and straightforward application in the field of machine learning. The world of JAX is wide and varied, and there’s much more to explore and learn. Happy coding!