Unlocking Linear Regression with JAX

Step-by-step demonstration of implementing a linear regression model with JAX

Last Modified: July 17th, 2023 | Reading Time: 8 minutes

Full code

To start off, I'll be offering you a comprehensive implementation of a linear regression model using JAX, unaccompanied by any regularization techniques. This model comes equipped with three potential loss functions: Mean Absolute Error (MAE), Mean Squared Error (MSE), and Root Mean Squared Error (RMSE). In this article, our primary focus will be the Mean Absolute Error as our chosen loss function, though I encourage you to experiment with the others. The aim here is to provide a pure, unadulterated glimpse into the workings of a JAX-powered linear regression model, while also inviting the opportunity for further exploration and learning. For the learning process, we will employ the mini-batch gradient descent method, which allows us to balance computational efficiency with the convergence speed of the model.

import jax
import jax.numpy as jnp
import numpy as np

from functools import partial
from typing import NamedTuple
from tqdm.notebook import tqdm


LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2))
}


class LinearParameters(NamedTuple):
    w: jnp.ndarray 
    b: jnp.ndarray | None


def linear_model(params: LinearParameters, x: jnp.ndarray) -> jnp.ndarray:
    return jax.lax.cond(
        jnp.isnan(params.b),
        lambda: jnp.dot(x, params.w),
        lambda: jnp.dot(x, params.w) + params.b
    )


batched_linear_model = jax.vmap(linear_model, in_axes=(None, 0))


def loss_fn(loss_fn_arg, params: LinearParameters, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    preds = batched_linear_model(params, x)
    return loss_fn_arg(y, preds)


@partial(jax.jit, static_argnames=('loss_fn_arg', 'learning_rate', ))
def update(learning_rate, loss_fn_arg, params, x, y):
    grad_loss_fn = jax.value_and_grad(partial(loss_fn, loss_fn_arg))
    loss, grad = grad_loss_fn(params, x, y)

    return jax.tree_map(
        lambda p, g: p - g * learning_rate,
        params,
        grad
    ), loss


class BatchGenerator:

    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size

            yield self.X[start:end], self.Y[start:end]


class LinearRegression:
    def __init__(
            self,
            use_bias: bool = True
    ):
        self.use_bias = use_bias
        self.params = None

    def fit(self, x: np.ndarray, y: np.ndarray, learning_rate: float = 0.01, **kwargs):
        loss = kwargs.get("loss", "mae")
        batch_size = kwargs.get("batch_size", 32)
        epochs = kwargs.get("epochs", 100)

        assert loss in LOSS_FN_MAPPING.keys(), f"loss must be one of {list(LOSS_FN_MAPPING.keys())}"

        number_of_features = x.shape[1]
        resolved_loss_fn = LOSS_FN_MAPPING[loss]

        batch_generator = BatchGenerator(x, y, batch_size)

        if self.use_bias:
            b = jnp.float32(1.0)
        else:
            b = jnp.nan

        w = jax.random.normal(jax.random.PRNGKey(42), (number_of_features,))

        self.params = LinearParameters(w, b)

        with tqdm(range(epochs), position=0) as pbar:

            for epoch in range(epochs):
                pbar.set_description(f"Epoch {epoch + 1}/{epochs}")

                for x_batch, y_batch in batch_generator:
                    self.params, loss_value = update(
                        learning_rate,
                        resolved_loss_fn,
                        self.params,
                        x_batch,
                        y_batch
                    )

                pbar.set_postfix({"loss": loss_value})
                pbar.update(1)

    def predict(self, x: np.ndarray) -> np.ndarray:
        assert self.params is not None, "Model not fitted yet"

        return batched_linear_model(self.params, jnp.asarray(x))

Imports

import jax
import jax.numpy as jnp
import numpy as np

from functools import partial
from typing import NamedTuple
from tqdm.notebook import tqdm

In the initial phase, the import section of our LinearRegression class comes into the limelight, showcasing the array of libraries we'll be deploying. Primary among these are the JAX and NumPy libraries, while we leverage auxiliary libraries like tqdm, functools, and typing for a variety of supporting tasks. These include rendering progress visually, neatly bundling regression parameters within a class, and enabling the provision of extra arguments to JAX's jit function.

Transitioning to a brief overview of JAX’s jit function, we’ll avoid diving too deep for the purpose of this article. Nonetheless, it’s worth noting that it plays a critical role in facilitating just-in-time compilation with XLA. This powerful tool allows us to harness the benefits of performance optimization, eliminating the need to micromanage low-level details. As a result, we achieve a seamless blend of efficiency and cleanliness in our code, making it an appealing choice for our linear regression model implementation.
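
To make the idea concrete, here is a minimal, self-contained sketch of jit in action. The function and array sizes below are purely illustrative and are not part of the model above; the point is simply that wrapping a function in jax.jit compiles it with XLA on its first call, and subsequent calls reuse the compiled version.

import jax
import jax.numpy as jnp

# A toy function, used only to illustrate jit.
def scaled_residual_norm(w, x, y):
    return jnp.sqrt(jnp.mean((jnp.dot(x, w) - y) ** 2))

fast_fn = jax.jit(scaled_residual_norm)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1_000, 10))
w = jnp.ones((10,))
y = jnp.zeros((1_000,))

print(fast_fn(w, x, y))  # first call traces and compiles; later calls reuse the compiled code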

Loss functions

LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2))
}

The code above features a selection of loss functions for your convenience. These include Mean Absolute Error (MAE), Mean Squared Error (MSE), and Root Mean Squared Error (RMSE). The choice of which function to apply can be easily made later via the fit method within the LinearRegression class.
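
As a quick sanity check, here is a small sketch of how an entry from LOSS_FN_MAPPING can be looked up and evaluated on its own. The sample arrays are made up purely for illustration.

import jax.numpy as jnp

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.5, 1.5, 2.0])

mae = LOSS_FN_MAPPING["mae"](y_true, y_pred)    # mean([0.5, 0.5, 1.0]) = 0.666...
rmse = LOSS_FN_MAPPING["rmse"](y_true, y_pred)  # sqrt(mean([0.25, 0.25, 1.0])) = 0.707...
print(mae, rmse)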

Data processing helpers

class BatchGenerator:

    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size

            yield self.X[start:end], self.Y[start:end]

The BatchGenerator class plays a supportive role in our code implementation. Its primary function is to divide the given dataset into manageable batches. While this class doesn’t perform overly complex tasks, it is essential for handling larger datasets. It accomplishes this by calculating the number of batches and, based on this, generates subsets of the data. The output from this class is essentially a sequence of data batches, making it a useful tool in our linear regression model execution. Importantly, the utilization of this class allows our model to support mini-batch learning.
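
To see the generator in action, here is a short usage sketch with toy arrays (made up for illustration): with 10 samples and a batch size of 4, it yields batches of 4, 4, and 2 rows.

import numpy as np

X = np.arange(20).reshape(10, 2)  # 10 samples, 2 features
Y = np.arange(10)

for x_batch, y_batch in BatchGenerator(X, Y, batch_size=4):
    print(x_batch.shape, y_batch.shape)

# (4, 2) (4,)
# (4, 2) (4,)
# (2, 2) (2,)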

Regression implementation

In this section, I will explain the functions and structures employed to implement the full model. Let us delve deeper into this topic.

class LinearParameters(NamedTuple):
    w: jnp.ndarray 
    b: jnp.ndarray | None

The LinearParameters class serves as a container for two crucial elements used in our regression model: the weights and the bias. These components play significant roles in shaping the model’s predictive capability. The inclusion of bias is optional and can be controlled during the creation of the LinearRegression class. Thus, LinearParameters essentially offers a neat and efficient way to manage and manipulate the primary parameters of our linear regression model.
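
A convenient side effect of using a NamedTuple is that JAX treats it as a pytree out of the box, so the parameters flow through value_and_grad and tree_map later on without any extra registration. A quick sketch with illustrative values:

import jax
import jax.numpy as jnp

params = LinearParameters(w=jnp.array([0.5, -1.2]), b=jnp.float32(0.0))

# NamedTuples are pytrees, so JAX utilities see the individual leaves:
# the weight vector and the bias scalar.
print(jax.tree_util.tree_leaves(params))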

def linear_model(params: LinearParameters, x: jnp.ndarray) -> jnp.ndarray:
    return jax.lax.cond(
        jnp.isnan(params.b),
        lambda: jnp.dot(x, params.w),
        lambda: jnp.dot(x, params.w) + params.b
    )
    
    
batched_linear_model = jax.vmap(linear_model, in_axes=(None, 0))


def loss_fn(loss_fn_arg, params: LinearParameters, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    preds = batched_linear_model(params, x)
    return loss_fn_arg(y, preds)


@partial(jax.jit, static_argnames=('loss_fn_arg', 'learning_rate', ))
def update(learning_rate, loss_fn_arg, params, x, y):
    grad_loss_fn = jax.value_and_grad(partial(loss_fn, loss_fn_arg))
    loss, grad = grad_loss_fn(params, x, y)

    return jax.tree_map(
        lambda p, g: p - g * learning_rate,
        params,
        grad
    ), loss

The upcoming suite of functions formulates the essence of the linear regression model. The functions linear_model and batched_linear_model lay down the framework of the linear model, which involves the dot product of features and weights (plus bias, if it has been incorporated). The batched_linear_model function specifically facilitates the batch processing of data, as the short demonstration below illustrates. If you're intrigued by the usage of vmap, you'll find a thorough explanation in a separate article devoted to this function.
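
Here is a small sketch, with made-up numbers, of the difference between the two: linear_model handles a single sample, while the vmap-wrapped batched_linear_model maps the same computation over the leading batch dimension.

import jax.numpy as jnp

params = LinearParameters(w=jnp.array([2.0, -1.0]), b=jnp.float32(3.0))

single = linear_model(params, jnp.array([1.0, 4.0]))           # 2*1 - 1*4 + 3 = 1.0
batch = batched_linear_model(params, jnp.array([[1.0, 4.0],
                                                [0.0, 0.0]]))  # [1.0, 3.0]
print(single, batch)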

The function labeled loss_fn takes on the task of computing the loss for the supplied features and target data. In contrast, the update function encapsulates the full learning process. During this process, it calculates the gradient for a given batch of data, then updates the current parameters based on the learning rate. Together, these functions constitute a comprehensive learning mechanism for our linear regression model.

Let's take a moment to delve deeper into the update function, specifically focusing on two significant calls nested within it (a minimal standalone sketch of both follows the list):

  • jax.value_and_grad(…)
    • This method serves a dual purpose by acting as a wrapper around the loss function. Not only does it compute the gradient, but it also returns the value of the loss function for the given input. It essentially provides a two-for-one functionality—giving us both the gradient and the loss value
      • loss value is used in this example just to provide more information visible during the training process
    • If you wish to calculate the gradient without using the value of the loss function, you can do so by simply replacing jax.value_and_grad(…) with jax.grad(…). Subsequently, instead of using loss, grad = grad_loss_fn(…), you would just need to use grad = grad_loss_fn(…)
  • jax.tree_map(…)
    • jax.tree_map is a powerful function within the JAX library that enables applying a function to each element in a structure, while preserving that structure. This structure can be a nested combination of tuples, lists, dictionaries, or custom structures. The utility of jax.tree_map comes from its ability to handle complex, nested data structures with ease, reducing the amount of manual coding you need to do
      • In simpler terms, when we use tree_map, it will navigate through both the parameters and the gradients that have been calculated using the value_and_grad function. During this process, it will update each parameter based on the equation p - g * learning_rate. The end result is an updated set of parameters that incorporate the learnings from the gradient computation, thereby incrementally refining our regression model.
      • In essence, it is like an advanced version of the map function, equipped with the intelligence to navigate and operate on intricate data structures right out of the box
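
Here is a minimal, self-contained sketch of both calls working together, detached from the regression code above. The quadratic toy_loss function and the parameter values are purely illustrative.

import jax
import jax.numpy as jnp

def toy_loss(params):
    # A simple quadratic "loss" over a dict of parameters.
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.float32(3.0)}

loss, grads = jax.value_and_grad(toy_loss)(params)  # loss = 1 + 4 + 9 = 14, grads = 2 * params
new_params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)

print(loss)        # 14.0
print(new_params)  # roughly {'b': 2.4, 'w': [0.8, -1.6]}
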
Figure 1. Depiction of the Learning Process in Pursuit of the Loss Function’s Minimum

This figure showcases an animation demonstrating the learning process as it seeks the minimum of the loss function. It’s an ideal representation of how we’d like our model to navigate through the complex landscape of the loss function. However, it’s essential to understand that this represents an optimal scenario. In real-world applications, it’s often challenging to locate the global minimum due to the presence of numerous local minima. These local minima could potentially trap the learning process, obstructing it from reaching the global minimum that we aim to achieve. The visual illustration in this figure, therefore, while representing the ideal learning journey, should be taken with the understanding that actual model training can be a more challenging and nuanced process.

Learning

To demonstrate our linear regression model in action, we’ll need some data. In this next section, we’ll create a synthetic dataset specifically tailored for regression tasks. To accomplish this, we’ll utilize the make_regression(…) function from the popular sklearn library. This function allows us to easily generate a dataset with a single feature and a corresponding continuous target variable. Let’s proceed to generate our data for model training.

import numpy as np

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error


xs, ys, coef = make_regression(
    n_features=1,
    n_informative=1,
    n_targets=1,
    n_samples=10_000,
    noise=2,
    coef=True,
    bias=5
)

scaler = StandardScaler()
xs = scaler.fit_transform(xs)

x_train, x_test, y_train, y_test = train_test_split(xs, ys, test_size=0.2, random_state=42)

Let’s put our implemented linear regression class to use by creating a model instance with pre-determined hyperparameters. Here’s what we’re going to set:

  • The learning rate, which guides the model’s learning speed, will be fixed at 0.5 for illustrative purposes
  • We’ll run the training for a total of 300 epochs to allow the model ample opportunity to learn from our data
  • The batch size, which determines how many data points are processed together, will be set to 512
  • For our loss function, which quantifies how far off our model's predictions are from the actual values, we'll employ Mean Absolute Error (MAE)

linear_regression = LinearRegression(use_bias=True)
linear_regression.fit(x_train, y_train, loss="mae", learning_rate=0.5, epochs=300, batch_size=512)


y_predictions = linear_regression.predict(x_test)
print(f"MAE: {mean_absolute_error(y_test, y_predictions)}")

Once training is complete, it’s time to assess the model’s performance using the test dataset we set aside earlier. We can also examine the learned weights of our model and compare them with the coef variable provided by the make_regression function. This variable signifies the coefficients of the underlying linear model used to generate the dataset.

In my case, the weights obtained by our model align well with the coefficients used to generate the synthetic data. This strong correspondence indicates that our model has been quite successful in uncovering the underlying patterns within the dataset, a promising affirmation of the effectiveness of our linear regression model with this specific data.

However, there’s a slight divergence concerning the bias term. While the make_regression function set the bias to 5, our model learned a bias of 4.5. Despite this minor discrepancy, it’s important to note that further fine-tuning of the model parameters may help to reduce this gap. This minor variation serves as a reminder that while our model is performing well, there’s always room for further optimization.

linear_regression.params

# coef = array(39.69681269)
>>>
LinearParameters(w=Array([39.758873], dtype=float32), b=Array(4.5382257, dtype=float32))

To wrap things up, let’s visually represent our findings. We’ll plot our linear regression function, overlaid with the training and testing datasets. This graphical representation will help us better understand how well our model fits the data at hand.
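
The exact plotting code is not part of the listing above, but a minimal matplotlib sketch along the following lines (reusing the variables defined earlier) produces a figure like the one below.

import matplotlib.pyplot as plt

plt.scatter(x_train, y_train, s=5, alpha=0.4, label="train")
plt.scatter(x_test, y_test, s=5, alpha=0.4, label="test")

# Sort by feature value so the fitted line is drawn from left to right.
order = x_test[:, 0].argsort()
plt.plot(x_test[order], linear_regression.predict(x_test[order]), color="red", label="fitted line")

plt.legend()
plt.show()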

Figure 2. Visualization of Linear Regression Model: Fitted Line Versus Observed Data Points

Summary

In conclusion, this blog post provided a practical, step-by-step guide to implementing linear regression using JAX. It dove into the intricacies of coding the model, detailing key functionalities, and showcasing the application of the model on a sample dataset. The post illustrated the model’s learning process, highlighting how it effectively identifies underlying patterns.

Importantly, the intention behind this post was not to delve deeply into comprehensive data preprocessing, model fine-tuning, or in-depth model analysis. Rather, the aim was to swiftly illustrate the implementation of linear regression using JAX. This streamlined approach offers a starting point for those interested in the capabilities of JAX, showcasing its efficient and straightforward application in the field of machine learning. The world of JAX is wide and varied, and there's much more to explore and learn. Happy coding!

Article by:

Dino Causevic

Dino is a highly skilled engineer with 8 years of hands-on experience tackling complex tasks and solutions, navigating his way through machine learning, computer vision, and artificial intelligence in general. Along the way, Dino has led teams and produced solutions that disrupted the industry, shaping the way we live and use technology today.
