Step-by-step demonstration of implementing a linear regression model with JAX
Last Modified: July 17th, 2023 | Reading Time: 8 minutes
To start off, I’ll be offering you a comprehensive implementation of a linear regression model using JAX, unaccompanied by any regularization techniques. This model comes equipped with three potential loss functions: Mean Absolute Error (MAE), Mean Squared Error (MSE), and Root Mean Squared Error (RMSE). In this article, our primary focus will be the Mean Absolute Error as our chosen loss function, though I encourage you to experiment with the others. The aim here is to provide a pure, unadulterated glimpse into the workings of a JAX-powered linear regression model, while also inviting the opportunity for further exploration and learning. For the learning process, we will employ the mini-batch gradient descent method, which allows us to balance computational efficiency with the convergence speed of the model.
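For reference, the update we will implement is the standard mini-batch gradient descent rule, written here in generic notation (with theta standing for the model parameters, eta for the learning rate, L for the chosen loss, and (x_B, y_B) for the current mini-batch):

$$\theta \leftarrow \theta - \eta \, \nabla_\theta \, \mathcal{L}(\theta;\, x_B, y_B)$$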
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import NamedTuple
from tqdm.notebook import tqdm


LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2))
}


class LinearParameters(NamedTuple):
    w: jnp.ndarray
    b: jnp.ndarray | None


def linear_model(params: LinearParameters, x: jnp.ndarray) -> jnp.ndarray:
    return jax.lax.cond(
        jnp.isnan(params.b),
        lambda: jnp.dot(x, params.w),
        lambda: jnp.dot(x, params.w) + params.b
    )


batched_linear_model = jax.vmap(linear_model, in_axes=(None, 0))


def loss_fn(loss_fn_arg, params: LinearParameters, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    preds = batched_linear_model(params, x)
    return loss_fn_arg(y, preds)


@partial(jax.jit, static_argnames=('loss_fn_arg', 'learning_rate',))
def update(learning_rate, loss_fn_arg, params, x, y):
    grad_loss_fn = jax.value_and_grad(partial(loss_fn, loss_fn_arg))
    loss, grad = grad_loss_fn(params, x, y)
    return jax.tree_map(
        lambda p, g: p - g * learning_rate, params, grad
    ), loss


class BatchGenerator:
    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size
            yield self.X[start:end], self.Y[start:end]


class LinearRegression:
    def __init__(
        self,
        use_bias: bool = True
    ):
        self.use_bias = use_bias
        self.params = None

    def fit(self, x: np.ndarray, y: np.ndarray, learning_rate: float = 0.01, **kwargs):
        loss = kwargs.get("loss", "mae")
        batch_size = kwargs.get("batch_size", 32)
        epochs = kwargs.get("epochs", 100)

        assert loss in LOSS_FN_MAPPING.keys(), f"loss must be one of {list(LOSS_FN_MAPPING.keys())}"

        number_of_features = x.shape[1]
        resolved_loss_fn = LOSS_FN_MAPPING[loss]
        batch_generator = BatchGenerator(x, y, batch_size)

        if self.use_bias:
            b = jnp.float32(1.0)
        else:
            b = jnp.nan

        w = jax.random.normal(jax.random.PRNGKey(42), (number_of_features,))
        self.params = LinearParameters(w, b)

        with tqdm(range(epochs), position=0) as pbar:
            for epoch in range(epochs):
                pbar.set_description(f"Epoch {epoch + 1}/{epochs}")
                for x_batch, y_batch in batch_generator:
                    self.params, loss_value = update(
                        learning_rate, resolved_loss_fn, self.params, x_batch, y_batch
                    )
                pbar.set_postfix({"loss": loss_value})
                pbar.update(1)

    def predict(self, x: np.ndarray) -> np.ndarray:
        assert self.params is not None, "Model not fitted yet"
        return batched_linear_model(self.params, jnp.asarray(x))
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import NamedTuple
from tqdm.notebook import tqdm
In the initial phase, the import section of our LinearRegression implementation comes into the limelight, showcasing the libraries we’ll be deploying. Primary among these are JAX and NumPy, while auxiliary libraries like tqdm, functools, and typing handle a variety of supporting tasks: rendering progress visually, neatly bundling the regression parameters within a class, and marking selected arguments as static when applying JAX’s jit function.
Transitioning to a brief overview of JAX’s jit function, we’ll avoid diving too deep for the purpose of this article. Nonetheless, it’s worth noting that it plays a critical role in facilitating just-in-time compilation with XLA. This powerful tool allows us to harness the benefits of performance optimization, eliminating the need to micromanage low-level details. As a result, we achieve a seamless blend of efficiency and cleanliness in our code, making it an appealing choice for our linear regression model implementation.
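To make that a bit more tangible, here is a minimal, self-contained jit example that is separate from the regression code (the function and values are made up purely for illustration):

import jax
import jax.numpy as jnp

@jax.jit
def weighted_sum(w, x):
    # Traced and compiled with XLA on the first call for these shapes/dtypes;
    # later calls reuse the cached, compiled version.
    return jnp.sum(w * x)

w = jnp.ones(1_000)
x = jnp.arange(1_000, dtype=jnp.float32)

print(weighted_sum(w, x))  # first call triggers compilation
print(weighted_sum(w, x))  # subsequent calls run the compiled code directly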
LOSS_FN_MAPPING = {
    "mae": lambda y_true, y_pred: jnp.mean(jnp.abs(y_true - y_pred)),
    "mse": lambda y_true, y_pred: jnp.mean((y_true - y_pred) ** 2),
    "rmse": lambda y_true, y_pred: jnp.sqrt(jnp.mean((y_true - y_pred) ** 2))
}
The code above defines a selection of loss functions for your convenience: Mean Absolute Error (MAE), Mean Squared Error (MSE), and Root Mean Squared Error (RMSE). The choice of which function to apply can be made later via the fit method of the LinearRegression class.
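As a quick sanity check of how the mapping is meant to be used (toy numbers chosen only for illustration):

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.5, 2.0, 2.0])

print(LOSS_FN_MAPPING["mae"](y_true, y_pred))   # mean(|[-0.5, 0.0, 1.0]|) = 0.5
print(LOSS_FN_MAPPING["mse"](y_true, y_pred))   # mean([0.25, 0.0, 1.0]) ≈ 0.4167
print(LOSS_FN_MAPPING["rmse"](y_true, y_pred))  # sqrt(0.4167) ≈ 0.6455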
class BatchGenerator:
    def __init__(self, X, Y, batch_size):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.num_batches = (X.shape[0] - 1) // batch_size + 1

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size
            yield self.X[start:end], self.Y[start:end]
The BatchGenerator class plays a supportive role in our code implementation. Its primary function is to divide the given dataset into manageable batches. While this class doesn’t perform overly complex tasks, it is essential for handling larger datasets. It accomplishes this by calculating the number of batches and, based on this, generates subsets of the data. The output from this class is essentially a sequence of data batches, making it a useful tool in our linear regression model execution. Importantly, the utilization of this class allows our model to support mini-batch learning.
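A quick usage sketch makes its behaviour concrete (array shapes picked arbitrarily for illustration):

X = np.random.rand(10, 3)   # 10 samples, 3 features
Y = np.random.rand(10)

for x_batch, y_batch in BatchGenerator(X, Y, batch_size=4):
    print(x_batch.shape, y_batch.shape)
# (4, 3) (4,)
# (4, 3) (4,)
# (2, 3) (2,)  <- the last batch may be smaller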
In this section, I will explain the functions and structures employed to implement the full model. Let us delve deeper into this topic.
class LinearParameters(NamedTuple):
    w: jnp.ndarray
    b: jnp.ndarray | None
The LinearParameters class serves as a container for two crucial elements used in our regression model: the weights and the bias. These components play significant roles in shaping the model’s predictive capability. The inclusion of bias is optional and can be controlled during the creation of the LinearRegression class. Thus, LinearParameters essentially offers a neat and efficient way to manage and manipulate the primary parameters of our linear regression model.
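Because NamedTuple instances are handled as pytrees by JAX, jax.tree_map can later update the weights and the bias in a single call. Constructing the parameters looks roughly like this (a small illustration mirroring what fit does internally; jnp.nan is used as a sentinel for "no bias"):

w = jax.random.normal(jax.random.PRNGKey(42), (3,))     # weights for, say, 3 features

params_with_bias = LinearParameters(w=w, b=jnp.float32(1.0))
params_without_bias = LinearParameters(w=w, b=jnp.nan)  # nan marks the absent bias term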
def linear_model(params: LinearParameters, x: jnp.ndarray) -> jnp.ndarray:
    return jax.lax.cond(
        jnp.isnan(params.b),
        lambda: jnp.dot(x, params.w),
        lambda: jnp.dot(x, params.w) + params.b
    )


batched_linear_model = jax.vmap(linear_model, in_axes=(None, 0))


def loss_fn(loss_fn_arg, params: LinearParameters, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    preds = batched_linear_model(params, x)
    return loss_fn_arg(y, preds)


@partial(jax.jit, static_argnames=('loss_fn_arg', 'learning_rate',))
def update(learning_rate, loss_fn_arg, params, x, y):
    grad_loss_fn = jax.value_and_grad(partial(loss_fn, loss_fn_arg))
    loss, grad = grad_loss_fn(params, x, y)
    return jax.tree_map(
        lambda p, g: p - g * learning_rate, params, grad
    ), loss
The suite of functions above formulates the essence of the linear regression model. The functions linear_model and batched_linear_model lay down the framework of the linear model, which boils down to the dot product of features and weights (plus the bias, if it has been incorporated). The batched_linear_model function specifically enables batch processing of data. If you’re intrigued by the usage of vmap, you’ll find a thorough explanation in a separate article devoted to this function.
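In short, vmap vectorizes linear_model over the leading axis of x. A quick sketch of that equivalence, using throwaway shapes and values chosen only for illustration:

x = jnp.ones((5, 3))                                   # a batch of 5 samples with 3 features
params = LinearParameters(w=jnp.ones(3), b=jnp.float32(0.5))

batched = batched_linear_model(params, x)              # applies linear_model to every row at once
looped = jnp.stack([linear_model(params, row) for row in x])

print(jnp.allclose(batched, looped))                   # True -- same result as an explicit Python loop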
The function labeled loss_fn takes on the task of computing the loss for the supplied features and target data. In contrast, the update function encapsulates the full learning process. During this process, it calculates the gradient for a given batch of data, then updates the current parameters based on the learning rate. Together, these functions constitute a comprehensive learning mechanism for our linear regression model.
Let’s take a moment to delve deeper into the update function, specifically the two significant calls nested within it: jax.value_and_grad, which returns the loss value together with its gradient with respect to the parameters, and jax.tree_map, which applies the gradient-descent step to every leaf of the parameter tuple. A toy sketch of both calls follows below.
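Here is that sketch, kept entirely separate from the regression code (the function, parameter names, and numbers are made up purely for illustration):

import jax
import jax.numpy as jnp

def toy_loss(params):
    # A simple quadratic "loss" over a dict of parameters.
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.float32(3.0)}

# value_and_grad returns the loss and a gradient pytree with the same structure as params.
loss, grads = jax.value_and_grad(toy_loss)(params)
print(loss)   # 14.0
print(grads)  # {'b': 6.0, 'w': [2.0, -4.0]}

# tree_map applies the update rule leaf by leaf, preserving the pytree structure.
learning_rate = 0.1
new_params = jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(new_params)  # {'b': 2.4, 'w': [0.8, -1.6]}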
This figure showcases an animation demonstrating the learning process as it seeks the minimum of the loss function. It’s an ideal representation of how we’d like our model to navigate through the complex landscape of the loss function. However, it’s essential to understand that this represents an optimal scenario. In real-world applications, it’s often challenging to locate the global minimum due to the presence of numerous local minima. These local minima could potentially trap the learning process, obstructing it from reaching the global minimum that we aim to achieve. The visual illustration in this figure, therefore, while representing the ideal learning journey, should be taken with the understanding that actual model training can be a more challenging and nuanced process.
To demonstrate our linear regression model in action, we’ll need some data. In this next section, we’ll create a synthetic dataset specifically tailored for regression tasks. To accomplish this, we’ll utilize the make_regression(…) function from the popular sklearn library. This function allows us to easily generate a dataset with a single feature and a corresponding continuous target variable. Let’s proceed to generate our data for model training.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error

xs, ys, coef = make_regression(
    n_features=1,
    n_informative=1,
    n_targets=1,
    n_samples=10_000,
    noise=2,
    coef=True,
    bias=5
)

scaler = StandardScaler()
xs = scaler.fit_transform(xs)

x_train, x_test, y_train, y_test = train_test_split(xs, ys, test_size=0.2, random_state=42)
Let’s put our implemented linear regression class to use by creating a model instance with pre-determined hyperparameters. Here’s what we’re going to set: a model with a bias term (use_bias=True), the MAE loss, a learning rate of 0.5, 300 epochs, and a batch size of 512.
linear_regression = LinearRegression(use_bias=True)
linear_regression.fit(x_train, y_train, loss="mae", learning_rate=0.5, epochs=300, batch_size=512)

y_predictions = linear_regression.predict(x_test)
print(f"MAE: {mean_absolute_error(y_test, y_predictions)}")
Once training is complete, it’s time to assess the model’s performance using the test dataset we set aside earlier. We can also examine the learned weights of our model and compare them with the coef variable provided by the make_regression function. This variable signifies the coefficients of the underlying linear model used to generate the dataset.
In my case, the weights obtained by our model align well with the coefficients used to generate the synthetic data. This strong correspondence indicates that our model has been quite successful in uncovering the underlying patterns within the dataset, a promising affirmation of the effectiveness of our linear regression model with this specific data.
However, there’s a slight divergence concerning the bias term. While the make_regression function set the bias to 5, our model learned a bias of 4.5. Despite this minor discrepancy, it’s important to note that further fine-tuning of the model parameters may help to reduce this gap. This minor variation serves as a reminder that while our model is performing well, there’s always room for further optimization.
linear_regression.params
# coef = array(39.69681269)

>>> LinearParameters(w=Array([39.758873], dtype=float32), b=Array(4.5382257, dtype=float32))
To wrap things up, let’s visually represent our findings. We’ll plot our linear regression function, overlaid with the training and testing datasets. This graphical representation will help us better understand how well our model fits the data at hand.
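For completeness, a minimal plotting sketch along those lines might look as follows (it uses matplotlib and is just one possible way to draw the figure, not the exact code behind the original plot):

import matplotlib.pyplot as plt

plt.scatter(x_train[:, 0], y_train, s=5, alpha=0.3, label="train")
plt.scatter(x_test[:, 0], y_test, s=5, alpha=0.3, label="test")

# Sort the test inputs so the fitted line is drawn left to right.
order = np.argsort(x_test[:, 0])
line_preds = np.asarray(linear_regression.predict(x_test))
plt.plot(x_test[order, 0], line_preds[order], color="red", linewidth=2, label="fitted line")

plt.xlabel("feature (standardized)")
plt.ylabel("target")
plt.legend()
plt.show()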
In conclusion, this blog post provided a practical, step-by-step guide to implementing linear regression using JAX. It dove into the intricacies of coding the model, detailing key functionalities, and showcasing the application of the model on a sample dataset. The post illustrated the model’s learning process, highlighting how it effectively identifies underlying patterns.
Importantly, the intention behind this post was not to delve deeply into comprehensive data preprocessing, model fine-tuning, or in-depth model analysis. Rather, the aim was to swiftly illustrate the implementation of linear regression using JAX. This streamlined approach offers a starting point for those interested in the capabilities of JAX, showcasing its efficient and straightforward application in the field of machine learning. The world of JAX is wide and varied, and there’s much more to explore and learn. Happy coding!