Skip to content

Commit ccbfb22

Browse files
author
Ricardo
committed
Implement censored RVs logprob
1 parent b98d51c commit ccbfb22

File tree

3 files changed

+335
-0
lines changed

3 files changed

+335
-0
lines changed

aeppl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
# Add optimizations to the DBs
1414
import aeppl.mixture
1515
import aeppl.scan
16+
import aeppl.truncation
1617

1718
# isort: on

aeppl/truncation.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import warnings
2+
from typing import List, Optional
3+
4+
import aesara.tensor as at
5+
import numpy as np
6+
from aesara.assert_op import Assert
7+
from aesara.graph.basic import Node
8+
from aesara.graph.fg import FunctionGraph
9+
from aesara.graph.opt import local_optimizer
10+
from aesara.scalar.basic import Clip
11+
from aesara.scalar.basic import clip as scalar_clip
12+
from aesara.tensor.elemwise import Elemwise
13+
from aesara.tensor.random.op import RandomVariable
14+
from aesara.tensor.var import TensorConstant
15+
16+
from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
17+
from aeppl.logprob import _logcdf, _logprob
18+
from aeppl.opt import rv_sinking_db
19+
20+
21+
class CensoredRV(Elemwise):
22+
"""A placeholder used to specify a log-likelihood for a censored RV sub-graph."""
23+
24+
25+
MeasurableVariable.register(CensoredRV)
26+
27+
28+
@local_optimizer(tracks=[Elemwise])
29+
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
30+
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
31+
32+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
33+
if rv_map_feature is None:
34+
return None # pragma: no cover
35+
36+
if isinstance(node.op, CensoredRV):
37+
return None
38+
39+
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
40+
return None
41+
42+
clipped_var = node.outputs[0]
43+
if clipped_var not in rv_map_feature.rv_values:
44+
return None
45+
46+
base_var, lower_bound, upper_bound = node.inputs
47+
48+
if not (base_var.owner and isinstance(base_var.owner.op, RandomVariable)):
49+
return None
50+
51+
if base_var in rv_map_feature.rv_values:
52+
warnings.warn(
53+
f"Value variables were assigned to both the input ({base_var}) and "
54+
f"output ({clipped_var}) of a censored random variable."
55+
)
56+
return None
57+
58+
# Replace bounds by `+-inf` if `y = clip(x, x, ?)` or `y=clip(x, ?, x)`
59+
# This is used in `censor_logprob` to generate a more succint logprob graph
60+
# for one-sided censored random variables
61+
lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
62+
upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)
63+
64+
censored_op = CensoredRV(scalar_clip)
65+
# Make base_var unmeasurable
66+
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
67+
censored_rv_node = censored_op.make_node(
68+
unmeasurable_base_var, lower_bound, upper_bound
69+
)
70+
censored_rv = censored_rv_node.outputs[0]
71+
72+
censored_rv.name = clipped_var.name
73+
74+
return [censored_rv]
75+
76+
77+
rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")
78+
79+
80+
@_logprob.register(CensoredRV)
81+
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
82+
(value,) = values
83+
84+
base_rv_op = base_rv.owner.op
85+
base_rv_inputs = base_rv.owner.inputs
86+
87+
logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
88+
logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
89+
90+
if base_rv_op.name:
91+
logprob.name = f"{base_rv_op}_logprob"
92+
logcdf.name = f"{base_rv_op}_logcdf"
93+
94+
is_lower_bounded, is_upper_bounded = False, False
95+
if not (
96+
isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
97+
):
98+
is_upper_bounded = True
99+
100+
logccdf = at.log(1 - at.exp(logcdf))
101+
# For right censored discrete RVs, we need to add an extra term
102+
# corresponding to the pmf at the upper bound
103+
if base_rv_op.dtype == "int64":
104+
logccdf = at.logaddexp(logccdf, logprob)
105+
106+
logprob = at.switch(
107+
at.eq(value, upper_bound),
108+
logccdf,
109+
at.switch(at.gt(value, upper_bound), -np.inf, logprob),
110+
)
111+
if not (
112+
isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
113+
):
114+
is_lower_bounded = True
115+
logprob = at.switch(
116+
at.eq(value, lower_bound),
117+
logcdf,
118+
at.switch(at.lt(value, lower_bound), -np.inf, logprob),
119+
)
120+
121+
if is_lower_bounded and is_upper_bounded:
122+
logprob = Assert("lower_bound <= upper_bound")(
123+
logprob, at.all(at.le(lower_bound, upper_bound))
124+
)
125+
126+
return logprob

