Skip to content

Commit 38f6d75

Browse files
author
Ricardo
committed
Implement censored RVs logprob
1 parent 6e83ee6 commit 38f6d75

File tree

3 files changed

+318
-0
lines changed

3 files changed

+318
-0
lines changed

aeppl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
# isort: off
1313
# Add optimizations to the DBs
1414
import aeppl.mixture
15+
import aeppl.truncation
1516

1617
# isort: on

aeppl/truncation.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import warnings
2+
from typing import List, Optional
3+
4+
import aesara.tensor as at
5+
import numpy as np
6+
from aesara.assert_op import Assert
7+
from aesara.graph.basic import Node
8+
from aesara.graph.fg import FunctionGraph
9+
from aesara.graph.opt import local_optimizer
10+
from aesara.scalar.basic import Clip
11+
from aesara.scalar.basic import clip as scalar_clip
12+
from aesara.tensor.elemwise import Elemwise
13+
from aesara.tensor.random.op import RandomVariable
14+
from aesara.tensor.var import TensorConstant
15+
16+
from aeppl.abstract import MeasurableVariable
17+
from aeppl.logprob import _logcdf, _logprob
18+
from aeppl.opt import rv_sinking_db
19+
20+
21+
class CensoredRV(Elemwise):
    """Placeholder `Elemwise` op marking a censored random variable sub-graph.

    It exists so that a log-likelihood can be specified for (dispatched on)
    the censored sub-graph it stands in for.
    """


# Make `CensoredRV` count as a (virtual) subclass of `MeasurableVariable`
MeasurableVariable.register(CensoredRV)
26+
27+
28+
@local_optimizer(tracks=[Elemwise])
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
    """Replace a valued ``clip`` of a random variable by a `CensoredRV` output.

    The rewrite only applies when all of the following hold:

    * ``fgraph`` carries a ``preserve_rv_mappings`` feature;
    * ``node`` is an `Elemwise` whose scalar op is a `Clip` (and is not
      already a `CensoredRV`);
    * the output of ``node`` is present in the feature's ``rv_values``;
    * the first input of ``node`` is the output of a `RandomVariable`;
    * that first input is *not* itself present in ``rv_values`` (a warning is
      emitted in this case).

    Returns
    -------
    ``None`` when the rewrite does not apply, otherwise a one-element list
    holding the output of the new `CensoredRV` node.  As a side effect, the
    base variable is tagged with ``ignore_logprob = True``.
    """
    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

    rv_mappings = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_mappings is None:
        return None  # pragma: no cover

    node_op = node.op

    # A `CensoredRV` is itself an `Elemwise` clip, so it has to be excluded
    # before the generic clip check below
    if isinstance(node_op, CensoredRV):
        return None

    is_elemwise_clip = isinstance(node_op, Elemwise) and isinstance(
        node_op.scalar_op, Clip
    )
    if not is_elemwise_clip:
        return None

    clipped_var = node.outputs[0]
    if clipped_var not in rv_mappings.rv_values:
        return None

    base_var, lower_bound, upper_bound = node.inputs

    base_owner = base_var.owner
    if not base_owner or not isinstance(base_owner.op, RandomVariable):
        return None

    if base_var in rv_mappings.rv_values:
        warnings.warn(
            f"Value variables were assigned to both the input ({base_var}) and "
            f"output ({clipped_var}) of a censored random variable."
        )
        return None

    # A bound that is the base variable itself (`y = clip(x, x, ?)` or
    # `y = clip(x, ?, x)`) never censors anything, so it is swapped for the
    # matching infinity.  `censor_logprob` relies on this to produce a more
    # succinct logprob graph for one-sided censored random variables.
    if lower_bound is base_var:
        lower_bound = at.constant(-np.inf)
    if upper_bound is base_var:
        upper_bound = at.constant(np.inf)

    censored_node = CensoredRV(scalar_clip).make_node(
        base_var, lower_bound, upper_bound
    )
    censored_rv = censored_node.outputs[0]

    censored_rv.name = clipped_var.name
    base_var.tag.ignore_logprob = True

    return [censored_rv]


rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")
75+
76+
77+
@_logprob.register(CensoredRV)
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
    """Build the log-probability graph of a censored (clipped) random variable.

    Parameters
    ----------
    op
        The `CensoredRV` op this implementation is dispatched on (unused).
    values
        A one-element sequence holding the value variable of the censored RV.
    base_rv
        The uncensored random variable.  The op and inputs of its owner are
        used to dispatch to `_logprob` and `_logcdf`.
    lower_bound, upper_bound
        The censoring bounds.  A `TensorConstant` that equals ``-inf``
        (respectively ``+inf``) everywhere marks that side as unbounded, and
        no censoring term is generated for it.
    **kwargs
        Forwarded unchanged to `_logprob` and `_logcdf`.

    Returns
    -------
    The log-probability graph: the base log-probability strictly inside the
    bounds, the accumulated tail mass at a bound, and ``-inf`` outside of the
    bounds.  When both sides are bounded the result is wrapped in an `Assert`
    checking ``lower_bound <= upper_bound``.
    """
    (value,) = values

    base_rv_op = base_rv.owner.op
    base_rv_inputs = base_rv.owner.inputs

    logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
    logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

    if base_rv_op.name:
        logprob.name = f"{base_rv_op}_logprob"
        logcdf.name = f"{base_rv_op}_logcdf"

    is_lower_bounded, is_upper_bounded = False, False
    if not (
        isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
    ):
        is_upper_bounded = True

        # log(1 - cdf), written with `log1p` so that precision is not lost
        # when the CDF is very close to zero
        logccdf = at.log1p(-at.exp(logcdf))
        # For right censored discrete RVs, we need to add an extra term
        # corresponding to the pmf at the upper bound.
        # The dtype of the output variable is inspected (instead of comparing
        # the op's dtype against "int64") so that every integer dtype is
        # treated as discrete.
        if base_rv.dtype.startswith(("int", "uint")):
            logccdf = at.logaddexp(logccdf, logprob)

        logprob = at.switch(
            at.eq(value, upper_bound),
            logccdf,
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )
    if not (
        isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
    ):
        is_lower_bounded = True
        # All the mass at or below the lower bound is accumulated on the bound
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    # The bounds can only be inconsistent when both of them are finite
    if is_lower_bounded and is_upper_bounded:
        logprob = Assert("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    return logprob

tests/test_truncation.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import aesara
2+
import aesara.tensor as at
3+
import numpy as np
4+
import pytest
5+
import scipy as sp
6+
import scipy.stats as st
7+
8+
from aeppl import joint_logprob
9+
from aeppl.transforms import LogTransform, TransformValuesOpt
10+
from tests.utils import assert_no_rvs
11+
12+
13+
def test_continuous_rv_censoring():
    """Two-sided censoring of a normal RV agrees with the SciPy reference."""
    x_rv = at.random.normal(0.5, 1, name="x_rv")
    cens_x_rv = at.clip(x_rv, -2, 2)
    cens_x = cens_x_rv.type()

    logp = joint_logprob({cens_x_rv: cens_x})
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    ref_scipy = st.norm(0.5, 1)

    # Values outside of the censoring interval are impossible
    for outside_value in (-3, 3):
        assert logp_fn(outside_value) == -np.inf

    # Tail mass at the bounds, plain density strictly inside of them
    expected_logps = {
        -2: ref_scipy.logcdf(-2),
        2: ref_scipy.logsf(2),
        0: ref_scipy.logpdf(0),
    }
    for test_value, expected in expected_logps.items():
        assert np.isclose(logp_fn(test_value), expected)
31+
32+
33+
def test_discrete_rv_censoring():
    """Two-sided censoring of a Poisson RV agrees with the SciPy reference."""
    x_rv = at.random.poisson(2)
    cens_x_rv = at.clip(x_rv, 1, 4)
    cens_x_rv.name = "cens_x_rv"
    cens_x = cens_x_rv.type()

    logp = joint_logprob({cens_x_rv: cens_x})
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    ref_scipy = st.poisson(2)

    # Values outside of the censoring interval are impossible
    for outside_value in (0, 5):
        assert logp_fn(outside_value) == -np.inf

    # The upper bound collects the survival mass plus the pmf at the bound
    upper_bound_logp = np.logaddexp(ref_scipy.logsf(4), ref_scipy.logpmf(4))

    assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
    assert np.isclose(logp_fn(4), upper_bound_logp)
    assert np.isclose(logp_fn(2), ref_scipy.logpmf(2))
52+
53+
54+
def test_one_sided_censoring():
    """Lower-only and upper-only censoring of a standard normal RV."""
    x_rv = at.random.normal(0, 1)
    lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
    ub_cens_x_rv = at.clip(x_rv, x_rv, 1)

    lb_cens_x = lb_cens_x_rv.type()
    ub_cens_x = ub_cens_x_rv.type()

    lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x})
    ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x})
    for logp_graph in (lb_logp, ub_logp):
        assert_no_rvs(logp_graph)

    logp_fn = aesara.function([lb_cens_x, ub_cens_x], [lb_logp, ub_logp])
    ref_scipy = st.norm(0, 1)

    # Both values on the censored side of their respective bound
    censored_side = np.array(logp_fn(-2, 2))
    assert np.all(censored_side == -np.inf)

    # Both values on the uncensored side of their respective bound
    free_side = np.array(logp_fn(2, -2))
    assert np.all(free_side != -np.inf)

    np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
    np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))
74+
75+
76+
def test_useless_censoring():
    """Clipping an RV by itself on both sides gives the plain normal logpdf."""
    x_rv = at.random.normal(0.5, 1, size=3)
    cens_x_rv = at.clip(x_rv, x_rv, x_rv)
    cens_x = cens_x_rv.type()

    logp = joint_logprob({cens_x_rv: cens_x}, sum=False)
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    ref_scipy = st.norm(0.5, 1)

    test_values = [-2, 0, 2]
    np.testing.assert_allclose(logp_fn(test_values), ref_scipy.logpdf(test_values))
