Skip to content

Commit 958ec44

Browse files
author
Ricardo
committed
Implement censored RVs logprob
1 parent 2df4f7d commit 958ec44

File tree

3 files changed

+317
-0
lines changed

3 files changed

+317
-0
lines changed

aeppl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
import aeppl.cumsum
1515
import aeppl.mixture
1616
import aeppl.scan
17+
import aeppl.truncation
1718

1819
# isort: on

aeppl/truncation.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import warnings
2+
from typing import List, Optional
3+
4+
import aesara.tensor as at
5+
import numpy as np
6+
from aesara.assert_op import Assert
7+
from aesara.graph.basic import Node
8+
from aesara.graph.fg import FunctionGraph
9+
from aesara.graph.opt import local_optimizer
10+
from aesara.scalar.basic import Clip
11+
from aesara.scalar.basic import clip as scalar_clip
12+
from aesara.tensor.elemwise import Elemwise
13+
from aesara.tensor.var import TensorConstant
14+
15+
from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
16+
from aeppl.logprob import _logcdf, _logprob
17+
from aeppl.opt import rv_sinking_db
18+
19+
20+
class CensoredRV(Elemwise):
21+
"""A placeholder used to specify a log-likelihood for a censored RV sub-graph."""
22+
23+
24+
MeasurableVariable.register(CensoredRV)
25+
26+
27+
@local_optimizer(tracks=[Elemwise])
28+
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
29+
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
30+
31+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
32+
if rv_map_feature is None:
33+
return None # pragma: no cover
34+
35+
if isinstance(node.op, CensoredRV):
36+
return None # pragma: no cover
37+
38+
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
39+
return None
40+
41+
clipped_var = node.outputs[0]
42+
if clipped_var not in rv_map_feature.rv_values:
43+
return None
44+
45+
base_var, lower_bound, upper_bound = node.inputs
46+
47+
if not (base_var.owner and isinstance(base_var.owner.op, MeasurableVariable)):
48+
return None
49+
50+
if base_var in rv_map_feature.rv_values:
51+
warnings.warn(
52+
f"Value variables were assigned to both the input ({base_var}) and "
53+
f"output ({clipped_var}) of a censored random variable."
54+
)
55+
return None
56+
57+
# Replace bounds by `+-inf` if `y = clip(x, x, ?)` or `y=clip(x, ?, x)`
58+
# This is used in `censor_logprob` to generate a more succint logprob graph
59+
# for one-sided censored random variables
60+
lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
61+
upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)
62+
63+
censored_op = CensoredRV(scalar_clip)
64+
# Make base_var unmeasurable
65+
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
66+
censored_rv_node = censored_op.make_node(
67+
unmeasurable_base_var, lower_bound, upper_bound
68+
)
69+
censored_rv = censored_rv_node.outputs[0]
70+
71+
censored_rv.name = clipped_var.name
72+
73+
return [censored_rv]
74+
75+
76+
rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")
77+
78+
79+
@_logprob.register(CensoredRV)
80+
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
81+
(value,) = values
82+
83+
base_rv_op = base_rv.owner.op
84+
base_rv_inputs = base_rv.owner.inputs
85+
86+
logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
87+
logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
88+
89+
if base_rv_op.name:
90+
logprob.name = f"{base_rv_op}_logprob"
91+
logcdf.name = f"{base_rv_op}_logcdf"
92+
93+
is_lower_bounded, is_upper_bounded = False, False
94+
if not (
95+
isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
96+
):
97+
is_upper_bounded = True
98+
99+
logccdf = at.log(1 - at.exp(logcdf))
100+
# For right censored discrete RVs, we need to add an extra term
101+
# corresponding to the pmf at the upper bound
102+
if base_rv_op.dtype == "int64":
103+
logccdf = at.logaddexp(logccdf, logprob)
104+
105+
logprob = at.switch(
106+
at.eq(value, upper_bound),
107+
logccdf,
108+
at.switch(at.gt(value, upper_bound), -np.inf, logprob),
109+
)
110+
if not (
111+
isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
112+
):
113+
is_lower_bounded = True
114+
logprob = at.switch(
115+
at.eq(value, lower_bound),
116+
logcdf,
117+
at.switch(at.lt(value, lower_bound), -np.inf, logprob),
118+
)
119+
120+
if is_lower_bounded and is_upper_bounded:
121+
logprob = Assert("lower_bound <= upper_bound")(
122+
logprob, at.all(at.le(lower_bound, upper_bound))
123+
)
124+
125+
return logprob