tests/test_truncation.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import aesara
2+
import aesara.tensor as at
3+
import numpy as np
4+
import pytest
5+
import scipy as sp
6+
import scipy.stats as st
7+
8+
from aeppl import joint_logprob
9+
from aeppl.transforms import LogTransform, TransformValuesOpt
10+
from tests.utils import assert_no_rvs
11+
12+
13+
def test_continuous_rv_censoring():
14+
x_rv = at.random.normal(0.5, 1, name="x_rv")
15+
cens_x_rv = at.clip(x_rv, -2, 2)
16+
17+
cens_x = cens_x_rv.type()
18+
19+
logp = joint_logprob({cens_x_rv: cens_x})
20+
assert_no_rvs(logp)
21+
22+
logp_fn = aesara.function([cens_x], logp)
23+
ref_scipy = st.norm(0.5, 1)
24+
25+
assert logp_fn(-3) == -np.inf
26+
assert logp_fn(3) == -np.inf
27+
28+
assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
29+
assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
30+
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
31+
32+
33+
def test_discrete_rv_censoring():
34+
x_rv = at.random.poisson(2)
35+
cens_x_rv = at.clip(x_rv, 1, 4)
36+
cens_x_rv.name = "cens_x_rv"
37+
38+
cens_x = cens_x_rv.type()
39+
40+
logp = joint_logprob({cens_x_rv: cens_x})
41+
assert_no_rvs(logp)
42+
43+
logp_fn = aesara.function([cens_x], logp)
44+
ref_scipy = st.poisson(2)
45+
46+
assert logp_fn(0) == -np.inf
47+
assert logp_fn(5) == -np.inf
48+
49+
assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
50+
assert np.isclose(logp_fn(4), np.logaddexp(ref_scipy.logsf(4), ref_scipy.logpmf(4)))
51+
assert np.isclose(logp_fn(2), ref_scipy.logpmf(2))
52+
53+
54+
def test_one_sided_censoring():
55+
x_rv = at.random.normal(0, 1)
56+
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
57+
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)
58+
59+
lb_cens_x = lb_cens_x_rv.type()
60+
ub_cens_x = ub_cens_x_rv.type()
61+
62+
lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x})
63+
ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x})
64+
assert_no_rvs(lb_logp)
65+
assert_no_rvs(ub_logp)
66+
67+
logp_fn = aesara.function([lb_cens_x, ub_cens_x], [lb_logp, ub_logp])
68+
ref_scipy = st.norm(0, 1)
69+
70+
assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
71+
assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
72+
np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
73+
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))
74+
75+
76+
def test_useless_censoring():
77+
x_rv = at.random.normal(0.5, 1, size=3)
78+
cens_x_rv = at.clip(x_rv, x_rv, x_rv)
79+
80+
cens_x = cens_x_rv.type()
81+
82+
logp = joint_logprob({cens_x_rv: cens_x}, sum=False)
83+
assert_no_rvs(logp)
84+
85+
logp_fn = aesara.function([cens_x], logp)
86+
ref_scipy = st.norm(0.5, 1)
87+
88+
np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
89+
90+
91+
def test_random_censoring():
92+
lb_rv = at.random.normal(0, 1, size=2)
93+
x_rv = at.random.normal(0, 2)
94+
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
95+
96+
lb = lb_rv.type()
97+
cens_x = cens_x_rv.type()
98+
logp = joint_logprob({cens_x_rv: cens_x, lb_rv: lb}, sum=False)
99+
assert_no_rvs(logp)
100+
101+
logp_fn = aesara.function([lb, cens_x], logp)
102+
res = logp_fn([0, -1], [-1, -1])
103+
assert res[0] == -np.inf
104+
assert res[1] != -np.inf
105+
106+
107+
def test_broadcasted_censoring_constant():
108+
lb_rv = at.random.uniform(0, 1, name="lb_rv")
109+
x_rv = at.random.normal(0, 2, name="x_rv")
110+
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
111+
112+
lb = lb_rv.type()
113+
lb.name = "lb"
114+
cens_x = cens_x_rv.type()
115+
cens_x.name = "cens_x"
116+
117+
logp = joint_logprob({cens_x_rv: cens_x, lb_rv: lb})
118+
assert_no_rvs(logp)
119+
120+
121+
def test_broadcasted_censoring_random():
122+
lb_rv = at.random.normal(0, 1, name="lb_rv")
123+
x_rv = at.random.normal(0, 2, size=2, name="x_rv")
124+
cens_x_rv = at.clip(x_rv, lb_rv, 1)
125+
126+
lb = lb_rv.type()
127+
lb.name = "lb"
128+
cens_x = cens_x_rv.type()
129+
cens_x.name = "cens_x"
130+
131+
logp = joint_logprob({cens_x_rv: cens_x, lb_rv: lb})
132+
assert_no_rvs(logp)
133+
134+
135+
def test_fail_base_and_censored_have_values():
136+
"""Test failure when both base_rv and clipped_rv are given value vars"""
137+
x_rv = at.random.normal(0, 1)
138+
cens_x_rv = at.clip(x_rv, x_rv, 1)
139+
140+
x = x_rv.type()
141+
cens_x = cens_x_rv.type()
142+
with pytest.raises(RuntimeError):
143+
joint_logprob({cens_x_rv: cens_x, x_rv: x})
144+
145+
# Test failure when base is not a RandomVariable
146+
x = at.vector("x")
147+
clipped_x_rv = at.clip(x, x, 1)
148+
149+
clipped_x = clipped_x_rv.type()
150+
with pytest.raises(RuntimeError):
151+
joint_logprob({clipped_x_rv: clipped_x})
152+
153+
154+
def test_fail_multiple_censored_single_base():
155+
"""Test failure when multiple clipped_rvs share a single base_rv"""
156+
base_rv = at.random.normal(0, 1)
157+
cens_rv1 = at.clip(base_rv, -1, 1)
158+
cens_rv1.name = "cens1"
159+
cens_rv2 = at.clip(base_rv, -1, 1)
160+
cens_rv2.name = "cens2"
161+
162+
cens_vv1 = cens_rv1.clone()
163+
cens_vv2 = cens_rv2.clone()
164+
with pytest.raises(RuntimeError):
165+
joint_logprob({cens_rv1: cens_vv1, cens_rv2: cens_vv2})
166+
167+
168+
def test_deterministic_clipping():
169+
x_rv = at.random.normal(0, 1)
170+
clip = at.clip(x_rv, 0, 0)
171+
y_rv = at.random.normal(clip, 1)
172+
173+
x = x_rv.type()
174+
y = y_rv.type()
175+
logp = joint_logprob({x_rv: x, y_rv: y})
176+
assert_no_rvs(logp)
177+
178+
logp_fn = aesara.function([x, y], logp)
179+
assert np.isclose(
180+
logp_fn(-1, 1),
181+
st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
182+
)
183+
184+
185+
@aesara.config.change_flags(compute_test_value="raise")
186+
def test_censored_test_value():
187+
x_rv = at.random.normal(0, 1)
188+
cens_x_rv = at.clip(x_rv, -1, 1)
189+
cens_x = cens_x_rv.type()
190+
cens_x.tag.test_value = 0
191+
joint_logprob({cens_x_rv: cens_x})
192+
193+
194+
@pytest.mark.xfail(reason="Transform does not work with Elemwise ops, see #60")
195+
def test_censored_transform():
196+
x_rv = at.random.normal(0.5, 1, name="x_rv")
197+
cens_x_rv = at.clip(x_rv, 0, x_rv)
198+
199+
cens_x = cens_x_rv.type()
200+
201+
transform = TransformValuesOpt({cens_x: LogTransform()})
202+
logp = joint_logprob({cens_x_rv: cens_x}, extra_rewrites=transform)
203+
204+
cens_x_val = -1
205+
obs_logp = logp.eval({cens_x: cens_x_val})
206+
exp_logp = sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_val)) + cens_x_val
207+
208+
assert np.isclose(obs_logp, exp_logp)

0 commit comments

Comments
 (0)