From 93048811da0f2e1c4afe0c57664f025a47f92f17 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 18 Jan 2025 02:55:14 -0800 Subject: [PATCH 01/35] Latent Information Gain --- .../acquisition/latent_information_gain.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 botorch_community/acquisition/latent_information_gain.py diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py new file mode 100644 index 0000000000..5cae0fc0f3 --- /dev/null +++ b/botorch_community/acquisition/latent_information_gain.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Latent Information Gain Acquisition Function for Neural Process Models. + +References: + +.. [Wu2023arxiv] + Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). + Deep Bayesian Active Learning for Accelerating Stochastic Simulation. + arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 + +Contributor: eibarolle +""" + +from __future__ import annotations + +import warnings +from typing import Optional + +import torch +from botorch import settings +from botorch_community.models.np_regression import NeuralProcessModel +from torch import Tensor + +import torch +#reference: https://arxiv.org/abs/2106.02770 + +class LatentInformationGain: + def __init__( + self, + model: NeuralProcessModel, + num_samples: int = 10, + min_std: float = 0.1, + scaler: float = 0.9 + ) -> None: + """ + Latent Information Gain (LIG) Acquisition Function, designed for the + NeuralProcessModel. + + Args: + model: Trained NeuralProcessModel. + num_samples (int): Number of samples for calculation, defaults to 10. + min_std: Float representing the minimum possible standardized std, defaults to 0.1. + scaler: Float scaling the std, defaults to 0.9. + """ + self.model = model + self.num_samples = num_samples + self.min_std = min_std + self.scaler = scaler + + def acquisition(self, candidate_x, context_x, context_y): + """ + Conduct the Latent Information Gain acquisition function for the inputs. + + Args: + candidate_x: Candidate input points, as a Tensor. + context_x: Context input points, as a Tensor. + context_y: Context target points, as a Tensor. + + Returns: + torch.Tensor: The LIG score of computed KLDs. + """ + + # Encoding and Scaling the context data + z_mu_context, z_logvar_context = self.model.data_to_z_params(context_x, context_y) + kl = 0.0 + for _ in range(self.num_samples): + # Taking reparameterized samples + samples = self.model.sample_z(z_mu_context, z_logvar_context) + + # Using the Decoder to take predicted values + y_pred = self.model.decoder(candidate_x, samples) + + # Combining context and candidate data + combined_x = torch.cat([context_x, candidate_x], dim=0) + combined_y = torch.cat([context_y, y_pred], dim=0) + + # Computing posterior variables + z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y) + std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context) + std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior) + + p = torch.distributions.Normal(z_mu_posterior, std_posterior) + q = torch.distributions.Normal(z_mu_context, std_prior) + + kl_divergence = torch.distributions.kl_divergence(p, q).sum() + kl += kl_divergence + + # Average KLD + return kl / self.num_samples From 7a8e4aba0b326744cb48f9dac99afeae0a3e4467 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 18 Jan 2025 02:55:55 -0800 Subject: [PATCH 02/35] NP Regression --- botorch_community/models/np_regression.py | 458 ++++++++++++++++++++++ 1 file changed, 458 insertions(+) create mode 100644 botorch_community/models/np_regression.py diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py new file mode 100644 index 0000000000..955da2315b --- /dev/null +++ b/botorch_community/models/np_regression.py @@ -0,0 +1,458 @@ +r""" +Neural Process Regression models based on PyTorch models. + +References: + +.. [Wu2023arxiv] + Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). + Deep Bayesian Active Learning for Accelerating Stochastic Simulation. + arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 + +Contributor: eibarolle +""" + +import copy +import numpy as np +from numpy.random import binomial +import torch +import torch.nn as nn +import matplotlib.pyplot as plts +# %matplotlib inline +from botorch.models.model import Model +from botorch.posteriors import GPyTorchPosterior +from botorch.acquisition.objective import PosteriorTransform +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic, + ExpSineSquared, DotProduct, + ConstantKernel) +from typing import Callable, List, Optional, Tuple +from torch.nn import Module, ModuleDict, ModuleList +from sklearn import preprocessing +from scipy.stats import multivariate_normal +from gpytorch.distributions import MultivariateNormal + +device = torch.device("cpu") +# Account for different acquisitions + +#reference: https://chrisorm.github.io/NGP.html +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_ + ) -> None: + r""" + A modular implementation of a Multilayer Perceptron (MLP). + + Args: + input_dim: An int representing the total input dimensionality. + output_dim: An int representing the total encoded dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden dimension. + activation: Activation function applied between layers, defaults to nn.Sigmoid. + init_func: A function initializing the weights, defaults to nn.init.normal_. + """ + super().__init__() + layers = [] + prev_dim = input_dim + + for hidden_dim in hidden_dims: + layer = nn.Linear(prev_dim, hidden_dim) + if init_func is not None: + init_func(layer.weight) + layers.append(layer) + layers.append(activation()) + prev_dim = hidden_dim + + final_layer = nn.Linear(prev_dim, output_dim) + if init_func is not None: + init_func(final_layer.weight) + layers.append(final_layer) + self.model = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +class REncoder(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_ + ) -> None: + r"""Encodes inputs of the form (x_i,y_i) into representations, r_i. + + Args: + input_dim: An int representing the total input dimensionality. + output_dim: An int representing the total encoded dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden dimension. + activation: Activation function applied between layers, defaults to nn.Sigmoid. + init_func: A function initializing the weights, defaults to nn.init.normal_. + """ + super().__init__() + self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + + def forward( + self, + inputs: torch.Tensor, + ) -> torch.Tensor: + r"""Forward pass for representation encoder. + + Args: + inputs: Input tensor + + Returns: + torch.Tensor: Encoded representations + """ + return self.mlp(inputs) + +class ZEncoder(nn.Module): + def __init__(self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_ + ) -> None: + r"""Takes an r representation and produces the mean & standard + deviation of the normally distributed function encoding, z. + + Args: + input_dim: An int representing r's aggregated dimensionality. + output_dim: An int representing z's latent dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden dimension. + activation: Activation function applied between layers, defaults to nn.Sigmoid. + init_func: A function initializing the weights, defaults to nn.init.normal_. + """ + super().__init__() + self.mean_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + self.logvar_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + + def forward( + self, + inputs: torch.Tensor, + ) -> torch.Tensor: + r"""Forward pass for latent encoder. + + Args: + inputs: Input tensor + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Mean of the latent Gaussian distribution. + - Log variance of the latent Gaussian distribution. + """ + return self.mean_net(inputs), self.logvar_net(inputs) + +class Decoder(torch.nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_ + ) -> None: + r"""Takes the x star points, along with a 'function encoding', z, and makes predictions. + + Args: + input_dim: An int representing the total input dimensionality. + output_dim: An int representing the total encoded dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden dimension. + activation: Activation function applied between layers, defaults to nn.Sigmoid. + init_func: A function initializing the weights, defaults to nn.init.normal_. + """ + super().__init__() + self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + + def forward( + self, + x_pred: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + r"""Forward pass for decoder. + + Args: + x_pred: No. of data points, by x_dim + z: No. of samples, by z_dim + + Returns: + torch.Tensor: Predicted target values. + """ + z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1) + xz = torch.cat([x_pred, z_expanded], dim=-1) + return self.mlp(xz) + +def MAE( + pred: torch.Tensor, + target: torch.Tensor, +) -> torch.Tensor: + r"""Mean Absolute Error loss function. + + Args: + pred: The predicted values tensor. + target: The target values tensor. + + Returns: + torch.Tensor: A tensor representing the MAE. + """ + loss = torch.abs(pred-target) + return loss.mean() + +class NeuralProcessModel(Model): + def __init__( + self, + r_hidden_dims: List[int], + z_hidden_dims: List[int], + decoder_hidden_dims: List[int], + x_dim: int, + y_dim: int, + r_dim: int, + z_dim: int, + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = torch.nn.init.normal_, + ) -> None: + r"""Diffusion Convolutional Recurrent Neural Network Model Implementation. + + Args: + r_hidden_dims: Hidden Dimensions/Layer list for REncoder + z_hidden_dims: Hidden Dimensions/Layer list for ZEncoder + decoder_hidden_dims: Hidden Dimensions/Layer for Decoder + x_dim: Int dimensionality of input data x. + y_dim: Int dimensionality of target data y. + r_dim: Int dimensionality of representation r. + z_dim: Int dimensionality of latent variable z. + activation: Activation function applied between layers, defaults to nn.Sigmoid. + init_func: A function initializing the weights, defaults to nn.init.normal_. + """ + super().__init__() + self.r_encoder = REncoder(x_dim+y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func) + self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func) + self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func) + self.z_dim = z_dim + self.z_mu_all = None + self.z_logvar_all = None + self.z_mu_context = None + self.z_logvar_context = None + # self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) # Look at BoTorch native versions + #self.train(n_epochs, x_train, y_train) + + def data_to_z_params( + self, + x: torch.Tensor, + y: torch.Tensor, + xy_dim: int = 1, + r_dim: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Compute latent parameters from inputs as a latent distribution. + + Args: + x: Input tensor + y: Target tensor + xy_dim: Combined Input Dimension as int, defaults as 1 + r_dim: Combined Target Dimension as int, defaults as 0. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - x_c: Context input data. + - y_c: Context target data. + - x_t: Target input data. + - y_t: Target target data. + """ + xy = torch.cat([x,y], dim=xy_dim) + rs = self.r_encoder(xy) + r_agg = rs.mean(dim=r_dim) + return self.z_encoder(r_agg) + + def sample_z( + self, + mu: torch.Tensor, + logvar: torch.Tensor, + n: int = 1, + min_std: float = 0.1, + scaler: float = 0.9 + ) -> torch.Tensor: + r"""Reparameterization trick for z's latent distribution. + + Args: + mu: Tensor representing the Gaussian distribution mean. + logvar: Tensor representing the log variance of the Gaussian distribution. + n: Int representing the # of samples, defaults to 1. + min_std: Float representing the minimum possible standardized std, defaults to 0.1. + scaler: Float scaling the std, defaults to 0.9. + + Returns: + torch.Tensor: Samples from the Gaussian distribution. + """ + if min_std <= 0 or scaler <= 0: + raise ValueError() + if n == 1: + eps = torch.autograd.Variable(logvar.data.new(self.z_dim).normal_()).to(device) + else: + eps = torch.autograd.Variable(logvar.data.new(n,self.z_dim).normal_()).to(device) + + std = min_std + scaler * torch.sigmoid(logvar) + return mu + std * eps + + def KLD_gaussian( + self, + min_std: float = 0.1, + scaler: float = 0.9 + ) -> torch.Tensor: + r"""Analytical KLD between 2 Gaussian Distributions. + + Args: + min_std: Float representing the minimum possible standardized std, defaults to 0.1. + scaler: Float scaling the std, defaults to 0.9. + + Returns: + torch.Tensor: A tensor representing the KLD. + """ + + if min_std <= 0 or scaler <= 0: + raise ValueError() + std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all) + std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context) + p = torch.distributions.Normal(self.z_mu_context, std_p) + q = torch.distributions.Normal(self.z_mu_all, std_q) + return torch.distributions.kl_divergence(p, q).sum() + + def posterior( + self, + X: torch.Tensor, + covariance_multiplier: float, + observation_constant: float, + observation_noise: bool = False, + posterior_transform: Optional[PosteriorTransform] = None, + ) -> GPyTorchPosterior: + r"""Computes the model's posterior distribution for given input tensors. + + Args: + X: Input Tensor + covariance_multiplier: Float scaling the covariance. + observation_constant: Float representing the noise constant. + observation_noise: Adds observation noise to the covariance if True, defaults to False. + posterior_transform: An optional posterior transformation, defaults to None. + + Returns: + GPyTorchPosterior: The posterior distribution object + utilizing MultivariateNormal. + """ + mean = self.decoder(X, self.sample_z(self.z_mu_all, self.z_logvar_all)) + covariance = torch.eye(X.size(0)) * covariance_multiplier + if (observation_noise): + covariance = covariance + observation_constant + mvn = MultivariateNormal(mean, covariance) + posterior = GPyTorchPosterior(mvn) + if posterior_transform is not None: + posterior = posterior_transform(posterior) + return posterior + + def load_state_dict( + self, + state_dict: dict, + strict: bool = True + ) -> None: + """ + Initialize the fully Bayesian model before loading the state dict. + + Args: + state_dict (dict): A dictionary containing the parameters. + strict (bool): Case matching strictness. + """ + super().load_state_dict(state_dict, strict=strict) + + def transform_inputs( + self, + X: torch.Tensor, + input_transform: Optional[Module] = None, + ) -> torch.Tensor: + r"""Transform inputs. + + Args: + X: A tensor of inputs + input_transform: A Module that performs the input transformation. + + Returns: + torch.Tensor: A tensor of transformed inputs + """ + if input_transform is not None: + input_transform.to(X) + return input_transform(X) + try: + return self.input_transform(X) + except AttributeError: + return X + + def forward( + self, + x_t: torch.Tensor, + x_c: torch.Tensor, + y_c: torch.Tensor, + y_t: torch.Tensor, + input_dim: int = 0, + target_dim: int = 0 + ) -> torch.Tensor: + r"""Forward pass for the model. + + Args: + x_t: Target input data. + x_c: Context input data. + y_c: Context target data. + y_t: Target output data. + input_dim: Input dimension concatenated + target_dim: Target dimension concatendated + + Returns: + torch.Tensor: Predicted target values. + """ + if any(tensor.numel() == 0 for tensor in [x_t, x_c, y_c]): + raise ValueError() + if input_dim not in [0, 1]: + raise ValueError() + if x_c.size(1 - input_dim) != x_t.size(1 - input_dim): + raise ValueError() + if y_c.size(1 - target_dim) != y_t.size(1 - target_dim): + raise ValueError() + + self.z_mu_all, self.z_logvar_all = self.data_to_z_params(torch.cat([x_c, x_t], dim = input_dim), torch.cat([y_c, y_t], dim = target_dim)) + self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c) + z = self.sample_z(self.z_mu_all, self.z_logvar_all) + return self.decoder(x_t, z) + + def random_split_context_target( + self, + x: torch.Tensor, + y: torch.Tensor, + n_context: int, + axis: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Helper function to split randomly into context and target. + + Args: + x: Input data tensor. + y: Target data tensor. + n_context (int): Number of context points. + axis: Dimension axis as int + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - x_c: Context input data. + - y_c: Context target data. + - x_t: Target input data. + - y_t: Target target data. + """ + ind = np.arange(x.shape[0]) + mask = np.random.choice(ind, size=n_context, replace=False) + x_c = torch.from_numpy(x[mask]) + y_c = torch.from_numpy(y[mask]) + x_t = torch.from_numpy(np.delete(x, mask, axis=0)) + y_t = torch.from_numpy(np.delete(y, mask, axis=0)) + + return x_c, y_c, x_t, y_t + \ No newline at end of file From 5f0dba0c0f58c13f8f2eb3e3bf2d5ed4e7c3fdb8 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 18 Jan 2025 02:57:30 -0800 Subject: [PATCH 03/35] Test NP Regression --- test_community/models/test_np_regression.py | 126 ++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 test_community/models/test_np_regression.py diff --git a/test_community/models/test_np_regression.py b/test_community/models/test_np_regression.py new file mode 100644 index 0000000000..67ee410f96 --- /dev/null +++ b/test_community/models/test_np_regression.py @@ -0,0 +1,126 @@ +import unittest +import numpy as np +import torch +from torch import nn +from torch.optim import Adam +from botorch_community.models.np_regression import NeuralProcessModel +from botorch.posteriors import GPyTorchPosterior +from torch import Tensor + +class TestNeuralProcessModel(unittest.TestCase): + def initialize(self): + self.r_hidden_dims = [16, 16] + self.z_hidden_dims = [32, 32] + self.decoder_hidden_dims = [16, 16] + self.x_dim = 2 + self.y_dim = 1 + self.r_dim = 8 + self.z_dim = 8 + self.model = NeuralProcessModel( + self.r_hidden_dims, + self.z_hidden_dims, + self.decoder_hidden_dims, + self.x_dim, + self.y_dim, + self.r_dim, + self.z_dim, + ) + self.x_data = np.random.rand(100, self.x_dim) + self.y_data = np.random.rand(100, self.y_dim) + + def test_r_encoder(self): + self.initialize() + input = torch.rand(10, self.x_dim + self.y_dim) + output = self.model.r_encoder(input) + self.assertEqual(output.shape, (10, self.r_dim)) + self.assertTrue(torch.is_tensor(output)) + + def test_z_encoder(self): + self.initialize() + input = torch.rand(10, self.r_dim) + mean, logvar = self.model.z_encoder(input) + self.assertEqual(mean.shape, (10, self.z_dim)) + self.assertEqual(logvar.shape, (10, self.z_dim)) + self.assertTrue(torch.is_tensor(mean)) + self.assertTrue(torch.is_tensor(logvar)) + + def test_decoder(self): + self.initialize() + x_pred = torch.rand(10, self.x_dim) + z = torch.rand(self.z_dim) + output = self.model.decoder(x_pred, z) + self.assertEqual(output.shape, (10, self.y_dim)) + self.assertTrue(torch.is_tensor(output)) + + def test_sample_z(self): + self.initialize() + mu = torch.rand(self.z_dim) + logvar = torch.rand(self.z_dim) + samples = self.model.sample_z(mu, logvar, n=5) + self.assertEqual(samples.shape, (5, self.z_dim)) + self.assertTrue(torch.is_tensor(samples)) + + def test_KLD_gaussian(self): + self.initialize() + self.model.z_mu_all = torch.rand(self.z_dim) + self.model.z_logvar_all = torch.rand(self.z_dim) + self.model.z_mu_context = torch.rand(self.z_dim) + self.model.z_logvar_context = torch.rand(self.z_dim) + kld = self.model.KLD_gaussian() + self.assertGreaterEqual(kld.item(), 0) + self.assertTrue(torch.is_tensor(kld)) + + def test_data_to_z_params(self): + self.initialize() + x = torch.rand(10, self.x_dim) + y = torch.rand(10, self.y_dim) + mu, logvar = self.model.data_to_z_params(x, y) + self.assertEqual(mu.shape, (self.z_dim,)) + self.assertEqual(logvar.shape, (self.z_dim,)) + self.assertTrue(torch.is_tensor(mu)) + self.assertTrue(torch.is_tensor(logvar)) + + def test_forward(self): + self.initialize() + x_t = torch.rand(5, self.x_dim) + x_c = torch.rand(10, self.x_dim) + y_c = torch.rand(10, self.y_dim) + y_t = torch.rand(5, self.y_dim) + output = self.model(x_t, x_c, y_c, y_t) + self.assertEqual(output.shape, (5, self.y_dim)) + + def test_random_split_context_target(self): + self.initialize() + x_c, y_c, x_t, y_t = self.model.random_split_context_target( + self.x_data[:, 0], self.y_data, 20, 0 + ) + self.assertEqual(x_c.shape[0], 20) + self.assertEqual(y_c.shape[0], 20) + self.assertEqual(x_t.shape[0], 80) + self.assertEqual(y_t.shape[0], 80) + + def test_posterior(self): + self.initialize() + x_t = torch.rand(5, self.x_dim) + x_c = torch.rand(10, self.x_dim) + y_c = torch.rand(10, self.y_dim) + y_t = torch.rand(5, self.y_dim) + output = self.model(x_t, x_c, y_c, y_t) + posterior = self.model.posterior(x_t, 0.1, 0.01, observation_noise=True) + self.assertIsInstance(posterior, GPyTorchPosterior) + mvn = posterior.mvn + self.assertEqual(mvn.covariance_matrix.size(), (5, 5, 5)) + + def test_load_state_dict(self): + self.initialize() + state_dict = {"r_encoder.mlp.model.0.bias": torch.rand(16)} + self.model.load_state_dict(state_dict, strict = False) + + def test_transform_inputs(self): + self.initialize() + X = torch.rand(5, 3) + self.assertTrue(torch.equal(self.model.transform_inputs(X), X)) + + +if __name__ == "__main__": + unittest.main() From c55d7a9877df30a840f702f38fca4f552d0d9ed4 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 18 Jan 2025 02:59:06 -0800 Subject: [PATCH 04/35] Test Latent Information Gain --- .../test_latent_information_gain.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 test_community/acquisition/test_latent_information_gain.py diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py new file mode 100644 index 0000000000..ff135bd2fa --- /dev/null +++ b/test_community/acquisition/test_latent_information_gain.py @@ -0,0 +1,91 @@ +import unittest +import torch +from torch import nn +from torch.distributions import Normal +from botorch_community.acquisition.latent_information_gain import LatentInformationGain +from botorch_community.models.np_regression import NeuralProcessModel + +class TestLatentInformationGain(unittest.TestCase): + def setUp(self): + self.x_dim = 2 + self.y_dim = 1 + self.r_dim = 8 + self.z_dim = 3 + self.r_hidden_dims = [16, 16] + self.z_hidden_dims = [32, 32] + self.decoder_hidden_dims = [16, 16] + self.num_samples = 10 + self.model = NeuralProcessModel( + r_hidden_dims = self.r_hidden_dims, + z_hidden_dims = self.z_hidden_dims, + decoder_hidden_dims = self.decoder_hidden_dims, + x_dim=self.x_dim, + y_dim=self.y_dim, + r_dim=self.r_dim, + z_dim=self.z_dim, + ) + self.acquisition_function = LatentInformationGain( + model=self.model, + num_samples=self.num_samples, + ) + self.context_x = torch.rand(10, self.x_dim) + self.context_y = torch.rand(10, self.y_dim) + self.candidate_x = torch.rand(5, self.x_dim) + + def test_initialization(self): + self.assertEqual(self.acquisition_function.num_samples, self.num_samples) + self.assertEqual(self.acquisition_function.model, self.model) + + def test_acquisition_shape(self): + lig_score = self.acquisition_function.acquisition( + candidate_x=self.candidate_x, + context_x=self.context_x, + context_y=self.context_y, + ) + self.assertTrue(torch.is_tensor(lig_score)) + self.assertEqual(lig_score.shape, ()) + + def test_acquisition_kl(self): + lig_score = self.acquisition_function.acquisition( + candidate_x=self.candidate_x, + context_x=self.context_x, + context_y=self.context_y, + ) + self.assertGreaterEqual(lig_score.item(), 0) + + def test_acquisition_samples(self): + lig_1 = self.acquisition_function.acquisition( + candidate_x=self.candidate_x, + context_x=self.context_x, + context_y=self.context_y, + ) + + self.acquisition_function.num_samples = 20 + lig_2 = self.acquisition_function.acquisition( + candidate_x=self.candidate_x, + context_x=self.context_x, + context_y=self.context_y, + ) + self.assertTrue(lig_2.item() < lig_1.item()) + self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2) + + def test_acquisition_invalid_inputs(self): + invalid_context_x = torch.rand(10, self.x_dim + 5) + with self.assertRaises(Exception): + self.acquisition_function.acquisition( + candidate_x=self.candidate_x, + context_x=invalid_context_x, + context_y=self.context_y, + ) + + invalid_candidate_x = torch.rand(5, self.x_dim + 5) + with self.assertRaises(Exception): + self.acquisition_function.acquisition( + candidate_x=invalid_candidate_x, + context_x=self.context_x, + context_y=self.context_y, + ) + + +if __name__ == "__main__": + unittest.main() From 657151ff3d265a947ed6bf95b859021726e92463 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 18 Jan 2025 03:03:01 -0800 Subject: [PATCH 05/35] NP Regression Documentation --- docs/acquisition.md | 379 ++++++++++++++++++++++++-------------------- docs/models.md | 377 +++++++++++++++++++++---------------------- 2 files changed, 397 insertions(+), 359 deletions(-) diff --git a/docs/acquisition.md b/docs/acquisition.md index 761a79e096..dc126a2388 100644 --- a/docs/acquisition.md +++ b/docs/acquisition.md @@ -1,172 +1,207 @@ ---- -id: acquisition -title: Acquisition Functions ---- - -Acquisition functions are heuristics employed to evaluate the usefulness of one -of more design points for achieving the objective of maximizing the underlying -black box function. - -BoTorch supports both analytic as well as (quasi-) Monte-Carlo based acquisition -functions. It provides a generic -[`AcquisitionFunction`](../api/acquisition.html#acquisitionfunction) API that -abstracts away from the particular type, so that optimization can be performed -on the same objects. - - -## Monte Carlo Acquisition Functions - -Many common acquisition functions can be expressed as the expectation of some -real-valued function of the model output(s) at the design point(s): - -$$ -\alpha(X) = \mathbb{E}\bigl[ a(\xi) \mid - \xi \sim \mathbb{P}(f(X) \mid \mathcal{D}) \bigr] -$$ - -where $X = (x_1, \dotsc, x_q)$, and $\mathbb{P}(f(X) \mid \mathcal{D})$ is the -posterior distribution of the function $f$ at $X$ given the data $\mathcal{D}$ -observed so far. - -Evaluating the acquisition function thus requires evaluating an integral over -the posterior distribution. In most cases, this is analytically intractable. In -particular, analytic expressions generally do not exist for batch acquisition -functions that consider multiple design points jointly (i.e. $q > 1$). - -An alternative is to use Monte-Carlo (MC) sampling to approximate the integrals. -An MC approximation of $\alpha$ at $X$ using $N$ MC samples is - -$$ \alpha(X) \approx \frac{1}{N} \sum_{i=1}^N a(\xi_{i}) $$ - -where $\xi_i \sim \mathbb{P}(f(X) \mid \mathcal{D})$. - -For instance, for q-Expected Improvement (qEI), we have: - -$$ -\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} -\bigl\\{ \max(\xi_{ij} - f^\*, 0) \bigr\\}, -\qquad \xi_{i} \sim \mathbb{P}(f(X) \mid \mathcal{D}) -$$ - -where $f^\*$ is the best function value observed so far (assuming noiseless -observations). Using the reparameterization trick ([^KingmaWelling2014], -[^Rezende2014]), - -$$ -\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} -\bigl\\{ \max\bigl( \mu(X)\_j + (L(X) \epsilon_i)\_j - f^\*, 0 \bigr) \bigr\\}, -\qquad \epsilon_{i} \sim \mathcal{N}(0, I) -$$ - -where $\mu(X)$ is the posterior mean of $f$ at $X$, and $L(X)L(X)^T = \Sigma(X)$ -is a root decomposition of the posterior covariance matrix. - -All MC-based acquisition functions in BoTorch are derived from -[`MCAcquisitionFunction`](../api/acquisition.html#mcacquisitionfunction). - -Acquisition functions expect input tensors $X$ of shape -$\textit{batch_shape} \times q \times d$, where $d$ is the dimension of the -feature space, $q$ is the number of points considered jointly, and -$\textit{batch_shape}$ is the batch-shape of the input tensor. The output -$\alpha(X)$ will have shape $\textit{batch_shape}$, with each element -corresponding to the respective $q \times d$ batch tensor in the input $X$. -Note that for analytic acquisition functions, it must be that $q=1$. - -### MC, q-MC, and Fixed Base Samples - -BoTorch relies on the re-parameterization trick and (quasi)-Monte-Carlo sampling -for optimization and estimation of the batch acquisition functions [^Wilson2017]. -The results below show the reduced variance when estimating an expected -improvement (EI) acquisition function using base samples obtained via quasi-MC -sampling versus standard MC sampling. - -![MC_qMC](assets/EI_MC_qMC.png) - -In the plots above, the base samples used to estimate each point are resampled. -As discussed in the [Overview](./overview), a single set of base samples can be -used for optimization when the re-parameterization trick is employed. What are the -trade-offs between using a fixed set of base samples versus re-sampling on every -MC evaluation of the acquisition function? Below, we show that fixing base samples -produces functions that are potentially much easier to optimize, without resorting to -stochastic optimization methods. - -![resampling_fixed](assets/EI_resampling_fixed.png) - -If the base samples are fixed, the problem of optimizing the acquisition function -is deterministic, allowing for conventional quasi-second order methods such as -L-BFGS or sequential least-squares programming (SLSQP) to be used. These have -faster convergence rates than first-order methods and can speed up acquisition -function optimization significantly. - -One concern is that the approximated acquisition function is *biased* for any -fixed set of base samples, which may adversely affect the solution. However, we -find that in practice, both the optimal value and the optimal solution of these -biased problems for standard acquisition functions converge quite rapidly to -their true counterparts as more samples are used. Note that for evaluation of -the acquisition function we integrate over a $qo$-dimensional space (where -$q$ is the number of points in the q-batch and $o$ is the number of outputs -included in the objective). Therefore, the MC integration problem can be quite -low-dimensional even for models on high-dimensional feature spaces (large $d$). -Because using additional samples is relatively cheap computationally, -we default to 500 base samples in the MC acquisition functions. - -On the other hand, when re-sampling is used in conjunction with a stochastic -optimization algorithm, the kind of bias noted above is no longer a concern. -The trade-off here is that the optimization may be less effective, as discussed -above. - - -## Analytic Acquisition Functions - -BoTorch also provides implementations of analytic acquisition functions that -do not depend on MC sampling. These acquisition functions are subclasses of -[`AnalyticAcquisitionFunction`](../api/acquisition.html#analyticacquisitionfunction) -and only exist for the case of a single candidate point ($q = 1$). These -include classical acquisition functions such as Expected Improvement (EI), -Upper Confidence Bound (UCB), and Probability of Improvement (PI). An example -comparing [`ExpectedImprovement`](../api/acquisition.html#expectedimprovement), -the analytic version of EI, to it's MC counterpart -[`qExpectedImprovement`](../api/acquisition.html#qexpectedimprovement) -can be found in -[this tutorial](../tutorials/compare_mc_analytic_acquisition). - -Analytic acquisition functions allow for an explicit expression in terms of the -summary statistics of the posterior distribution at the evaluated point(s). -A popular acquisition function is Expected Improvement of a single point -for a Gaussian posterior, given by - -$$ \text{EI}(x) = \mathbb{E}\bigl[ -\max(y - f^\*, 0) \mid y\sim \mathcal{N}(\mu(x), \sigma^2(x)) -\bigr] $$ - -where $\mu(x)$ and $\sigma(x)$ are the posterior mean and variance of $f$ at the -point $x$, and $f^\*$ is again the best function value observed so far (assuming -noiseless observations). It can be shown that - -$$ \text{EI}(x) = \sigma(x) \bigl( z \Phi(z) + \varphi(z) \bigr)$$ - -where $z = \frac{\mu(x) - f_{\max}}{\sigma(x)}$ and $\Phi$ and $\varphi$ are -the cdf and pdf of the standard normal distribution, respectively. - -With some additional work, it is also possible to express the gradient of -the Expected Improvement with respect to the design $x$. Classic Bayesian -Optimization software will implement this gradient function explicitly, so that -it can be used for numerically optimizing the acquisition function. - -BoTorch, in contrast, harnesses PyTorch's automatic differentiation feature -("autograd") in order to obtain gradients of acquisition functions. This makes -implementing new acquisition functions much less cumbersome, as it does not -require to analytically derive gradients. All that is required is that the -operations performed in the acquisition function computation allow for the -back-propagation of gradient information through the posterior and the model. - - -[^KingmaWelling2014]: D. P. Kingma, M. Welling. Auto-Encoding Variational Bayes. -ICLR, 2013. - -[^Rezende2014]: D. J. Rezende, S. Mohamed, D. Wierstra. Stochastic -Backpropagation and Approximate Inference in Deep Generative Models. ICML, 2014. - -[^Wilson2017]: J. T. Wilson, R. Moriconi, F. Hutter, M. P. Deisenroth. -The Reparameterization Trick for Acquisition Functions. NeurIPS Workshop on -Bayesian Optimization, 2017. +--- +id: acquisition +title: Acquisition Functions +--- + +Acquisition functions are heuristics employed to evaluate the usefulness of one +of more design points for achieving the objective of maximizing the underlying +black box function. + +BoTorch supports both analytic as well as (quasi-) Monte-Carlo based acquisition +functions. It provides a generic +[`AcquisitionFunction`](../api/acquisition.html#acquisitionfunction) API that +abstracts away from the particular type, so that optimization can be performed +on the same objects. + + +## Monte Carlo Acquisition Functions + +Many common acquisition functions can be expressed as the expectation of some +real-valued function of the model output(s) at the design point(s): + +$$ +\alpha(X) = \mathbb{E}\bigl[ a(\xi) \mid + \xi \sim \mathbb{P}(f(X) \mid \mathcal{D}) \bigr] +$$ + +where $X = (x_1, \dotsc, x_q)$, and $\mathbb{P}(f(X) \mid \mathcal{D})$ is the +posterior distribution of the function $f$ at $X$ given the data $\mathcal{D}$ +observed so far. + +Evaluating the acquisition function thus requires evaluating an integral over +the posterior distribution. In most cases, this is analytically intractable. In +particular, analytic expressions generally do not exist for batch acquisition +functions that consider multiple design points jointly (i.e. $q > 1$). + +An alternative is to use Monte-Carlo (MC) sampling to approximate the integrals. +An MC approximation of $\alpha$ at $X$ using $N$ MC samples is + +$$ \alpha(X) \approx \frac{1}{N} \sum_{i=1}^N a(\xi_{i}) $$ + +where $\xi_i \sim \mathbb{P}(f(X) \mid \mathcal{D})$. + +For instance, for q-Expected Improvement (qEI), we have: + +$$ +\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} +\bigl\\{ \max(\xi_{ij} - f^\*, 0) \bigr\\}, +\qquad \xi_{i} \sim \mathbb{P}(f(X) \mid \mathcal{D}) +$$ + +where $f^\*$ is the best function value observed so far (assuming noiseless +observations). Using the reparameterization trick ([^KingmaWelling2014], +[^Rezende2014]), + +$$ +\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} +\bigl\\{ \max\bigl( \mu(X)\_j + (L(X) \epsilon_i)\_j - f^\*, 0 \bigr) \bigr\\}, +\qquad \epsilon_{i} \sim \mathcal{N}(0, I) +$$ + +where $\mu(X)$ is the posterior mean of $f$ at $X$, and $L(X)L(X)^T = \Sigma(X)$ +is a root decomposition of the posterior covariance matrix. + +All MC-based acquisition functions in BoTorch are derived from +[`MCAcquisitionFunction`](../api/acquisition.html#mcacquisitionfunction). + +Acquisition functions expect input tensors $X$ of shape +$\textit{batch_shape} \times q \times d$, where $d$ is the dimension of the +feature space, $q$ is the number of points considered jointly, and +$\textit{batch_shape}$ is the batch-shape of the input tensor. The output +$\alpha(X)$ will have shape $\textit{batch_shape}$, with each element +corresponding to the respective $q \times d$ batch tensor in the input $X$. +Note that for analytic acquisition functions, it must be that $q=1$. + +### MC, q-MC, and Fixed Base Samples + +BoTorch relies on the re-parameterization trick and (quasi)-Monte-Carlo sampling +for optimization and estimation of the batch acquisition functions [^Wilson2017]. +The results below show the reduced variance when estimating an expected +improvement (EI) acquisition function using base samples obtained via quasi-MC +sampling versus standard MC sampling. + +![MC_qMC](assets/EI_MC_qMC.png) + +In the plots above, the base samples used to estimate each point are resampled. +As discussed in the [Overview](./overview), a single set of base samples can be +used for optimization when the re-parameterization trick is employed. What are the +trade-offs between using a fixed set of base samples versus re-sampling on every +MC evaluation of the acquisition function? Below, we show that fixing base samples +produces functions that are potentially much easier to optimize, without resorting to +stochastic optimization methods. + +![resampling_fixed](assets/EI_resampling_fixed.png) + +If the base samples are fixed, the problem of optimizing the acquisition function +is deterministic, allowing for conventional quasi-second order methods such as +L-BFGS or sequential least-squares programming (SLSQP) to be used. These have +faster convergence rates than first-order methods and can speed up acquisition +function optimization significantly. + +One concern is that the approximated acquisition function is *biased* for any +fixed set of base samples, which may adversely affect the solution. However, we +find that in practice, both the optimal value and the optimal solution of these +biased problems for standard acquisition functions converge quite rapidly to +their true counterparts as more samples are used. Note that for evaluation of +the acquisition function we integrate over a $qo$-dimensional space (where +$q$ is the number of points in the q-batch and $o$ is the number of outputs +included in the objective). Therefore, the MC integration problem can be quite +low-dimensional even for models on high-dimensional feature spaces (large $d$). +Because using additional samples is relatively cheap computationally, +we default to 500 base samples in the MC acquisition functions. + +On the other hand, when re-sampling is used in conjunction with a stochastic +optimization algorithm, the kind of bias noted above is no longer a concern. +The trade-off here is that the optimization may be less effective, as discussed +above. + + +## Analytic Acquisition Functions + +BoTorch also provides implementations of analytic acquisition functions that +do not depend on MC sampling. These acquisition functions are subclasses of +[`AnalyticAcquisitionFunction`](../api/acquisition.html#analyticacquisitionfunction) +and only exist for the case of a single candidate point ($q = 1$). These +include classical acquisition functions such as Expected Improvement (EI), +Upper Confidence Bound (UCB), and Probability of Improvement (PI). An example +comparing [`ExpectedImprovement`](../api/acquisition.html#expectedimprovement), +the analytic version of EI, to it's MC counterpart +[`qExpectedImprovement`](../api/acquisition.html#qexpectedimprovement) +can be found in +[this tutorial](../tutorials/compare_mc_analytic_acquisition). + +Analytic acquisition functions allow for an explicit expression in terms of the +summary statistics of the posterior distribution at the evaluated point(s). +A popular acquisition function is Expected Improvement of a single point +for a Gaussian posterior, given by + +$$ \text{EI}(x) = \mathbb{E}\bigl[ +\max(y - f^\*, 0) \mid y\sim \mathcal{N}(\mu(x), \sigma^2(x)) +\bigr] $$ + +where $\mu(x)$ and $\sigma(x)$ are the posterior mean and variance of $f$ at the +point $x$, and $f^\*$ is again the best function value observed so far (assuming +noiseless observations). It can be shown that + +$$ \text{EI}(x) = \sigma(x) \bigl( z \Phi(z) + \varphi(z) \bigr)$$ + +where $z = \frac{\mu(x) - f_{\max}}{\sigma(x)}$ and $\Phi$ and $\varphi$ are +the cdf and pdf of the standard normal distribution, respectively. + +With some additional work, it is also possible to express the gradient of +the Expected Improvement with respect to the design $x$. Classic Bayesian +Optimization software will implement this gradient function explicitly, so that +it can be used for numerically optimizing the acquisition function. + +BoTorch, in contrast, harnesses PyTorch's automatic differentiation feature +("autograd") in order to obtain gradients of acquisition functions. This makes +implementing new acquisition functions much less cumbersome, as it does not +require to analytically derive gradients. All that is required is that the +operations performed in the acquisition function computation allow for the +back-propagation of gradient information through the posterior and the model. + + +[^KingmaWelling2014]: D. P. Kingma, M. Welling. Auto-Encoding Variational Bayes. +ICLR, 2013. + +[^Rezende2014]: D. J. Rezende, S. Mohamed, D. Wierstra. Stochastic +Backpropagation and Approximate Inference in Deep Generative Models. ICML, 2014. + +[^Wilson2017]: J. T. Wilson, R. Moriconi, F. Hutter, M. P. Deisenroth. +The Reparameterization Trick for Acquisition Functions. NeurIPS Workshop on +Bayesian Optimization, 2017. + +## Latent Information Gain + +In the high-dimensional spatiotemporal domain, Expected Information Gain becomes +less informative for useful observations, and it can be difficult to calculate +its parameters. To overcome these limitations, we propose a novel acquisition +function by computing the expected information gain in the latent space rather +than the observational space. To design this acquisition function, +we prove the equivalence between the expected information gain +in the observational space and the expected KL divergence in the +latent processes w.r.t. a candidate parameter 𝜃, as illustrated by the +following proposition. + +Proposition 1. The expected information gain (EIG) for Neural +Process is equivalent to the KL divergence between the prior and +posterior in the latent process, that is + +$$ \text{EIG}(\hat{x}_{1:T}, \theta) := \mathbb{E} \left[ H(\hat{x}_{1:T}) - +H(\hat{x}_{1:T} \mid z_{1:T}, \theta) \right] += \mathbb{E}_{p(\hat{x}_{1:T} \mid \theta)} +\text{KL} \left( p(z_{1:T} \mid \hat{x}_{1:T}, \theta) \,\|\, p(z_{1:T}) \right) +$$ + + +Inspired by this fact, we propose a novel acquisition function computing the +expected KL divergence in the latent processes and name it LIG. Specifically, +the trained NP model produces a variational posterior given the current dataset. +For every parameter $$\theta$$ remained in the search space, we can predict +$$\hat{x}_{1:T}$$ with the decoder. We use $$\hat{x}_{1:T}$$ and $$\theta$$ +as input to the encoder to re-evaluate the posterior. LIG computes the +distributional difference with respect to the latent process. +[Wu2023arxiv]: + Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). + Deep Bayesian Active Learning for Accelerating Stochastic Simulation. + arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 diff --git a/docs/models.md b/docs/models.md index cec82f5c46..ea63ea0ea8 100644 --- a/docs/models.md +++ b/docs/models.md @@ -1,187 +1,190 @@ ---- -id: models -title: Models ---- - -Models play an essential role in Bayesian Optimization (BO). A model is used as -a surrogate function for the actual underlying black box function to be -optimized. In BoTorch, a `Model` maps a set of design points to a posterior -probability distribution of its output(s) over the design points. - -In BO, the model used is traditionally a Gaussian Process (GP), in which case -the posterior distribution is a multivariate normal. While BoTorch supports many -GP models, **BoTorch makes no assumption on the model being a GP** or the -posterior being multivariate normal. With the exception of some of the analytic -acquisition functions in the -[`botorch.acquisition.analytic`](../api/acquisition.html#analytic-acquisition-function-api) -module, BoTorch’s Monte Carlo-based acquisition functions are compatible with -any model that conforms to the `Model` interface, whether user-implemented or -provided. - -Under the hood, BoTorch models are PyTorch `Modules` that implement the -light-weight [`Model`](../api/models.html#model-apis) interface. When working -with GPs, -[`GPyTorchModel`](../api/models.html#module-botorch.models.gp_regression) -provides a base class for conveniently wrapping GPyTorch models. - -Users can extend `Model` and `GPyTorchModel` to generate their own models. For -more on implementing your own models, see -[Implementing Custom Models](#implementing-custom-models) below. - -## Terminology - -### Multi-Output and Multi-Task - -A `Model` (as in the BoTorch object) may have multiple outputs, multiple inputs, -and may exploit correlation between different inputs. BoTorch uses the following -terminology to distinguish these model types: - -- _Multi-Output Model_: a `Model` with multiple outputs. Most BoTorch `Model`s - are multi-output. -- _Multi-Task Model_: a `Model` making use of a logical grouping of - inputs/observations (as in the underlying process). For example, there could - be multiple tasks where each task has a different fidelity. In a multi-task - model, the relationship between different outputs is modeled, with a joint - model across tasks. - -Note the following: - -- A multi-task (MT) model may or may not be a multi-output model. For example, - if a multi-task model uses different tasks for modeling but only outputs - predictions for one of those tasks, it is single-output. -- Conversely, a multi-output (MO) model may or may not be a multi-task model. - For example, multi-output `Model`s that model different outputs independently - rather than building a joint model are not multi-task. -- If a model is both, we refer to it as a multi-task-multi-output (MTMO) model. - -### Noise: Homoskedastic, fixed, and heteroskedastic - -Noise can be treated in several different ways: - -- _Homoskedastic_: Noise is not provided as an input and is inferred, with a - constant variance that does not depend on `X`. Many models, such as - `SingleTaskGP`, take this approach. Use these models if you know that your - observations are noisy, but not how noisy. - -- _Fixed_: Noise is provided as an input, `train_Yvar`, and is not fit. In - “fixed noise” models like `SingleTaskGP` with noise observations, noise cannot - be predicted out-of-sample because it has not been modeled. Use these models - if you have estimates of the noise in your observations (e.g. observations may - be averages over individual samples in which case you would provide the mean - as observation and the standard error of the mean as the noise estimate), or - if you know your observations are noiseless (by passing a zero noise level). - -- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for - predicting noise out-of-sample. BoTorch does not implement a model that - supports this out of the box. - -## Standard BoTorch Models - -BoTorch provides several GPyTorch models to cover most standard BO use cases: - -### Single-Task GPs - -These models use the same training data for all outputs and assume conditional -independence of the outputs given the input. If different training data is -required for each output, use a -[`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression) -instead. - -- [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP): - a single-task exact GP that supports both inferred and observed noise. When - noise observations are not provided, it infers a homoskedastic noise level. -- [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): - a single-task exact GP that supports mixed search spaces, which combine - discrete and continuous features. -- [`SaasFullyBayesianSingleTaskGP`](../api/models.html#botorch.models.fully_bayesian.SaasFullyBayesianSingleTaskGP): - a fully Bayesian single-task GP with the SAAS prior. This model is suitable - for sample-efficient high-dimensional Bayesian optimization. - -### Model List of Single-Task GPs - -- [`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression): - A multi-output model in which outcomes are modeled independently, given a list - of any type of single-task GP. This model should be used when the same - training data is not used for all outputs. - -### Multi-Task GPs - -- [`MultiTaskGP`](../api/models.html#module-botorch.models.multitask): a - Hadamard multi-task, multi-output GP using an ICM kernel. Supports both known - observation noise levels and inferring a homoskedastic noise level (when noise - observations are not provided). -- [`KroneckerMultiTaskGP`](../api/models.html#botorch.models.multitask.KroneckerMultiTaskGP): - A multi-task, multi-output GP using an ICM kernel, with Kronecker structure. - Useful for multi-fidelity optimization. -- [`SaasFullyBayesianMultiTaskGP`](../api/models.html#saasfullybayesianmultitaskgp): - a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the - SAAS prior to model high-dimensional parameter spaces. - -All of the above models use RBF kernels with Automatic Relevance Discovery -(ARD), and have reasonable priors on hyperparameters that make them work well in -settings where the **input features are normalized to the unit cube** and the -**observations are standardized** (zero mean, unit variance). The lengthscale -priors scale with the input dimension, which makes them adaptable to both low -and high dimensional problems. See -[this discussion](https://github.com/pytorch/botorch/discussions/2451) for -additional context on the default hyperparameters. - -## Other useful models - -- [`ModelList`](../api/models.html#botorch.models.model.ModelList): a - multi-output model container in which outcomes are modeled independently by - individual `Model`s (as in `ModelListGP`, but the component models do not all - need to be GPyTorch models). -- [`SingleTaskMultiFidelityGP`](../api/models.html#botorch.models.gp_regression_fidelity.SingleTaskMultiFidelityGP): - A GP model for multi-fidelity optimization. For more on Multi-Fidelity BO, see - the [tutorial](../tutorials/discrete_multi_fidelity_bo). -- [`HigherOrderGP`](../api/models.html#botorch.models.higher_order_gp.HigherOrderGP): - A GP model with matrix-valued predictions, such as images or grids of images. -- [`PairwiseGP`](../api/models.html#module-botorch.models.pairwise_gp): A - probit-likelihood GP that learns via pairwise comparison data, useful for - preference learning. -- [`ApproximateGPyTorchModel`](../api/models.html#botorch.models.approximate_gp.ApproximateGPyTorchModel): - for efficient computation when data is large or responses are non-Gaussian. -- [Deterministic models](../api/models.html#module-botorch.models.deterministic), - such as - [`AffineDeterministicModel`](../api/models.html#botorch.models.deterministic.AffineDeterministicModel), - [`AffineFidelityCostModel`](../api/models.html#botorch.models.cost.AffineFidelityCostModel), - [`GenericDeterministicModel`](../api/models.html#botorch.models.deterministic.GenericDeterministicModel), - and - [`PosteriorMeanModel`](../api/models.html#botorch.models.deterministic.PosteriorMeanModel) - express known input-output relationships; they conform to the BoTorch `Model` - API, so they can easily be used in conjunction with other BoTorch models. - Deterministic models are useful for multi-objective optimization with known - objective functions and for encoding cost functions for cost-aware - acquisition. -- [`SingleTaskVariationalGP`](../api/models.html#botorch.models.approximate_gp.SingleTaskVariationalGP): - an approximate model for faster computation when you have a lot of data or - your responses are non-Gaussian. - -## Implementing Custom Models - -The configurability of the above models is limited (for instance, it is not -straightforward to use a different kernel). Doing so is an intentional design -decision -- we believe that having a few simple and easy-to-understand models -for basic use cases is more valuable than having a highly complex and -configurable model class whose implementation is difficult to understand. - -Instead, we advocate that users implement their own models to cover more -specialized use cases. The light-weight nature of BoTorch's Model API makes this -easy to do. See the -[Using a custom BoTorch model in Ax](../tutorials/custom_botorch_model_in_ax) -tutorial for an example. - -The BoTorch `Model` interface is light-weight and easy to extend. The only -requirement for using BoTorch's Monte-Carlo based acquisition functions is that -the model has a `posterior` method. It takes in a Tensor `X` of design points, -and returns a Posterior object describing the (joint) probability distribution -of the model output(s) over the design points in `X`. The `Posterior` object -must implement an `rsample()` method for sampling from the posterior of the -model. If you wish to use gradient-based optimization algorithms, the model -should allow back-propagating gradients through the samples to the model input. - -If you happen to implement a model that would be useful for other researchers as -well (and involves more than just swapping out the RBF kernel for a Matérn -kernel), please consider [contributing](getting_started#contributing) this model -to BoTorch. +--- +id: models +title: Models +--- + +Models play an essential role in Bayesian Optimization (BO). A model is used as +a surrogate function for the actual underlying black box function to be +optimized. In BoTorch, a `Model` maps a set of design points to a posterior +probability distribution of its output(s) over the design points. + +In BO, the model used is traditionally a Gaussian Process (GP), in which case +the posterior distribution is a multivariate normal. While BoTorch supports many +GP models, **BoTorch makes no assumption on the model being a GP** or the +posterior being multivariate normal. With the exception of some of the analytic +acquisition functions in the +[`botorch.acquisition.analytic`](../api/acquisition.html#analytic-acquisition-function-api) +module, BoTorch’s Monte Carlo-based acquisition functions are compatible with +any model that conforms to the `Model` interface, whether user-implemented or +provided. + +Under the hood, BoTorch models are PyTorch `Modules` that implement the +light-weight [`Model`](../api/models.html#model-apis) interface. When working +with GPs, +[`GPyTorchModel`](../api/models.html#module-botorch.models.gp_regression) +provides a base class for conveniently wrapping GPyTorch models. + +Users can extend `Model` and `GPyTorchModel` to generate their own models. For +more on implementing your own models, see +[Implementing Custom Models](#implementing-custom-models) below. + +## Terminology + +### Multi-Output and Multi-Task + +A `Model` (as in the BoTorch object) may have multiple outputs, multiple inputs, +and may exploit correlation between different inputs. BoTorch uses the following +terminology to distinguish these model types: + +- _Multi-Output Model_: a `Model` with multiple outputs. Most BoTorch `Model`s + are multi-output. +- _Multi-Task Model_: a `Model` making use of a logical grouping of + inputs/observations (as in the underlying process). For example, there could + be multiple tasks where each task has a different fidelity. In a multi-task + model, the relationship between different outputs is modeled, with a joint + model across tasks. + +Note the following: + +- A multi-task (MT) model may or may not be a multi-output model. For example, + if a multi-task model uses different tasks for modeling but only outputs + predictions for one of those tasks, it is single-output. +- Conversely, a multi-output (MO) model may or may not be a multi-task model. + For example, multi-output `Model`s that model different outputs independently + rather than building a joint model are not multi-task. +- If a model is both, we refer to it as a multi-task-multi-output (MTMO) model. + +### Noise: Homoskedastic, fixed, and heteroskedastic + +Noise can be treated in several different ways: + +- _Homoskedastic_: Noise is not provided as an input and is inferred, with a + constant variance that does not depend on `X`. Many models, such as + `SingleTaskGP`, take this approach. Use these models if you know that your + observations are noisy, but not how noisy. + +- _Fixed_: Noise is provided as an input, `train_Yvar`, and is not fit. In + “fixed noise” models like `SingleTaskGP` with noise observations, noise cannot + be predicted out-of-sample because it has not been modeled. Use these models + if you have estimates of the noise in your observations (e.g. observations may + be averages over individual samples in which case you would provide the mean + as observation and the standard error of the mean as the noise estimate), or + if you know your observations are noiseless (by passing a zero noise level). + +- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for + predicting noise out-of-sample. BoTorch does not implement a model that + supports this out of the box. + +## Standard BoTorch Models + +BoTorch provides several GPyTorch models to cover most standard BO use cases: + +### Single-Task GPs + +These models use the same training data for all outputs and assume conditional +independence of the outputs given the input. If different training data is +required for each output, use a +[`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression) +instead. + +- [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP): + a single-task exact GP that supports both inferred and observed noise. When + noise observations are not provided, it infers a homoskedastic noise level. +- [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): + a single-task exact GP that supports mixed search spaces, which combine + discrete and continuous features. +- [`SaasFullyBayesianSingleTaskGP`](../api/models.html#botorch.models.fully_bayesian.SaasFullyBayesianSingleTaskGP): + a fully Bayesian single-task GP with the SAAS prior. This model is suitable + for sample-efficient high-dimensional Bayesian optimization. + +### Model List of Single-Task GPs + +- [`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression): + A multi-output model in which outcomes are modeled independently, given a list + of any type of single-task GP. This model should be used when the same + training data is not used for all outputs. + +### Multi-Task GPs + +- [`MultiTaskGP`](../api/models.html#module-botorch.models.multitask): a + Hadamard multi-task, multi-output GP using an ICM kernel. Supports both known + observation noise levels and inferring a homoskedastic noise level (when noise + observations are not provided). +- [`KroneckerMultiTaskGP`](../api/models.html#botorch.models.multitask.KroneckerMultiTaskGP): + A multi-task, multi-output GP using an ICM kernel, with Kronecker structure. + Useful for multi-fidelity optimization. +- [`SaasFullyBayesianMultiTaskGP`](../api/models.html#saasfullybayesianmultitaskgp): + a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the + SAAS prior to model high-dimensional parameter spaces. + +All of the above models use RBF kernels with Automatic Relevance Discovery +(ARD), and have reasonable priors on hyperparameters that make them work well in +settings where the **input features are normalized to the unit cube** and the +**observations are standardized** (zero mean, unit variance). The lengthscale +priors scale with the input dimension, which makes them adaptable to both low +and high dimensional problems. See +[this discussion](https://github.com/pytorch/botorch/discussions/2451) for +additional context on the default hyperparameters. + +## Other useful models + +- [`ModelList`](../api/models.html#botorch.models.model.ModelList): a + multi-output model container in which outcomes are modeled independently by + individual `Model`s (as in `ModelListGP`, but the component models do not all + need to be GPyTorch models). +- [`SingleTaskMultiFidelityGP`](../api/models.html#botorch.models.gp_regression_fidelity.SingleTaskMultiFidelityGP): + A GP model for multi-fidelity optimization. For more on Multi-Fidelity BO, see + the [tutorial](../tutorials/discrete_multi_fidelity_bo). +- [`HigherOrderGP`](../api/models.html#botorch.models.higher_order_gp.HigherOrderGP): + A GP model with matrix-valued predictions, such as images or grids of images. +- [`PairwiseGP`](../api/models.html#module-botorch.models.pairwise_gp): A + probit-likelihood GP that learns via pairwise comparison data, useful for + preference learning. +- [`ApproximateGPyTorchModel`](../api/models.html#botorch.models.approximate_gp.ApproximateGPyTorchModel): + for efficient computation when data is large or responses are non-Gaussian. +- [Deterministic models](../api/models.html#module-botorch.models.deterministic), + such as + [`AffineDeterministicModel`](../api/models.html#botorch.models.deterministic.AffineDeterministicModel), + [`AffineFidelityCostModel`](../api/models.html#botorch.models.cost.AffineFidelityCostModel), + [`GenericDeterministicModel`](../api/models.html#botorch.models.deterministic.GenericDeterministicModel), + and + [`PosteriorMeanModel`](../api/models.html#botorch.models.deterministic.PosteriorMeanModel) + express known input-output relationships; they conform to the BoTorch `Model` + API, so they can easily be used in conjunction with other BoTorch models. + Deterministic models are useful for multi-objective optimization with known + objective functions and for encoding cost functions for cost-aware + acquisition. +- [`SingleTaskVariationalGP`](../api/models.html#botorch.models.approximate_gp.SingleTaskVariationalGP): + an approximate model for faster computation when you have a lot of data or + your responses are non-Gaussian. +- [`NeuralProcessModel`](../api/models.html#botorch_community.models.np_regression.NeuralProcessModel): + A NP Model utilizing a novel acquisition function computing the expected KL + Divergence in the latent processes. + +## Implementing Custom Models + +The configurability of the above models is limited (for instance, it is not +straightforward to use a different kernel). Doing so is an intentional design +decision -- we believe that having a few simple and easy-to-understand models +for basic use cases is more valuable than having a highly complex and +configurable model class whose implementation is difficult to understand. + +Instead, we advocate that users implement their own models to cover more +specialized use cases. The light-weight nature of BoTorch's Model API makes this +easy to do. See the +[Using a custom BoTorch model in Ax](../tutorials/custom_botorch_model_in_ax) +tutorial for an example. + +The BoTorch `Model` interface is light-weight and easy to extend. The only +requirement for using BoTorch's Monte-Carlo based acquisition functions is that +the model has a `posterior` method. It takes in a Tensor `X` of design points, +and returns a Posterior object describing the (joint) probability distribution +of the model output(s) over the design points in `X`. The `Posterior` object +must implement an `rsample()` method for sampling from the posterior of the +model. If you wish to use gradient-based optimization algorithms, the model +should allow back-propagating gradients through the samples to the model input. + +If you happen to implement a model that would be useful for other researchers as +well (and involves more than just swapping out the RBF kernel for a Matérn +kernel), please consider [contributing](getting_started#contributing) this model +to BoTorch. From 4f35e0f74403627b130baf78fd54aa09a7d0bb0f Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 25 Jan 2025 06:46:41 -0800 Subject: [PATCH 06/35] 1/25 Updates --- .../acquisition/latent_information_gain.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 5cae0fc0f3..0b000053c6 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -23,51 +23,60 @@ from typing import Optional import torch -from botorch import settings +from botorch.acquisition import AcquisitionFunction from botorch_community.models.np_regression import NeuralProcessModel from torch import Tensor import torch #reference: https://arxiv.org/abs/2106.02770 -class LatentInformationGain: +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class LatentInformationGain(AcquisitionFunction): def __init__( self, + context_x: torch.Tensor, + context_y: torch.Tensor, model: NeuralProcessModel, num_samples: int = 10, - min_std: float = 0.1, - scaler: float = 0.9 + min_std: float = 0.01, + scaler: float = 0.5 ) -> None: """ Latent Information Gain (LIG) Acquisition Function, designed for the - NeuralProcessModel. + NeuralProcessModel. This is a subclass of AcquisitionFunction. Args: model: Trained NeuralProcessModel. + context_x: Context input points, as a Tensor. + context_y: Context target points, as a Tensor. num_samples (int): Number of samples for calculation, defaults to 10. min_std: Float representing the minimum possible standardized std, defaults to 0.1. scaler: Float scaling the std, defaults to 0.9. """ - self.model = model + super().__init__(model=model) + self.model = model.to(device) self.num_samples = num_samples self.min_std = min_std self.scaler = scaler + self.context_x = context_x.to(device) + self.context_y = context_y.to(device) - def acquisition(self, candidate_x, context_x, context_y): + def forward(self, candidate_x): """ Conduct the Latent Information Gain acquisition function for the inputs. Args: candidate_x: Candidate input points, as a Tensor. - context_x: Context input points, as a Tensor. - context_y: Context target points, as a Tensor. Returns: torch.Tensor: The LIG score of computed KLDs. """ + candidate_x = candidate_x.to(device) + # Encoding and Scaling the context data - z_mu_context, z_logvar_context = self.model.data_to_z_params(context_x, context_y) + z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y) kl = 0.0 for _ in range(self.num_samples): # Taking reparameterized samples @@ -77,8 +86,8 @@ def acquisition(self, candidate_x, context_x, context_y): y_pred = self.model.decoder(candidate_x, samples) # Combining context and candidate data - combined_x = torch.cat([context_x, candidate_x], dim=0) - combined_y = torch.cat([context_y, y_pred], dim=0) + combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device) + combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device) # Computing posterior variables z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y) From 280776dda4d06aee20e83a58028796b67d38a66f Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 25 Jan 2025 06:47:07 -0800 Subject: [PATCH 07/35] 1/25 Updates --- botorch_community/models/np_regression.py | 132 +++++++++------------- 1 file changed, 55 insertions(+), 77 deletions(-) diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 955da2315b..bce3495948 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -11,27 +11,16 @@ Contributor: eibarolle """ -import copy -import numpy as np -from numpy.random import binomial import torch import torch.nn as nn -import matplotlib.pyplot as plts -# %matplotlib inline from botorch.models.model import Model from botorch.posteriors import GPyTorchPosterior from botorch.acquisition.objective import PosteriorTransform -from sklearn.gaussian_process import GaussianProcessRegressor -from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic, - ExpSineSquared, DotProduct, - ConstantKernel) from typing import Callable, List, Optional, Tuple -from torch.nn import Module, ModuleDict, ModuleList -from sklearn import preprocessing -from scipy.stats import multivariate_normal +from torch.nn import Module from gpytorch.distributions import MultivariateNormal -device = torch.device("cpu") +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Account for different acquisitions #reference: https://chrisorm.github.io/NGP.html @@ -59,21 +48,21 @@ def __init__( prev_dim = input_dim for hidden_dim in hidden_dims: - layer = nn.Linear(prev_dim, hidden_dim) + layer = nn.Linear(prev_dim, hidden_dim).to(device) if init_func is not None: init_func(layer.weight) layers.append(layer) layers.append(activation()) prev_dim = hidden_dim - final_layer = nn.Linear(prev_dim, output_dim) + final_layer = nn.Linear(prev_dim, output_dim).to(device) if init_func is not None: init_func(final_layer.weight) layers.append(final_layer) - self.model = nn.Sequential(*layers) + self.model = nn.Sequential(*layers).to(device) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x) + return self.model(x.to(device)) class REncoder(nn.Module): @@ -95,12 +84,9 @@ def __init__( init_func: A function initializing the weights, defaults to nn.init.normal_. """ super().__init__() - self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - def forward( - self, - inputs: torch.Tensor, - ) -> torch.Tensor: + def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward pass for representation encoder. Args: @@ -109,7 +95,7 @@ def forward( Returns: torch.Tensor: Encoded representations """ - return self.mlp(inputs) + return self.mlp(inputs.to(device)) class ZEncoder(nn.Module): def __init__(self, @@ -130,13 +116,10 @@ def __init__(self, init_func: A function initializing the weights, defaults to nn.init.normal_. """ super().__init__() - self.mean_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) - self.logvar_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + self.mean_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) + self.logvar_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - def forward( - self, - inputs: torch.Tensor, - ) -> torch.Tensor: + def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward pass for latent encoder. Args: @@ -147,6 +130,7 @@ def forward( - Mean of the latent Gaussian distribution. - Log variance of the latent Gaussian distribution. """ + inputs = inputs.to(device) return self.mean_net(inputs), self.logvar_net(inputs) class Decoder(torch.nn.Module): @@ -168,23 +152,21 @@ def __init__( init_func: A function initializing the weights, defaults to nn.init.normal_. """ super().__init__() - self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func) + self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - def forward( - self, - x_pred: torch.Tensor, - z: torch.Tensor, - ) -> torch.Tensor: + def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: r"""Forward pass for decoder. Args: - x_pred: No. of data points, by x_dim - z: No. of samples, by z_dim + x_pred: Input points of shape (n x d_x), representing # of data points by x_dim. + z: Latent encoding of shape (num_samples x d_z), representing # of samples by z_dim. Returns: - torch.Tensor: Predicted target values. + torch.Tensor: Predicted target values of shape (n, z_dim), representing # of data points by z_dim. """ - z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1) + z = z.to(device) + z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device) + x_pred = x_pred.to(device) xz = torch.cat([x_pred, z_expanded], dim=-1) return self.mlp(xz) @@ -231,16 +213,14 @@ def __init__( init_func: A function initializing the weights, defaults to nn.init.normal_. """ super().__init__() - self.r_encoder = REncoder(x_dim+y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func) - self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func) - self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func) + self.r_encoder = REncoder(x_dim+y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func).to(device) + self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func).to(device) + self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func).to(device) self.z_dim = z_dim self.z_mu_all = None self.z_logvar_all = None self.z_mu_context = None self.z_logvar_context = None - # self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) # Look at BoTorch native versions - #self.train(n_epochs, x_train, y_train) def data_to_z_params( self, @@ -264,9 +244,11 @@ def data_to_z_params( - x_t: Target input data. - y_t: Target target data. """ - xy = torch.cat([x,y], dim=xy_dim) + x = x.to(device) + y = y.to(device) + xy = torch.cat([x,y], dim=xy_dim).to(device).to(device) rs = self.r_encoder(xy) - r_agg = rs.mean(dim=r_dim) + r_agg = rs.mean(dim=r_dim).to(device) return self.z_encoder(r_agg) def sample_z( @@ -274,8 +256,8 @@ def sample_z( mu: torch.Tensor, logvar: torch.Tensor, n: int = 1, - min_std: float = 0.1, - scaler: float = 0.9 + min_std: float = 0.01, + scaler: float = 0.5 ) -> torch.Tensor: r"""Reparameterization trick for z's latent distribution. @@ -291,12 +273,15 @@ def sample_z( """ if min_std <= 0 or scaler <= 0: raise ValueError() + + shape = [n, self.z_dim] if n == 1: - eps = torch.autograd.Variable(logvar.data.new(self.z_dim).normal_()).to(device) - else: - eps = torch.autograd.Variable(logvar.data.new(n,self.z_dim).normal_()).to(device) + shape = shape[1:] + eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device) std = min_std + scaler * torch.sigmoid(logvar) + std = std.to(device) + mu = mu.to(device) return mu + std * eps def KLD_gaussian( @@ -316,10 +301,10 @@ def KLD_gaussian( if min_std <= 0 or scaler <= 0: raise ValueError() - std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all) - std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context) - p = torch.distributions.Normal(self.z_mu_context, std_p) - q = torch.distributions.Normal(self.z_mu_all, std_q) + std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device) + std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(device) + p = torch.distributions.Normal(self.z_mu_context.to(device), std_p) + q = torch.distributions.Normal(self.z_mu_all.to(device), std_q) return torch.distributions.kl_divergence(p, q).sum() def posterior( @@ -343,7 +328,8 @@ def posterior( GPyTorchPosterior: The posterior distribution object utilizing MultivariateNormal. """ - mean = self.decoder(X, self.sample_z(self.z_mu_all, self.z_logvar_all)) + X = X.to(device) + mean = self.decoder(X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all)) covariance = torch.eye(X.size(0)) * covariance_multiplier if (observation_noise): covariance = covariance + observation_constant @@ -352,20 +338,6 @@ def posterior( if posterior_transform is not None: posterior = posterior_transform(posterior) return posterior - - def load_state_dict( - self, - state_dict: dict, - strict: bool = True - ) -> None: - """ - Initialize the fully Bayesian model before loading the state dict. - - Args: - state_dict (dict): A dictionary containing the parameters. - strict (bool): Case matching strictness. - """ - super().load_state_dict(state_dict, strict=strict) def transform_inputs( self, @@ -381,6 +353,7 @@ def transform_inputs( Returns: torch.Tensor: A tensor of transformed inputs """ + X = X.to(device) if input_transform is not None: input_transform.to(X) return input_transform(X) @@ -420,6 +393,11 @@ def forward( if y_c.size(1 - target_dim) != y_t.size(1 - target_dim): raise ValueError() + x_t = x_t.to(device) + x_c = x_c.to(device) + y_c = y_c.to(device) + y_t = y_t.to(device) + self.z_mu_all, self.z_logvar_all = self.data_to_z_params(torch.cat([x_c, x_t], dim = input_dim), torch.cat([y_c, y_t], dim = target_dim)) self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c) z = self.sample_z(self.z_mu_all, self.z_logvar_all) @@ -447,12 +425,12 @@ def random_split_context_target( - x_t: Target input data. - y_t: Target target data. """ - ind = np.arange(x.shape[0]) - mask = np.random.choice(ind, size=n_context, replace=False) - x_c = torch.from_numpy(x[mask]) - y_c = torch.from_numpy(y[mask]) - x_t = torch.from_numpy(np.delete(x, mask, axis=0)) - y_t = torch.from_numpy(np.delete(y, mask, axis=0)) - + mask = torch.randperm(x.shape[0])[:n_context] + x_c = torch.from_numpy(x[mask]).to(device) + y_c = torch.from_numpy(y[mask]).to(device) + splitter = torch.zeros(x.shape[0], dtype=torch.bool) + splitter[mask] = True + x_t = torch.from_numpy(x[~splitter]).to(device) + y_t = torch.from_numpy(y[~splitter]).to(device) return x_c, y_c, x_t, y_t \ No newline at end of file From a8114292880cf96fa2a52b11b22636911e4aec58 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 25 Jan 2025 06:48:43 -0800 Subject: [PATCH 08/35] 1/25 Updates --- .../test_latent_information_gain.py | 48 ++++--------------- 1 file changed, 10 insertions(+), 38 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index ff135bd2fa..856ef65efd 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -1,7 +1,5 @@ import unittest import torch -from torch import nn -from torch.distributions import Normal from botorch_community.acquisition.latent_information_gain import LatentInformationGain from botorch_community.models.np_regression import NeuralProcessModel @@ -11,6 +9,8 @@ def setUp(self): self.y_dim = 1 self.r_dim = 8 self.z_dim = 3 + self.context_x = torch.rand(10, self.x_dim) + self.context_y = torch.rand(10, self.y_dim) self.r_hidden_dims = [16, 16] self.z_hidden_dims = [32, 32] self.decoder_hidden_dims = [16, 16] @@ -25,11 +25,11 @@ def setUp(self): z_dim=self.z_dim, ) self.acquisition_function = LatentInformationGain( + context_x=self.context_x, + context_y=self.context_y, model=self.model, num_samples=self.num_samples, ) - self.context_x = torch.rand(10, self.x_dim) - self.context_y = torch.rand(10, self.y_dim) self.candidate_x = torch.rand(5, self.x_dim) def test_initialization(self): @@ -37,54 +37,26 @@ def test_initialization(self): self.assertEqual(self.acquisition_function.model, self.model) def test_acquisition_shape(self): - lig_score = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, + lig_score = self.acquisition_function.forward( + candidate_x=self.candidate_x ) self.assertTrue(torch.is_tensor(lig_score)) self.assertEqual(lig_score.shape, ()) def test_acquisition_kl(self): - lig_score = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, + lig_score = self.acquisition_function.forward( + candidate_x=self.candidate_x ) self.assertGreaterEqual(lig_score.item(), 0) def test_acquisition_samples(self): - lig_1 = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, - ) + lig_1 = self.acquisition_function.forward(candidate_x=self.candidate_x) self.acquisition_function.num_samples = 20 - lig_2 = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, - ) + lig_2 = self.acquisition_function.forward(candidate_x=self.candidate_x) self.assertTrue(lig_2.item() < lig_1.item()) self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2) - def test_acquisition_invalid_inputs(self): - invalid_context_x = torch.rand(10, self.x_dim + 5) - with self.assertRaises(Exception): - self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=invalid_context_x, - context_y=self.context_y, - ) - - invalid_candidate_x = torch.rand(5, self.x_dim + 5) - with self.assertRaises(Exception): - self.acquisition_function.acquisition( - candidate_x=invalid_candidate_x, - context_x=self.context_x, - context_y=self.context_y, - ) if __name__ == "__main__": From 50fe7a137ecbac95d4223f97d081a2f95d7225ca Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 25 Jan 2025 06:49:45 -0800 Subject: [PATCH 09/35] 1/25 Updates --- test_community/models/test_np_regression.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/test_community/models/test_np_regression.py b/test_community/models/test_np_regression.py index 67ee410f96..ce3e1ff046 100644 --- a/test_community/models/test_np_regression.py +++ b/test_community/models/test_np_regression.py @@ -1,11 +1,10 @@ import unittest import numpy as np import torch -from torch import nn -from torch.optim import Adam from botorch_community.models.np_regression import NeuralProcessModel from botorch.posteriors import GPyTorchPosterior -from torch import Tensor + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class TestNeuralProcessModel(unittest.TestCase): def initialize(self): @@ -111,15 +110,10 @@ def test_posterior(self): mvn = posterior.mvn self.assertEqual(mvn.covariance_matrix.size(), (5, 5, 5)) - def test_load_state_dict(self): - self.initialize() - state_dict = {"r_encoder.mlp.model.0.bias": torch.rand(16)} - self.model.load_state_dict(state_dict, strict = False) - def test_transform_inputs(self): self.initialize() X = torch.rand(5, 3) - self.assertTrue(torch.equal(self.model.transform_inputs(X), X)) + self.assertTrue(torch.equal(self.model.transform_inputs(X), X.to(device))) if __name__ == "__main__": From 4aeeeeb29111e7fc50c66a839972f1051b86eab4 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sun, 2 Feb 2025 05:46:34 -0800 Subject: [PATCH 10/35] Update Acquisition Dimensions --- .../acquisition/latent_information_gain.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 0b000053c6..bf07d8f36d 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -62,43 +62,36 @@ def __init__( self.context_x = context_x.to(device) self.context_y = context_y.to(device) - def forward(self, candidate_x): + def forward(self, candidate_x: Tensor) -> Tensor: """ Conduct the Latent Information Gain acquisition function for the inputs. Args: - candidate_x: Candidate input points, as a Tensor. + candidate_x: Candidate input points, as a Tensor. Ideally in the shape (N, q, D), and assumes N = 1 if the given dimensions are 2D. Returns: - torch.Tensor: The LIG score of computed KLDs. + torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q). """ - candidate_x = candidate_x.to(device) - + if candidate_x.dim() == 2: + candidate_x = candidate_x.unsqueeze(0) + N, q, D = candidate_x.shape # Encoding and Scaling the context data z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y) - kl = 0.0 + kl = torch.zeros(N, q, device=device) for _ in range(self.num_samples): - # Taking reparameterized samples + # Taking Samples/Predictions samples = self.model.sample_z(z_mu_context, z_logvar_context) - - # Using the Decoder to take predicted values - y_pred = self.model.decoder(candidate_x, samples) - - # Combining context and candidate data - combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device) + y_pred = self.model.decoder(candidate_x.view(-1, D), samples) + # Combining the data + combined_x = torch.cat([self.context_x, candidate_x.view(-1, D)], dim=0).to(device) combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device) - # Computing posterior variables z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y) std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context) std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior) - p = torch.distributions.Normal(z_mu_posterior, std_posterior) q = torch.distributions.Normal(z_mu_context, std_prior) - kl_divergence = torch.distributions.kl_divergence(p, q).sum() kl += kl_divergence - - # Average KLD return kl / self.num_samples From be34f608255d2b201d1c25d2b2cb1166baa90c15 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sun, 2 Feb 2025 05:47:06 -0800 Subject: [PATCH 11/35] Updated Test Files --- .../acquisition/test_latent_information_gain.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index 856ef65efd..f944292fd9 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -41,23 +41,13 @@ def test_acquisition_shape(self): candidate_x=self.candidate_x ) self.assertTrue(torch.is_tensor(lig_score)) - self.assertEqual(lig_score.shape, ()) + self.assertEqual(lig_score.shape, (1, 5)) def test_acquisition_kl(self): lig_score = self.acquisition_function.forward( candidate_x=self.candidate_x ) - self.assertGreaterEqual(lig_score.item(), 0) - - def test_acquisition_samples(self): - lig_1 = self.acquisition_function.forward(candidate_x=self.candidate_x) - - self.acquisition_function.num_samples = 20 - lig_2 = self.acquisition_function.forward(candidate_x=self.candidate_x) - self.assertTrue(lig_2.item() < lig_1.item()) - self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2) - - + self.assertGreaterEqual(lig_score.mean().item(), 0) if __name__ == "__main__": unittest.main() From 204ba31d74f6584956c65f607b7d38700636367d Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Thu, 6 Mar 2025 02:17:05 -0800 Subject: [PATCH 12/35] Updated LIG Parameters/Generalizability --- .../acquisition/latent_information_gain.py | 99 +++++++++++-------- 1 file changed, 56 insertions(+), 43 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index bf07d8f36d..8c5a524f37 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -18,80 +18,93 @@ """ from __future__ import annotations - -import warnings -from typing import Optional - +from typing import Type, Any import torch from botorch.acquisition import AcquisitionFunction from botorch_community.models.np_regression import NeuralProcessModel from torch import Tensor - -import torch -#reference: https://arxiv.org/abs/2106.02770 +# reference: https://arxiv.org/abs/2106.02770 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + class LatentInformationGain(AcquisitionFunction): def __init__( - self, - context_x: torch.Tensor, - context_y: torch.Tensor, - model: NeuralProcessModel, + self, + model: Type[Any] = NeuralProcessModel, num_samples: int = 10, min_std: float = 0.01, - scaler: float = 0.5 + scaler: float = 0.5, ) -> None: """ - Latent Information Gain (LIG) Acquisition Function, designed for the - NeuralProcessModel. This is a subclass of AcquisitionFunction. + Latent Information Gain (LIG) Acquisition Function. + Uses the model's built-in posterior function to generalize KL computation. Args: - model: Trained NeuralProcessModel. - context_x: Context input points, as a Tensor. - context_y: Context target points, as a Tensor. + model: The model class to be used, defaults to NeuralProcessModel. num_samples (int): Number of samples for calculation, defaults to 10. - min_std: Float representing the minimum possible standardized std, defaults to 0.1. - scaler: Float scaling the std, defaults to 0.9. + min_std: Float representing the minimum possible standardized std, + defaults to 0.01. + scaler: Float scaling the std, defaults to 0.5. """ - super().__init__(model=model) - self.model = model.to(device) + super().__init__() + self.model = model self.num_samples = num_samples self.min_std = min_std self.scaler = scaler - self.context_x = context_x.to(device) - self.context_y = context_y.to(device) def forward(self, candidate_x: Tensor) -> Tensor: """ - Conduct the Latent Information Gain acquisition function for the inputs. + Conduct the Latent Information Gain acquisition function using the model's + posterior. Args: - candidate_x: Candidate input points, as a Tensor. Ideally in the shape (N, q, D), and assumes N = 1 if the given dimensions are 2D. + candidate_x: Candidate input points, as a Tensor. Ideally in the shape + (N, q, D). Returns: torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q). """ candidate_x = candidate_x.to(device) if candidate_x.dim() == 2: - candidate_x = candidate_x.unsqueeze(0) + candidate_x = candidate_x.unsqueeze(0) # Ensure (N, q, D) format N, q, D = candidate_x.shape - # Encoding and Scaling the context data - z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y) + kl = torch.zeros(N, q, device=device) - for _ in range(self.num_samples): - # Taking Samples/Predictions - samples = self.model.sample_z(z_mu_context, z_logvar_context) - y_pred = self.model.decoder(candidate_x.view(-1, D), samples) - # Combining the data - combined_x = torch.cat([self.context_x, candidate_x.view(-1, D)], dim=0).to(device) - combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device) - # Computing posterior variables - z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y) - std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context) - std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior) - p = torch.distributions.Normal(z_mu_posterior, std_posterior) - q = torch.distributions.Normal(z_mu_context, std_prior) - kl_divergence = torch.distributions.kl_divergence(p, q).sum() - kl += kl_divergence + + if self.model is NeuralProcessModel: + z_mu_context, z_logvar_context = self.model.data_to_z_params( + self.context_x, self.context_y + ) + for _ in range(self.num_samples): + # Taking Samples/Predictions + samples = self.model.sample_z(z_mu_context, z_logvar_context) + y_pred = self.model.decoder(candidate_x.view(-1, D), samples) + # Combining the data + combined_x = torch.cat( + [self.context_x, candidate_x.view(-1, D)], dim=0 + ).to(device) + combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device) + # Computing posterior variables + z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params( + combined_x, combined_y + ) + std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context) + std_posterior = self.min_std + self.scaler * torch.sigmoid( + z_logvar_posterior + ) + p = torch.distributions.Normal(z_mu_posterior, std_posterior) + q = torch.distributions.Normal(z_mu_context, std_prior) + kl_divergence = torch.distributions.kl_divergence(p, q).sum(dim=-1) + kl += kl_divergence + else: + for _ in range(self.num_samples): + posterior_prior = self.model.posterior(self.model.train_X) + posterior_candidate = self.model.posterior(candidate_x.view(-1, D)) + + kl_divergence = torch.distributions.kl_divergence( + posterior_candidate.mvn, posterior_prior.mvn + ).sum(dim=-1) + kl += kl_divergence + return kl / self.num_samples From 8e33fc4cf43770bc75a9fa03d4ec14f1ccf06d0b Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Thu, 6 Mar 2025 02:17:50 -0800 Subject: [PATCH 13/35] Updated NPR Compatability --- botorch_community/models/np_regression.py | 360 +++++++++++++--------- 1 file changed, 210 insertions(+), 150 deletions(-) diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index bce3495948..860538f94d 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -11,37 +11,45 @@ Contributor: eibarolle """ +from typing import Callable, List, Optional, Tuple + import torch import torch.nn as nn +from botorch.acquisition.objective import PosteriorTransform from botorch.models.model import Model +from botorch.models.transforms.input import InputTransform from botorch.posteriors import GPyTorchPosterior -from botorch.acquisition.objective import PosteriorTransform -from typing import Callable, List, Optional, Tuple -from torch.nn import Module from gpytorch.distributions import MultivariateNormal +from gpytorch.likelihoods import GaussianLikelihood +from gpytorch.likelihoods.likelihood import Likelihood +from torch.nn import Module device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Account for different acquisitions -#reference: https://chrisorm.github.io/NGP.html + +# reference: https://chrisorm.github.io/NGP.html class MLP(nn.Module): def __init__( - self, - input_dim: int, - output_dim: int, - hidden_dims: List[int], + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_ + init_func: Optional[Callable] = nn.init.normal_, ) -> None: r""" A modular implementation of a Multilayer Perceptron (MLP). - + Args: input_dim: An int representing the total input dimensionality. output_dim: An int representing the total encoded dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden dimension. - activation: Activation function applied between layers, defaults to nn.Sigmoid. - init_func: A function initializing the weights, defaults to nn.init.normal_. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to + nn.Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. """ super().__init__() layers = [] @@ -63,29 +71,38 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x.to(device)) - + class REncoder(nn.Module): def __init__( self, - input_dim: int, - output_dim: int, + input_dim: int, + output_dim: int, hidden_dims: List[int], activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_ + init_func: Optional[Callable] = nn.init.normal_, ) -> None: r"""Encodes inputs of the form (x_i,y_i) into representations, r_i. Args: input_dim: An int representing the total input dimensionality. output_dim: An int representing the total encoded dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden dimension. - activation: Activation function applied between layers, defaults to nn.Sigmoid. - init_func: A function initializing the weights, defaults to nn.init.normal_. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to nn. + Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. """ super().__init__() - self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - + self.mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ).to(device) + def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward pass for representation encoder. @@ -97,72 +114,103 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ return self.mlp(inputs.to(device)) + class ZEncoder(nn.Module): - def __init__(self, + def __init__( + self, input_dim: int, output_dim: int, hidden_dims: List[int], activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_ + init_func: Optional[Callable] = nn.init.normal_, ) -> None: - r"""Takes an r representation and produces the mean & standard + r"""Takes an r representation and produces the mean & standard deviation of the normally distributed function encoding, z. - + Args: input_dim: An int representing r's aggregated dimensionality. output_dim: An int representing z's latent dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden dimension. - activation: Activation function applied between layers, defaults to nn.Sigmoid. - init_func: A function initializing the weights, defaults to nn.init.normal_. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to nn. + Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. """ super().__init__() - self.mean_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - self.logvar_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - + self.mean_net = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ).to(device) + self.logvar_net = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ).to(device) + def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward pass for latent encoder. Args: inputs: Input tensor - + Returns: - Tuple[torch.Tensor, torch.Tensor]: + Tuple[torch.Tensor, torch.Tensor]: - Mean of the latent Gaussian distribution. - Log variance of the latent Gaussian distribution. """ inputs = inputs.to(device) return self.mean_net(inputs), self.logvar_net(inputs) - + + class Decoder(torch.nn.Module): def __init__( self, - input_dim: int, - output_dim: int, + input_dim: int, + output_dim: int, hidden_dims: List[int], activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_ + init_func: Optional[Callable] = nn.init.normal_, ) -> None: - r"""Takes the x star points, along with a 'function encoding', z, and makes predictions. - + r"""Takes the x star points, along with a 'function encoding', z, and makes + predictions. + Args: input_dim: An int representing the total input dimensionality. output_dim: An int representing the total encoded dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden dimension. - activation: Activation function applied between layers, defaults to nn.Sigmoid. - init_func: A function initializing the weights, defaults to nn.init.normal_. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to + nn.Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. """ super().__init__() - self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device) - + self.mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ).to(device) + def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: r"""Forward pass for decoder. Args: - x_pred: Input points of shape (n x d_x), representing # of data points by x_dim. - z: Latent encoding of shape (num_samples x d_z), representing # of samples by z_dim. + x_pred: Input points of shape (n x d_x), representing # of data points by + x_dim. + z: Latent encoding of shape (num_samples x d_z), representing # of samples + by z_dim. Returns: - torch.Tensor: Predicted target values of shape (n, z_dim), representing # of data points by z_dim. + torch.Tensor: Predicted target values of shape (n x z_dim), representing # + of data points by z_dim. """ z = z.to(device) z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device) @@ -170,38 +218,29 @@ def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: xz = torch.cat([x_pred, z_expanded], dim=-1) return self.mlp(xz) -def MAE( - pred: torch.Tensor, - target: torch.Tensor, -) -> torch.Tensor: - r"""Mean Absolute Error loss function. - - Args: - pred: The predicted values tensor. - target: The target values tensor. - - Returns: - torch.Tensor: A tensor representing the MAE. - """ - loss = torch.abs(pred-target) - return loss.mean() class NeuralProcessModel(Model): def __init__( self, - r_hidden_dims: List[int], - z_hidden_dims: List[int], - decoder_hidden_dims: List[int], - x_dim: int, - y_dim: int, - r_dim: int, - z_dim: int, + train_X: torch.Tensor, + train_Y: torch.Tensor, + r_hidden_dims: List[int] = [16, 16], + z_hidden_dims: List[int] = [32, 32], + decoder_hidden_dims: List[int] = [16, 16], + x_dim: int = 2, + y_dim: int = 1, + r_dim: int = 64, + z_dim: int = 8, activation: Callable = nn.Sigmoid, init_func: Optional[Callable] = torch.nn.init.normal_, + likelihood: Likelihood | None = None, + input_transform: InputTransform | None = None, ) -> None: r"""Diffusion Convolutional Recurrent Neural Network Model Implementation. Args: + train_X: A `batch_shape x n x d` tensor of training features. + train_Y: A `batch_shape x n x m` tensor of training observations. r_hidden_dims: Hidden Dimensions/Layer list for REncoder z_hidden_dims: Hidden Dimensions/Layer list for ZEncoder decoder_hidden_dims: Hidden Dimensions/Layer for Decoder @@ -209,19 +248,45 @@ def __init__( y_dim: Int dimensionality of target data y. r_dim: Int dimensionality of representation r. z_dim: Int dimensionality of latent variable z. - activation: Activation function applied between layers, defaults to nn.Sigmoid. - init_func: A function initializing the weights, defaults to nn.init.normal_. + activation: Activation function applied between layers, defaults to nn. + Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. + likelihood: A likelihood. If omitted, use a standard GaussianLikelihood. + input_transform: An input transform that is applied in the model's + forward pass. """ super().__init__() - self.r_encoder = REncoder(x_dim+y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func).to(device) - self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func).to(device) - self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func).to(device) + self.r_encoder = REncoder( + x_dim + y_dim, + r_dim, + r_hidden_dims, + activation=activation, + init_func=init_func, + ).to(device) + self.z_encoder = ZEncoder( + r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func + ).to(device) + self.decoder = Decoder( + x_dim + z_dim, + y_dim, + decoder_hidden_dims, + activation=activation, + init_func=init_func, + ).to(device) + self.train_X = train_X.to(device) + self.train_Y = train_Y.to(device) self.z_dim = z_dim self.z_mu_all = None self.z_logvar_all = None self.z_mu_context = None self.z_logvar_context = None - + if likelihood is None: + self.likelihood = GaussianLikelihood().to(device) + else: + self.likelihood = likelihood.to(device) + self.input_transform = input_transform + def data_to_z_params( self, x: torch.Tensor, @@ -238,7 +303,7 @@ def data_to_z_params( r_dim: Combined Target Dimension as int, defaults as 0. Returns: - Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - x_c: Context input data. - y_c: Context target data. - x_t: Target input data. @@ -246,18 +311,18 @@ def data_to_z_params( """ x = x.to(device) y = y.to(device) - xy = torch.cat([x,y], dim=xy_dim).to(device).to(device) + xy = torch.cat([x, y], dim=xy_dim).to(device).to(device) rs = self.r_encoder(xy) r_agg = rs.mean(dim=r_dim).to(device) - return self.z_encoder(r_agg) - + return self.z_encoder(r_agg) + def sample_z( self, mu: torch.Tensor, logvar: torch.Tensor, n: int = 1, min_std: float = 0.01, - scaler: float = 0.5 + scaler: float = 0.5, ) -> torch.Tensor: r"""Reparameterization trick for z's latent distribution. @@ -265,40 +330,38 @@ def sample_z( mu: Tensor representing the Gaussian distribution mean. logvar: Tensor representing the log variance of the Gaussian distribution. n: Int representing the # of samples, defaults to 1. - min_std: Float representing the minimum possible standardized std, defaults to 0.1. - scaler: Float scaling the std, defaults to 0.9. + min_std: Float representing the minimum possible standardized std, defaults + to 0.01. + scaler: Float scaling the std, defaults to 0.5. Returns: torch.Tensor: Samples from the Gaussian distribution. - """ + """ if min_std <= 0 or scaler <= 0: raise ValueError() - + shape = [n, self.z_dim] if n == 1: shape = shape[1:] eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device) - - std = min_std + scaler * torch.sigmoid(logvar) + + std = min_std + scaler * torch.sigmoid(logvar) std = std.to(device) mu = mu.to(device) return mu + std * eps - def KLD_gaussian( - self, - min_std: float = 0.1, - scaler: float = 0.9 - ) -> torch.Tensor: + def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor: r"""Analytical KLD between 2 Gaussian Distributions. Args: - min_std: Float representing the minimum possible standardized std, defaults to 0.1. - scaler: Float scaling the std, defaults to 0.9. - + min_std: Float representing the minimum possible standardized std, defaults + to 0.01. + scaler: Float scaling the std, defaults to 0.5. + Returns: torch.Tensor: A tensor representing the KLD. """ - + if min_std <= 0 or scaler <= 0: raise ValueError() std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device) @@ -306,14 +369,13 @@ def KLD_gaussian( p = torch.distributions.Normal(self.z_mu_context.to(device), std_p) q = torch.distributions.Normal(self.z_mu_all.to(device), std_q) return torch.distributions.kl_divergence(p, q).sum() - + def posterior( - self, - X: torch.Tensor, - covariance_multiplier: float, - observation_constant: float, - observation_noise: bool = False, - posterior_transform: Optional[PosteriorTransform] = None, + self, + X: torch.Tensor, + output_indices: list[int] | None = None, + observation_noise: bool = False, + posterior_transform: PosteriorTransform | None = None, ) -> GPyTorchPosterior: r"""Computes the model's posterior distribution for given input tensors. @@ -321,18 +383,27 @@ def posterior( X: Input Tensor covariance_multiplier: Float scaling the covariance. observation_constant: Float representing the noise constant. - observation_noise: Adds observation noise to the covariance if True, defaults to False. - posterior_transform: An optional posterior transformation, defaults to None. + output_indices: Ignored (defined in parent Model, but not used here). + observation_noise: Adds observation noise to the covariance if True, + defaults to False. + posterior_transform: An optional posterior transformation, + defaults to None. Returns: - GPyTorchPosterior: The posterior distribution object + GPyTorchPosterior: The posterior distribution object utilizing MultivariateNormal. """ + X = self.transform_inputs(X) X = X.to(device) - mean = self.decoder(X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all)) - covariance = torch.eye(X.size(0)) * covariance_multiplier - if (observation_noise): - covariance = covariance + observation_constant + mean = self.decoder( + X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all) + ) + z_var = torch.exp(self.z_logvar_all) + covariance = torch.eye(X.size(0)).to(device) * z_var.mean() + if observation_noise: + covariance = covariance + self.likelihood.noise * torch.eye( + covariance.size(0) + ).to(device) mvn = MultivariateNormal(mean, covariance) posterior = GPyTorchPosterior(mvn) if posterior_transform is not None: @@ -364,73 +435,62 @@ def transform_inputs( def forward( self, - x_t: torch.Tensor, - x_c: torch.Tensor, - y_c: torch.Tensor, - y_t: torch.Tensor, - input_dim: int = 0, - target_dim: int = 0 - ) -> torch.Tensor: + train_X: torch.Tensor, + train_Y: torch.Tensor, + n_context: int, + axis: int = 0, + ) -> MultivariateNormal: r"""Forward pass for the model. Args: - x_t: Target input data. - x_c: Context input data. - y_c: Context target data. - y_t: Target output data. - input_dim: Input dimension concatenated - target_dim: Target dimension concatendated + train_X: A `batch_shape x n x d` tensor of training features. + train_Y: A `batch_shape x n x m` tensor of training observations. + n_context (int): Number of context points. + axis: Dimension axis as int, defaulted as 0. Returns: - torch.Tensor: Predicted target values. + MultivariateNormal: Predicted target distribution. """ - if any(tensor.numel() == 0 for tensor in [x_t, x_c, y_c]): - raise ValueError() - if input_dim not in [0, 1]: - raise ValueError() - if x_c.size(1 - input_dim) != x_t.size(1 - input_dim): - raise ValueError() - if y_c.size(1 - target_dim) != y_t.size(1 - target_dim): - raise ValueError() - + train_X = self.transform_inputs(train_X) + x_c, y_c, x_t, y_t = self.random_split_context_target( + train_X, train_Y, n_context, axis=axis + ) x_t = x_t.to(device) x_c = x_c.to(device) y_c = y_c.to(device) y_t = y_t.to(device) - - self.z_mu_all, self.z_logvar_all = self.data_to_z_params(torch.cat([x_c, x_t], dim = input_dim), torch.cat([y_c, y_t], dim = target_dim)) - self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c) - z = self.sample_z(self.z_mu_all, self.z_logvar_all) - return self.decoder(x_t, z) - + self.z_mu_all, self.z_logvar_all = self.data_to_z_params( + self.train_X, self.train_Y, dim=axis + ) + self.z_mu_context, self.z_logvar_context = self.data_to_z_params( + x_c, y_c, dim=axis + ) + x_t = self.transform_inputs(x_t) + return self.posterior(x_t).distribution + def random_split_context_target( - self, - x: torch.Tensor, - y: torch.Tensor, - n_context: int, - axis: int + self, x: torch.Tensor, y: torch.Tensor, n_context: int, axis: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r"""Helper function to split randomly into context and target. Args: - x: Input data tensor. - y: Target data tensor. + x: A `batch_shape x n x d` tensor of training features. + y: A `batch_shape x n x m` tensor of training observations. n_context (int): Number of context points. axis: Dimension axis as int Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x_c: Context input data. - y_c: Context target data. - x_t: Target input data. - y_t: Target target data. """ - mask = torch.randperm(x.shape[0])[:n_context] + mask = torch.randperm(x.shape[axis])[:n_context] x_c = torch.from_numpy(x[mask]).to(device) y_c = torch.from_numpy(y[mask]).to(device) - splitter = torch.zeros(x.shape[0], dtype=torch.bool) + splitter = torch.zeros(x.shape[axis], dtype=torch.bool) splitter[mask] = True x_t = torch.from_numpy(x[~splitter]).to(device) y_t = torch.from_numpy(y[~splitter]).to(device) return x_c, y_c, x_t, y_t - \ No newline at end of file From 72327257c2d9f712b4678d215c8f586794352845 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:45:35 -0700 Subject: [PATCH 14/35] Updated NPR Parameters --- botorch_community/models/np_regression.py | 46 ++++++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 860538f94d..0941397203 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -231,6 +231,7 @@ def __init__( y_dim: int = 1, r_dim: int = 64, z_dim: int = 8, + n_context: int = 20, activation: Callable = nn.Sigmoid, init_func: Optional[Callable] = torch.nn.init.normal_, likelihood: Likelihood | None = None, @@ -248,6 +249,7 @@ def __init__( y_dim: Int dimensionality of target data y. r_dim: Int dimensionality of representation r. z_dim: Int dimensionality of latent variable z. + n_context (int): Number of context points, defaults to 20. activation: Activation function applied between layers, defaults to nn. Sigmoid. init_func: A function initializing the weights, @@ -276,6 +278,7 @@ def __init__( ).to(device) self.train_X = train_X.to(device) self.train_Y = train_Y.to(device) + self.n_context = n_context self.z_dim = z_dim self.z_mu_all = None self.z_logvar_all = None @@ -430,14 +433,13 @@ def transform_inputs( return input_transform(X) try: return self.input_transform(X) - except AttributeError: + except (AttributeError, TypeError): return X def forward( self, train_X: torch.Tensor, train_Y: torch.Tensor, - n_context: int, axis: int = 0, ) -> MultivariateNormal: r"""Forward pass for the model. @@ -445,7 +447,6 @@ def forward( Args: train_X: A `batch_shape x n x d` tensor of training features. train_Y: A `batch_shape x n x m` tensor of training observations. - n_context (int): Number of context points. axis: Dimension axis as int, defaulted as 0. Returns: @@ -453,17 +454,17 @@ def forward( """ train_X = self.transform_inputs(train_X) x_c, y_c, x_t, y_t = self.random_split_context_target( - train_X, train_Y, n_context, axis=axis + train_X, train_Y, self.n_context, axis=axis ) x_t = x_t.to(device) x_c = x_c.to(device) y_c = y_c.to(device) y_t = y_t.to(device) self.z_mu_all, self.z_logvar_all = self.data_to_z_params( - self.train_X, self.train_Y, dim=axis + self.train_X, self.train_Y ) self.z_mu_context, self.z_logvar_context = self.data_to_z_params( - x_c, y_c, dim=axis + x_c, y_c ) x_t = self.transform_inputs(x_t) return self.posterior(x_t).distribution @@ -486,6 +487,7 @@ def random_split_context_target( - x_t: Target input data. - y_t: Target target data. """ + self.n_context = n_context mask = torch.randperm(x.shape[axis])[:n_context] x_c = torch.from_numpy(x[mask]).to(device) y_c = torch.from_numpy(y[mask]).to(device) @@ -494,3 +496,35 @@ def random_split_context_target( x_t = torch.from_numpy(x[~splitter]).to(device) y_t = torch.from_numpy(y[~splitter]).to(device) return x_c, y_c, x_t, y_t + + def random_split_context_target( + self, + x: torch.Tensor, + y: torch.Tensor, + n_context: int = 20, + axis: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Helper function to split randomly into context and target. + + Args: + x: A `batch_shape x n x d` tensor of training features. + y: A `batch_shape x n x m` tensor of training observations. + n_context (int): Number of context points. + axis: Dimension axis as int + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - x_c: Context input data. + - y_c: Context target data. + - x_t: Target input data. + - y_t: Target target data. + """ + self.n_context = n_context + mask = torch.randperm(x.shape[axis])[:n_context] + splitter = torch.zeros(x.shape[axis], dtype=torch.bool) + x_c = x[mask].to(device) + y_c = y[mask].to(device) + splitter[mask] = True + x_t = x[~splitter].to(device) + y_t = y[~splitter].to(device) + return x_c, y_c, x_t, y_t From d78d26269826a56e7bd603783fa2df53f61d6de0 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:46:09 -0700 Subject: [PATCH 15/35] LIG WIP --- .../acquisition/latent_information_gain.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 8c5a524f37..6a6ad91c5f 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -47,7 +47,7 @@ def __init__( defaults to 0.01. scaler: Float scaling the std, defaults to 0.5. """ - super().__init__() + super().__init__(model) self.model = model self.num_samples = num_samples self.min_std = min_std @@ -72,19 +72,28 @@ def forward(self, candidate_x: Tensor) -> Tensor: kl = torch.zeros(N, q, device=device) - if self.model is NeuralProcessModel: - z_mu_context, z_logvar_context = self.model.data_to_z_params( - self.context_x, self.context_y + if isinstance(self.model, NeuralProcessModel): + x_c, y_c, x_t, y_t = self.model.random_split_context_target( + self.model.train_X[:, 0], self.model.train_Y ) + print(x_c.shape) + print(y_c.shape) + print(self.model.train_X) + print(self.model.train_X[:, 0]) + print(self.model.train_Y) + print(self.model.train_Y[:, 0]) + z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c, xy_dim = -1) + print(z_mu_context) + print(z_logvar_context) for _ in range(self.num_samples): # Taking Samples/Predictions samples = self.model.sample_z(z_mu_context, z_logvar_context) y_pred = self.model.decoder(candidate_x.view(-1, D), samples) # Combining the data combined_x = torch.cat( - [self.context_x, candidate_x.view(-1, D)], dim=0 + [x_c, candidate_x.view(-1, D)], dim=0 ).to(device) - combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device) + combined_y = torch.cat([self.y_c, y_pred], dim=0).to(device) # Computing posterior variables z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params( combined_x, combined_y From 32916b973704e95fe411645e76cf0b1eee63322a Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:47:28 -0700 Subject: [PATCH 16/35] Updated Tests --- test_community/models/test_np_regression.py | 57 +++++++++------------ 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/test_community/models/test_np_regression.py b/test_community/models/test_np_regression.py index ce3e1ff046..4f60c07848 100644 --- a/test_community/models/test_np_regression.py +++ b/test_community/models/test_np_regression.py @@ -1,11 +1,11 @@ import unittest -import numpy as np import torch -from botorch_community.models.np_regression import NeuralProcessModel from botorch.posteriors import GPyTorchPosterior +from botorch_community.models.np_regression import NeuralProcessModel device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + class TestNeuralProcessModel(unittest.TestCase): def initialize(self): self.r_hidden_dims = [16, 16] @@ -16,7 +16,9 @@ def initialize(self): self.r_dim = 8 self.z_dim = 8 self.model = NeuralProcessModel( - self.r_hidden_dims, + torch.rand(100, self.x_dim), + torch.rand(100, self.y_dim), + self.r_hidden_dims, self.z_hidden_dims, self.decoder_hidden_dims, self.x_dim, @@ -24,31 +26,29 @@ def initialize(self): self.r_dim, self.z_dim, ) - self.x_data = np.random.rand(100, self.x_dim) - self.y_data = np.random.rand(100, self.y_dim) def test_r_encoder(self): self.initialize() - input = torch.rand(10, self.x_dim + self.y_dim) + input = torch.rand(100, self.x_dim + self.y_dim) output = self.model.r_encoder(input) - self.assertEqual(output.shape, (10, self.r_dim)) + self.assertEqual(output.shape, (100, self.r_dim)) self.assertTrue(torch.is_tensor(output)) def test_z_encoder(self): self.initialize() - input = torch.rand(10, self.r_dim) + input = torch.rand(100, self.r_dim) mean, logvar = self.model.z_encoder(input) - self.assertEqual(mean.shape, (10, self.z_dim)) - self.assertEqual(logvar.shape, (10, self.z_dim)) + self.assertEqual(mean.shape, (100, self.z_dim)) + self.assertEqual(logvar.shape, (100, self.z_dim)) self.assertTrue(torch.is_tensor(mean)) self.assertTrue(torch.is_tensor(logvar)) def test_decoder(self): self.initialize() - x_pred = torch.rand(10, self.x_dim) + x_pred = torch.rand(100, self.x_dim) z = torch.rand(self.z_dim) output = self.model.decoder(x_pred, z) - self.assertEqual(output.shape, (10, self.y_dim)) + self.assertEqual(output.shape, (100, self.y_dim)) self.assertTrue(torch.is_tensor(output)) def test_sample_z(self): @@ -71,9 +71,10 @@ def test_KLD_gaussian(self): def test_data_to_z_params(self): self.initialize() - x = torch.rand(10, self.x_dim) - y = torch.rand(10, self.y_dim) - mu, logvar = self.model.data_to_z_params(x, y) + mu, logvar = self.model.data_to_z_params( + self.model.train_X, + self.model.train_Y + ) self.assertEqual(mu.shape, (self.z_dim,)) self.assertEqual(logvar.shape, (self.z_dim,)) self.assertTrue(torch.is_tensor(mu)) @@ -81,40 +82,32 @@ def test_data_to_z_params(self): def test_forward(self): self.initialize() - x_t = torch.rand(5, self.x_dim) - x_c = torch.rand(10, self.x_dim) - y_c = torch.rand(10, self.y_dim) - y_t = torch.rand(5, self.y_dim) - output = self.model(x_t, x_c, y_c, y_t) - self.assertEqual(output.shape, (5, self.y_dim)) + output = self.model(self.model.train_X, self.model.train_Y) + self.assertEqual(output.loc.shape, (80, self.y_dim)) def test_random_split_context_target(self): self.initialize() x_c, y_c, x_t, y_t = self.model.random_split_context_target( - self.x_data[:, 0], self.y_data, 20, 0 + self.model.train_X[:, 0], self.model.train_Y ) self.assertEqual(x_c.shape[0], 20) self.assertEqual(y_c.shape[0], 20) self.assertEqual(x_t.shape[0], 80) self.assertEqual(y_t.shape[0], 80) - + def test_posterior(self): self.initialize() - x_t = torch.rand(5, self.x_dim) - x_c = torch.rand(10, self.x_dim) - y_c = torch.rand(10, self.y_dim) - y_t = torch.rand(5, self.y_dim) - output = self.model(x_t, x_c, y_c, y_t) - posterior = self.model.posterior(x_t, 0.1, 0.01, observation_noise=True) + self.model(self.model.train_X, self.model.train_Y) + posterior = self.model.posterior(self.model.train_X, observation_noise=True) self.assertIsInstance(posterior, GPyTorchPosterior) mvn = posterior.mvn - self.assertEqual(mvn.covariance_matrix.size(), (5, 5, 5)) - + self.assertEqual(mvn.covariance_matrix.size(), (100, 100, 100)) + def test_transform_inputs(self): self.initialize() X = torch.rand(5, 3) self.assertTrue(torch.equal(self.model.transform_inputs(X), X.to(device))) - + if __name__ == "__main__": unittest.main() From e13f38ce0fd4902a1cda9ad78de1c94b9bce8a86 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:47:52 -0700 Subject: [PATCH 17/35] Test LIG WIP --- .../test_latent_information_gain.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index f944292fd9..4961bd266e 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -1,53 +1,50 @@ import unittest + import torch from botorch_community.acquisition.latent_information_gain import LatentInformationGain from botorch_community.models.np_regression import NeuralProcessModel + class TestLatentInformationGain(unittest.TestCase): def setUp(self): self.x_dim = 2 self.y_dim = 1 self.r_dim = 8 - self.z_dim = 3 - self.context_x = torch.rand(10, self.x_dim) - self.context_y = torch.rand(10, self.y_dim) + self.z_dim = 8 self.r_hidden_dims = [16, 16] self.z_hidden_dims = [32, 32] self.decoder_hidden_dims = [16, 16] - self.num_samples = 10 self.model = NeuralProcessModel( - r_hidden_dims = self.r_hidden_dims, - z_hidden_dims = self.z_hidden_dims, - decoder_hidden_dims = self.decoder_hidden_dims, + torch.rand(10, self.x_dim), + torch.rand(10, self.y_dim), + r_hidden_dims=self.r_hidden_dims, + z_hidden_dims=self.z_hidden_dims, + decoder_hidden_dims=self.decoder_hidden_dims, x_dim=self.x_dim, y_dim=self.y_dim, r_dim=self.r_dim, z_dim=self.z_dim, ) self.acquisition_function = LatentInformationGain( - context_x=self.context_x, - context_y=self.context_y, model=self.model, - num_samples=self.num_samples, ) self.candidate_x = torch.rand(5, self.x_dim) def test_initialization(self): - self.assertEqual(self.acquisition_function.num_samples, self.num_samples) + self.assertEqual(self.acquisition_function.num_samples, 10) self.assertEqual(self.acquisition_function.model, self.model) def test_acquisition_shape(self): - lig_score = self.acquisition_function.forward( - candidate_x=self.candidate_x - ) + self.model(self.model.train_X, self.model.train_Y) + lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x) self.assertTrue(torch.is_tensor(lig_score)) self.assertEqual(lig_score.shape, (1, 5)) def test_acquisition_kl(self): - lig_score = self.acquisition_function.forward( - candidate_x=self.candidate_x - ) + self.model(self.model.train_X, self.model.train_Y) + lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x) self.assertGreaterEqual(lig_score.mean().item(), 0) + if __name__ == "__main__": unittest.main() From bf95d41d9b175995ec5230ce56e8282382604cbb Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 21 Mar 2025 05:47:36 -0700 Subject: [PATCH 18/35] LIG Updated Parameters --- .../acquisition/latent_information_gain.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 6a6ad91c5f..71fab3a9e0 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -18,7 +18,9 @@ """ from __future__ import annotations -from typing import Type, Any + +from typing import Any, Type + import torch from botorch.acquisition import AcquisitionFunction from botorch_community.models.np_regression import NeuralProcessModel @@ -42,7 +44,7 @@ def __init__( Args: model: The model class to be used, defaults to NeuralProcessModel. - num_samples (int): Number of samples for calculation, defaults to 10. + num_samples: Int showing the # of samples for calculation, defaults to 10. min_std: Float representing the minimum possible standardized std, defaults to 0.01. scaler: Float scaling the std, defaults to 0.5. @@ -74,26 +76,18 @@ def forward(self, candidate_x: Tensor) -> Tensor: if isinstance(self.model, NeuralProcessModel): x_c, y_c, x_t, y_t = self.model.random_split_context_target( - self.model.train_X[:, 0], self.model.train_Y + self.model.train_X, + self.model.train_Y, + self.model.n_context ) - print(x_c.shape) - print(y_c.shape) - print(self.model.train_X) - print(self.model.train_X[:, 0]) - print(self.model.train_Y) - print(self.model.train_Y[:, 0]) - z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c, xy_dim = -1) - print(z_mu_context) - print(z_logvar_context) + z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c) for _ in range(self.num_samples): # Taking Samples/Predictions samples = self.model.sample_z(z_mu_context, z_logvar_context) y_pred = self.model.decoder(candidate_x.view(-1, D), samples) # Combining the data - combined_x = torch.cat( - [x_c, candidate_x.view(-1, D)], dim=0 - ).to(device) - combined_y = torch.cat([self.y_c, y_pred], dim=0).to(device) + combined_x = torch.cat([x_c, candidate_x.view(-1, D)], dim=0).to(device) + combined_y = torch.cat([y_c, y_pred], dim=0).to(device) # Computing posterior variables z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params( combined_x, combined_y From 046f609f8a9d0248bf39965226f291f26dba8ade Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 21 Mar 2025 05:47:57 -0700 Subject: [PATCH 19/35] NPR Updated Parameters --- botorch_community/models/np_regression.py | 49 +++-------------------- 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 0941397203..3c667010f9 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -291,18 +291,13 @@ def __init__( self.input_transform = input_transform def data_to_z_params( - self, - x: torch.Tensor, - y: torch.Tensor, - xy_dim: int = 1, - r_dim: int = 0, + self, x: torch.Tensor, y: torch.Tensor, r_dim: int = 0 ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Compute latent parameters from inputs as a latent distribution. Args: x: Input tensor y: Target tensor - xy_dim: Combined Input Dimension as int, defaults as 1 r_dim: Combined Target Dimension as int, defaults as 0. Returns: @@ -314,7 +309,7 @@ def data_to_z_params( """ x = x.to(device) y = y.to(device) - xy = torch.cat([x, y], dim=xy_dim).to(device).to(device) + xy = torch.cat([x, y], dim=-1).to(device).to(device) rs = self.r_encoder(xy) r_agg = rs.mean(dim=r_dim).to(device) return self.z_encoder(r_agg) @@ -463,46 +458,12 @@ def forward( self.z_mu_all, self.z_logvar_all = self.data_to_z_params( self.train_X, self.train_Y ) - self.z_mu_context, self.z_logvar_context = self.data_to_z_params( - x_c, y_c - ) + self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c) x_t = self.transform_inputs(x_t) return self.posterior(x_t).distribution def random_split_context_target( - self, x: torch.Tensor, y: torch.Tensor, n_context: int, axis: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - r"""Helper function to split randomly into context and target. - - Args: - x: A `batch_shape x n x d` tensor of training features. - y: A `batch_shape x n x m` tensor of training observations. - n_context (int): Number of context points. - axis: Dimension axis as int - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - - x_c: Context input data. - - y_c: Context target data. - - x_t: Target input data. - - y_t: Target target data. - """ - self.n_context = n_context - mask = torch.randperm(x.shape[axis])[:n_context] - x_c = torch.from_numpy(x[mask]).to(device) - y_c = torch.from_numpy(y[mask]).to(device) - splitter = torch.zeros(x.shape[axis], dtype=torch.bool) - splitter[mask] = True - x_t = torch.from_numpy(x[~splitter]).to(device) - y_t = torch.from_numpy(y[~splitter]).to(device) - return x_c, y_c, x_t, y_t - - def random_split_context_target( - self, - x: torch.Tensor, - y: torch.Tensor, - n_context: int = 20, - axis: int = 0 + self, x: torch.Tensor, y: torch.Tensor, n_context, axis: int = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r"""Helper function to split randomly into context and target. @@ -510,7 +471,7 @@ def random_split_context_target( x: A `batch_shape x n x d` tensor of training features. y: A `batch_shape x n x m` tensor of training observations. n_context (int): Number of context points. - axis: Dimension axis as int + axis: Dimension axis as int, defaults to 0. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: From 4c037c516678049c36209c793bda558662735ce2 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 21 Mar 2025 05:48:22 -0700 Subject: [PATCH 20/35] LIG Updated Tests --- test_community/acquisition/test_latent_information_gain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index 4961bd266e..34c65d4a7c 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -10,7 +10,7 @@ def setUp(self): self.x_dim = 2 self.y_dim = 1 self.r_dim = 8 - self.z_dim = 8 + self.z_dim = 3 self.r_hidden_dims = [16, 16] self.z_hidden_dims = [32, 32] self.decoder_hidden_dims = [16, 16] From e7c964d10f7c5d6d8b2379c5a328ee4f004d3384 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 21 Mar 2025 05:48:39 -0700 Subject: [PATCH 21/35] NPR Updated Tests --- test_community/models/test_np_regression.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test_community/models/test_np_regression.py b/test_community/models/test_np_regression.py index 4f60c07848..6cf1e26cc8 100644 --- a/test_community/models/test_np_regression.py +++ b/test_community/models/test_np_regression.py @@ -1,4 +1,5 @@ import unittest + import torch from botorch.posteriors import GPyTorchPosterior from botorch_community.models.np_regression import NeuralProcessModel @@ -15,6 +16,7 @@ def initialize(self): self.y_dim = 1 self.r_dim = 8 self.z_dim = 8 + self.n_context = 20 self.model = NeuralProcessModel( torch.rand(100, self.x_dim), torch.rand(100, self.y_dim), @@ -25,6 +27,7 @@ def initialize(self): self.y_dim, self.r_dim, self.z_dim, + self.n_context ) def test_r_encoder(self): @@ -71,10 +74,7 @@ def test_KLD_gaussian(self): def test_data_to_z_params(self): self.initialize() - mu, logvar = self.model.data_to_z_params( - self.model.train_X, - self.model.train_Y - ) + mu, logvar = self.model.data_to_z_params(self.model.train_X, self.model.train_Y) self.assertEqual(mu.shape, (self.z_dim,)) self.assertEqual(logvar.shape, (self.z_dim,)) self.assertTrue(torch.is_tensor(mu)) @@ -88,7 +88,7 @@ def test_forward(self): def test_random_split_context_target(self): self.initialize() x_c, y_c, x_t, y_t = self.model.random_split_context_target( - self.model.train_X[:, 0], self.model.train_Y + self.model.train_X[:, 0], self.model.train_Y, self.model.n_context ) self.assertEqual(x_c.shape[0], 20) self.assertEqual(y_c.shape[0], 20) From 96d2bb499664413da7c7d58096832d335cfa96ec Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Mon, 24 Mar 2025 06:21:55 -0700 Subject: [PATCH 22/35] Update Parameters --- botorch_community/acquisition/latent_information_gain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 71fab3a9e0..6c298f121f 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -33,7 +33,7 @@ class LatentInformationGain(AcquisitionFunction): def __init__( self, - model: Type[Any] = NeuralProcessModel, + model: Type[Any], num_samples: int = 10, min_std: float = 0.01, scaler: float = 0.5, From c4f5f879471bdcb7ee401ae7f4475e0f3159a745 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Mon, 24 Mar 2025 06:22:27 -0700 Subject: [PATCH 23/35] Update Parameters --- botorch_community/models/np_regression.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 3c667010f9..3d25d13660 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -242,13 +242,16 @@ def __init__( Args: train_X: A `batch_shape x n x d` tensor of training features. train_Y: A `batch_shape x n x m` tensor of training observations. - r_hidden_dims: Hidden Dimensions/Layer list for REncoder - z_hidden_dims: Hidden Dimensions/Layer list for ZEncoder - decoder_hidden_dims: Hidden Dimensions/Layer for Decoder - x_dim: Int dimensionality of input data x. - y_dim: Int dimensionality of target data y. - r_dim: Int dimensionality of representation r. - z_dim: Int dimensionality of latent variable z. + r_hidden_dims: Hidden Dimensions/Layer list for REncoder, defaults to + [16, 16] + z_hidden_dims: Hidden Dimensions/Layer list for ZEncoder, defaults to + [32, 32] + decoder_hidden_dims: Hidden Dimensions/Layer for Decoder, defaults to + [16, 16] + x_dim: Int dimensionality of input data x, defaults to 2. + y_dim: Int dimensionality of target data y, defaults to 1. + r_dim: Int dimensionality of representation r, defaults to 64. + z_dim: Int dimensionality of latent variable z, defaults to 8. n_context (int): Number of context points, defaults to 20. activation: Activation function applied between layers, defaults to nn. Sigmoid. From 529e36c5c3892a46eef9ef6abd5b9dd0349b19c1 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Mon, 24 Mar 2025 06:22:55 -0700 Subject: [PATCH 24/35] Update Parameters --- test_community/acquisition/test_latent_information_gain.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index 34c65d4a7c..cfacf6cb37 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -25,9 +25,7 @@ def setUp(self): r_dim=self.r_dim, z_dim=self.z_dim, ) - self.acquisition_function = LatentInformationGain( - model=self.model, - ) + self.acquisition_function = LatentInformationGain(self.model) self.candidate_x = torch.rand(5, self.x_dim) def test_initialization(self): From 0e280776b862d430ac4e2e6ebfb4abf2200ad375 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 29 Apr 2025 06:34:08 -0700 Subject: [PATCH 25/35] April Updates --- botorch_community/models/np_regression.py | 104 +++++++++++----------- 1 file changed, 53 insertions(+), 51 deletions(-) diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 3d25d13660..6100be5bd0 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -22,11 +22,9 @@ from gpytorch.distributions import MultivariateNormal from gpytorch.likelihoods import GaussianLikelihood from gpytorch.likelihoods.likelihood import Likelihood +from gpytorch.models.gp import GP from torch.nn import Module -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# Account for different acquisitions - # reference: https://chrisorm.github.io/NGP.html class MLP(nn.Module): @@ -56,21 +54,21 @@ def __init__( prev_dim = input_dim for hidden_dim in hidden_dims: - layer = nn.Linear(prev_dim, hidden_dim).to(device) + layer = nn.Linear(prev_dim, hidden_dim) if init_func is not None: init_func(layer.weight) layers.append(layer) layers.append(activation()) prev_dim = hidden_dim - final_layer = nn.Linear(prev_dim, output_dim).to(device) + final_layer = nn.Linear(prev_dim, output_dim) if init_func is not None: init_func(final_layer.weight) layers.append(final_layer) - self.model = nn.Sequential(*layers).to(device) + self.model = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x.to(device)) + return self.model(x) class REncoder(nn.Module): @@ -101,7 +99,7 @@ def __init__( hidden_dims=hidden_dims, activation=activation, init_func=init_func, - ).to(device) + ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward pass for representation encoder. @@ -112,7 +110,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Encoded representations """ - return self.mlp(inputs.to(device)) + return self.mlp(inputs) class ZEncoder(nn.Module): @@ -144,14 +142,14 @@ def __init__( hidden_dims=hidden_dims, activation=activation, init_func=init_func, - ).to(device) + ) self.logvar_net = MLP( input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func, - ).to(device) + ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward pass for latent encoder. @@ -164,7 +162,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: - Mean of the latent Gaussian distribution. - Log variance of the latent Gaussian distribution. """ - inputs = inputs.to(device) return self.mean_net(inputs), self.logvar_net(inputs) @@ -197,7 +194,7 @@ def __init__( hidden_dims=hidden_dims, activation=activation, init_func=init_func, - ).to(device) + ) def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: r"""Forward pass for decoder. @@ -212,14 +209,17 @@ def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: torch.Tensor: Predicted target values of shape (n x z_dim), representing # of data points by z_dim. """ - z = z.to(device) - z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device) - x_pred = x_pred.to(device) + if z.dim() == 1: + z = z.unsqueeze(0) + if z.dim() == 3: + z = z.squeeze(0) + z_expanded = z.expand(x_pred.size(0), -1) + x_pred = x_pred xz = torch.cat([x_pred, z_expanded], dim=-1) return self.mlp(xz) -class NeuralProcessModel(Model): +class NeuralProcessModel(Model, GP): def __init__( self, train_X: torch.Tensor, @@ -262,25 +262,28 @@ def __init__( forward pass. """ super().__init__() + self.device = train_X.device + + # self._validate_tensor_args(X=train_X, Y=train_Y) self.r_encoder = REncoder( x_dim + y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func, - ).to(device) + ).to(self.device) self.z_encoder = ZEncoder( r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func - ).to(device) + ).to(self.device) self.decoder = Decoder( x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func, - ).to(device) - self.train_X = train_X.to(device) - self.train_Y = train_Y.to(device) + ).to(self.device) + self.train_X = train_X.to(self.device) + self.train_Y = train_Y.to(self.device) self.n_context = n_context self.z_dim = z_dim self.z_mu_all = None @@ -288,9 +291,9 @@ def __init__( self.z_mu_context = None self.z_logvar_context = None if likelihood is None: - self.likelihood = GaussianLikelihood().to(device) + self.likelihood = GaussianLikelihood().to(self.device) else: - self.likelihood = likelihood.to(device) + self.likelihood = likelihood.to(self.device) self.input_transform = input_transform def data_to_z_params( @@ -310,11 +313,11 @@ def data_to_z_params( - x_t: Target input data. - y_t: Target target data. """ - x = x.to(device) - y = y.to(device) - xy = torch.cat([x, y], dim=-1).to(device).to(device) + x = x.to(self.device) + y = y.to(self.device) + xy = torch.cat([x, y], dim=-1).to(self.device).to(self.device) rs = self.r_encoder(xy) - r_agg = rs.mean(dim=r_dim).to(device) + r_agg = rs.mean(dim=r_dim).to(self.device) return self.z_encoder(r_agg) def sample_z( @@ -344,11 +347,11 @@ def sample_z( shape = [n, self.z_dim] if n == 1: shape = shape[1:] - eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device) + eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(self.device) std = min_std + scaler * torch.sigmoid(logvar) - std = std.to(device) - mu = mu.to(device) + std = std.to(self.device) + mu = mu.to(self.device) return mu + std * eps def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor: @@ -365,10 +368,10 @@ def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tens if min_std <= 0 or scaler <= 0: raise ValueError() - std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device) - std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(device) - p = torch.distributions.Normal(self.z_mu_context.to(device), std_p) - q = torch.distributions.Normal(self.z_mu_all.to(device), std_q) + std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(self.device) + std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(self.device) + p = torch.distributions.Normal(self.z_mu_context.to(self.device), std_p) + q = torch.distributions.Normal(self.z_mu_all.to(self.device), std_q) return torch.distributions.kl_divergence(p, q).sum() def posterior( @@ -378,7 +381,7 @@ def posterior( observation_noise: bool = False, posterior_transform: PosteriorTransform | None = None, ) -> GPyTorchPosterior: - r"""Computes the model's posterior distribution for given input tensors. + r"""Computes the model's posterior for given input tensors. Args: X: Input Tensor @@ -391,20 +394,19 @@ def posterior( defaults to None. Returns: - GPyTorchPosterior: The posterior distribution object - utilizing MultivariateNormal. + GPyTorchPosterior: The posterior utilizing MultivariateNormal. """ X = self.transform_inputs(X) - X = X.to(device) + X = X.to(self.device) mean = self.decoder( - X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all) + X.to(self.device), self.sample_z(self.z_mu_all, self.z_logvar_all) ) z_var = torch.exp(self.z_logvar_all) - covariance = torch.eye(X.size(0)).to(device) * z_var.mean() + covariance = torch.eye(X.size(0)).to(self.device) * z_var.mean() if observation_noise: covariance = covariance + self.likelihood.noise * torch.eye( covariance.size(0) - ).to(device) + ).to(self.device) mvn = MultivariateNormal(mean, covariance) posterior = GPyTorchPosterior(mvn) if posterior_transform is not None: @@ -425,7 +427,7 @@ def transform_inputs( Returns: torch.Tensor: A tensor of transformed inputs """ - X = X.to(device) + X = X.to(self.device) if input_transform is not None: input_transform.to(X) return input_transform(X) @@ -454,10 +456,10 @@ def forward( x_c, y_c, x_t, y_t = self.random_split_context_target( train_X, train_Y, self.n_context, axis=axis ) - x_t = x_t.to(device) - x_c = x_c.to(device) - y_c = y_c.to(device) - y_t = y_t.to(device) + x_t = x_t.to(self.device) + x_c = x_c.to(self.device) + y_c = y_c.to(self.device) + y_t = y_t.to(self.device) self.z_mu_all, self.z_logvar_all = self.data_to_z_params( self.train_X, self.train_Y ) @@ -486,9 +488,9 @@ def random_split_context_target( self.n_context = n_context mask = torch.randperm(x.shape[axis])[:n_context] splitter = torch.zeros(x.shape[axis], dtype=torch.bool) - x_c = x[mask].to(device) - y_c = y[mask].to(device) + x_c = x[mask].to(self.device) + y_c = y[mask].to(self.device) splitter[mask] = True - x_t = x[~splitter].to(device) - y_t = y[~splitter].to(device) + x_t = x[~splitter].to(self.device) + y_t = y[~splitter].to(self.device) return x_c, y_c, x_t, y_t From 0ceb9ca566011f1ee145b5347bcf099baf6d2634 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 29 Apr 2025 06:34:49 -0700 Subject: [PATCH 26/35] April Updates --- .../acquisition/latent_information_gain.py | 98 +++++++++---------- 1 file changed, 48 insertions(+), 50 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 6c298f121f..421b9ac64f 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -27,8 +27,6 @@ from torch import Tensor # reference: https://arxiv.org/abs/2106.02770 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class LatentInformationGain(AcquisitionFunction): def __init__( @@ -56,58 +54,58 @@ def __init__( self.scaler = scaler def forward(self, candidate_x: Tensor) -> Tensor: - """ - Conduct the Latent Information Gain acquisition function using the model's - posterior. - - Args: - candidate_x: Candidate input points, as a Tensor. Ideally in the shape - (N, q, D). - - Returns: - torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q). - """ + device = candidate_x.device candidate_x = candidate_x.to(device) - if candidate_x.dim() == 2: - candidate_x = candidate_x.unsqueeze(0) # Ensure (N, q, D) format N, q, D = candidate_x.shape - - kl = torch.zeros(N, q, device=device) - + kl = torch.zeros(N, device=device) if isinstance(self.model, NeuralProcessModel): - x_c, y_c, x_t, y_t = self.model.random_split_context_target( - self.model.train_X, - self.model.train_Y, - self.model.n_context + x_c, y_c, _, _ = self.model.random_split_context_target( + self.model.train_X, self.model.train_Y, self.model.n_context ) z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c) - for _ in range(self.num_samples): - # Taking Samples/Predictions - samples = self.model.sample_z(z_mu_context, z_logvar_context) - y_pred = self.model.decoder(candidate_x.view(-1, D), samples) - # Combining the data - combined_x = torch.cat([x_c, candidate_x.view(-1, D)], dim=0).to(device) - combined_y = torch.cat([y_c, y_pred], dim=0).to(device) - # Computing posterior variables - z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params( - combined_x, combined_y - ) - std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context) - std_posterior = self.min_std + self.scaler * torch.sigmoid( - z_logvar_posterior - ) - p = torch.distributions.Normal(z_mu_posterior, std_posterior) - q = torch.distributions.Normal(z_mu_context, std_prior) - kl_divergence = torch.distributions.kl_divergence(p, q).sum(dim=-1) - kl += kl_divergence - else: - for _ in range(self.num_samples): - posterior_prior = self.model.posterior(self.model.train_X) - posterior_candidate = self.model.posterior(candidate_x.view(-1, D)) - kl_divergence = torch.distributions.kl_divergence( - posterior_candidate.mvn, posterior_prior.mvn - ).sum(dim=-1) - kl += kl_divergence + for i in range(N): + x_i = candidate_x[i] + kl_i = 0.0 - return kl / self.num_samples + for _ in range(self.num_samples): + sample_z = self.model.sample_z(z_mu_context, z_logvar_context) + if sample_z.dim() == 1: + sample_z = sample_z.unsqueeze(0) + + y_pred = self.model.decoder(x_i, sample_z) + + combined_x = torch.cat([x_c, x_i], dim=0) + combined_y = torch.cat([y_c, y_pred], dim=0) + + z_mu_post, z_logvar_post = self.model.data_to_z_params( + combined_x, combined_y + ) + + std_prior = self.min_std + self.scaler * torch.sigmoid( + z_logvar_context + ) + std_post = self.min_std + self.scaler * torch.sigmoid(z_logvar_post) + + p = torch.distributions.Normal(z_mu_post, std_post) + q = torch.distributions.Normal(z_mu_context, std_prior) + kl_sample = torch.distributions.kl_divergence(p, q).sum() + + kl_i += kl_sample + + kl[i] = kl_i / self.num_samples + + else: + for i in range(N): + x_i = candidate_x[i] + kl_i = 0.0 + for _ in range(self.num_samples): + posterior_prior = self.model.posterior(self.model.train_X) + posterior_candidate = self.model.posterior(x_i) + + kl_i += torch.distributions.kl_divergence( + posterior_candidate.mvn, posterior_prior.mvn + ).sum() + + kl[i] = kl_i / self.num_samples + return kl From 9949ff980d1c4b713c93a33164f7242903c13cf6 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 29 Apr 2025 06:36:24 -0700 Subject: [PATCH 27/35] April Tests --- test_community/models/test_np_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_community/models/test_np_regression.py b/test_community/models/test_np_regression.py index 6cf1e26cc8..4f5707e37b 100644 --- a/test_community/models/test_np_regression.py +++ b/test_community/models/test_np_regression.py @@ -27,7 +27,7 @@ def initialize(self): self.y_dim, self.r_dim, self.z_dim, - self.n_context + self.n_context, ) def test_r_encoder(self): From 2c9c95843dfbc79f65ee7b08db0a069c971ba5be Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 29 Apr 2025 06:36:44 -0700 Subject: [PATCH 28/35] April Tests --- .../test_latent_information_gain.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index cfacf6cb37..5a76c00163 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -1,6 +1,7 @@ import unittest import torch +from botorch.optim.optimize import optimize_acqf from botorch_community.acquisition.latent_information_gain import LatentInformationGain from botorch_community.models.np_regression import NeuralProcessModel @@ -32,16 +33,22 @@ def test_initialization(self): self.assertEqual(self.acquisition_function.num_samples, 10) self.assertEqual(self.acquisition_function.model, self.model) - def test_acquisition_shape(self): - self.model(self.model.train_X, self.model.train_Y) - lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x) - self.assertTrue(torch.is_tensor(lig_score)) - self.assertEqual(lig_score.shape, (1, 5)) - - def test_acquisition_kl(self): - self.model(self.model.train_X, self.model.train_Y) - lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x) - self.assertGreaterEqual(lig_score.mean().item(), 0) + def test_acqf(self): + bounds = torch.tensor([[0.0] * self.x_dim, [1.0] * self.x_dim]) + q = 3 + raw_samples = 8 + num_restarts = 2 + + candidate = optimize_acqf( + acq_function=self.acquisition_function, + bounds=bounds, + q=q, + raw_samples=raw_samples, + num_restarts=num_restarts, + ) + self.assertTrue(isinstance(candidate, tuple)) + self.assertEqual(candidate[0].shape, (q, self.x_dim)) + self.assertTrue(torch.all(candidate[1] >= 0)) if __name__ == "__main__": From 918a4b4e5e7d57b6133e72ab2d6a614885fcc06b Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 16 May 2025 17:53:27 -0700 Subject: [PATCH 29/35] 5/16 Updates --- .../acquisition/latent_information_gain.py | 39 ++++++++++++++----- botorch_community/models/np_regression.py | 4 +- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 421b9ac64f..9b893df41b 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -54,10 +54,35 @@ def __init__( self.scaler = scaler def forward(self, candidate_x: Tensor) -> Tensor: + """ + Conduct the Latent Information Gain acquisition function for the inputs. + + Args: + candidate_x: Candidate input points, as a Tensor. Ideally in the shape + (N, q, D). + + Returns: + torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q). + """ device = candidate_x.device candidate_x = candidate_x.to(device) N, q, D = candidate_x.shape - kl = torch.zeros(N, device=device) + kl = torch.zeros(N, device=device, dtype=torch.float32) + def normal_dist(mu, logvar, min_std, scaler): + r"""Helper function for creating the normal distributions. + + Args: + mu: Tensor representing the Gaussian distribution mean. + logvar: Tensor representing the log variance of the + Gaussian distribution. + min_std: Float representing the minimum standardized std. + scaler: Float scaling the std. + + Returns: + torch.distributions.Normal: The normal distribution. + """ + std = min_std + scaler * torch.sigmoid(logvar) + return torch.distributions.Normal(mu, std) if isinstance(self.model, NeuralProcessModel): x_c, y_c, _, _ = self.model.random_split_context_target( self.model.train_X, self.model.train_Y, self.model.n_context @@ -82,15 +107,11 @@ def forward(self, candidate_x: Tensor) -> Tensor: combined_x, combined_y ) - std_prior = self.min_std + self.scaler * torch.sigmoid( - z_logvar_context + p = normal_dist(z_mu_post, z_logvar_post, self.min_std, self.scaler) + q = normal_dist( + z_mu_context, z_logvar_context, self.min_std, self.scaler ) - std_post = self.min_std + self.scaler * torch.sigmoid(z_logvar_post) - - p = torch.distributions.Normal(z_mu_post, std_post) - q = torch.distributions.Normal(z_mu_context, std_prior) kl_sample = torch.distributions.kl_divergence(p, q).sum() - kl_i += kl_sample kl[i] = kl_i / self.num_samples @@ -108,4 +129,4 @@ def forward(self, candidate_x: Tensor) -> Tensor: ).sum() kl[i] = kl_i / self.num_samples - return kl + return kl \ No newline at end of file diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 6100be5bd0..0371a76c8f 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -282,8 +282,8 @@ def __init__( activation=activation, init_func=init_func, ).to(self.device) - self.train_X = train_X.to(self.device) - self.train_Y = train_Y.to(self.device) + self.train_X = train_X + self.train_Y = train_Y self.n_context = n_context self.z_dim = z_dim self.z_mu_all = None From fe75f4327dede96def2e56b4be98a65d648c48e1 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 16 May 2025 18:32:05 -0700 Subject: [PATCH 30/35] Recent Fixes --- .../acquisition/latent_information_gain.py | 37 ++++++------------- botorch_community/models/np_regression.py | 13 +++---- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 9b893df41b..5a2d8c38a8 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -68,33 +68,23 @@ def forward(self, candidate_x: Tensor) -> Tensor: candidate_x = candidate_x.to(device) N, q, D = candidate_x.shape kl = torch.zeros(N, device=device, dtype=torch.float32) - def normal_dist(mu, logvar, min_std, scaler): - r"""Helper function for creating the normal distributions. - - Args: - mu: Tensor representing the Gaussian distribution mean. - logvar: Tensor representing the log variance of the - Gaussian distribution. - min_std: Float representing the minimum standardized std. - scaler: Float scaling the std. - - Returns: - torch.distributions.Normal: The normal distribution. - """ - std = min_std + scaler * torch.sigmoid(logvar) - return torch.distributions.Normal(mu, std) + if isinstance(self.model, NeuralProcessModel): x_c, y_c, _, _ = self.model.random_split_context_target( self.model.train_X, self.model.train_Y, self.model.n_context ) - z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c) + self.model.z_mu_context, self.model.z_logvar_context = ( + self.model.data_to_z_params(x_c, y_c) + ) for i in range(N): x_i = candidate_x[i] kl_i = 0.0 for _ in range(self.num_samples): - sample_z = self.model.sample_z(z_mu_context, z_logvar_context) + sample_z = self.model.sample_z( + self.model.z_mu_context, self.model.z_logvar_context + ) if sample_z.dim() == 1: sample_z = sample_z.unsqueeze(0) @@ -103,15 +93,10 @@ def normal_dist(mu, logvar, min_std, scaler): combined_x = torch.cat([x_c, x_i], dim=0) combined_y = torch.cat([y_c, y_pred], dim=0) - z_mu_post, z_logvar_post = self.model.data_to_z_params( - combined_x, combined_y - ) - - p = normal_dist(z_mu_post, z_logvar_post, self.min_std, self.scaler) - q = normal_dist( - z_mu_context, z_logvar_context, self.min_std, self.scaler + self.model.z_mu_all, self.model.z_logvar_all = ( + self.model.data_to_z_params(combined_x, combined_y) ) - kl_sample = torch.distributions.kl_divergence(p, q).sum() + kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler) kl_i += kl_sample kl[i] = kl_i / self.num_samples @@ -129,4 +114,4 @@ def normal_dist(mu, logvar, min_std, scaler): ).sum() kl[i] = kl_i / self.num_samples - return kl \ No newline at end of file + return kl diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 0371a76c8f..3b1af8d9b2 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -264,24 +264,23 @@ def __init__( super().__init__() self.device = train_X.device - # self._validate_tensor_args(X=train_X, Y=train_Y) self.r_encoder = REncoder( x_dim + y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func, - ).to(self.device) + ) self.z_encoder = ZEncoder( r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func - ).to(self.device) + ) self.decoder = Decoder( x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func, - ).to(self.device) + ) self.train_X = train_X self.train_Y = train_Y self.n_context = n_context @@ -290,11 +289,9 @@ def __init__( self.z_logvar_all = None self.z_mu_context = None self.z_logvar_context = None - if likelihood is None: - self.likelihood = GaussianLikelihood().to(self.device) - else: - self.likelihood = likelihood.to(self.device) + self.likelihood = likelihood if likelihood is not None else GaussianLikelihood() self.input_transform = input_transform + self.to(device=self.device) def data_to_z_params( self, x: torch.Tensor, y: torch.Tensor, r_dim: int = 0 From badba2092e20ee256b05ca744016e978c3e577cb Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Fri, 16 May 2025 18:42:05 -0700 Subject: [PATCH 31/35] Fixing merge conflicts for acquisition/models docs --- docs/acquisition.md | 385 ++++++++++++++++++++------------------------ docs/models.md | 377 +++++++++++++++++++++---------------------- 2 files changed, 365 insertions(+), 397 deletions(-) diff --git a/docs/acquisition.md b/docs/acquisition.md index dc126a2388..f88aa2bbe4 100644 --- a/docs/acquisition.md +++ b/docs/acquisition.md @@ -1,207 +1,178 @@ ---- -id: acquisition -title: Acquisition Functions ---- - -Acquisition functions are heuristics employed to evaluate the usefulness of one -of more design points for achieving the objective of maximizing the underlying -black box function. - -BoTorch supports both analytic as well as (quasi-) Monte-Carlo based acquisition -functions. It provides a generic -[`AcquisitionFunction`](../api/acquisition.html#acquisitionfunction) API that -abstracts away from the particular type, so that optimization can be performed -on the same objects. - - -## Monte Carlo Acquisition Functions - -Many common acquisition functions can be expressed as the expectation of some -real-valued function of the model output(s) at the design point(s): - -$$ -\alpha(X) = \mathbb{E}\bigl[ a(\xi) \mid - \xi \sim \mathbb{P}(f(X) \mid \mathcal{D}) \bigr] -$$ - -where $X = (x_1, \dotsc, x_q)$, and $\mathbb{P}(f(X) \mid \mathcal{D})$ is the -posterior distribution of the function $f$ at $X$ given the data $\mathcal{D}$ -observed so far. - -Evaluating the acquisition function thus requires evaluating an integral over -the posterior distribution. In most cases, this is analytically intractable. In -particular, analytic expressions generally do not exist for batch acquisition -functions that consider multiple design points jointly (i.e. $q > 1$). - -An alternative is to use Monte-Carlo (MC) sampling to approximate the integrals. -An MC approximation of $\alpha$ at $X$ using $N$ MC samples is - -$$ \alpha(X) \approx \frac{1}{N} \sum_{i=1}^N a(\xi_{i}) $$ - -where $\xi_i \sim \mathbb{P}(f(X) \mid \mathcal{D})$. - -For instance, for q-Expected Improvement (qEI), we have: - -$$ -\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} -\bigl\\{ \max(\xi_{ij} - f^\*, 0) \bigr\\}, -\qquad \xi_{i} \sim \mathbb{P}(f(X) \mid \mathcal{D}) -$$ - -where $f^\*$ is the best function value observed so far (assuming noiseless -observations). Using the reparameterization trick ([^KingmaWelling2014], -[^Rezende2014]), - -$$ -\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} -\bigl\\{ \max\bigl( \mu(X)\_j + (L(X) \epsilon_i)\_j - f^\*, 0 \bigr) \bigr\\}, -\qquad \epsilon_{i} \sim \mathcal{N}(0, I) -$$ - -where $\mu(X)$ is the posterior mean of $f$ at $X$, and $L(X)L(X)^T = \Sigma(X)$ -is a root decomposition of the posterior covariance matrix. - -All MC-based acquisition functions in BoTorch are derived from -[`MCAcquisitionFunction`](../api/acquisition.html#mcacquisitionfunction). - -Acquisition functions expect input tensors $X$ of shape -$\textit{batch_shape} \times q \times d$, where $d$ is the dimension of the -feature space, $q$ is the number of points considered jointly, and -$\textit{batch_shape}$ is the batch-shape of the input tensor. The output -$\alpha(X)$ will have shape $\textit{batch_shape}$, with each element -corresponding to the respective $q \times d$ batch tensor in the input $X$. -Note that for analytic acquisition functions, it must be that $q=1$. - -### MC, q-MC, and Fixed Base Samples - -BoTorch relies on the re-parameterization trick and (quasi)-Monte-Carlo sampling -for optimization and estimation of the batch acquisition functions [^Wilson2017]. -The results below show the reduced variance when estimating an expected -improvement (EI) acquisition function using base samples obtained via quasi-MC -sampling versus standard MC sampling. - -![MC_qMC](assets/EI_MC_qMC.png) - -In the plots above, the base samples used to estimate each point are resampled. -As discussed in the [Overview](./overview), a single set of base samples can be -used for optimization when the re-parameterization trick is employed. What are the -trade-offs between using a fixed set of base samples versus re-sampling on every -MC evaluation of the acquisition function? Below, we show that fixing base samples -produces functions that are potentially much easier to optimize, without resorting to -stochastic optimization methods. - -![resampling_fixed](assets/EI_resampling_fixed.png) - -If the base samples are fixed, the problem of optimizing the acquisition function -is deterministic, allowing for conventional quasi-second order methods such as -L-BFGS or sequential least-squares programming (SLSQP) to be used. These have -faster convergence rates than first-order methods and can speed up acquisition -function optimization significantly. - -One concern is that the approximated acquisition function is *biased* for any -fixed set of base samples, which may adversely affect the solution. However, we -find that in practice, both the optimal value and the optimal solution of these -biased problems for standard acquisition functions converge quite rapidly to -their true counterparts as more samples are used. Note that for evaluation of -the acquisition function we integrate over a $qo$-dimensional space (where -$q$ is the number of points in the q-batch and $o$ is the number of outputs -included in the objective). Therefore, the MC integration problem can be quite -low-dimensional even for models on high-dimensional feature spaces (large $d$). -Because using additional samples is relatively cheap computationally, -we default to 500 base samples in the MC acquisition functions. - -On the other hand, when re-sampling is used in conjunction with a stochastic -optimization algorithm, the kind of bias noted above is no longer a concern. -The trade-off here is that the optimization may be less effective, as discussed -above. - - -## Analytic Acquisition Functions - -BoTorch also provides implementations of analytic acquisition functions that -do not depend on MC sampling. These acquisition functions are subclasses of -[`AnalyticAcquisitionFunction`](../api/acquisition.html#analyticacquisitionfunction) -and only exist for the case of a single candidate point ($q = 1$). These -include classical acquisition functions such as Expected Improvement (EI), -Upper Confidence Bound (UCB), and Probability of Improvement (PI). An example -comparing [`ExpectedImprovement`](../api/acquisition.html#expectedimprovement), -the analytic version of EI, to it's MC counterpart -[`qExpectedImprovement`](../api/acquisition.html#qexpectedimprovement) -can be found in -[this tutorial](../tutorials/compare_mc_analytic_acquisition). - -Analytic acquisition functions allow for an explicit expression in terms of the -summary statistics of the posterior distribution at the evaluated point(s). -A popular acquisition function is Expected Improvement of a single point -for a Gaussian posterior, given by - -$$ \text{EI}(x) = \mathbb{E}\bigl[ -\max(y - f^\*, 0) \mid y\sim \mathcal{N}(\mu(x), \sigma^2(x)) -\bigr] $$ - -where $\mu(x)$ and $\sigma(x)$ are the posterior mean and variance of $f$ at the -point $x$, and $f^\*$ is again the best function value observed so far (assuming -noiseless observations). It can be shown that - -$$ \text{EI}(x) = \sigma(x) \bigl( z \Phi(z) + \varphi(z) \bigr)$$ - -where $z = \frac{\mu(x) - f_{\max}}{\sigma(x)}$ and $\Phi$ and $\varphi$ are -the cdf and pdf of the standard normal distribution, respectively. - -With some additional work, it is also possible to express the gradient of -the Expected Improvement with respect to the design $x$. Classic Bayesian -Optimization software will implement this gradient function explicitly, so that -it can be used for numerically optimizing the acquisition function. - -BoTorch, in contrast, harnesses PyTorch's automatic differentiation feature -("autograd") in order to obtain gradients of acquisition functions. This makes -implementing new acquisition functions much less cumbersome, as it does not -require to analytically derive gradients. All that is required is that the -operations performed in the acquisition function computation allow for the -back-propagation of gradient information through the posterior and the model. - - -[^KingmaWelling2014]: D. P. Kingma, M. Welling. Auto-Encoding Variational Bayes. -ICLR, 2013. - -[^Rezende2014]: D. J. Rezende, S. Mohamed, D. Wierstra. Stochastic -Backpropagation and Approximate Inference in Deep Generative Models. ICML, 2014. - -[^Wilson2017]: J. T. Wilson, R. Moriconi, F. Hutter, M. P. Deisenroth. -The Reparameterization Trick for Acquisition Functions. NeurIPS Workshop on -Bayesian Optimization, 2017. - -## Latent Information Gain - -In the high-dimensional spatiotemporal domain, Expected Information Gain becomes -less informative for useful observations, and it can be difficult to calculate -its parameters. To overcome these limitations, we propose a novel acquisition -function by computing the expected information gain in the latent space rather -than the observational space. To design this acquisition function, -we prove the equivalence between the expected information gain -in the observational space and the expected KL divergence in the -latent processes w.r.t. a candidate parameter 𝜃, as illustrated by the -following proposition. - -Proposition 1. The expected information gain (EIG) for Neural -Process is equivalent to the KL divergence between the prior and -posterior in the latent process, that is - -$$ \text{EIG}(\hat{x}_{1:T}, \theta) := \mathbb{E} \left[ H(\hat{x}_{1:T}) - -H(\hat{x}_{1:T} \mid z_{1:T}, \theta) \right] -= \mathbb{E}_{p(\hat{x}_{1:T} \mid \theta)} -\text{KL} \left( p(z_{1:T} \mid \hat{x}_{1:T}, \theta) \,\|\, p(z_{1:T}) \right) -$$ - - -Inspired by this fact, we propose a novel acquisition function computing the -expected KL divergence in the latent processes and name it LIG. Specifically, -the trained NP model produces a variational posterior given the current dataset. -For every parameter $$\theta$$ remained in the search space, we can predict -$$\hat{x}_{1:T}$$ with the decoder. We use $$\hat{x}_{1:T}$$ and $$\theta$$ -as input to the encoder to re-evaluate the posterior. LIG computes the -distributional difference with respect to the latent process. -[Wu2023arxiv]: - Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). - Deep Bayesian Active Learning for Accelerating Stochastic Simulation. - arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 +--- +id: acquisition +title: Acquisition Functions +--- + +Acquisition functions are heuristics employed to evaluate the usefulness of one +of more design points for achieving the objective of maximizing the underlying +black box function. + +BoTorch supports both analytic as well as (quasi-) Monte-Carlo based acquisition +functions. It provides a generic +[`AcquisitionFunction`](https://botorch.readthedocs.io/en/latest/acquisition.html#botorch.acquisition.acquisition.AcquisitionFunction) API that +abstracts away from the particular type, so that optimization can be performed +on the same objects. + + +## Monte Carlo Acquisition Functions + +Many common acquisition functions can be expressed as the expectation of some +real-valued function of the model output(s) at the design point(s): + +$$ +\alpha(X) = \mathbb{E}\bigl[ a(\xi) \mid + \xi \sim \mathbb{P}(f(X) \mid \mathcal{D}) \bigr] +$$ + +where $X = (x_1, \dotsc, x_q)$, and $\mathbb{P}(f(X) \mid \mathcal{D})$ is the +posterior distribution of the function $f$ at $X$ given the data $\mathcal{D}$ +observed so far. + +Evaluating the acquisition function thus requires evaluating an integral over +the posterior distribution. In most cases, this is analytically intractable. In +particular, analytic expressions generally do not exist for batch acquisition +functions that consider multiple design points jointly (i.e. $q > 1$). + +An alternative is to use Monte-Carlo (MC) sampling to approximate the integrals. +An MC approximation of $\alpha$ at $X$ using $N$ MC samples is + +$$ +\alpha(X) \approx \frac{1}{N} \sum_{i=1}^N a(\xi_{i}) +$$ + +where $\xi_i \sim \mathbb{P}(f(X) \mid \mathcal{D})$. + +For instance, for q-Expected Improvement (qEI), we have: + +$$ +\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} +\bigl\{ \max(\xi_{ij} - f^*, 0) \bigr\}, +\qquad \xi_{i} \sim \mathbb{P}(f(X) \mid \mathcal{D}) +$$ + +where $f^*$ is the best function value observed so far (assuming noiseless +observations). Using the reparameterization trick ([^KingmaWelling2014], +[^Rezende2014]), + +$$ +\text{qEI}(X) \approx \frac{1}{N} \sum_{i=1}^N \max_{j=1,..., q} +\bigl\{ \max\bigl( \mu(X)\_j + (L(X) \epsilon_i)\_j - f^*, 0 \bigr) \bigr\}, +\qquad \epsilon_{i} \sim \mathcal{N}(0, I) +$$ + +where $\mu(X)$ is the posterior mean of $f$ at $X$, and $L(X)L(X)^T = \Sigma(X)$ +is a root decomposition of the posterior covariance matrix. + +All MC-based acquisition functions in BoTorch are derived from +[`MCAcquisitionFunction`](https://botorch.readthedocs.io/en/latest/acquisition.html#botorch.acquisition.monte_carlo.MCAcquisitionFunction). + +Acquisition functions expect input tensors $X$ of shape +$\textit{batch\_shape} \times q \times d$, where $d$ is the dimension of the +feature space, $q$ is the number of points considered jointly, and +$\textit{batch\_shape}$ is the batch-shape of the input tensor. The output +$\alpha(X)$ will have shape $\textit{batch\_shape}$, with each element +corresponding to the respective $q \times d$ batch tensor in the input $X$. +Note that for analytic acquisition functions, it must be that $q=1$. + +### MC, q-MC, and Fixed Base Samples + +BoTorch relies on the re-parameterization trick and (quasi)-Monte-Carlo sampling +for optimization and estimation of the batch acquisition functions [^Wilson2017]. +The results below show the reduced variance when estimating an expected +improvement (EI) acquisition function using base samples obtained via quasi-MC +sampling versus standard MC sampling. + +![MC_qMC](assets/EI_MC_qMC.png) + +In the plots above, the base samples used to estimate each point are resampled. +As discussed in the [Overview](./overview), a single set of base samples can be +used for optimization when the re-parameterization trick is employed. What are the +trade-offs between using a fixed set of base samples versus re-sampling on every +MC evaluation of the acquisition function? Below, we show that fixing base samples +produces functions that are potentially much easier to optimize, without resorting to +stochastic optimization methods. + +![resampling_fixed](assets/EI_resampling_fixed.png) + +If the base samples are fixed, the problem of optimizing the acquisition function +is deterministic, allowing for conventional quasi-second order methods such as +L-BFGS or sequential least-squares programming (SLSQP) to be used. These have +faster convergence rates than first-order methods and can speed up acquisition +function optimization significantly. + +One concern is that the approximated acquisition function is *biased* for any +fixed set of base samples, which may adversely affect the solution. However, we +find that in practice, both the optimal value and the optimal solution of these +biased problems for standard acquisition functions converge quite rapidly to +their true counterparts as more samples are used. Note that for evaluation of +the acquisition function we integrate over a $qo$-dimensional space (where +$q$ is the number of points in the q-batch and $o$ is the number of outputs +included in the objective). Therefore, the MC integration problem can be quite +low-dimensional even for models on high-dimensional feature spaces (large $d$). +Because using additional samples is relatively cheap computationally, +we default to 500 base samples in the MC acquisition functions. + +On the other hand, when re-sampling is used in conjunction with a stochastic +optimization algorithm, the kind of bias noted above is no longer a concern. +The trade-off here is that the optimization may be less effective, as discussed +above. + + +## Analytic Acquisition Functions + +BoTorch also provides implementations of analytic acquisition functions that +do not depend on MC sampling. These acquisition functions are subclasses of +[`AnalyticAcquisitionFunction`](https://botorch.readthedocs.io/en/latest/acquisition.html#botorch.acquisition.analytic.AnalyticAcquisitionFunction) +and only exist for the case of a single candidate point ($q = 1$). These +include classical acquisition functions such as Expected Improvement (EI), +Upper Confidence Bound (UCB), and Probability of Improvement (PI). An example +comparing [`ExpectedImprovement`](https://botorch.readthedocs.io/en/latest/acquisition.html#botorch.acquisition.analytic.ExpectedImprovement), +the analytic version of EI, to it's MC counterpart +[`qExpectedImprovement`](https://botorch.readthedocs.io/en/latest/acquisition.html#botorch.acquisition.monte_carlo.qExpectedImprovement) +can be found in +[this tutorial](tutorials/compare_mc_analytic_acquisition). + +Analytic acquisition functions allow for an explicit expression in terms of the +summary statistics of the posterior distribution at the evaluated point(s). +A popular acquisition function is Expected Improvement of a single point +for a Gaussian posterior, given by + +$$ +\text{EI}(x) = \mathbb{E}\bigl[ +\max(y - f^*, 0) \mid y\sim \mathcal{N}(\mu(x), \sigma^2(x)) +\bigr] +$$ + +where $\mu(x)$ and $\sigma(x)$ are the posterior mean and variance of $f$ at the +point $x$, and $f^*$ is again the best function value observed so far (assuming +noiseless observations). It can be shown that + +$$ +\text{EI}(x) = \sigma(x) \bigl( z \Phi(z) + \varphi(z) \bigr) +$$ + +where $z = \frac{\mu(x) - f_{\max}}{\sigma(x)}$ and $\Phi$ and $\varphi$ are +the cdf and pdf of the standard normal distribution, respectively. + +With some additional work, it is also possible to express the gradient of +the Expected Improvement with respect to the design $x$. Classic Bayesian +Optimization software will implement this gradient function explicitly, so that +it can be used for numerically optimizing the acquisition function. + +BoTorch, in contrast, harnesses PyTorch's automatic differentiation feature +("autograd") in order to obtain gradients of acquisition functions. This makes +implementing new acquisition functions much less cumbersome, as it does not +require to analytically derive gradients. All that is required is that the +operations performed in the acquisition function computation allow for the +back-propagation of gradient information through the posterior and the model. + + +[^KingmaWelling2014]: D. P. Kingma, M. Welling. Auto-Encoding Variational Bayes. +ICLR, 2013. + +[^Rezende2014]: D. J. Rezende, S. Mohamed, D. Wierstra. Stochastic +Backpropagation and Approximate Inference in Deep Generative Models. ICML, 2014. + +[^Wilson2017]: J. T. Wilson, R. Moriconi, F. Hutter, M. P. Deisenroth. +The Reparameterization Trick for Acquisition Functions. NeurIPS Workshop on +Bayesian Optimization, 2017. diff --git a/docs/models.md b/docs/models.md index ea63ea0ea8..917855decb 100644 --- a/docs/models.md +++ b/docs/models.md @@ -1,190 +1,187 @@ ---- -id: models -title: Models ---- - -Models play an essential role in Bayesian Optimization (BO). A model is used as -a surrogate function for the actual underlying black box function to be -optimized. In BoTorch, a `Model` maps a set of design points to a posterior -probability distribution of its output(s) over the design points. - -In BO, the model used is traditionally a Gaussian Process (GP), in which case -the posterior distribution is a multivariate normal. While BoTorch supports many -GP models, **BoTorch makes no assumption on the model being a GP** or the -posterior being multivariate normal. With the exception of some of the analytic -acquisition functions in the -[`botorch.acquisition.analytic`](../api/acquisition.html#analytic-acquisition-function-api) -module, BoTorch’s Monte Carlo-based acquisition functions are compatible with -any model that conforms to the `Model` interface, whether user-implemented or -provided. - -Under the hood, BoTorch models are PyTorch `Modules` that implement the -light-weight [`Model`](../api/models.html#model-apis) interface. When working -with GPs, -[`GPyTorchModel`](../api/models.html#module-botorch.models.gp_regression) -provides a base class for conveniently wrapping GPyTorch models. - -Users can extend `Model` and `GPyTorchModel` to generate their own models. For -more on implementing your own models, see -[Implementing Custom Models](#implementing-custom-models) below. - -## Terminology - -### Multi-Output and Multi-Task - -A `Model` (as in the BoTorch object) may have multiple outputs, multiple inputs, -and may exploit correlation between different inputs. BoTorch uses the following -terminology to distinguish these model types: - -- _Multi-Output Model_: a `Model` with multiple outputs. Most BoTorch `Model`s - are multi-output. -- _Multi-Task Model_: a `Model` making use of a logical grouping of - inputs/observations (as in the underlying process). For example, there could - be multiple tasks where each task has a different fidelity. In a multi-task - model, the relationship between different outputs is modeled, with a joint - model across tasks. - -Note the following: - -- A multi-task (MT) model may or may not be a multi-output model. For example, - if a multi-task model uses different tasks for modeling but only outputs - predictions for one of those tasks, it is single-output. -- Conversely, a multi-output (MO) model may or may not be a multi-task model. - For example, multi-output `Model`s that model different outputs independently - rather than building a joint model are not multi-task. -- If a model is both, we refer to it as a multi-task-multi-output (MTMO) model. - -### Noise: Homoskedastic, fixed, and heteroskedastic - -Noise can be treated in several different ways: - -- _Homoskedastic_: Noise is not provided as an input and is inferred, with a - constant variance that does not depend on `X`. Many models, such as - `SingleTaskGP`, take this approach. Use these models if you know that your - observations are noisy, but not how noisy. - -- _Fixed_: Noise is provided as an input, `train_Yvar`, and is not fit. In - “fixed noise” models like `SingleTaskGP` with noise observations, noise cannot - be predicted out-of-sample because it has not been modeled. Use these models - if you have estimates of the noise in your observations (e.g. observations may - be averages over individual samples in which case you would provide the mean - as observation and the standard error of the mean as the noise estimate), or - if you know your observations are noiseless (by passing a zero noise level). - -- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for - predicting noise out-of-sample. BoTorch does not implement a model that - supports this out of the box. - -## Standard BoTorch Models - -BoTorch provides several GPyTorch models to cover most standard BO use cases: - -### Single-Task GPs - -These models use the same training data for all outputs and assume conditional -independence of the outputs given the input. If different training data is -required for each output, use a -[`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression) -instead. - -- [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP): - a single-task exact GP that supports both inferred and observed noise. When - noise observations are not provided, it infers a homoskedastic noise level. -- [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): - a single-task exact GP that supports mixed search spaces, which combine - discrete and continuous features. -- [`SaasFullyBayesianSingleTaskGP`](../api/models.html#botorch.models.fully_bayesian.SaasFullyBayesianSingleTaskGP): - a fully Bayesian single-task GP with the SAAS prior. This model is suitable - for sample-efficient high-dimensional Bayesian optimization. - -### Model List of Single-Task GPs - -- [`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression): - A multi-output model in which outcomes are modeled independently, given a list - of any type of single-task GP. This model should be used when the same - training data is not used for all outputs. - -### Multi-Task GPs - -- [`MultiTaskGP`](../api/models.html#module-botorch.models.multitask): a - Hadamard multi-task, multi-output GP using an ICM kernel. Supports both known - observation noise levels and inferring a homoskedastic noise level (when noise - observations are not provided). -- [`KroneckerMultiTaskGP`](../api/models.html#botorch.models.multitask.KroneckerMultiTaskGP): - A multi-task, multi-output GP using an ICM kernel, with Kronecker structure. - Useful for multi-fidelity optimization. -- [`SaasFullyBayesianMultiTaskGP`](../api/models.html#saasfullybayesianmultitaskgp): - a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the - SAAS prior to model high-dimensional parameter spaces. - -All of the above models use RBF kernels with Automatic Relevance Discovery -(ARD), and have reasonable priors on hyperparameters that make them work well in -settings where the **input features are normalized to the unit cube** and the -**observations are standardized** (zero mean, unit variance). The lengthscale -priors scale with the input dimension, which makes them adaptable to both low -and high dimensional problems. See -[this discussion](https://github.com/pytorch/botorch/discussions/2451) for -additional context on the default hyperparameters. - -## Other useful models - -- [`ModelList`](../api/models.html#botorch.models.model.ModelList): a - multi-output model container in which outcomes are modeled independently by - individual `Model`s (as in `ModelListGP`, but the component models do not all - need to be GPyTorch models). -- [`SingleTaskMultiFidelityGP`](../api/models.html#botorch.models.gp_regression_fidelity.SingleTaskMultiFidelityGP): - A GP model for multi-fidelity optimization. For more on Multi-Fidelity BO, see - the [tutorial](../tutorials/discrete_multi_fidelity_bo). -- [`HigherOrderGP`](../api/models.html#botorch.models.higher_order_gp.HigherOrderGP): - A GP model with matrix-valued predictions, such as images or grids of images. -- [`PairwiseGP`](../api/models.html#module-botorch.models.pairwise_gp): A - probit-likelihood GP that learns via pairwise comparison data, useful for - preference learning. -- [`ApproximateGPyTorchModel`](../api/models.html#botorch.models.approximate_gp.ApproximateGPyTorchModel): - for efficient computation when data is large or responses are non-Gaussian. -- [Deterministic models](../api/models.html#module-botorch.models.deterministic), - such as - [`AffineDeterministicModel`](../api/models.html#botorch.models.deterministic.AffineDeterministicModel), - [`AffineFidelityCostModel`](../api/models.html#botorch.models.cost.AffineFidelityCostModel), - [`GenericDeterministicModel`](../api/models.html#botorch.models.deterministic.GenericDeterministicModel), - and - [`PosteriorMeanModel`](../api/models.html#botorch.models.deterministic.PosteriorMeanModel) - express known input-output relationships; they conform to the BoTorch `Model` - API, so they can easily be used in conjunction with other BoTorch models. - Deterministic models are useful for multi-objective optimization with known - objective functions and for encoding cost functions for cost-aware - acquisition. -- [`SingleTaskVariationalGP`](../api/models.html#botorch.models.approximate_gp.SingleTaskVariationalGP): - an approximate model for faster computation when you have a lot of data or - your responses are non-Gaussian. -- [`NeuralProcessModel`](../api/models.html#botorch_community.models.np_regression.NeuralProcessModel): - A NP Model utilizing a novel acquisition function computing the expected KL - Divergence in the latent processes. - -## Implementing Custom Models - -The configurability of the above models is limited (for instance, it is not -straightforward to use a different kernel). Doing so is an intentional design -decision -- we believe that having a few simple and easy-to-understand models -for basic use cases is more valuable than having a highly complex and -configurable model class whose implementation is difficult to understand. - -Instead, we advocate that users implement their own models to cover more -specialized use cases. The light-weight nature of BoTorch's Model API makes this -easy to do. See the -[Using a custom BoTorch model in Ax](../tutorials/custom_botorch_model_in_ax) -tutorial for an example. - -The BoTorch `Model` interface is light-weight and easy to extend. The only -requirement for using BoTorch's Monte-Carlo based acquisition functions is that -the model has a `posterior` method. It takes in a Tensor `X` of design points, -and returns a Posterior object describing the (joint) probability distribution -of the model output(s) over the design points in `X`. The `Posterior` object -must implement an `rsample()` method for sampling from the posterior of the -model. If you wish to use gradient-based optimization algorithms, the model -should allow back-propagating gradients through the samples to the model input. - -If you happen to implement a model that would be useful for other researchers as -well (and involves more than just swapping out the RBF kernel for a Matérn -kernel), please consider [contributing](getting_started#contributing) this model -to BoTorch. +--- +id: models +title: Models +--- + +Models play an essential role in Bayesian Optimization (BO). A model is used as +a surrogate function for the actual underlying black box function to be +optimized. In BoTorch, a `Model` maps a set of design points to a posterior +probability distribution of its output(s) over the design points. + +In BO, the model used is traditionally a Gaussian Process (GP), in which case +the posterior distribution is a multivariate normal. While BoTorch supports many +GP models, **BoTorch makes no assumption on the model being a GP** or the +posterior being multivariate normal. With the exception of some of the analytic +acquisition functions in the +[`botorch.acquisition.analytic`](https://botorch.readthedocs.io/en/latest/acquisition.html#analytic-acquisition-function-api) +module, BoTorch’s Monte Carlo-based acquisition functions are compatible with +any model that conforms to the `Model` interface, whether user-implemented or +provided. + +Under the hood, BoTorch models are PyTorch `Modules` that implement the +light-weight [`Model`](https://botorch.readthedocs.io/en/latest/models.html#model-apis) interface. When working +with GPs, +[`GPyTorchModel`](https://botorch.readthedocs.io/en/latest/models.html#module-botorch.models.gp_regression) +provides a base class for conveniently wrapping GPyTorch models. + +Users can extend `Model` and `GPyTorchModel` to generate their own models. For +more on implementing your own models, see +[Implementing Custom Models](#implementing-custom-models) below. + +## Terminology + +### Multi-Output and Multi-Task + +A `Model` (as in the BoTorch object) may have multiple outputs, multiple inputs, +and may exploit correlation between different inputs. BoTorch uses the following +terminology to distinguish these model types: + +- _Multi-Output Model_: a `Model` with multiple outputs. Most BoTorch `Model`s + are multi-output. +- _Multi-Task Model_: a `Model` making use of a logical grouping of + inputs/observations (as in the underlying process). For example, there could + be multiple tasks where each task has a different fidelity. In a multi-task + model, the relationship between different outputs is modeled, with a joint + model across tasks. + +Note the following: + +- A multi-task (MT) model may or may not be a multi-output model. For example, + if a multi-task model uses different tasks for modeling but only outputs + predictions for one of those tasks, it is single-output. +- Conversely, a multi-output (MO) model may or may not be a multi-task model. + For example, multi-output `Model`s that model different outputs independently + rather than building a joint model are not multi-task. +- If a model is both, we refer to it as a multi-task-multi-output (MTMO) model. + +### Noise: Homoskedastic, fixed, and heteroskedastic + +Noise can be treated in several different ways: + +- _Homoskedastic_: Noise is not provided as an input and is inferred, with a + constant variance that does not depend on `X`. Many models, such as + `SingleTaskGP`, take this approach. Use these models if you know that your + observations are noisy, but not how noisy. + +- _Fixed_: Noise is provided as an input, `train_Yvar`, and is not fit. In + “fixed noise” models like `SingleTaskGP` with noise observations, noise cannot + be predicted out-of-sample because it has not been modeled. Use these models + if you have estimates of the noise in your observations (e.g. observations may + be averages over individual samples in which case you would provide the mean + as observation and the standard error of the mean as the noise estimate), or + if you know your observations are noiseless (by passing a zero noise level). + +- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for + predicting noise out-of-sample. BoTorch does not implement a model that + supports this out of the box. + +## Standard BoTorch Models + +BoTorch provides several GPyTorch models to cover most standard BO use cases: + +### Single-Task GPs + +These models use the same training data for all outputs and assume conditional +independence of the outputs given the input. If different training data is +required for each output, use a +[`ModelListGP`](https://botorch.readthedocs.io/en/latest/models.html#module-botorch.models.model_list_gp_regression) +instead. + +- [`SingleTaskGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.gp_regression.SingleTaskGP): + a single-task exact GP that supports both inferred and observed noise. When + noise observations are not provided, it infers a homoskedastic noise level. +- [`MixedSingleTaskGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): + a single-task exact GP that supports mixed search spaces, which combine + discrete and continuous features. +- [`SaasFullyBayesianSingleTaskGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.fully_bayesian.SaasFullyBayesianSingleTaskGP): + a fully Bayesian single-task GP with the SAAS prior. This model is suitable + for sample-efficient high-dimensional Bayesian optimization. + +### Model List of Single-Task GPs + +- [`ModelListGP`](https://botorch.readthedocs.io/en/latest/models.html#module-botorch.models.model_list_gp_regression): + A multi-output model in which outcomes are modeled independently, given a list + of any type of single-task GP. This model should be used when the same + training data is not used for all outputs. + +### Multi-Task GPs + +- [`MultiTaskGP`](https://botorch.readthedocs.io/en/latest/models.html#module-botorch.models.multitask): a + Hadamard multi-task, multi-output GP using an ICM kernel. Supports both known + observation noise levels and inferring a homoskedastic noise level (when noise + observations are not provided). +- [`KroneckerMultiTaskGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.multitask.KroneckerMultiTaskGP): + A multi-task, multi-output GP using an ICM kernel, with Kronecker structure. + Useful for multi-fidelity optimization. +- [`SaasFullyBayesianMultiTaskGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.fully_bayesian_multitask.SaasFullyBayesianMultiTaskGP): + a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the + SAAS prior to model high-dimensional parameter spaces. + +All of the above models use RBF kernels with Automatic Relevance Discovery +(ARD), and have reasonable priors on hyperparameters that make them work well in +settings where the **input features are normalized to the unit cube** and the +**observations are standardized** (zero mean, unit variance). The lengthscale +priors scale with the input dimension, which makes them adaptable to both low +and high dimensional problems. See +[this discussion](https://github.com/pytorch/botorch/discussions/2451) for +additional context on the default hyperparameters. + +## Other useful models + +- [`ModelList`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.model.ModelList): a + multi-output model container in which outcomes are modeled independently by + individual `Model`s (as in `ModelListGP`, but the component models do not all + need to be GPyTorch models). +- [`SingleTaskMultiFidelityGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.gp_regression_fidelity.SingleTaskMultiFidelityGP): + A GP model for multi-fidelity optimization. For more on Multi-Fidelity BO, see + the [tutorial](tutorials/discrete_multi_fidelity_bo). +- [`HigherOrderGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.higher_order_gp.HigherOrderGP): + A GP model with matrix-valued predictions, such as images or grids of images. +- [`PairwiseGP`](https://botorch.readthedocs.io/en/latest/models.html#module-botorch.models.pairwise_gp): A + probit-likelihood GP that learns via pairwise comparison data, useful for + preference learning. +- [`ApproximateGPyTorchModel`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.approximate_gp.ApproximateGPyTorchModel): + for efficient computation when data is large or responses are non-Gaussian. +- [Deterministic models](https://botorch.readthedocs.io/en/latest/models.html#module-botorch.models.deterministic), + such as + [`AffineDeterministicModel`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.deterministic.AffineDeterministicModel), + [`AffineFidelityCostModel`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.cost.AffineFidelityCostModel), + [`GenericDeterministicModel`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.deterministic.GenericDeterministicModel), + and + [`PosteriorMeanModel`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.deterministic.PosteriorMeanModel) + express known input-output relationships; they conform to the BoTorch `Model` + API, so they can easily be used in conjunction with other BoTorch models. + Deterministic models are useful for multi-objective optimization with known + objective functions and for encoding cost functions for cost-aware + acquisition. +- [`SingleTaskVariationalGP`](https://botorch.readthedocs.io/en/latest/models.html#botorch.models.approximate_gp.SingleTaskVariationalGP): + an approximate model for faster computation when you have a lot of data or + your responses are non-Gaussian. + +## Implementing Custom Models + +The configurability of the above models is limited (for instance, it is not +straightforward to use a different kernel). Doing so is an intentional design +decision -- we believe that having a few simple and easy-to-understand models +for basic use cases is more valuable than having a highly complex and +configurable model class whose implementation is difficult to understand. + +Instead, we advocate that users implement their own models to cover more +specialized use cases. The light-weight nature of BoTorch's Model API makes this +easy to do. See Ax's +[Modular BoTorch tutorial](https://ax.dev/docs/tutorials/modular_botorch/) +tutorial for an example for this and how to use such a custom model in Ax. + +The BoTorch `Model` interface is light-weight and easy to extend. The only +requirement for using BoTorch's Monte-Carlo based acquisition functions is that +the model has a `posterior` method. It takes in a Tensor `X` of design points, +and returns a Posterior object describing the (joint) probability distribution +of the model output(s) over the design points in `X`. The `Posterior` object +must implement an `rsample()` method for sampling from the posterior of the +model. If you wish to use gradient-based optimization algorithms, the model +should allow back-propagating gradients through the samples to the model input. + +If you happen to implement a model that would be useful for other researchers as +well (and involves more than just swapping out the RBF kernel for a Matérn +kernel), please consider [contributing](getting_started#contributing) this model +to BoTorch. From ef43cc197bba59231a2c69bfa76416f095e27524 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 27 May 2025 19:16:55 -0700 Subject: [PATCH 32/35] Cleaned up code --- .../acquisition/latent_information_gain.py | 245 ++--- botorch_community/models/np_regression.py | 983 +++++++++--------- 2 files changed, 618 insertions(+), 610 deletions(-) diff --git a/botorch_community/acquisition/latent_information_gain.py b/botorch_community/acquisition/latent_information_gain.py index 5a2d8c38a8..920aa993c5 100644 --- a/botorch_community/acquisition/latent_information_gain.py +++ b/botorch_community/acquisition/latent_information_gain.py @@ -1,117 +1,128 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -r""" -Latent Information Gain Acquisition Function for Neural Process Models. - -References: - -.. [Wu2023arxiv] - Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). - Deep Bayesian Active Learning for Accelerating Stochastic Simulation. - arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 - -Contributor: eibarolle -""" - -from __future__ import annotations - -from typing import Any, Type - -import torch -from botorch.acquisition import AcquisitionFunction -from botorch_community.models.np_regression import NeuralProcessModel -from torch import Tensor -# reference: https://arxiv.org/abs/2106.02770 - - -class LatentInformationGain(AcquisitionFunction): - def __init__( - self, - model: Type[Any], - num_samples: int = 10, - min_std: float = 0.01, - scaler: float = 0.5, - ) -> None: - """ - Latent Information Gain (LIG) Acquisition Function. - Uses the model's built-in posterior function to generalize KL computation. - - Args: - model: The model class to be used, defaults to NeuralProcessModel. - num_samples: Int showing the # of samples for calculation, defaults to 10. - min_std: Float representing the minimum possible standardized std, - defaults to 0.01. - scaler: Float scaling the std, defaults to 0.5. - """ - super().__init__(model) - self.model = model - self.num_samples = num_samples - self.min_std = min_std - self.scaler = scaler - - def forward(self, candidate_x: Tensor) -> Tensor: - """ - Conduct the Latent Information Gain acquisition function for the inputs. - - Args: - candidate_x: Candidate input points, as a Tensor. Ideally in the shape - (N, q, D). - - Returns: - torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q). - """ - device = candidate_x.device - candidate_x = candidate_x.to(device) - N, q, D = candidate_x.shape - kl = torch.zeros(N, device=device, dtype=torch.float32) - - if isinstance(self.model, NeuralProcessModel): - x_c, y_c, _, _ = self.model.random_split_context_target( - self.model.train_X, self.model.train_Y, self.model.n_context - ) - self.model.z_mu_context, self.model.z_logvar_context = ( - self.model.data_to_z_params(x_c, y_c) - ) - - for i in range(N): - x_i = candidate_x[i] - kl_i = 0.0 - - for _ in range(self.num_samples): - sample_z = self.model.sample_z( - self.model.z_mu_context, self.model.z_logvar_context - ) - if sample_z.dim() == 1: - sample_z = sample_z.unsqueeze(0) - - y_pred = self.model.decoder(x_i, sample_z) - - combined_x = torch.cat([x_c, x_i], dim=0) - combined_y = torch.cat([y_c, y_pred], dim=0) - - self.model.z_mu_all, self.model.z_logvar_all = ( - self.model.data_to_z_params(combined_x, combined_y) - ) - kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler) - kl_i += kl_sample - - kl[i] = kl_i / self.num_samples - - else: - for i in range(N): - x_i = candidate_x[i] - kl_i = 0.0 - for _ in range(self.num_samples): - posterior_prior = self.model.posterior(self.model.train_X) - posterior_candidate = self.model.posterior(x_i) - - kl_i += torch.distributions.kl_divergence( - posterior_candidate.mvn, posterior_prior.mvn - ).sum() - - kl[i] = kl_i / self.num_samples - return kl +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Latent Information Gain Acquisition Function for Neural Process Models. + +References: + +.. [Wu2023arxiv] + Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). + Deep Bayesian Active Learning for Accelerating Stochastic Simulation. + arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 + +Contributor: eibarolle +""" + +from __future__ import annotations + +from typing import Any, Type + +import torch +from botorch.acquisition import AcquisitionFunction +from botorch_community.models.np_regression import NeuralProcessModel +from torch import Tensor +# reference: https://arxiv.org/abs/2106.02770 + + +class LatentInformationGain(AcquisitionFunction): + def __init__( + self, + model: Type[Any], + num_samples: int = 10, + min_std: float = 0.01, + scaler: float = 0.5, + ) -> None: + """ + Latent Information Gain (LIG) Acquisition Function. + Uses the model's built-in posterior function to generalize KL computation. + + Args: + model: The model class to be used, defaults to NeuralProcessModel. + num_samples: Int showing the # of samples for calculation, defaults to 10. + min_std: Float representing the minimum possible standardized std, + defaults to 0.01. + scaler: Float scaling the std, defaults to 0.5. + """ + super().__init__(model) + self.model = model + self.num_samples = num_samples + self.min_std = min_std + self.scaler = scaler + + def forward(self, candidate_x: Tensor) -> Tensor: + """ + Conduct the Latent Information Gain acquisition function for the inputs. + + Args: + candidate_x: Candidate input points, as a Tensor. Ideally in the shape + (N, q, D). + + Returns: + torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q). + """ + device = candidate_x.device + candidate_x = candidate_x.to(device) + N, q, D = candidate_x.shape + kl = torch.zeros(N, device=device, dtype=torch.float32) + + if isinstance(self.model, NeuralProcessModel): + x_c, y_c, _, _ = self.model.random_split_context_target( + self.model.train_X, self.model.train_Y, self.model.n_context + ) + self.model.z_mu_context, self.model.z_logvar_context = ( + self.model.data_to_z_params(x_c, y_c) + ) + + for i in range(N): + x_i = candidate_x[i] + kl_i = 0.0 + + for _ in range(self.num_samples): + sample_z = self.model.sample_z( + self.model.z_mu_context, self.model.z_logvar_context + ) + if sample_z.dim() == 1: + sample_z = sample_z.unsqueeze(0) + + y_pred = self.model.decoder(x_i, sample_z) + + combined_x = torch.cat([x_c, x_i], dim=0) + combined_y = torch.cat([y_c, y_pred], dim=0) + + self.model.z_mu_all, self.model.z_logvar_all = ( + self.model.data_to_z_params(combined_x, combined_y) + ) + kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler) + kl_i += kl_sample + + kl[i] = kl_i / self.num_samples + + else: + for i in range(N): + x_i = candidate_x[i] + kl_i = 0.0 + for _ in range(self.num_samples): + posterior_prior = self.model.posterior(self.model.train_inputs[0]) + posterior_candidate = self.model.posterior(x_i) + + mean_prior = posterior_prior.mean.mean(dim=0) + cov_prior = posterior_prior.variance.mean(dim=0) + mvn_prior = torch.distributions.MultivariateNormal( + mean_prior, torch.diag(cov_prior) + ) + + mean_candidate = posterior_candidate.mean.mean(dim=0) + cov_candidate = posterior_candidate.variance.mean(dim=0) + mvn_candidate = torch.distributions.MultivariateNormal( + mean_candidate, torch.diag(cov_candidate) + ) + + kl_i += torch.distributions.kl_divergence(mvn_candidate, mvn_prior) + + kl[i] = kl_i / self.num_samples + + return kl diff --git a/botorch_community/models/np_regression.py b/botorch_community/models/np_regression.py index 3b1af8d9b2..444389968c 100644 --- a/botorch_community/models/np_regression.py +++ b/botorch_community/models/np_regression.py @@ -1,493 +1,490 @@ -r""" -Neural Process Regression models based on PyTorch models. - -References: - -.. [Wu2023arxiv] - Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). - Deep Bayesian Active Learning for Accelerating Stochastic Simulation. - arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 - -Contributor: eibarolle -""" - -from typing import Callable, List, Optional, Tuple - -import torch -import torch.nn as nn -from botorch.acquisition.objective import PosteriorTransform -from botorch.models.model import Model -from botorch.models.transforms.input import InputTransform -from botorch.posteriors import GPyTorchPosterior -from gpytorch.distributions import MultivariateNormal -from gpytorch.likelihoods import GaussianLikelihood -from gpytorch.likelihoods.likelihood import Likelihood -from gpytorch.models.gp import GP -from torch.nn import Module - - -# reference: https://chrisorm.github.io/NGP.html -class MLP(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - hidden_dims: List[int], - activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_, - ) -> None: - r""" - A modular implementation of a Multilayer Perceptron (MLP). - - Args: - input_dim: An int representing the total input dimensionality. - output_dim: An int representing the total encoded dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden - dimension. - activation: Activation function applied between layers, defaults to - nn.Sigmoid. - init_func: A function initializing the weights, - defaults to nn.init.normal_. - """ - super().__init__() - layers = [] - prev_dim = input_dim - - for hidden_dim in hidden_dims: - layer = nn.Linear(prev_dim, hidden_dim) - if init_func is not None: - init_func(layer.weight) - layers.append(layer) - layers.append(activation()) - prev_dim = hidden_dim - - final_layer = nn.Linear(prev_dim, output_dim) - if init_func is not None: - init_func(final_layer.weight) - layers.append(final_layer) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x) - - -class REncoder(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - hidden_dims: List[int], - activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_, - ) -> None: - r"""Encodes inputs of the form (x_i,y_i) into representations, r_i. - - Args: - input_dim: An int representing the total input dimensionality. - output_dim: An int representing the total encoded dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden - dimension. - activation: Activation function applied between layers, defaults to nn. - Sigmoid. - init_func: A function initializing the weights, - defaults to nn.init.normal_. - """ - super().__init__() - self.mlp = MLP( - input_dim=input_dim, - output_dim=output_dim, - hidden_dims=hidden_dims, - activation=activation, - init_func=init_func, - ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - r"""Forward pass for representation encoder. - - Args: - inputs: Input tensor - - Returns: - torch.Tensor: Encoded representations - """ - return self.mlp(inputs) - - -class ZEncoder(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - hidden_dims: List[int], - activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_, - ) -> None: - r"""Takes an r representation and produces the mean & standard - deviation of the normally distributed function encoding, z. - - Args: - input_dim: An int representing r's aggregated dimensionality. - output_dim: An int representing z's latent dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden - dimension. - activation: Activation function applied between layers, defaults to nn. - Sigmoid. - init_func: A function initializing the weights, - defaults to nn.init.normal_. - """ - super().__init__() - self.mean_net = MLP( - input_dim=input_dim, - output_dim=output_dim, - hidden_dims=hidden_dims, - activation=activation, - init_func=init_func, - ) - self.logvar_net = MLP( - input_dim=input_dim, - output_dim=output_dim, - hidden_dims=hidden_dims, - activation=activation, - init_func=init_func, - ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - r"""Forward pass for latent encoder. - - Args: - inputs: Input tensor - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - Mean of the latent Gaussian distribution. - - Log variance of the latent Gaussian distribution. - """ - return self.mean_net(inputs), self.logvar_net(inputs) - - -class Decoder(torch.nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - hidden_dims: List[int], - activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = nn.init.normal_, - ) -> None: - r"""Takes the x star points, along with a 'function encoding', z, and makes - predictions. - - Args: - input_dim: An int representing the total input dimensionality. - output_dim: An int representing the total encoded dimensionality. - hidden_dims: A list of integers representing the # of units in each hidden - dimension. - activation: Activation function applied between layers, defaults to - nn.Sigmoid. - init_func: A function initializing the weights, - defaults to nn.init.normal_. - """ - super().__init__() - self.mlp = MLP( - input_dim=input_dim, - output_dim=output_dim, - hidden_dims=hidden_dims, - activation=activation, - init_func=init_func, - ) - - def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: - r"""Forward pass for decoder. - - Args: - x_pred: Input points of shape (n x d_x), representing # of data points by - x_dim. - z: Latent encoding of shape (num_samples x d_z), representing # of samples - by z_dim. - - Returns: - torch.Tensor: Predicted target values of shape (n x z_dim), representing # - of data points by z_dim. - """ - if z.dim() == 1: - z = z.unsqueeze(0) - if z.dim() == 3: - z = z.squeeze(0) - z_expanded = z.expand(x_pred.size(0), -1) - x_pred = x_pred - xz = torch.cat([x_pred, z_expanded], dim=-1) - return self.mlp(xz) - - -class NeuralProcessModel(Model, GP): - def __init__( - self, - train_X: torch.Tensor, - train_Y: torch.Tensor, - r_hidden_dims: List[int] = [16, 16], - z_hidden_dims: List[int] = [32, 32], - decoder_hidden_dims: List[int] = [16, 16], - x_dim: int = 2, - y_dim: int = 1, - r_dim: int = 64, - z_dim: int = 8, - n_context: int = 20, - activation: Callable = nn.Sigmoid, - init_func: Optional[Callable] = torch.nn.init.normal_, - likelihood: Likelihood | None = None, - input_transform: InputTransform | None = None, - ) -> None: - r"""Diffusion Convolutional Recurrent Neural Network Model Implementation. - - Args: - train_X: A `batch_shape x n x d` tensor of training features. - train_Y: A `batch_shape x n x m` tensor of training observations. - r_hidden_dims: Hidden Dimensions/Layer list for REncoder, defaults to - [16, 16] - z_hidden_dims: Hidden Dimensions/Layer list for ZEncoder, defaults to - [32, 32] - decoder_hidden_dims: Hidden Dimensions/Layer for Decoder, defaults to - [16, 16] - x_dim: Int dimensionality of input data x, defaults to 2. - y_dim: Int dimensionality of target data y, defaults to 1. - r_dim: Int dimensionality of representation r, defaults to 64. - z_dim: Int dimensionality of latent variable z, defaults to 8. - n_context (int): Number of context points, defaults to 20. - activation: Activation function applied between layers, defaults to nn. - Sigmoid. - init_func: A function initializing the weights, - defaults to nn.init.normal_. - likelihood: A likelihood. If omitted, use a standard GaussianLikelihood. - input_transform: An input transform that is applied in the model's - forward pass. - """ - super().__init__() - self.device = train_X.device - - self.r_encoder = REncoder( - x_dim + y_dim, - r_dim, - r_hidden_dims, - activation=activation, - init_func=init_func, - ) - self.z_encoder = ZEncoder( - r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func - ) - self.decoder = Decoder( - x_dim + z_dim, - y_dim, - decoder_hidden_dims, - activation=activation, - init_func=init_func, - ) - self.train_X = train_X - self.train_Y = train_Y - self.n_context = n_context - self.z_dim = z_dim - self.z_mu_all = None - self.z_logvar_all = None - self.z_mu_context = None - self.z_logvar_context = None - self.likelihood = likelihood if likelihood is not None else GaussianLikelihood() - self.input_transform = input_transform - self.to(device=self.device) - - def data_to_z_params( - self, x: torch.Tensor, y: torch.Tensor, r_dim: int = 0 - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Compute latent parameters from inputs as a latent distribution. - - Args: - x: Input tensor - y: Target tensor - r_dim: Combined Target Dimension as int, defaults as 0. - - Returns: - Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - - x_c: Context input data. - - y_c: Context target data. - - x_t: Target input data. - - y_t: Target target data. - """ - x = x.to(self.device) - y = y.to(self.device) - xy = torch.cat([x, y], dim=-1).to(self.device).to(self.device) - rs = self.r_encoder(xy) - r_agg = rs.mean(dim=r_dim).to(self.device) - return self.z_encoder(r_agg) - - def sample_z( - self, - mu: torch.Tensor, - logvar: torch.Tensor, - n: int = 1, - min_std: float = 0.01, - scaler: float = 0.5, - ) -> torch.Tensor: - r"""Reparameterization trick for z's latent distribution. - - Args: - mu: Tensor representing the Gaussian distribution mean. - logvar: Tensor representing the log variance of the Gaussian distribution. - n: Int representing the # of samples, defaults to 1. - min_std: Float representing the minimum possible standardized std, defaults - to 0.01. - scaler: Float scaling the std, defaults to 0.5. - - Returns: - torch.Tensor: Samples from the Gaussian distribution. - """ - if min_std <= 0 or scaler <= 0: - raise ValueError() - - shape = [n, self.z_dim] - if n == 1: - shape = shape[1:] - eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(self.device) - - std = min_std + scaler * torch.sigmoid(logvar) - std = std.to(self.device) - mu = mu.to(self.device) - return mu + std * eps - - def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor: - r"""Analytical KLD between 2 Gaussian Distributions. - - Args: - min_std: Float representing the minimum possible standardized std, defaults - to 0.01. - scaler: Float scaling the std, defaults to 0.5. - - Returns: - torch.Tensor: A tensor representing the KLD. - """ - - if min_std <= 0 or scaler <= 0: - raise ValueError() - std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(self.device) - std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(self.device) - p = torch.distributions.Normal(self.z_mu_context.to(self.device), std_p) - q = torch.distributions.Normal(self.z_mu_all.to(self.device), std_q) - return torch.distributions.kl_divergence(p, q).sum() - - def posterior( - self, - X: torch.Tensor, - output_indices: list[int] | None = None, - observation_noise: bool = False, - posterior_transform: PosteriorTransform | None = None, - ) -> GPyTorchPosterior: - r"""Computes the model's posterior for given input tensors. - - Args: - X: Input Tensor - covariance_multiplier: Float scaling the covariance. - observation_constant: Float representing the noise constant. - output_indices: Ignored (defined in parent Model, but not used here). - observation_noise: Adds observation noise to the covariance if True, - defaults to False. - posterior_transform: An optional posterior transformation, - defaults to None. - - Returns: - GPyTorchPosterior: The posterior utilizing MultivariateNormal. - """ - X = self.transform_inputs(X) - X = X.to(self.device) - mean = self.decoder( - X.to(self.device), self.sample_z(self.z_mu_all, self.z_logvar_all) - ) - z_var = torch.exp(self.z_logvar_all) - covariance = torch.eye(X.size(0)).to(self.device) * z_var.mean() - if observation_noise: - covariance = covariance + self.likelihood.noise * torch.eye( - covariance.size(0) - ).to(self.device) - mvn = MultivariateNormal(mean, covariance) - posterior = GPyTorchPosterior(mvn) - if posterior_transform is not None: - posterior = posterior_transform(posterior) - return posterior - - def transform_inputs( - self, - X: torch.Tensor, - input_transform: Optional[Module] = None, - ) -> torch.Tensor: - r"""Transform inputs. - - Args: - X: A tensor of inputs - input_transform: A Module that performs the input transformation. - - Returns: - torch.Tensor: A tensor of transformed inputs - """ - X = X.to(self.device) - if input_transform is not None: - input_transform.to(X) - return input_transform(X) - try: - return self.input_transform(X) - except (AttributeError, TypeError): - return X - - def forward( - self, - train_X: torch.Tensor, - train_Y: torch.Tensor, - axis: int = 0, - ) -> MultivariateNormal: - r"""Forward pass for the model. - - Args: - train_X: A `batch_shape x n x d` tensor of training features. - train_Y: A `batch_shape x n x m` tensor of training observations. - axis: Dimension axis as int, defaulted as 0. - - Returns: - MultivariateNormal: Predicted target distribution. - """ - train_X = self.transform_inputs(train_X) - x_c, y_c, x_t, y_t = self.random_split_context_target( - train_X, train_Y, self.n_context, axis=axis - ) - x_t = x_t.to(self.device) - x_c = x_c.to(self.device) - y_c = y_c.to(self.device) - y_t = y_t.to(self.device) - self.z_mu_all, self.z_logvar_all = self.data_to_z_params( - self.train_X, self.train_Y - ) - self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c) - x_t = self.transform_inputs(x_t) - return self.posterior(x_t).distribution - - def random_split_context_target( - self, x: torch.Tensor, y: torch.Tensor, n_context, axis: int = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - r"""Helper function to split randomly into context and target. - - Args: - x: A `batch_shape x n x d` tensor of training features. - y: A `batch_shape x n x m` tensor of training observations. - n_context (int): Number of context points. - axis: Dimension axis as int, defaults to 0. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - - x_c: Context input data. - - y_c: Context target data. - - x_t: Target input data. - - y_t: Target target data. - """ - self.n_context = n_context - mask = torch.randperm(x.shape[axis])[:n_context] - splitter = torch.zeros(x.shape[axis], dtype=torch.bool) - x_c = x[mask].to(self.device) - y_c = y[mask].to(self.device) - splitter[mask] = True - x_t = x[~splitter].to(self.device) - y_t = y[~splitter].to(self.device) - return x_c, y_c, x_t, y_t +r""" +Neural Process Regression models based on PyTorch models. + +References: + +.. [Wu2023arxiv] + Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023). + Deep Bayesian Active Learning for Accelerating Stochastic Simulation. + arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770 + +Contributor: eibarolle +""" + +from typing import Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +from botorch.acquisition.objective import PosteriorTransform +from botorch.models.model import Model +from botorch.models.transforms.input import InputTransform +from botorch.posteriors import GPyTorchPosterior +from gpytorch.distributions import MultivariateNormal +from gpytorch.likelihoods import GaussianLikelihood +from gpytorch.likelihoods.likelihood import Likelihood +from gpytorch.models.gp import GP +from torch.nn import Module + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_, + ) -> None: + r""" + A modular implementation of a Multilayer Perceptron (MLP). + + Args: + input_dim: An int representing the total input dimensionality. + output_dim: An int representing the total encoded dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to + nn.Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. + """ + super().__init__() + layers = [] + prev_dim = input_dim + + for hidden_dim in hidden_dims: + layer = nn.Linear(prev_dim, hidden_dim) + if init_func is not None: + init_func(layer.weight) + layers.append(layer) + layers.append(activation()) + prev_dim = hidden_dim + + final_layer = nn.Linear(prev_dim, output_dim) + if init_func is not None: + init_func(final_layer.weight) + layers.append(final_layer) + self.model = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +class REncoder(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_, + ) -> None: + r"""Encodes inputs of the form (x_i,y_i) into representations, r_i. + + Args: + input_dim: An int representing the total input dimensionality. + output_dim: An int representing the total encoded dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to nn. + Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. + """ + super().__init__() + self.mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + r"""Forward pass for representation encoder. + + Args: + inputs: Input tensor + + Returns: + torch.Tensor: Encoded representations + """ + return self.mlp(inputs) + + +class ZEncoder(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_, + ) -> None: + r"""Takes an r representation and produces the mean & standard + deviation of the normally distributed function encoding, z. + + Args: + input_dim: An int representing r's aggregated dimensionality. + output_dim: An int representing z's latent dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to nn. + Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. + """ + super().__init__() + self.mean_net = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ) + self.logvar_net = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + r"""Forward pass for latent encoder. + + Args: + inputs: Input tensor + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Mean of the latent Gaussian distribution. + - Log variance of the latent Gaussian distribution. + """ + return self.mean_net(inputs), self.logvar_net(inputs) + + +class Decoder(torch.nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: List[int], + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = nn.init.normal_, + ) -> None: + r"""Takes the x star points, along with a 'function encoding', z, and makes + predictions. + + Args: + input_dim: An int representing the total input dimensionality. + output_dim: An int representing the total encoded dimensionality. + hidden_dims: A list of integers representing the # of units in each hidden + dimension. + activation: Activation function applied between layers, defaults to + nn.Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. + """ + super().__init__() + self.mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dims=hidden_dims, + activation=activation, + init_func=init_func, + ) + + def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + r"""Forward pass for decoder. + + Args: + x_pred: Input points of shape (n x d_x), representing # of data points by + x_dim. + z: Latent encoding of shape (num_samples x d_z), representing # of samples + by z_dim. + + Returns: + torch.Tensor: Predicted target values of shape (n x z_dim), representing # + of data points by z_dim. + """ + if z.dim() == 1: + z = z.unsqueeze(0) + z_expanded = z.expand(x_pred.size(0), -1) + x_pred = x_pred + xz = torch.cat([x_pred, z_expanded], dim=-1) + return self.mlp(xz) + + +class NeuralProcessModel(Model, GP): + def __init__( + self, + train_X: torch.Tensor, + train_Y: torch.Tensor, + r_hidden_dims: List[int] = [16, 16], + z_hidden_dims: List[int] = [32, 32], + decoder_hidden_dims: List[int] = [16, 16], + x_dim: int = 2, + y_dim: int = 1, + r_dim: int = 64, + z_dim: int = 8, + n_context: int = 20, + activation: Callable = nn.Sigmoid, + init_func: Optional[Callable] = torch.nn.init.normal_, + likelihood: Likelihood | None = None, + input_transform: InputTransform | None = None, + ) -> None: + r"""Diffusion Convolutional Recurrent Neural Network Model Implementation. + + Args: + train_X: A `batch_shape x n x d` tensor of training features. + train_Y: A `batch_shape x n x m` tensor of training observations. + r_hidden_dims: Hidden Dimensions/Layer list for REncoder, defaults to + [16, 16] + z_hidden_dims: Hidden Dimensions/Layer list for ZEncoder, defaults to + [32, 32] + decoder_hidden_dims: Hidden Dimensions/Layer for Decoder, defaults to + [16, 16] + x_dim: Int dimensionality of input data x, defaults to 2. + y_dim: Int dimensionality of target data y, defaults to 1. + r_dim: Int dimensionality of representation r, defaults to 64. + z_dim: Int dimensionality of latent variable z, defaults to 8. + n_context (int): Number of context points, defaults to 20. + activation: Activation function applied between layers, defaults to nn. + Sigmoid. + init_func: A function initializing the weights, + defaults to nn.init.normal_. + likelihood: A likelihood. If omitted, use a standard GaussianLikelihood. + input_transform: An input transform that is applied in the model's + forward pass. + """ + super().__init__() + self.device = train_X.device + + self.r_encoder = REncoder( + x_dim + y_dim, + r_dim, + r_hidden_dims, + activation=activation, + init_func=init_func, + ) + self.z_encoder = ZEncoder( + r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func + ) + self.decoder = Decoder( + x_dim + z_dim, + y_dim, + decoder_hidden_dims, + activation=activation, + init_func=init_func, + ) + self.train_X = train_X + self.train_Y = train_Y + self.n_context = n_context + self.z_dim = z_dim + self.z_mu_all = None + self.z_logvar_all = None + self.z_mu_context = None + self.z_logvar_context = None + self.likelihood = likelihood if likelihood is not None else GaussianLikelihood() + self.input_transform = input_transform + self.to(device=self.device) + + def data_to_z_params( + self, x: torch.Tensor, y: torch.Tensor, r_dim: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Compute latent parameters from inputs as a latent distribution. + + Args: + x: Input tensor + y: Target tensor + r_dim: Combined Target Dimension as int, defaults as 0. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - x_c: Context input data. + - y_c: Context target data. + - x_t: Target input data. + - y_t: Target target data. + """ + x = x.to(self.device) + y = y.to(self.device) + xy = torch.cat([x, y], dim=-1).to(self.device).to(self.device) + rs = self.r_encoder(xy) + r_agg = rs.mean(dim=r_dim).to(self.device) + return self.z_encoder(r_agg) + + def sample_z( + self, + mu: torch.Tensor, + logvar: torch.Tensor, + n: int = 1, + min_std: float = 0.01, + scaler: float = 0.5, + ) -> torch.Tensor: + r"""Reparameterization trick for z's latent distribution. + + Args: + mu: Tensor representing the Gaussian distribution mean. + logvar: Tensor representing the log variance of the Gaussian distribution. + n: Int representing the # of samples, defaults to 1. + min_std: Float representing the minimum possible standardized std, defaults + to 0.01. + scaler: Float scaling the std, defaults to 0.5. + + Returns: + torch.Tensor: Samples from the Gaussian distribution. + """ + if min_std <= 0 or scaler <= 0: + raise ValueError() + + shape = [n, self.z_dim] + if n == 1: + shape = shape[1:] + eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(self.device) + + std = min_std + scaler * torch.sigmoid(logvar) + std = std.to(self.device) + mu = mu.to(self.device) + return mu + std * eps + + def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor: + r"""Analytical KLD between 2 Gaussian Distributions. + + Args: + min_std: Float representing the minimum possible standardized std, defaults + to 0.01. + scaler: Float scaling the std, defaults to 0.5. + + Returns: + torch.Tensor: A tensor representing the KLD. + """ + + if min_std <= 0 or scaler <= 0: + raise ValueError() + std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(self.device) + std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(self.device) + p = torch.distributions.Normal(self.z_mu_context.to(self.device), std_p) + q = torch.distributions.Normal(self.z_mu_all.to(self.device), std_q) + return torch.distributions.kl_divergence(p, q).sum() + + def posterior( + self, + X: torch.Tensor, + output_indices: list[int] | None = None, + observation_noise: bool = False, + posterior_transform: PosteriorTransform | None = None, + ) -> GPyTorchPosterior: + r"""Computes the model's posterior for given input tensors. + + Args: + X: Input Tensor + covariance_multiplier: Float scaling the covariance. + observation_constant: Float representing the noise constant. + output_indices: Ignored (defined in parent Model, but not used here). + observation_noise: Adds observation noise to the covariance if True, + defaults to False. + posterior_transform: An optional posterior transformation, + defaults to None. + + Returns: + GPyTorchPosterior: The posterior utilizing MultivariateNormal. + """ + X = self.transform_inputs(X) + X = X.to(self.device) + mean = self.decoder( + X.to(self.device), self.sample_z(self.z_mu_all, self.z_logvar_all) + ) + z_var = torch.exp(self.z_logvar_all) + covariance = torch.eye(X.size(0)).to(self.device) * z_var.mean() + if observation_noise: + covariance = covariance + self.likelihood.noise * torch.eye( + covariance.size(0) + ).to(self.device) + mvn = MultivariateNormal(mean, covariance) + posterior = GPyTorchPosterior(mvn) + if posterior_transform is not None: + posterior = posterior_transform(posterior) + return posterior + + def transform_inputs( + self, + X: torch.Tensor, + input_transform: Optional[Module] = None, + ) -> torch.Tensor: + r"""Transform inputs. + + Args: + X: A tensor of inputs + input_transform: A Module that performs the input transformation. + + Returns: + torch.Tensor: A tensor of transformed inputs + """ + X = X.to(self.device) + if input_transform is not None: + input_transform.to(X) + return input_transform(X) + try: + return self.input_transform(X) + except (AttributeError, TypeError): + return X + + def forward( + self, + train_X: torch.Tensor, + train_Y: torch.Tensor, + axis: int = 0, + ) -> MultivariateNormal: + r"""Forward pass for the model. + + Args: + train_X: A `batch_shape x n x d` tensor of training features. + train_Y: A `batch_shape x n x m` tensor of training observations. + axis: Dimension axis as int, defaulted as 0. + + Returns: + MultivariateNormal: Predicted target distribution. + """ + train_X = self.transform_inputs(train_X) + x_c, y_c, x_t, y_t = self.random_split_context_target( + train_X, train_Y, self.n_context, axis=axis + ) + x_t = x_t.to(self.device) + x_c = x_c.to(self.device) + y_c = y_c.to(self.device) + y_t = y_t.to(self.device) + self.z_mu_all, self.z_logvar_all = self.data_to_z_params( + self.train_X, self.train_Y + ) + self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c) + x_t = self.transform_inputs(x_t) + return self.posterior(x_t).distribution + + def random_split_context_target( + self, x: torch.Tensor, y: torch.Tensor, n_context, axis: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Helper function to split randomly into context and target. + + Args: + x: A `batch_shape x n x d` tensor of training features. + y: A `batch_shape x n x m` tensor of training observations. + n_context (int): Number of context points. + axis: Dimension axis as int, defaults to 0. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - x_c: Context input data. + - y_c: Context target data. + - x_t: Target input data. + - y_t: Target target data. + """ + self.n_context = n_context + mask = torch.randperm(x.shape[axis])[:n_context] + splitter = torch.zeros(x.shape[axis], dtype=torch.bool) + x_c = x[mask].to(self.device) + y_c = y[mask].to(self.device) + splitter[mask] = True + x_t = x[~splitter].to(self.device) + y_t = y[~splitter].to(self.device) + return x_c, y_c, x_t, y_t From acc35e85f8fc2c0a328b44e34cb0ac090fcbf8bb Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Tue, 27 May 2025 19:17:58 -0700 Subject: [PATCH 33/35] Codecov Test Cases --- .../__pycache__/test_scorebo.cpython-311.pyc | Bin 0 -> 4218 bytes test_community/models/test_np_regression.py | 246 ++++++++++-------- 2 files changed, 133 insertions(+), 113 deletions(-) create mode 100644 test_community/acquisition/__pycache__/test_scorebo.cpython-311.pyc diff --git a/test_community/acquisition/__pycache__/test_scorebo.cpython-311.pyc b/test_community/acquisition/__pycache__/test_scorebo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3a2d121429092176f4786b9b2b44df28b3980d5 GIT binary patch literal 4218 zcmc&0TWk~A^^V7n@jLM&B(X^XA*E*7&@81CmOvrN27;9A64M2fl{I)KVZzvx%!~;^ zvPY&=W)+D_ug~vw_dLc!MAX;ICsa7(7*7Y@!9Igiv%F|5rY_lM+w|pcp_~{SO|=h zJejs8tTo)q+tT)gy@uO(N194d1hSxU#Mtj3#=+VjkqAA8zwsrU3^9#j)CL|-#t1z+ zouYYKF8)IkevsM$;+F%8$gK@k=ZOKv}l%=+jZE41`pn29d ziVDP6H)ep6v*v6$d(NR)6@qb|03#Szjxu@PIACzF6l&JOA%%EI{GP};cjU5c;d05| zhG(EVs5ljuLMjfh#JysxYjQK5oM(G&iU+udl-b=a=DXH9on$fgz;$%{y)$jiRsqk!w@@N?TV`B-btnn**e{6|dq`Y)bnR?;}|M zanpueV0$kZA7k=*M=tBu*zCv!=Mng;Hdb!QU-+q+vfo^1c>sw)+So(EIzq1h%^uu)FYk1bh0pnOSkcnnwL-2uao@%T4k! zS|SK^8MGRdSa@;+UoG9vus66AyMc4`22P#2^BT+ZP+M3r{+aH*&L{oH&zRHQTxOOPlNo`N*bRrF z#8xYJx=+9*jOLSEW+5v}y1y|Nva&J3a9Cjqbg{C?rF7?|^tWNT5MzrGz7cEJw zD96Pt1W&p~FR)@Q$W|f4N;*mNydE}tn!G{tSynPk_88(9Xb703yMB#HW1=XCu@2oT z!BYT2e|?D-=Oo>SLoZyN#}9)WAdd^ai0_IR!G^en47@SY^i5DNhVhKE7HQbsv>3o_ zI1AT8rUz@P>iZ$l=>;D2wc_Y&20qEdp6CQ8MzKPYPA!5x@R&$ml5}ETx4{IU9YE~c z_Iq<-=`TDT+cdgVjD;N2_5k4kbmu6jOM_Xh6H z6hAH7%30NO1dpy150!tu8d5#KFleXBW2;A0&#}g6bag=W{Icq{`&?D`uycagLsU63lt?n}V{CI;fsdk)% zwmN6@`SBk@2dl^t_Gz8HxJ3rEy@$1EU)2qBHA2;cg56bbW0)VS22lS#)ziDad+$m@ zYwIs1R=cXCr3csnXeY;7gZyOGh1w!jM0kIq?HMTe)_eL_<_(!1VC%&)y*_PZqBY1* zflP1zH)9Vc%8RQpb>BF&r%Oh!?mnBJTn~3YI9&YAT4b;i8C(k=s)P?!QP?{{7_dNS z(F5h7O7!D`^M{UOT6eG3KcwwDj9brs2iTblFUn3eGz@L^OQUbeyho}!ZIM910mq}e_kmP8`n9s6taTl%biodg?bv(_ z3#YYcZ*ipLC=(A|Ur((}6~?{~b>5%)(^P5f;rZ30Plx_?ymENt*~qgC&qgYT&pzLG zb}e+S5;~{WTGfG-Y!2Xfyc$IQPqe@;EfUkBAFcPDSnujt=~lajpp|zSz1lehvESp{ zL;zl5P<1-|u1y5sB?fsCPMOpD2mkN<4{R&WQh#Odr_adBucp;g)6e%#uXV;No$)p5 zVuiY>QWv$hV1X=*7f;`t_@n=xKX2104;C%jR@|kf^5Vl=mHkJbl1~Fq$;$o{&pS`7 zQNtB#Sfz%)_qPAj+w*YJ#a+ph-keZYkh~bzMpB`{jiHc zFc|>sV$hc#sge!mngp2d1d?g4rSBc=_3PK~%)B#{a8CXEf$(_``JKmlzRmGnkN`Lt zv+7hbnW58cGO4?gxG-jUjC+#FMJO`{MLY@8;$8qR;E%*%45$XcFC*!v0B$#bRm+!z z4Lsi-if0<4F_+j6xW2odSO>Rc(!pl+FltnHgC{FUo z4Ag0mf$K@zv{XG`*NUX!7W6cWnr^MdAS3 Date: Sat, 21 Jun 2025 02:40:39 -0700 Subject: [PATCH 34/35] LIG Test Correct Placement --- .../test_latent_information_gain.py | 133 ++++++++++-------- 1 file changed, 78 insertions(+), 55 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index 5a76c00163..48f57d2241 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -1,55 +1,78 @@ -import unittest - -import torch -from botorch.optim.optimize import optimize_acqf -from botorch_community.acquisition.latent_information_gain import LatentInformationGain -from botorch_community.models.np_regression import NeuralProcessModel - - -class TestLatentInformationGain(unittest.TestCase): - def setUp(self): - self.x_dim = 2 - self.y_dim = 1 - self.r_dim = 8 - self.z_dim = 3 - self.r_hidden_dims = [16, 16] - self.z_hidden_dims = [32, 32] - self.decoder_hidden_dims = [16, 16] - self.model = NeuralProcessModel( - torch.rand(10, self.x_dim), - torch.rand(10, self.y_dim), - r_hidden_dims=self.r_hidden_dims, - z_hidden_dims=self.z_hidden_dims, - decoder_hidden_dims=self.decoder_hidden_dims, - x_dim=self.x_dim, - y_dim=self.y_dim, - r_dim=self.r_dim, - z_dim=self.z_dim, - ) - self.acquisition_function = LatentInformationGain(self.model) - self.candidate_x = torch.rand(5, self.x_dim) - - def test_initialization(self): - self.assertEqual(self.acquisition_function.num_samples, 10) - self.assertEqual(self.acquisition_function.model, self.model) - - def test_acqf(self): - bounds = torch.tensor([[0.0] * self.x_dim, [1.0] * self.x_dim]) - q = 3 - raw_samples = 8 - num_restarts = 2 - - candidate = optimize_acqf( - acq_function=self.acquisition_function, - bounds=bounds, - q=q, - raw_samples=raw_samples, - num_restarts=num_restarts, - ) - self.assertTrue(isinstance(candidate, tuple)) - self.assertEqual(candidate[0].shape, (q, self.x_dim)) - self.assertTrue(torch.all(candidate[1] >= 0)) - - -if __name__ == "__main__": - unittest.main() +import unittest + +import torch +from botorch.models import SingleTaskGP +from botorch.optim.optimize import optimize_acqf +from botorch_community.acquisition.latent_information_gain import LatentInformationGain +from botorch_community.models.np_regression import NeuralProcessModel + + +class TestLatentInformationGain(unittest.TestCase): + def setUp(self): + self.x_dim = 2 + self.y_dim = 1 + self.r_dim = 8 + self.z_dim = 3 + self.r_hidden_dims = [16, 16] + self.z_hidden_dims = [32, 32] + self.decoder_hidden_dims = [16, 16] + self.model = NeuralProcessModel( + torch.rand(10, self.x_dim), + torch.rand(10, self.y_dim), + r_hidden_dims=self.r_hidden_dims, + z_hidden_dims=self.z_hidden_dims, + decoder_hidden_dims=self.decoder_hidden_dims, + x_dim=self.x_dim, + y_dim=self.y_dim, + r_dim=self.r_dim, + z_dim=self.z_dim, + ) + self.acquisition_function = LatentInformationGain(self.model) + self.candidate_x = torch.rand(5, self.x_dim) + + def test_initialization(self): + self.assertEqual(self.acquisition_function.num_samples, 10) + self.assertEqual(self.acquisition_function.model, self.model) + + def test_acqf(self): + bounds = torch.tensor([[0.0] * self.x_dim, [1.0] * self.x_dim]) + q = 3 + raw_samples = 8 + num_restarts = 2 + + candidate = optimize_acqf( + acq_function=self.acquisition_function, + bounds=bounds, + q=q, + raw_samples=raw_samples, + num_restarts=num_restarts, + ) + self.assertTrue(isinstance(candidate, tuple)) + self.assertEqual(candidate[0].shape, (q, self.x_dim)) + self.assertTrue(torch.all(candidate[1] >= 0)) + + def test_non_NPR(self): + self.model = SingleTaskGP( + torch.rand(10, self.x_dim, dtype=torch.float64), + torch.rand(10, self.y_dim, dtype=torch.float64), + ) + self.acquisition_function = LatentInformationGain(self.model) + bounds = torch.tensor([[0.0] * self.x_dim, [1.0] * self.x_dim]) + q = 3 + raw_samples = 8 + num_restarts = 2 + + candidate = optimize_acqf( + acq_function=self.acquisition_function, + bounds=bounds, + q=q, + raw_samples=raw_samples, + num_restarts=num_restarts, + ) + self.assertTrue(isinstance(candidate, tuple)) + self.assertEqual(candidate[0].shape, (q, self.x_dim)) + self.assertTrue(torch.all(candidate[1] >= 0)) + + +if __name__ == "__main__": + unittest.main() From 6d0de3ab16b20e45b1479c9d0c51ea8929b9d279 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 21 Jun 2025 03:30:28 -0700 Subject: [PATCH 35/35] Deleting Broken File --- .../__pycache__/test_scorebo.cpython-311.pyc | Bin 4218 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test_community/acquisition/__pycache__/test_scorebo.cpython-311.pyc diff --git a/test_community/acquisition/__pycache__/test_scorebo.cpython-311.pyc b/test_community/acquisition/__pycache__/test_scorebo.cpython-311.pyc deleted file mode 100644 index c3a2d121429092176f4786b9b2b44df28b3980d5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4218 zcmc&0TWk~A^^V7n@jLM&B(X^XA*E*7&@81CmOvrN27;9A64M2fl{I)KVZzvx%!~;^ zvPY&=W)+D_ug~vw_dLc!MAX;ICsa7(7*7Y@!9Igiv%F|5rY_lM+w|pcp_~{SO|=h zJejs8tTo)q+tT)gy@uO(N194d1hSxU#Mtj3#=+VjkqAA8zwsrU3^9#j)CL|-#t1z+ zouYYKF8)IkevsM$;+F%8$gK@k=ZOKv}l%=+jZE41`pn29d ziVDP6H)ep6v*v6$d(NR)6@qb|03#Szjxu@PIACzF6l&JOA%%EI{GP};cjU5c;d05| zhG(EVs5ljuLMjfh#JysxYjQK5oM(G&iU+udl-b=a=DXH9on$fgz;$%{y)$jiRsqk!w@@N?TV`B-btnn**e{6|dq`Y)bnR?;}|M zanpueV0$kZA7k=*M=tBu*zCv!=Mng;Hdb!QU-+q+vfo^1c>sw)+So(EIzq1h%^uu)FYk1bh0pnOSkcnnwL-2uao@%T4k! zS|SK^8MGRdSa@;+UoG9vus66AyMc4`22P#2^BT+ZP+M3r{+aH*&L{oH&zRHQTxOOPlNo`N*bRrF z#8xYJx=+9*jOLSEW+5v}y1y|Nva&J3a9Cjqbg{C?rF7?|^tWNT5MzrGz7cEJw zD96Pt1W&p~FR)@Q$W|f4N;*mNydE}tn!G{tSynPk_88(9Xb703yMB#HW1=XCu@2oT z!BYT2e|?D-=Oo>SLoZyN#}9)WAdd^ai0_IR!G^en47@SY^i5DNhVhKE7HQbsv>3o_ zI1AT8rUz@P>iZ$l=>;D2wc_Y&20qEdp6CQ8MzKPYPA!5x@R&$ml5}ETx4{IU9YE~c z_Iq<-=`TDT+cdgVjD;N2_5k4kbmu6jOM_Xh6H z6hAH7%30NO1dpy150!tu8d5#KFleXBW2;A0&#}g6bag=W{Icq{`&?D`uycagLsU63lt?n}V{CI;fsdk)% zwmN6@`SBk@2dl^t_Gz8HxJ3rEy@$1EU)2qBHA2;cg56bbW0)VS22lS#)ziDad+$m@ zYwIs1R=cXCr3csnXeY;7gZyOGh1w!jM0kIq?HMTe)_eL_<_(!1VC%&)y*_PZqBY1* zflP1zH)9Vc%8RQpb>BF&r%Oh!?mnBJTn~3YI9&YAT4b;i8C(k=s)P?!QP?{{7_dNS z(F5h7O7!D`^M{UOT6eG3KcwwDj9brs2iTblFUn3eGz@L^OQUbeyho}!ZIM910mq}e_kmP8`n9s6taTl%biodg?bv(_ z3#YYcZ*ipLC=(A|Ur((}6~?{~b>5%)(^P5f;rZ30Plx_?ymENt*~qgC&qgYT&pzLG zb}e+S5;~{WTGfG-Y!2Xfyc$IQPqe@;EfUkBAFcPDSnujt=~lajpp|zSz1lehvESp{ zL;zl5P<1-|u1y5sB?fsCPMOpD2mkN<4{R&WQh#Odr_adBucp;g)6e%#uXV;No$)p5 zVuiY>QWv$hV1X=*7f;`t_@n=xKX2104;C%jR@|kf^5Vl=mHkJbl1~Fq$;$o{&pS`7 zQNtB#Sfz%)_qPAj+w*YJ#a+ph-keZYkh~bzMpB`{jiHc zFc|>sV$hc#sge!mngp2d1d?g4rSBc=_3PK~%)B#{a8CXEf$(_``JKmlzRmGnkN`Lt zv+7hbnW58cGO4?gxG-jUjC+#FMJO`{MLY@8;$8qR;E%*%45$XcFC*!v0B$#bRm+!z z4Lsi-if0<4F_+j6xW2odSO>Rc(!pl+FltnHgC{FUo z4Ag0mf$K@zv{XG`*NUX!7W6cWnr^MdAS3