Skip to content

Commit 6e83ee6

Browse files
author
Ricardo
committed
Add logcdf methods
1 parent 6b9f696 commit 6e83ee6

File tree

2 files changed

+150
-4
lines changed

2 files changed

+150
-4
lines changed

aeppl/logprob.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ def logprob(rv_var, *rv_values, **kwargs):
3838
return logprob
3939

4040

41+
def logcdf(rv_var, rv_value, **kwargs):
    """Create a graph for the logcdf of a ``RandomVariable``.

    The implementation is selected by dispatching `_logcdf` on
    ``rv_var.owner.op``; the inputs of ``rv_var.owner`` are forwarded to it
    positionally after ``rv_value``.

    Parameters
    ----------
    rv_var
        Variable whose ``owner`` node carries the op and inputs to dispatch on.
    rv_value
        Value at which the log-CDF graph is evaluated.
    **kwargs
        Extra keyword arguments forwarded to the dispatched implementation
        (``name=rv_var.name`` is always passed as well).

    Returns
    -------
    The graph returned by the dispatched implementation.  When ``rv_var`` has
    a truthy name, the result is renamed to ``"<rv_var.name>_logcdf"``.
    """
    node = rv_var.owner
    # Use a distinct local name so the function's own name is not shadowed.
    logcdf_var = _logcdf(
        node.op, rv_value, *node.inputs, name=rv_var.name, **kwargs
    )

    if rv_var.name:
        logcdf_var.name = f"{rv_var.name}_logcdf"

    return logcdf_var