tests/test_truncation.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import aesara
2+
import aesara.tensor as at
3+
import numpy as np
4+
import pytest
5+
import scipy as sp
6+
import scipy.stats as st
7+
8+
from aeppl import factorized_joint_logprob, joint_logprob
9+
from aeppl.transforms import LogTransform, TransformValuesOpt
10+
from tests.utils import assert_no_rvs
11+
12+
13+
@aesara.config.change_flags(compute_test_value="raise")
14+
def test_continuous_rv_censoring():
15+
x_rv = at.random.normal(0.5, 1)
16+
cens_x_rv = at.clip(x_rv, -2, 2)
17+
18+
cens_x_vv = cens_x_rv.clone()
19+
cens_x_vv.tag.test_value = 0
20+
21+
logp = joint_logprob({cens_x_rv: cens_x_vv})
22+
assert_no_rvs(logp)
23+
24+
logp_fn = aesara.function([cens_x_vv], logp)
25+
ref_scipy = st.norm(0.5, 1)
26+
27+
assert logp_fn(-3) == -np.inf
28+
assert logp_fn(3) == -np.inf
29+
30+
assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
31+
assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
32+
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
33+
34+
35+
def test_discrete_rv_censoring():
36+
x_rv = at.random.poisson(2)
37+
cens_x_rv = at.clip(x_rv, 1, 4)
38+
39+
cens_x_vv = cens_x_rv.clone()
40+
41+
logp = joint_logprob({cens_x_rv: cens_x_vv})
42+
assert_no_rvs(logp)
43+
44+
logp_fn = aesara.function([cens_x_vv], logp)
45+
ref_scipy = st.poisson(2)
46+
47+
assert logp_fn(0) == -np.inf
48+
assert logp_fn(5) == -np.inf
49+
50+
assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
51+
assert np.isclose(logp_fn(4), np.logaddexp(ref_scipy.logsf(4), ref_scipy.logpmf(4)))
52+
assert np.isclose(logp_fn(2), ref_scipy.logpmf(2))
53+
54+
55+
def test_one_sided_censoring():
56+
x_rv = at.random.normal(0, 1)
57+
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
58+
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)
59+
60+
lb_cens_x_vv = lb_cens_x_rv.clone()
61+
ub_cens_x_vv = ub_cens_x_rv.clone()
62+
63+
lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
64+
ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
65+
assert_no_rvs(lb_logp)
66+
assert_no_rvs(ub_logp)
67+
68+
logp_fn = aesara.function([lb_cens_x_vv, ub_cens_x_vv], [lb_logp, ub_logp])
69+
ref_scipy = st.norm(0, 1)
70+
71+
assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
72+
assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
73+
np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
74+
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))
75+
76+
77+
def test_useless_censoring():
78+
x_rv = at.random.normal(0.5, 1, size=3)
79+
cens_x_rv = at.clip(x_rv, x_rv, x_rv)
80+
81+
cens_x_vv = cens_x_rv.clone()
82+
83+
logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False)
84+
assert_no_rvs(logp)
85+
86+
logp_fn = aesara.function([cens_x_vv], logp)
87+
ref_scipy = st.norm(0.5, 1)
88+
89+
np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
90+
91+
92+
def test_random_censoring():
93+
lb_rv = at.random.normal(0, 1, size=2)
94+
x_rv = at.random.normal(0, 2)
95+
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
96+
97+
lb_vv = lb_rv.clone()
98+
cens_x_vv = cens_x_rv.clone()
99+
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False)
100+
assert_no_rvs(logp)
101+
102+
logp_fn = aesara.function([lb_vv, cens_x_vv], logp)
103+
res = logp_fn([0, -1], [-1, -1])
104+
assert res[0] == -np.inf
105+
assert res[1] != -np.inf
106+
107+
108+
def test_broadcasted_censoring_constant():
109+
lb_rv = at.random.uniform(0, 1)
110+
x_rv = at.random.normal(0, 2)
111+
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
112+
113+
lb_vv = lb_rv.clone()
114+
cens_x_vv = cens_x_rv.clone()
115+
116+
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
117+
assert_no_rvs(logp)
118+
119+
120+
def test_broadcasted_censoring_random():
121+
lb_rv = at.random.normal(0, 1)
122+
x_rv = at.random.normal(0, 2, size=2)
123+
cens_x_rv = at.clip(x_rv, lb_rv, 1)
124+
125+
lb_vv = lb_rv.clone()
126+
cens_x_vv = cens_x_rv.clone()
127+
128+
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
129+
assert_no_rvs(logp)
130+
131+
132+
def test_fail_base_and_censored_have_values():
133+
"""Test failure when both base_rv and clipped_rv are given value vars"""
134+
x_rv = at.random.normal(0, 1)
135+
cens_x_rv = at.clip(x_rv, x_rv, 1)
136+
cens_x_rv.name = "cens_x"
137+
138+
x_vv = x_rv.clone()
139+
cens_x_vv = cens_x_rv.clone()
140+
logp_terms = factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
141+
assert cens_x_vv not in logp_terms
142+
143+
144+
def test_fail_multiple_censored_single_base():
145+
"""Test failure when multiple clipped_rvs share a single base_rv"""
146+
base_rv = at.random.normal(0, 1)
147+
cens_rv1 = at.clip(base_rv, -1, 1)
148+
cens_rv1.name = "cens1"
149+
cens_rv2 = at.clip(base_rv, -1, 1)
150+
cens_rv2.name = "cens2"
151+
152+
cens_vv1 = cens_rv1.clone()
153+
cens_vv2 = cens_rv2.clone()
154+
logp_terms = factorized_joint_logprob({cens_rv1: cens_vv1, cens_rv2: cens_vv2})
155+
assert cens_rv2 not in logp_terms
156+
157+
158+
def test_deterministic_clipping():
159+
x_rv = at.random.normal(0, 1)
160+
clip = at.clip(x_rv, 0, 0)
161+
y_rv = at.random.normal(clip, 1)
162+
163+
x_vv = x_rv.clone()
164+
y_vv = y_rv.clone()
165+
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
166+
assert_no_rvs(logp)
167+
168+
logp_fn = aesara.function([x_vv, y_vv], logp)
169+
assert np.isclose(
170+
logp_fn(-1, 1),
171+
st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
172+
)
173+
174+
175+
@pytest.mark.xfail(reason="Transform does not work with Elemwise ops, see #60")
176+
def test_censored_transform():
177+
x_rv = at.random.normal(0.5, 1)
178+
cens_x_rv = at.clip(x_rv, 0, x_rv)
179+
180+
cens_x_vv = cens_x_rv.clone()
181+
182+
transform = TransformValuesOpt({cens_x_vv: LogTransform()})
183+
logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
184+
185+
cens_x_vv_testval = -1
186+
obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
187+
exp_logp = (
188+
sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval
189+
)
190+
191+
assert np.isclose(obs_logp, exp_logp)

0 commit comments

Comments
 (0)