Skip to content

Variance inconsistency in HeteroskedasticSingleTaskGP #933

Closed
@mklpr

Description

@mklpr

hi,
In HeteroskedasticSingleTaskGP, I get different results depending on which way I compute the posterior with noise. I can't explain or understand the discrepancy myself, so I am asking for help here.

I use four ways to compute the posterior with noise:

  1. model_heter.posterior(scan_x, observation_noise=True)
  2. mll_heter.likelihood(model_heter.posterior(scan_x, observation_noise=False), scan_x)
  3. model_heter.likelihood.noise_covar.noise_model.posterior(scan_x).mean to calculate the noise variance, and then add the variance from model_heter.posterior(scan_x, observation_noise=False) to compute the total posterior variance
  4. model_heter.likelihood.noise_covar.noise_model(scan_x).mean.exp() to calculate the noise variance, and then add the variance from model_heter.posterior(scan_x, observation_noise=False) to compute the total posterior variance

Methods 1 and 2 give the same results, but methods 3 and 4 each differ from all the others. To my knowledge, the total posterior variance equals the noise variance from noise_model plus the variance from the GP kernel, and I verified this with SingleTaskGP. So what is wrong in HeteroskedasticSingleTaskGP? Does the difference come from the log transform, and how does mll_heter.likelihood(model_heter.posterior(scan_x, observation_noise=False), scan_x) process it internally? Thanks.

test code

Refer to https://colab.research.google.com/drive/1dOUHQzl3aQ8hz6QUtwRrXlQBGqZadQgG#scrollTo=D0A4Cf0W_QkZ

import os
import torch
import matplotlib.pyplot as plt
import warnings
import numpy as np

# Global matplotlib styling applied to every figure this script draws.
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
    'font.size': 14,
})

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.double

# warnings.filterwarnings('ignore')

# Seed both the torch and numpy RNGs so the synthetic data set is reproducible.
seed = 7433
torch.manual_seed(seed)
np.random.seed(seed)

# 200 sorted inputs drawn uniformly from [0, pi]. The targets are Gaussian
# draws whose mean AND standard deviation both vary with the input, i.e. the
# observation noise is input-dependent (heteroskedastic).
W_x = np.sort(np.random.uniform(0, np.pi, size=200))
W_y = np.random.normal(
    loc=np.sin(2.5 * W_x) * np.sin(1.5 * W_x),
    scale=0.01 + 0.25 * (1 - np.sin(2.5 * W_x)) ** 2,
    size=200,
)

# Training data as double-precision column tensors of shape (200, 1).
X_train = torch.tensor(W_x[:, None], dtype=dtype)
y_train = torch.tensor(W_y[:, None], dtype=dtype)

from botorch.models import SingleTaskGP
from gpytorch.constraints import GreaterThan
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_model

# Homoskedastic baseline: fit a SingleTaskGP on the training data through its
# exact marginal log likelihood.
model = SingleTaskGP(train_X=X_train, train_Y=y_train)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_model(mll)

# Evaluation grid of 500 points on [0, pi], shaped (500, 1, 1).
scan_x = torch.linspace(0, np.pi, 500, dtype=dtype).reshape(-1, 1, 1)

with torch.no_grad():
    # Flattened grid reused as the x-axis of every artist below.
    grid = scan_x.numpy().reshape(-1)

    # Posterior without observation noise: mean curve plus confidence band.
    scan_y = model.posterior(scan_x, observation_noise=False)
    plt.plot(grid, scan_y.mean.reshape(-1), label='posterior mean')

    lower, upper = scan_y.mvn.confidence_region()
    plt.fill_between(
        grid,
        lower.numpy().reshape(-1),
        upper.numpy().reshape(-1),
        alpha=0.2,
        label='posterior confidence',
    )

    # Same band with observation noise included in the posterior.
    scan_y_with_noise = model.posterior(scan_x, observation_noise=True)
    lower_with_noise, upper_with_noise = scan_y_with_noise.mvn.confidence_region()
    plt.fill_between(
        grid,
        lower_with_noise.numpy().reshape(-1),
        upper_with_noise.numpy().reshape(-1),
        alpha=0.2,
        label='posterior confidence with noise',
    )

    plt.scatter(X_train, y_train, label='observed data')

    # Labels are attached to the artists above, so the legend picks them up.
    plt.legend()

