2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.3.7
+    rev: v0.11.5
     hooks:
       - id: ruff
         args: ["--fix"]
6 changes: 3 additions & 3 deletions examples/bayes_llama3/llama3/eval.py
@@ -43,9 +43,9 @@ def __init__(self, config: FrozenConfigDict):
config["pretrained_model_name_or_path"]
)
else:
assert os.path.isdir(
config["checkpoints_folder"]
), "Provided checkpoints is not a path to a folder"
assert os.path.isdir(config["checkpoints_folder"]), (
"Provided checkpoints is not a path to a folder"
)
checkpoints = [
os.path.join(config["checkpoints_folder"], path)
for path in os.listdir(config["checkpoints_folder"])
103 changes: 76 additions & 27 deletions examples/continual_regression.ipynb
@@ -62,7 +62,14 @@
"outputs": [],
"source": [
"episode_x_boundaries = torch.linspace(0, n_episodes, n_episodes + 1)\n",
"xs = torch.stack([torch.linspace(episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode) for i in range(n_episodes)])\n",
"xs = torch.stack(\n",
" [\n",
" torch.linspace(\n",
" episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode\n",
" )\n",
" for i in range(n_episodes)\n",
" ]\n",
")\n",
"ys = torch.stack([true_f(x) + y_sd * torch.randn_like(x) for x in xs])"
]
},
@@ -85,18 +92,20 @@
"source": [
"plt_linsp = torch.linspace(-1, episode_x_boundaries[-1] + 1, 1000)\n",
"\n",
"\n",
"def plot_data(ax, up_to_episode=None):\n",
" if up_to_episode is None:\n",
" up_to_episode = n_episodes\n",
" \n",
" ax.plot(xs.flatten(), ys.flatten(), 'o', color='gray', alpha=0.2)\n",
"\n",
" ax.plot(xs.flatten(), ys.flatten(), \"o\", color=\"gray\", alpha=0.2)\n",
" for i in range(up_to_episode):\n",
" ax.plot(xs[i], ys[i], 'o', color='orange')\n",
" \n",
" ax.plot(xs[i], ys[i], \"o\", color=\"orange\")\n",
"\n",
" for v in episode_x_boundaries:\n",
" ax.axvline(v, color='gray', linestyle='--', alpha=0.75)\n",
" ax.plot(plt_linsp, true_f(plt_linsp), color='green', zorder=10)\n",
" ax.set_ylim(-2., 2.5)\n",
" ax.axvline(v, color=\"gray\", linestyle=\"--\", alpha=0.75)\n",
" ax.plot(plt_linsp, true_f(plt_linsp), color=\"green\", zorder=10)\n",
" ax.set_ylim(-2.0, 2.5)\n",
"\n",
"\n",
"fig, ax = plt.subplots()\n",
"plot_data(ax)"
@@ -166,11 +175,21 @@
"outputs": [],
"source": [
"def log_prior(p, prior_mean, prior_sd: float):\n",
" all_vals = tree_map(lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False).log_prob(p).sum(), p, prior_mean, prior_sd)\n",
" all_vals = tree_map(\n",
" lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False)\n",
" .log_prob(p)\n",
" .sum(),\n",
" p,\n",
" prior_mean,\n",
" prior_sd,\n",
" )\n",
" return tree_reduce(torch.add, all_vals)\n",
" \n",
"\n",
"\n",
"def log_likelihood(y_pred, y):\n",
" return torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()"
" return (\n",
" torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()\n",
" )"
]
},
{
@@ -182,7 +201,10 @@
"def log_posterior(params, batch, prior_mean, prior_sd):\n",
" x, y = batch\n",
" y_pred = mlp_functional(params, x)\n",
" log_post = log_likelihood(y_pred, y) + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n",
" log_post = (\n",
" log_likelihood(y_pred, y)\n",
" + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n",
" )\n",
" return log_post, y_pred"
]
},
@@ -213,7 +235,13 @@
"outputs": [],
"source": [
"batch_size = 3\n",
"dataloaders = [torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)), batch_size=batch_size) for x, y in zip(xs, ys)]"
"dataloaders = [\n",
" torch.utils.data.DataLoader(\n",
" torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)),\n",
" batch_size=batch_size,\n",
" )\n",
" for x, y in zip(xs, ys)\n",
"]"
]
},
{
@@ -227,7 +255,9 @@
" for _ in range(n_epochs):\n",
" for batch in dataloader:\n",
" opt.zero_grad()\n",
" loss = -log_posterior(dict(mlp.named_parameters()), batch, prior_mean, prior_sd)[0]\n",
" loss = -log_posterior(\n",
" dict(mlp.named_parameters()), batch, prior_mean, prior_sd\n",
" )[0]\n",
" loss.backward()\n",
" opt.step()"
]
@@ -252,10 +282,10 @@
"metadata": {},
"outputs": [],
"source": [
"def plot_predictions(params, ax, x, sd=y_sd, alpha=1.):\n",
"def plot_predictions(params, ax, x, sd=y_sd, alpha=1.0):\n",
" preds = mlp_functional(params, x.unsqueeze(-1)).detach().numpy().squeeze()\n",
" ax.plot(x, preds, color='blue', alpha=alpha)\n",
" ax.fill_between(x, preds - sd, preds + sd, color='blue', alpha=0.2)"
" ax.plot(x, preds, color=\"blue\", alpha=alpha)\n",
" ax.fill_between(x, preds - sd, preds + sd, color=\"blue\", alpha=0.2)"
]
},
{
@@ -275,12 +305,14 @@
 }
 ],
 "source": [
-"fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n",
+"fig, axes = plt.subplots(\n",
+"    1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n",
+")\n",
 "\n",
 "for i, ax in enumerate(axes):\n",
-"    plot_data(ax, up_to_episode=i+1)\n",
+"    plot_data(ax, up_to_episode=i + 1)\n",
 "    plot_predictions(trained_params[i], ax, plt_linsp)\n",
-"    ax.set_title(f\"After Episode {i+1}\")"
+"    ax.set_title(f\"After Episode {i + 1}\")"
 ]
 },
 {
@@ -318,7 +350,13 @@
"def train_for_vi(dataloader, prior_mean, prior_sd, n_epochs=200, init_log_sds=None):\n",
" seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)\n",
" optimizer = torchopt.adam(lr=2e-3)\n",
" transform = posteriors.vi.diag.build(seq_log_post, optimizer, temperature=1/samps_per_episode, init_log_sds=init_log_sds, stl=False)\n",
" transform = posteriors.vi.diag.build(\n",
" seq_log_post,\n",
" optimizer,\n",
" temperature=1 / samps_per_episode,\n",
" init_log_sds=init_log_sds,\n",
" stl=False,\n",
" )\n",
" state = transform.init(dict(mlp.named_parameters()))\n",
" nelbos = []\n",
" for _ in range(n_epochs):\n",
@@ -346,9 +384,18 @@
"nelbos = []\n",
"for i in range(n_episodes):\n",
" seq_prior_mean = prior_mean if i == 0 else vi_states[i - 1].params\n",
" seq_prior_sd = prior_sd if i == 0 else tree_map(lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd ** 2), vi_states[i - 1].log_sd_diag)\n",
" seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6., mlp.state_dict())\n",
" state, nelbos_i = train_for_vi(dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds)\n",
" seq_prior_sd = (\n",
" prior_sd\n",
" if i == 0\n",
" else tree_map(\n",
" lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd**2),\n",
" vi_states[i - 1].log_sd_diag,\n",
" )\n",
" )\n",
" seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6.0, mlp.state_dict())\n",
" state, nelbos_i = train_for_vi(\n",
" dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds\n",
" )\n",
" vi_states += [state]\n",
" nelbos += [nelbos_i]\n",
" mlp.load_state_dict(vi_states[i].params)"
@@ -371,16 +418,18 @@
 }
 ],
 "source": [
-"fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n",
+"fig, axes = plt.subplots(\n",
+"    1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n",
+")\n",
 "\n",
 "n_samples = 20\n",
 "\n",
 "for i, ax in enumerate(axes):\n",
 "    for _ in range(n_samples):\n",
 "        sample = posteriors.vi.diag.sample(vi_states[i])\n",
 "        plot_predictions(sample, ax, plt_linsp, sd=y_sd, alpha=0.2)\n",
-"    plot_data(ax, up_to_episode=i+1)\n",
-"    ax.set_title(f\"After Episode {i+1}\")"
+"    plot_data(ax, up_to_episode=i + 1)\n",
+"    ax.set_title(f\"After Episode {i + 1}\")"
 ]
 },
 {
4 changes: 2 additions & 2 deletions examples/pyro_pima_indians_sghmc.ipynb
@@ -365,7 +365,7 @@
" samples[:, i] = torch.stack([state.params for state in states])\n",
" if i > N_warmup:\n",
" j = i - N_warmup\n",
" gelman_rubin[j] = pyro.ops.stats.gelman_rubin(log_posts[:, N_warmup:i + 1])\n"
" gelman_rubin[j] = pyro.ops.stats.gelman_rubin(log_posts[:, N_warmup : i + 1])"
]
},
{
@@ -469,7 +469,7 @@
"for ind, ax in enumerate(axes.flatten()):\n",
" ax.hist(samples[:, N_warmup:, ind].flatten(), bins=50, density=True)\n",
" ax.set_title(column_names[ind])\n",
"fig.tight_layout()\n"
"fig.tight_layout()"
]
},
{
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "posteriors"
-version = "0.1.0"
+version = "0.1.1"
 description = "Uncertainty quantification with PyTorch"
 readme = "README.md"
 requires-python =">=3.9"
3 changes: 2 additions & 1 deletion tests/test_utils.py
@@ -110,7 +110,8 @@ def test_model_to_function():

     func_output2 = func_lm(dict(lm.named_parameters()), input_ids, attention_mask)
 
-    assert type(output) == type(func_output1) == type(func_output2)
+    assert type(output) is type(func_output1)
+    assert type(output) is type(func_output2)
     assert torch.allclose(output["logits"], func_output1["logits"])
     assert torch.allclose(output["logits"], func_output2["logits"])
 