89+
90+
91+
def test_random_censoring():
    """Censoring with a random (valued) lower bound."""
    lb_rv = at.random.normal(0, 1, size=2)
    x_rv = at.random.normal(0, 2)
    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

    lb = lb_rv.type()
    cens_x = cens_x_rv.type()
    logp = joint_logprob({cens_x_rv: cens_x, lb_rv: lb}, sum=False)
    assert_no_rvs(logp)

    logp_fn = aesara.function([lb, cens_x], logp)
    res = logp_fn([0, -1], [-1, -1])

    # First entry: value below its lower bound; second entry: value on it
    below_bound_logp, on_bound_logp = res[0], res[1]
    assert below_bound_logp == -np.inf
    assert on_bound_logp != -np.inf
105+
106+
107+
def test_broadcasted_censoring_constant():
    """A logprob graph is derived when the constant upper bound is a vector."""
    lb_rv = at.random.uniform(0, 1, name="lb_rv")
    x_rv = at.random.normal(0, 2, name="x_rv")
    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

    lb = lb_rv.type()
    cens_x = cens_x_rv.type()
    lb.name = "lb"
    cens_x.name = "cens_x"

    value_vars = {cens_x_rv: cens_x, lb_rv: lb}
    logp = joint_logprob(value_vars)
    assert_no_rvs(logp)
119+
120+
121+
def test_broadcasted_censoring_random():
    """A logprob graph is derived when a scalar random bound clips a vector RV."""
    lb_rv = at.random.normal(0, 1, name="lb_rv")
    x_rv = at.random.normal(0, 2, size=2, name="x_rv")
    cens_x_rv = at.clip(x_rv, lb_rv, 1)

    lb = lb_rv.type()
    cens_x = cens_x_rv.type()
    lb.name = "lb"
    cens_x.name = "cens_x"

    value_vars = {cens_x_rv: cens_x, lb_rv: lb}
    logp = joint_logprob(value_vars)
    assert_no_rvs(logp)
133+
134+
135+
def test_failed_censoring():
    """`joint_logprob` raises `NotImplementedError` for unsupported clips."""
    # Case 1: both the base RV and the clipped RV are given value variables
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, x_rv, 1)

    x = x_rv.type()
    cens_x = cens_x_rv.type()
    both_valued = {cens_x_rv: cens_x, x_rv: x}
    with pytest.raises(NotImplementedError):
        joint_logprob(both_valued)

    # Case 2: the clipped base is not a RandomVariable
    x = at.vector("x")
    clipped_x_rv = at.clip(x, x, 1)

    clipped_x = clipped_x_rv.type()
    non_rv_base = {clipped_x_rv: clipped_x}
    with pytest.raises(NotImplementedError):
        joint_logprob(non_rv_base)
152+
153+
154+
def test_deterministic_clipping():
    """A clip that is not given a value variable acts as a plain deterministic."""
    x_rv = at.random.normal(0, 1)
    clip = at.clip(x_rv, 0, 0)
    y_rv = at.random.normal(clip, 1)

    x = x_rv.type()
    y = y_rv.type()
    logp = joint_logprob({x_rv: x, y_rv: y})
    assert_no_rvs(logp)

    logp_fn = aesara.function([x, y], logp)

    # `x` is clipped to 0, hence both terms are standard normal densities
    std_normal = st.norm(0, 1)
    expected_logp = std_normal.logpdf(-1) + std_normal.logpdf(1)
    assert np.isclose(logp_fn(-1, 1), expected_logp)
169+
170+
171+
@aesara.config.change_flags(compute_test_value="raise")
def test_censored_test_value():
    """The logprob graph can be built while test values are being computed."""
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, -1, 1)

    cens_x = cens_x_rv.type()
    cens_x.tag.test_value = 0

    value_vars = {cens_x_rv: cens_x}
    joint_logprob(value_vars)
178+
179+
180+
@pytest.mark.xfail(reason="Transform does not work with Elemwise ops, see #60")
def test_censored_transform():
    """Log-transforming the value of a lower-censored normal (expected to fail)."""
    x_rv = at.random.normal(0.5, 1, name="x_rv")
    cens_x_rv = at.clip(x_rv, 0, x_rv)
    cens_x = cens_x_rv.type()

    transform = TransformValuesOpt({cens_x: LogTransform()})
    logp = joint_logprob({cens_x_rv: cens_x}, extra_rewrites=transform)

    cens_x_val = -1
    obs_logp = logp.eval({cens_x: cens_x_val})

    # Density at the back-transformed value, plus `cens_x_val` as the
    # log-Jacobian term of the log transform
    base_logpdf = sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_val))
    exp_logp = base_logpdf + cens_x_val

    assert np.isclose(obs_logp, exp_logp)

0 commit comments

Comments
 (0)