51+
52+
4153
@singledispatch
4254
def _logprob(
4355
op: Op,
@@ -52,7 +64,23 @@ def _logprob(
5264
for a ``RandomVariable``, register a new function on this dispatcher.
5365
5466
"""
55-
raise NotImplementedError()
67+
raise NotImplementedError(f"Logprob method not implemented for {op}")
68+
69+
70+
@singledispatch
71+
def _logcdf(
72+
op: Op,
73+
value: TensorVariable,
74+
*inputs: TensorVariable,
75+
**kwargs,
76+
):
77+
"""Create a graph for the logcdf of a ``RandomVariable``.
78+
79+
This function dispatches on the type of ``op``, which should be a subclass
80+
of ``RandomVariable``. If you want to implement new logcdf graphs
81+
for a ``RandomVariable``, register a new function on this dispatcher.
82+
"""
83+
raise NotImplementedError(f"Logcdf method not implemented for {op}")
5684

5785

5886
@_logprob.register(arb.UniformRV)
@@ -66,6 +94,24 @@ def uniform_logprob(op, values, *inputs, **kwargs):
6694
)
6795

6896

97+
@_logcdf.register(arb.UniformRV)
def uniform_logcdf(op, value, *inputs, **kwargs):
    """Build the log-CDF graph for a uniform random variable.

    Parameters
    ----------
    op
        The dispatched op (unused here).
    value
        Value at which the log-CDF is evaluated.
    *inputs
        Node inputs; the distribution bounds ``lower`` and ``upper`` are taken
        from ``inputs[3:]``.
        NOTE(review): the first three inputs are presumably rng/size/dtype --
        confirm against the ``RandomVariable`` node layout.
    **kwargs
        Ignored.

    Returns
    -------
    A graph equal to ``-inf`` where ``value < lower``,
    ``log(value - lower) - log(upper - lower)`` where
    ``lower <= value < upper``, and ``0`` otherwise, wrapped in an ``Assert``
    that all ``lower <= upper``.
    """
    lower, upper = inputs[3:]

    res = at.switch(
        at.lt(value, lower),
        -np.inf,
        at.switch(
            at.lt(value, upper),
            at.log(value - lower) - at.log(upper - lower),
            0,
        ),
    )

    res = Assert("lower <= upper")(res, at.all(at.le(lower, upper)))
    return res
113+
114+
69115
@_logprob.register(arb.NormalRV)
70116
def normal_logprob(op, values, *inputs, **kwargs):
71117
(value,) = values
@@ -79,6 +125,21 @@ def normal_logprob(op, values, *inputs, **kwargs):
79125
return res
80126

81127

128+
@_logcdf.register(arb.NormalRV)
def normal_logcdf(op, value, *inputs, **kwargs):
    """Build the log-CDF graph for a normal random variable.

    Parameters
    ----------
    op
        The dispatched op (unused here).
    value
        Value at which the log-CDF is evaluated.
    *inputs
        Node inputs; ``mu`` and ``sigma`` are taken from ``inputs[3:]``.
        NOTE(review): the first three inputs are presumably rng/size/dtype --
        confirm against the ``RandomVariable`` node layout.
    **kwargs
        Ignored.

    Returns
    -------
    The log-CDF graph, wrapped in an ``Assert`` that all ``sigma > 0``.
    """
    mu, sigma = inputs[3:]

    # Standardized value.
    z = (value - mu) / sigma
    # Two algebraically equivalent forms of log(Phi(z)), switched at z = -1:
    #   z < -1 : log(erfcx(-z / sqrt(2)) / 2) - z**2 / 2
    #   else   : log1p(-erfc(z / sqrt(2)) / 2)
    # NOTE(review): the split is presumably for numerical stability in the
    # lower tail -- confirm.
    res = at.switch(
        at.lt(z, -1.0),
        at.log(at.erfcx(-z / at.sqrt(2.0)) / 2.0) - at.sqr(z) / 2.0,
        at.log1p(-at.erfc(z / at.sqrt(2.0)) / 2.0),
    )

    res = Assert("sigma > 0")(res, at.all(at.gt(sigma, 0.0)))
    return res
141+
142+
82143
@_logprob.register(arb.HalfNormalRV)
83144
def halfnormal_logprob(op, values, *inputs, **kwargs):
84145
(value,) = values
@@ -346,6 +407,16 @@ def poisson_logprob(op, values, *inputs, **kwargs):
346407
return res
347408

348409

410+
@_logcdf.register(arb.PoissonRV)
def poisson_logcdf(op, value, *inputs, **kwargs):
    """Build the log-CDF graph for a Poisson random variable.

    Parameters
    ----------
    op
        The dispatched op (unused here).
    value
        Value at which the log-CDF is evaluated; it is floored first.
    *inputs
        Node inputs; the rate ``mu`` is the single entry of ``inputs[3:]``.
        NOTE(review): the first three inputs are presumably rng/size/dtype --
        confirm against the ``RandomVariable`` node layout.
    **kwargs
        Ignored.

    Returns
    -------
    ``log(gammaincc(floor(value) + 1, mu))`` where ``floor(value) >= 0`` and
    ``-inf`` elsewhere, wrapped in an ``Assert`` that all ``mu >= 0``.
    """
    (mu,) = inputs[3:]
    value = at.floor(value)
    res = at.log(at.gammaincc(value + 1, mu))
    # Negative values have zero cumulative probability.
    res = at.switch(at.le(0, value), res, -np.inf)
    res = Assert("0 <= mu")(res, at.all(at.le(0.0, mu)))
    return res
418+
419+
349420
@_logprob.register(arb.NegBinomialRV)
350421
def nbinom_logprob(op, values, *inputs, **kwargs):
351422
(value,) = values

tests/test_logprob.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import scipy.stats as stats
77

8-
from aeppl.logprob import logprob
8+
from aeppl.logprob import logcdf, logprob
99

1010
# @pytest.fixture(scope="module", autouse=True)
1111
# def set_aesara_flags():
@@ -33,7 +33,7 @@ def create_aesara_params(dist_params, obs, size):
3333

3434

3535
def scipy_logprob_tester(
36-
rv_var, obs, dist_params, test_fn=None, check_broadcastable=True
36+
rv_var, obs, dist_params, test_fn=None, check_broadcastable=True, test_logcdf=False
3737
):
3838
"""Test for correspondence between `RandomVariable` and NumPy shape and
3939
broadcast dimensions.
@@ -46,7 +46,10 @@ def scipy_logprob_tester(
4646

4747
test_fn = getattr(stats, name)
4848

49-
aesara_res = logprob(rv_var, at.as_tensor(obs))
49+
if not test_logcdf:
50+
aesara_res = logprob(rv_var, at.as_tensor(obs))
51+
else:
52+
aesara_res = logcdf(rv_var, at.as_tensor(obs))
5053
aesara_res_val = aesara_res.eval(dist_params)
5154

5255
numpy_res = np.asarray(test_fn(obs, *dist_params.values()))
@@ -83,6 +86,26 @@ def scipy_logprob(obs, l, u):
8386
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
8487

8588

89+
@pytest.mark.parametrize(
    "dist_params, obs, size",
    [
        ((0, 1), np.array([-1, 0, 0.5, 1, 2], dtype=np.float64), ()),
        ((-2, -1), np.array([-3, -2, -0.5, -1, 0], dtype=np.float64), ()),
    ],
)
def test_uniform_logcdf(dist_params, obs, size):
    """Compare the uniform logcdf graph against ``scipy.stats.uniform.logcdf``.

    The observations cover points below, inside and above the support.
    """
    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
    # Map each symbolic parameter to its numeric value for evaluation.
    param_values = dict(zip(dist_params_at, dist_params))

    x = at.random.uniform(*dist_params_at, size=size_at)

    def scipy_logcdf(obs, lower, upper):
        # SciPy parameterizes the uniform by ``loc`` and ``scale``
        # rather than by its bounds.
        return stats.uniform.logcdf(obs, loc=lower, scale=upper - lower)

    scipy_logprob_tester(x, obs, param_values, test_fn=scipy_logcdf, test_logcdf=True)
107+
108+
86109
@pytest.mark.parametrize(
87110
"dist_params, obs, size",
88111
[
@@ -101,6 +124,26 @@ def test_normal_logprob(dist_params, obs, size):
101124
scipy_logprob_tester(x, obs, dist_params, test_fn=stats.norm.logpdf)
102125

103126

127+
@pytest.mark.parametrize(
    "dist_params, obs, size",
    [
        ((0, 1), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
        ((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
        ((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), (2, 3)),
    ],
)
def test_normal_logcdf(dist_params, obs, size):
    """Compare the normal logcdf graph against ``scipy.stats.norm.logcdf``."""
    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
    # Map each symbolic parameter to its numeric value for evaluation.
    param_values = dict(zip(dist_params_at, dist_params))

    x = at.random.normal(*dist_params_at, size=size_at)

    scipy_logprob_tester(
        x, obs, param_values, test_fn=stats.norm.logcdf, test_logcdf=True
    )
145+
146+
104147
@pytest.mark.parametrize(
105148
"dist_params, obs, size",
106149
[
@@ -620,6 +663,38 @@ def scipy_logprob(obs, mu):
620663
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
621664

622665

666+
@pytest.mark.parametrize(
    "dist_params, obs, size, error",
    [
        ((-1,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), True),
        ((1.0,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), False),
        ((0.5,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (3, 2), False),
        (
            (np.array([0.01, 0.2, 200]),),
            np.array([-1, 1, 84], dtype=np.int64),
            (),
            False,
        ),
    ],
)
def test_poisson_logcdf(dist_params, obs, size, error):
    """Compare the Poisson logcdf graph against ``scipy.stats.poisson.logcdf``.

    Cases with ``error=True`` must raise ``AssertionError`` (the case here
    uses a negative rate); the others must match SciPy.
    """
    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
    # Map each symbolic parameter to its numeric value for evaluation.
    param_values = dict(zip(dist_params_at, dist_params))

    x = at.random.poisson(*dist_params_at, size=size_at)

    # ``suppress()`` with no arguments acts as a no-op context manager.
    cm = contextlib.suppress() if not error else pytest.raises(AssertionError)

    def scipy_logcdf(obs, mu):
        return stats.poisson.logcdf(obs, mu)

    with cm:
        scipy_logprob_tester(
            x, obs, param_values, test_fn=scipy_logcdf, test_logcdf=True
        )
696+
697+
623698
@pytest.mark.parametrize(
624699
"dist_params, obs, size, error",
625700
[

0 commit comments

Comments
 (0)