%%capture
%matplotlib inline
from functools import partial
import matplotlib.pyplot as plt
import math
try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax
import jax
import jax.numpy as jnp
from jax import jit
from jax.nn.initializers import glorot_normal
try:
    import flax
except ModuleNotFoundError:
    %pip install -qq flax
    import flax
import flax.linen as nn
from flax.training import train_state
import tensorflow_probability.substrates.jax as tfp
from typing import Any, Callable, Sequence
!pip install jaxopt
import jaxopt
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import seaborn as sns
from flax.core import unfreeze
import random
Motivation
Neural networks traditionally give point estimates on test data. Such a prediction could be a lucky guess or completely wrong. A better approach is to give a range, or band, of predictions instead of a single value; this band can then support decision making. For example, if I were to predict tomorrow's AQI for city A, an interval \(I\in[200,230]\) is more informative than a single number like \(195\).
Model architecture and loss function
We use a neural network with ReLU (Rectified Linear Unit) activations to introduce non-linearity. The loss function is the mean squared error (halved in the implementation below), and the architecture is specified by the user through a list of layer widths.
class load(nn.Module):
    features: list

    @nn.compact
    def __call__(self, X, deterministic, rate=0.03):
        for i, feature in enumerate(self.features):
            X = nn.Dense(feature, name=f"layer{i}")(X)
            # ReLU and Dropout after every layer except the first and the last
            if i != 0 and i != len(self.features) - 1:
                X = nn.relu(X)
                X = nn.Dropout(rate=rate, deterministic=deterministic)(X)
        return X

    def loss_fn(self, params, X, y, deterministic=False, key=jax.random.PRNGKey(0)):
        # Half mean squared error; dropout stays active here (deterministic is hard-coded to False)
        y_pred = self.apply(params, X, False, rngs={"dropout": key})
        loss = jnp.sum((y - y_pred)**2) / (2 * X.shape[0])
        return loss
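As a quick sanity check, the module can be initialised on a batch of dummy inputs and the loss evaluated directly. This is only an illustrative sketch; the names demo_model, X_dummy and y_dummy are made up for this example and are not used elsewhere in the notebook.

# Illustrative only: initialise the model on dummy data and evaluate the loss once
demo_model = load([32, 64, 1])                      # two hidden layers and one output unit
X_dummy = jnp.ones((8, 1))                          # 8 inputs with a single feature
y_dummy = jnp.zeros((8, 1))
demo_params = demo_model.init(jax.random.PRNGKey(0), X_dummy, True)  # deterministic=True at init
demo_loss = demo_model.loss_fn(demo_params, X_dummy, y_dummy, key=jax.random.PRNGKey(1))
print(demo_loss)                                    # scalar half-MSE on the dummy batch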
We use the Adam optimizer from optax and let JAX do the automatic gradient computation with “value_and_grad.” It takes a function and returns a new function that evaluates both the loss value and its gradient with respect to the parameters in a single call. Wrapping the result in “jit” compiles this computation once, so repeated calls inside the training loop stay cheap.
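As a small self-contained illustration, independent of the training loop below (the toy function f is made up for this example), value_and_grad returns the value and the gradient with respect to the first argument in one pass:

# Toy example: value_and_grad differentiates with respect to the first argument (w) by default
f = lambda w, x: jnp.sum((w * x - 1.0) ** 2)
val, grad = jax.value_and_grad(f)(2.0, jnp.array([1.0, 2.0]))
print(val, grad)  # 10.0 and 14.0: the value and d(value)/dw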
def fit(model, params, X, y, learning_rate=0.01, epochs=100, key=0, verbose=False):
    opt = optax.adam(learning_rate=learning_rate)
    opt_state = opt.init(params)

    loss_grad_fn = jax.jit(jax.value_and_grad(model.loss_fn))
    key = jax.random.PRNGKey(key)
    losses = []
    for i in range(epochs):
        key, _ = jax.random.split(key)
        loss_val, grads = loss_grad_fn(params, X, y, True)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        losses.append(loss_val)
        if verbose and i % (epochs / 10) == 0:
            print('Loss step {}: '.format(i), loss_val)
    return params, jnp.array(losses)
In the sin data, the added Gaussian noise is scaled by \(x\), whereas in the hetero data it is scaled by \(x^2\). Since the noise level depends on \(x\), both datasets are heteroscedastic.
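Concretely, writing \(\eta \sim \mathcal{N}(0, \sigma^2)\) for the Gaussian noise and \(\epsilon_1, \epsilon_2, \epsilon_3\) for small fixed offsets, the two generators below produce

\[ y_{\text{sin}} = x + 0.3\sin\big(2\pi(x+\epsilon_1)\big) + 0.3\sin\big(4\pi(x+\epsilon_2)\big) + \epsilon_3 + x\,\eta, \qquad y_{\text{hetero}} = 10x^2 + x^2\eta, \]

so the noise standard deviation scales with \(|x|\) and \(x^2\) respectively.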
def data_sin(points=20, xrange=(-3, 3), std=0.3):
    xx = jnp.linspace(-1, 1, 1000).reshape(-1, 1)
    key = jax.random.PRNGKey(0)
    epsilons = jax.random.normal(key, shape=(3,)) * 0.02
    y_true = jnp.array([[x + 0.3 * jnp.sin(2 * jnp.pi * (x + epsilons[0])) +
                         0.3 * jnp.sin(4 * jnp.pi * (x + epsilons[1])) + epsilons[2]] for x in xx])
    yy = jnp.array([[x + 0.3 * jnp.sin(2 * jnp.pi * (x + epsilons[0])) +
                     0.3 * jnp.sin(4 * jnp.pi * (x + epsilons[1])) + epsilons[2] +
                     x * np.random.normal(0, std)] for x in xx])
    return xx.reshape(1000, 1), yy.reshape(1000, 1), y_true.reshape(1000, 1)

def data_hetero(points=20, xrange=(-3, 3), std=0.3):
    xx = jnp.linspace(-1, 1, 1000).reshape(-1, 1)
    key = jax.random.PRNGKey(0)
    y_true = jnp.array([[x * 10 * x] for x in xx])
    yy = jnp.array([[x * 10 * x + x * x * np.random.normal(0, std)] for x in xx])
    return xx.reshape(1000, 1), yy.reshape(1000, 1), y_true.reshape(1000, 1)
key, subkey = jax.random.split(jax.random.PRNGKey(0))
X, Y, y_true = data_sin(points=1000, xrange=(-1, 1), std=2)

plt.figure()
plt.scatter(X, Y, alpha=0.8)
plt.plot(X, y_true, linewidth=2, color='black', linestyle='-')
sns.despine()

model = load([32, 64, 1])
We keep the number of neural networks in the bootstrap ensemble at 5; in general it can be as large as 50.
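Each bootstrap model \(m\) is trained on a resampled version of the data and evaluated on a common grid. At the end of the loop below, the ensemble prediction and its uncertainty are the pointwise mean and standard deviation

\[ \hat{\mu}(x) = \frac{1}{M}\sum_{m=1}^{M} f_{\theta_m}(x), \qquad \hat{\sigma}(x) = \sqrt{\frac{1}{M}\sum_{m=1}^{M} \big(f_{\theta_m}(x) - \hat{\mu}(x)\big)^2}, \]

with \(M = 5\) (n_estimators) here.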
n_estimators = 5
x_grid = jnp.linspace(-2, 2, 1000).reshape(-1, 1)
keys = jax.random.split(jax.random.PRNGKey(0), n_estimators)
Y_final = []
parameters = []
for i in range(n_estimators):
    # Bootstrap resample: draw 1000 indices with replacement
    ids = jax.random.choice(keys[i], jnp.array(range(1000)), (1000, 1))
    x, y = X[ids], Y[ids]
    loss = []
    params = model.init(keys[i], x, True)
    params, losses = fit(model, params, x, y, epochs=100)
    Y_pred = model.apply(params, x_grid, True)
    parameters.append(params)
    Y_final.append(Y_pred)

Y_final = jnp.array(Y_final)
mean = Y_final.mean(axis=0)
std = Y_final.std(axis=0)
mean = mean.squeeze()
std = std.squeeze()
Since the training data lies between \(-1\) and \(1\), the standard deviation changes little inside this interval, but it grows as \(x\) moves away from it. This is a characteristic of a Bayesian treatment. If the model only gave point predictions, the best we could report is a constant uncertainty band equal to the training error, which carries no information about where the model was actually trained. A model trained on a limited interval should naturally become more uncertain outside that interval, and this is exactly what the Bayesian (ensemble) predictions show.
# Shade 1, 2 and 3 standard-deviation bands around the ensemble mean
plt.plot(X, Y, 'kx', alpha=0.1)
plt.fill_between(x_grid.squeeze(), mean+std, mean-std, color='red', alpha=1)
plt.fill_between(x_grid.squeeze(), mean+2*std, mean-2*std, color='red', alpha=0.6)
plt.fill_between(x_grid.squeeze(), mean+3*std, mean-3*std, color='red', alpha=0.2)
Quadratic dataset
key, subkey = jax.random.split(jax.random.PRNGKey(0))
X, Y, y_true = data_hetero(points=1000, xrange=(-1, 1), std=2)

plt.figure()
plt.scatter(X, Y, alpha=0.8)
plt.plot(X, y_true, linewidth=2, color='black', linestyle='-')
sns.despine()

model = load([32, 64, 1])
n_estimators = 5
x_grid = jnp.linspace(-2, 2, 1000).reshape(-1, 1)
keys = jax.random.split(jax.random.PRNGKey(0), n_estimators)
Y_final = []
parameters = []
for i in range(n_estimators):
    # Bootstrap resample: draw 1000 indices with replacement
    ids = jax.random.choice(keys[i], jnp.array(range(1000)), (1000, 1))
    x, y = X[ids], Y[ids]
    loss = []
    params = model.init(keys[i], x, True)
    params, losses = fit(model, params, x, y, epochs=100)
    Y_pred = model.apply(params, x_grid, True)
    parameters.append(params)
    Y_final.append(Y_pred)

Y_final = jnp.array(Y_final)
mean = Y_final.mean(axis=0)
std = Y_final.std(axis=0)
mean = mean.squeeze()
std = std.squeeze()
Here too the effect is similar to the one above: outside \([-1, 1]\) the uncertainty bands widen, while within the training region the band stays narrow since the model is more confident about predictions close to its training data.
# Shade 1, 2 and 3 standard-deviation bands around the ensemble mean
plt.plot(X, Y, 'kx', alpha=0.1)
plt.fill_between(x_grid.squeeze(), mean+std, mean-std, color='red', alpha=1)
plt.fill_between(x_grid.squeeze(), mean+2*std, mean-2*std, color='red', alpha=0.6)
plt.fill_between(x_grid.squeeze(), mean+3*std, mean-3*std, color='red', alpha=0.2)