with torch.no_grad():
    # Squared residual of the baseline posterior mean at each training input;
    # used further down as the per-point observed variance.
    observed_var = (model.posterior(X_train).mean - y_train) ** 2

from botorch.models import HeteroskedasticSingleTaskGP

# Heteroskedastic GP: same training data, with the squared residuals of the
# baseline fit (observed_var) supplied as the observed noise variances.
model_heter = HeteroskedasticSingleTaskGP(
    train_X=X_train,
    train_Y=y_train,
    train_Yvar=observed_var,
)
mll_heter = ExactMarginalLogLikelihood(model_heter.likelihood, model_heter)
_ = fit_gpytorch_model(mll_heter)

mll_heter.eval()
model_heter.eval()
with torch.no_grad():
    plt.figure()
    # Flattened grid reused as the x-axis of every artist below.
    grid = scan_x.numpy().reshape(-1)

    # Noise-free posterior: mean curve and its confidence band.
    scan_y = model_heter.posterior(scan_x, observation_noise=False)
    center = scan_y.mean.reshape(-1)
    latent_var = scan_y.variance.reshape(-1)
    plt.plot(grid, center, label='posterior mean')

    lower, upper = scan_y.mvn.confidence_region()
    plt.fill_between(
        grid,
        lower.numpy().reshape(-1),
        upper.numpy().reshape(-1),
        alpha=0.2,
        label='posterior confidence',
    )

    # Method 1: let posterior() add the observation noise.
    scan_y_with_noise = model_heter.posterior(scan_x, observation_noise=True)
    lower_with_noise, upper_with_noise = scan_y_with_noise.mvn.confidence_region()
    plt.fill_between(
        grid,
        lower_with_noise.numpy().reshape(-1),
        upper_with_noise.numpy().reshape(-1),
        alpha=0.2,
        label='posterior confidence with noise',
    )

    # Method 2: push the noise-free MVN through the likelihood by hand.
    scan_y_with_noise2 = mll_heter.likelihood(scan_y.mvn, scan_x)
    lower_with_noise2, upper_with_noise2 = scan_y_with_noise2.confidence_region()
    plt.fill_between(
        grid,
        lower_with_noise2.numpy().reshape(-1),
        upper_with_noise2.numpy().reshape(-1),
        alpha=0.2,
        label='posterior confidence with noise2',
    )

    # NOTE(review): methods 3 and 4 obtain the noise variance from the noise
    # model in two different ways (posterior mean vs. exp of the forward-pass
    # mean). Whether either matches what the likelihood adds internally in
    # methods 1/2 is the open question of this report -- confirm against the
    # BoTorch/GPyTorch sources.

    # Method 3: noise variance taken from the noise model's posterior mean,
    # added to the latent variance; band drawn at mean +/- 2 std.
    noise_var = model_heter.likelihood.noise_covar.noise_model.posterior(scan_x).mean
    std_with_noise = (latent_var + noise_var.reshape(-1)).sqrt()
    half_width = 2 * std_with_noise.reshape(-1)
    plt.fill_between(
        grid,
        (center - half_width).numpy(),
        (center + half_width).numpy(),
        alpha=0.2,
        label='posterior confidence with noise3',
    )

    # Method 4: noise variance taken as exp() of the noise model's forward
    # mean, added to the latent variance; band drawn at mean +/- 2 std.
    noise_var2 = model_heter.likelihood.noise_covar.noise_model(scan_x).mean.exp()
    std_with_noise2 = (latent_var + noise_var2.reshape(-1)).sqrt()
    half_width2 = 2 * std_with_noise2.reshape(-1)
    plt.fill_between(
        grid,
        (center - half_width2).numpy(),
        (center + half_width2).numpy(),
        alpha=0.2,
        label='posterior confidence with noise4',
    )

    plt.scatter(X_train, y_train, label='observed data')

    # Labels are attached to the artists above, so the legend picks them up.
    plt.legend()

image

image

system info

  • botorch==0.5.0
  • gpytorch==1.5.0
  • torch==1.9.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions