# MMGP_demo / gp_torch.py
import numpy as np
import torch
import torch.nn as nn

class GaussianProcessRegressor(nn.Module):
    """Gaussian process regressor with an RBF kernel, trained by maximizing
    the log marginal likelihood with a PyTorch optimizer."""

    def __init__(
        self,
        length_scale=1.0,
        noise_scale=1.0,
        amplitude_scale=1.0,
    ):
        super().__init__()
        if isinstance(length_scale, float):
            length_scale = np.array([length_scale])
        elif isinstance(length_scale, np.ndarray):
            assert length_scale.ndim == 1
        else:
            raise TypeError("length_scale must be a float or a 1-D numpy array")
        # Hyperparameters are stored as log-values so that they remain positive
        # under unconstrained gradient updates.
        self.register_parameter(
            "length_scale_",
            param=nn.Parameter(torch.Tensor(np.log(length_scale)), requires_grad=True),
        )
        self.register_parameter(
            "noise_scale_",
            param=nn.Parameter(torch.tensor(np.log(noise_scale)), requires_grad=True),
        )
        self.register_parameter(
            "amplitude_scale_",
            param=nn.Parameter(
                torch.tensor(np.log(amplitude_scale)), requires_grad=True
            ),
        )
        self.nll = None

    def forward(self, x):
        # Posterior mean at the query points x. Requires fit() (or a prior call
        # to log_marginal_likelihood()) so that self.X and self.alpha exist.
        alpha = self.alpha
        k = self.Kxy(self.X, x)
        mu = k.T.mm(alpha)
        return mu

    def log_marginal_likelihood(self, X, y):
        N = X.shape[0]
        K = self.Kxx(X)
        L = torch.linalg.cholesky(K)
        # alpha = K^{-1} y, computed via two triangular solves.
        alpha = torch.linalg.solve(L.T, torch.linalg.solve(L, y))
        # log p(y | X) = -0.5 y^T K^{-1} y - sum(log diag(L)) - 0.5 N log(2 pi)
        marginal_likelihood = (
            -0.5 * y.T.mm(alpha)
            - torch.log(torch.diag(L)).sum()
            - 0.5 * N * np.log(2 * np.pi)
        )
        self.L = L
        self.alpha = alpha
        self.K = K
        return marginal_likelihood

    def Kxx(self, X):
        # RBF kernel matrix on the training inputs plus noise on the diagonal.
        # Note that exp(length_scale_) plays the role of the squared length scale.
        param = self.length_scale_.exp().sqrt()
        sqdist = torch.cdist(X / param[None], X / param[None]) ** 2
        res = (
            self.amplitude_scale_.exp() * torch.exp(-0.5 * sqdist)
            + self.noise_scale_.exp() * torch.eye(len(X)).type_as(X)
        )
        return res

    def Kxy(self, X, Z):
        # Cross-covariance between training inputs X and query points Z.
        param = self.length_scale_.exp().sqrt()
        sqdist = torch.cdist(X / param[None], Z / param[None]) ** 2
        res = self.amplitude_scale_.exp() * torch.exp(-0.5 * sqdist)
        return res

    def fit(self, X, y, opt, num_steps):
        assert X.shape[1] == len(self.length_scale_)
        self.y = y
        self.X = X
        scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)
        self.train()
        nll_hist = []
        for it in range(num_steps):
            opt.zero_grad()
            try:
                nll = -self.log_marginal_likelihood(self.X, self.y).sum()
            except torch.linalg.LinAlgError:
                # The Cholesky factorization can fail if K becomes
                # ill-conditioned; stop the optimization in that case.
                break
            nll.backward()
            opt.step()
            # Decay the learning rate every 10 steps during the first 1000 steps.
            if it % 10 == 0 and it < 1000:
                scheduler.step()
            nll_hist.append(nll.item())
        return nll_hist
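

# Minimal usage sketch (not part of the original file): it assumes synthetic
# 1-D data and an Adam optimizer, and only illustrates how fit() and the
# forward pass are meant to be called together; all names below (X_test,
# nll_hist, learning rate, step count) are illustrative choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    X = torch.linspace(0.0, 1.0, 20).reshape(-1, 1)
    y = torch.sin(2 * np.pi * X) + 0.1 * torch.randn_like(X)

    gp = GaussianProcessRegressor(length_scale=0.5, noise_scale=0.1, amplitude_scale=1.0)
    opt = torch.optim.Adam(gp.parameters(), lr=1e-2)
    nll_hist = gp.fit(X, y, opt, num_steps=200)

    with torch.no_grad():
        X_test = torch.linspace(0.0, 1.0, 50).reshape(-1, 1)
        mu = gp(X_test)  # posterior mean at the test points
    print("final NLL:", nll_hist[-1], "prediction shape:", tuple(mu.shape))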