5
5
import pytest
6
6
import scipy .stats as stats
7
7
8
- from aeppl .logprob import logprob
8
+ from aeppl .logprob import logcdf , logprob
9
9
10
10
# @pytest.fixture(scope="module", autouse=True)
11
11
# def set_aesara_flags():
@@ -33,7 +33,7 @@ def create_aesara_params(dist_params, obs, size):
33
33
34
34
35
35
def scipy_logprob_tester (
36
- rv_var , obs , dist_params , test_fn = None , check_broadcastable = True
36
+ rv_var , obs , dist_params , test_fn = None , check_broadcastable = True , test_logcdf = False
37
37
):
38
38
"""Test for correspondence between `RandomVariable` and NumPy shape and
39
39
broadcast dimensions.
@@ -46,7 +46,10 @@ def scipy_logprob_tester(
46
46
47
47
test_fn = getattr (stats , name )
48
48
49
- aesara_res = logprob (rv_var , at .as_tensor (obs ))
49
+ if not test_logcdf :
50
+ aesara_res = logprob (rv_var , at .as_tensor (obs ))
51
+ else :
52
+ aesara_res = logcdf (rv_var , at .as_tensor (obs ))
50
53
aesara_res_val = aesara_res .eval (dist_params )
51
54
52
55
numpy_res = np .asarray (test_fn (obs , * dist_params .values ()))
@@ -83,6 +86,26 @@ def scipy_logprob(obs, l, u):
83
86
scipy_logprob_tester (x , obs , dist_params , test_fn = scipy_logprob )
84
87
85
88
89
+ @pytest .mark .parametrize (
90
+ "dist_params, obs, size" ,
91
+ [
92
+ ((0 , 1 ), np .array ([- 1 , 0 , 0.5 , 1 , 2 ], dtype = np .float64 ), ()),
93
+ ((- 2 , - 1 ), np .array ([- 3 , - 2 , - 0.5 , - 1 , 0 ], dtype = np .float64 ), ()),
94
+ ],
95
+ )
96
+ def test_uniform_logcdf (dist_params , obs , size ):
97
+
98
+ dist_params_at , obs_at , size_at = create_aesara_params (dist_params , obs , size )
99
+ dist_params = dict (zip (dist_params_at , dist_params ))
100
+
101
+ x = at .random .uniform (* dist_params_at , size = size_at )
102
+
103
+ def scipy_logcdf (obs , l , u ):
104
+ return stats .uniform .logcdf (obs , loc = l , scale = u - l )
105
+
106
+ scipy_logprob_tester (x , obs , dist_params , test_fn = scipy_logcdf , test_logcdf = True )
107
+
108
+
86
109
@pytest .mark .parametrize (
87
110
"dist_params, obs, size" ,
88
111
[
@@ -101,6 +124,26 @@ def test_normal_logprob(dist_params, obs, size):
101
124
scipy_logprob_tester (x , obs , dist_params , test_fn = stats .norm .logpdf )
102
125
103
126
127
+ @pytest .mark .parametrize (
128
+ "dist_params, obs, size" ,
129
+ [
130
+ ((0 , 1 ), np .array ([0 , 0.5 , 1 , - 1 ], dtype = np .float64 ), ()),
131
+ ((- 1 , 20 ), np .array ([0 , 0.5 , 1 , - 1 ], dtype = np .float64 ), ()),
132
+ ((- 1 , 20 ), np .array ([0 , 0.5 , 1 , - 1 ], dtype = np .float64 ), (2 , 3 )),
133
+ ],
134
+ )
135
+ def test_normal_logcdf (dist_params , obs , size ):
136
+
137
+ dist_params_at , obs_at , size_at = create_aesara_params (dist_params , obs , size )
138
+ dist_params = dict (zip (dist_params_at , dist_params ))
139
+
140
+ x = at .random .normal (* dist_params_at , size = size_at )
141
+
142
+ scipy_logprob_tester (
143
+ x , obs , dist_params , test_fn = stats .norm .logcdf , test_logcdf = True
144
+ )
145
+
146
+
104
147
@pytest .mark .parametrize (
105
148
"dist_params, obs, size" ,
106
149
[
@@ -620,6 +663,38 @@ def scipy_logprob(obs, mu):
620
663
scipy_logprob_tester (x , obs , dist_params , test_fn = scipy_logprob )
621
664
622
665
666
+ @pytest .mark .parametrize (
667
+ "dist_params, obs, size, error" ,
668
+ [
669
+ ((- 1 ,), np .array ([- 1 , 0 , 1 , 100 , 10000 ], dtype = np .int64 ), (), True ),
670
+ ((1.0 ,), np .array ([- 1 , 0 , 1 , 100 , 10000 ], dtype = np .int64 ), (), False ),
671
+ ((0.5 ,), np .array ([- 1 , 0 , 1 , 100 , 10000 ], dtype = np .int64 ), (3 , 2 ), False ),
672
+ (
673
+ (np .array ([0.01 , 0.2 , 200 ]),),
674
+ np .array ([- 1 , 1 , 84 ], dtype = np .int64 ),
675
+ (),
676
+ False ,
677
+ ),
678
+ ],
679
+ )
680
+ def test_poisson_logcdf (dist_params , obs , size , error ):
681
+
682
+ dist_params_at , obs_at , size_at = create_aesara_params (dist_params , obs , size )
683
+ dist_params = dict (zip (dist_params_at , dist_params ))
684
+
685
+ x = at .random .poisson (* dist_params_at , size = size_at )
686
+
687
+ cm = contextlib .suppress () if not error else pytest .raises (AssertionError )
688
+
689
+ def scipy_logcdf (obs , mu ):
690
+ return stats .poisson .logcdf (obs , mu )
691
+
692
+ with cm :
693
+ scipy_logprob_tester (
694
+ x , obs , dist_params , test_fn = scipy_logcdf , test_logcdf = True
695
+ )
696
+
697
+
623
698
@pytest .mark .parametrize (
624
699
"dist_params, obs, size, error" ,
625
700
[
0 commit comments