Skip to content

Commit b7aa31d

Browse files
Example of an autoguide
1 parent a870c7c commit b7aa31d

File tree

2 files changed

+232
-0
lines changed

2 files changed

+232
-0
lines changed

pymc/variational/autoguide.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytensor.tensor as pt
17+
18+
from pytensor import Variable, graph_replace
19+
from pytensor.graph import vectorize_graph
20+
21+
import pymc as pm
22+
23+
from pymc.model.core import Model
24+
25+
ModelVariable = Variable | str
26+
27+
28+
def AutoDiagonalNormal(model):
29+
coords = model.coords
30+
free_rvs = model.free_RVs
31+
draws = pt.tensor("draws", shape=(), dtype="int64")
32+
33+
with Model(coords=coords) as guide_model:
34+
for rv in free_rvs:
35+
loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
36+
scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
37+
z = pm.Normal(
38+
f"{rv.name}_z",
39+
mu=0,
40+
sigma=1,
41+
shape=(draws, *rv.type.shape),
42+
transform=model.rvs_to_transforms[rv],
43+
)
44+
pm.Deterministic(
45+
rv.name, loc + scale * z, dims=model.named_vars_to_dims.get(rv.name, None)
46+
)
47+
48+
return guide_model
49+
50+
51+
def AutoFullRankNormal(model):
52+
# TODO: Broken
53+
54+
coords = model.coords
55+
free_rvs = model.free_RVs
56+
draws = pt.tensor("draws", shape=(), dtype="int64")
57+
58+
rv_sizes = [np.prod(rv.type.shape) for rv in free_rvs]
59+
total_size = np.sum(rv_sizes)
60+
tril_size = total_size * (total_size + 1) // 2
61+
62+
locs = [pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) for rv in free_rvs]
63+
packed_L = pt.tensor("L", shape=(tril_size,), dtype="float64")
64+
L = pm.expand_packed_triangular(packed_L)
65+
66+
with Model(coords=coords) as guide_model:
67+
z = pm.MvNormal(
68+
"z", mu=np.zeros(total_size), cov=np.eye(total_size), size=(draws, total_size)
69+
)
70+
params = pt.concatenate([loc.ravel() for loc in locs]) + L @ z
71+
72+
cursor = 0
73+
74+
for rv, size in zip(free_rvs, rv_sizes):
75+
pm.Deterministic(
76+
rv.name,
77+
params[cursor : cursor + size].reshape(rv.type.shape),
78+
dims=model.named_vars_to_dims.get(rv.name, None),
79+
)
80+
cursor += size
81+
82+
return guide_model
83+
84+
85+
def get_logp_logq(model, guide_model):
86+
inputs_to_guide_rvs = {
87+
model_value_var: guide_model[rv.name]
88+
for rv, model_value_var in model.rvs_to_values.items()
89+
if rv not in model.observed_RVs
90+
}
91+
92+
logp = vectorize_graph(model.logp(), inputs_to_guide_rvs)
93+
logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)
94+
95+
return logp, logq

