
Commit 9dd9052

Ricardo committed
Implement RV censoring logprob and opt
1 parent e3930a2 commit 9dd9052

File tree: 3 files changed, +291 -2 lines

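In user-facing terms, this commit lets `joint_logprob` derive the log-probability of clipped (censored) random variables. A minimal usage sketch, adapted from the new tests/test_truncation.py below (variable names are only illustrative):

import aesara
import aesara.tensor as at

from aeppl import joint_logprob

# A normal variable censored to the interval [-2, 2] via `at.clip`
x_rv = at.random.normal(0.5, 1, name="x_rv")
cens_x_rv = at.clip(x_rv, -2, 2)

# Value variable for the censored RV and its log-probability graph
cens_x = cens_x_rv.type()
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})

logp_fn = aesara.function([cens_x], logp)
logp_fn(2.0)  # log-mass piled at the upper bound: log(1 - CDF(2))
logp_fn(0.0)  # ordinary log-density inside the bounds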

aeppl/opt.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,7 @@
 from aesara.compile.mode import optdb
 from aesara.graph.features import Feature
 from aesara.graph.op import compute_test_value
-from aesara.graph.opt import EquilibriumOptimizer, local_optimizer
+from aesara.graph.opt import EquilibriumOptimizer, local_optimizer, out2in
 from aesara.graph.optdb import SequenceDB
 from aesara.tensor.extra_ops import BroadcastTo
 from aesara.tensor.random.op import RandomVariable
@@ -21,6 +21,7 @@
 )
 from aesara.tensor.var import TensorVariable
 
+from aeppl.truncation import censor_rvs
 from aeppl.utils import indices_from_subtensor
 
 inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -182,4 +183,5 @@ def naive_bcast_rv_lift(fgraph, node):
 
 
 logprob_canonicalize.register("canonicalize", optdb["canonicalize"], -10, "basic")
-logprob_canonicalize.register("rvsinker", RVSinker(), -1, "basic")
+logprob_canonicalize.register("rvsinker", RVSinker(), -5, "basic")
+logprob_canonicalize.register("censor_rvs", out2in(censor_rvs), -1, "basic")
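For orientation (this note is not part of the diff): `logprob_canonicalize` is an Aesara `SequenceDB`, so registered passes run in order of increasing position, and `out2in` wraps the node-level `censor_rvs` rewrite into a graph-wide optimizer that visits nodes from outputs to inputs. Moving `RVSinker` to -5 keeps it ahead of the new clip-detection pass at -1. A trivial sketch of the resulting schedule:

# Positions passed to `register` above; lower positions run first.
pipeline = [
    (-10, "canonicalize"),  # optdb["canonicalize"]
    (-5, "rvsinker"),       # RVSinker()
    (-1, "censor_rvs"),     # out2in(censor_rvs)
]
for _, name in sorted(pipeline):
    print(name)             # canonicalize, rvsinker, censor_rvs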

aeppl/truncation.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import warnings
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.logprob import _logcdf, _logprob


# TODO: Add interval transform
class CensoredRV(RandomVariable):
    r"""A base class for censored `RandomVariable`\s."""

    def __init__(self, *args, base_op, **kwargs):
        self.base_op = base_op
        super().__init__(*args, **kwargs)


@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):

    *rv_params, lower_bound, upper_bound = inputs
    logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
    logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)
    if op.base_op.name:
        logprob.name = f"{op.base_op.name}_logprob"
        logcdf.name = f"{op.base_op.name}_logcdf"

    is_lower_bounded, is_upper_bounded = False, False
    if not (
        isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
    ):
        is_upper_bounded = True
        logprob = at.switch(
            at.eq(value, upper_bound),
            at.log(1 - at.exp(logcdf)),
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )
    if not (
        isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
    ):
        is_lower_bounded = True
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    if is_lower_bounded and is_upper_bounded:
        logprob = Assert("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    return logprob


@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return

    if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
        return

    clipped_var = node.outputs[0]
    if clipped_var not in rv_map_feature.rv_values:
        return

    base_var, lower_bound, upper_bound = node.inputs
    base_op = base_var.owner.op

    if not isinstance(base_op, RandomVariable):
        return

    if base_var in rv_map_feature.rv_values:
        warnings.warn(
            f"Value variables were assigned to both the input ({base_var}) and "
            f"output ({clipped_var}) of a censored random variable."
        )
        return

    censored_rv = CensoredRV(
        "censored",
        base_op.ndim_supp,
        list(base_op.ndims_params) + [base_op.ndim_supp] * 2,
        base_op.dtype,
        inplace=False,
        base_op=base_op,
    )

    # Replace bounds by +-inf if `y = clip(x, x, ?)` or `y = clip(x, ?, x)`
    # This is used in `censor_logprob` to generate a more succinct logprob graph
    # for one-sided censored random variables
    # TODO: This will probably fail for multivariate variables
    lower_bound = lower_bound if (lower_bound is not base_var) else -np.inf
    upper_bound = upper_bound if (upper_bound is not base_var) else np.inf

    censored_node = censored_rv.make_node(
        *base_var.owner.inputs,
        lower_bound,
        upper_bound,
    )

    censored_node_out = censored_node.outputs[1]

    if not censored_node_out.type == clipped_var.type:
        # TODO: issue warning?
        return

    if clipped_var.name:
        censored_node_out.name = clipped_var.name
    elif base_var.name:
        censored_node_out.name = f"{base_var.name}_censored"

    return [censored_node_out]
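For reference (not part of the diff), the `at.switch` branches in `censor_logprob` above encode the standard distribution of a censored variable: if the latent variable X has density f and CDF F, then Y = clip(X, lb, ub) has

p_Y(y) =
\begin{cases}
  F(lb)     & y = lb, \\
  f(y)      & lb < y < ub, \\
  1 - F(ub) & y = ub, \\
  0         & \text{otherwise.}
\end{cases}

This is why the lower-bound branch returns `logcdf` directly, the upper-bound branch returns `at.log(1 - at.exp(logcdf))`, and values outside the bounds get `-np.inf`. For one-sided censoring (a bound equal to the base variable itself), `censor_rvs` substitutes -inf or +inf for that bound so the corresponding branch is omitted from the graph.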

tests/test_truncation.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
import scipy.stats as st

from aeppl import joint_logprob
from tests.utils import assert_no_rvs


def test_uniform_censoring():
    x_rv = at.random.uniform(-4, 5)
    cens_x_rv = at.clip(x_rv, -1, 1)
    cens_x_rv.name = "cens_x_rv"

    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    ref_scipy = st.uniform(-4, 9)

    assert logp_fn(-5) == -np.inf
    assert logp_fn(6) == -np.inf

    assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
    assert np.isclose(logp_fn(5), ref_scipy.logsf(5))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_normal_censoring():
    x_rv = at.random.normal(0.5, 1, name="x_rv")
    cens_x_rv = at.clip(x_rv, -2, 2)

    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    ref_scipy = st.norm(0.5, 1)

    assert logp_fn(-3) == -np.inf
    assert logp_fn(3) == -np.inf

    assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
    assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_one_sided_censoring():
    x_rv = at.random.normal(0, 1)
    lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
    ub_cens_x_rv = at.clip(x_rv, x_rv, 1)

    lb_cens_x = lb_cens_x_rv.type()
    ub_cens_x = ub_cens_x_rv.type()

    lb_logp = joint_logprob(lb_cens_x_rv, {lb_cens_x_rv: lb_cens_x})
    ub_logp = joint_logprob(ub_cens_x_rv, {ub_cens_x_rv: ub_cens_x})
    assert_no_rvs(lb_logp)
    assert_no_rvs(ub_logp)

    logp_fn = aesara.function([lb_cens_x, ub_cens_x], [lb_logp, ub_logp])
    ref_scipy = st.norm(0, 1)

    assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
    assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
    np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
    np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))


def test_useless_censoring():
    x_rv = at.random.normal(0.5, 1, size=3)
    cens_x_rv = at.clip(x_rv, x_rv, x_rv)

    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    ref_scipy = st.norm(0.5, 1)

    np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))


def test_random_censoring():
    lb_rv = at.random.normal(0, 1, size=2)
    x_rv = at.random.normal(0, 2)
    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

    lb = lb_rv.type()
    cens_x = cens_x_rv.type()
    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
    assert_no_rvs(logp)

    logp_fn = aesara.function([lb, cens_x], logp)
    res = logp_fn([0, -1], [-1, -1])
    assert res[0] == -np.inf
    assert res[1] != -np.inf


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring_constant():
    lb_rv = at.random.uniform(0, 1, name="lb_rv")
    x_rv = at.random.normal(0, 2, name="x_rv")
    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

    lb = lb_rv.type()
    lb.name = "lb"
    cens_x = cens_x_rv.type()
    cens_x.name = "cens_x"

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
    assert_no_rvs(logp)


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring_random():
    lb_rv = at.random.normal(0, 1, name="lb_rv")
    x_rv = at.random.normal(0, 2, size=2, name="x_rv")
    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

    lb = lb_rv.type()
    lb.name = "lb"
    cens_x = cens_x_rv.type()
    cens_x.name = "cens_x"

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
    assert_no_rvs(logp)


def test_failed_censoring():
    # Test that `joint_logprob` fails when both base_rv and clipped_rv are given
    # value vars
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, x_rv, 1)

    x = x_rv.type()
    cens_x = cens_x_rv.type()
    with pytest.raises(NotImplementedError):
        joint_logprob(cens_x_rv, {cens_x_rv: cens_x, x_rv: x})


def test_deterministic_clipping():
    x_rv = at.random.normal(0, 1)
    clip = at.clip(x_rv, 0, 0)
    y_rv = at.random.normal(clip, 1)

    x = x_rv.type()
    y = y_rv.type()
    logp = joint_logprob(y_rv, {x_rv: x, y_rv: y})
    assert_no_rvs(logp)

    logp_fn = aesara.function([x, y], logp)
    assert np.isclose(
        logp_fn(-1, 1),
        st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
    )