tests/variational/test_autoguide.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor.tensor as pt
16+
import pytest
17+
18+
import pymc as pm
19+
20+
from pymc.variational.autoguide import AutoDiagonalNormal, AutoFullRankNormal, get_logp_logq
21+
22+
Parameter = pt.tensor
23+
24+
25+
@pytest.fixture(scope="module")
26+
def X_y_params():
27+
"""Generate synthetic data for testing."""
28+
29+
rng = np.random.default_rng(sum(map(ord, "autoguide_test")))
30+
31+
alpha = rng.normal(loc=100, scale=10)
32+
beta = rng.normal(loc=0, scale=1, size=(10,))
33+
34+
true_params = {
35+
"alpha": alpha,
36+
"beta": beta,
37+
}
38+
39+
X_data = rng.normal(size=(100, 10))
40+
y_data = alpha + X_data @ beta
41+
42+
return X_data, y_data, true_params
43+
44+
45+
@pytest.fixture(scope="module")
46+
def model(X_y_params):
47+
X_data, y_data, _ = X_y_params
48+
49+
with pm.Model() as model:
50+
X = pm.Data("X", X_data)
51+
alpha = pm.Normal("alpha", 100, 10)
52+
beta = pm.Normal("beta", 0, 5, size=(10,))
53+
54+
mu = alpha + X @ beta
55+
sigma = pm.Exponential("sigma", 1)
56+
y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_data)
57+
58+
return model
59+
60+
61+
@pytest.fixture(scope="module")
62+
def target_guide_model(X_y_params):
63+
X_data, *_ = X_y_params
64+
65+
draws = pt.tensor("draws", shape=(), dtype="int64")
66+
67+
with pm.Model() as guide_model:
68+
X = pm.Data("X", X_data)
69+
70+
alpha_loc = Parameter("alpha_loc", shape=())
71+
alpha_scale = Parameter("alpha_scale", shape=())
72+
alpha_z = pm.Normal("alpha_z", mu=0, sigma=1, shape=(draws,))
73+
alpha = pm.Deterministic("alpha", alpha_loc + alpha_scale * alpha_z)
74+
75+
beta_loc = Parameter("beta_loc", shape=(10,))
76+
beta_scale = Parameter("beta_scale", shape=(10,))
77+
beta_z = pm.Normal("beta_z", mu=0, sigma=1, shape=(draws, 10))
78+
beta = pm.Deterministic("beta", beta_loc + beta_scale * beta_z)
79+
80+
sigma_loc = Parameter("sigma_loc", shape=())
81+
sigma_scale = Parameter("sigma_scale", shape=())
82+
sigma_z = pm.Normal(
83+
"sigma_z", 0, 1, shape=(draws,), transform=pm.distributions.transforms.log
84+
)
85+
sigma = pm.Deterministic("sigma", sigma_loc + sigma_scale * sigma_z)
86+
87+
return guide_model
88+
89+
90+
def test_diagonal_normal_autoguide(model, target_guide_model, X_y_params):
91+
guide_model = AutoDiagonalNormal(model)
92+
93+
logp, logq = get_logp_logq(model, guide_model)
94+
logp_target, logq_target = get_logp_logq(model, target_guide_model)
95+
96+
inputs = pm.inputvars(logp)
97+
target_inputs = pm.inputvars(logp_target)
98+
99+
expected_locs = [f"{var}_loc" for var in ["alpha", "beta", "sigma"]]
100+
expected_scales = [f"{var}_scale" for var in ["alpha", "beta", "sigma"]]
101+
102+
expected_inputs = expected_locs + expected_scales + ["draws"]
103+
name_to_input = {input.name: input for input in inputs}
104+
name_to_target_input = {input.name: input for input in target_inputs}
105+
106+
assert all(input.name in expected_inputs for input in inputs), (
107+
"Guide inputs do not match expected inputs"
108+
)
109+
110+
negative_elbo = (logq - logp).mean()
111+
negative_elbo_target = (logq_target - logp_target).mean()
112+
113+
fn = pm.compile(
114+
[name_to_input[input] for input in expected_inputs], negative_elbo, random_seed=69420
115+
)
116+
fn_target = pm.compile(
117+
[name_to_target_input[input] for input in expected_inputs],
118+
negative_elbo_target,
119+
random_seed=69420,
120+
)
121+
122+
test_inputs = {
123+
"alpha_loc": np.zeros(()),
124+
"alpha_scale": np.ones(()),
125+
"beta_loc": np.zeros(10),
126+
"beta_scale": np.ones(10),
127+
"sigma_loc": np.zeros(()),
128+
"sigma_scale": np.ones(()),
129+
"draws": 100,
130+
}
131+
132+
np.testing.assert_allclose(fn(**test_inputs), fn_target(**test_inputs))
133+
134+
135+
def test_full_mv_normal_guide(model, X_y_params):
136+
guide_model = AutoFullRankNormal(model)
137+
logp, logq = get_logp_logq(model, guide_model)

0 commit comments

Comments
 (0)