Skip to content

Commit be77d9d

Browse files
Fix bug in TimeSeasonality when pop_state = False (#407)
Rename `pop_state` to `remove_first_state` for clarity. Add test for `remove_first_state = False`
1 parent dcc2bec commit be77d9d

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

pymc_extras/statespace/models/structural.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):
10711071
10721072
If None, states will be numbered ``[State_0, ..., State_s]``
10731073
1074+
remove_first_state: bool, default True
1075+
If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
1076+
freedom in the seasonal component, and one state is not identified. If False, the first state will be
1077+
included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
1078+
ZeroSumNormal).
1079+
10741080
Notes
10751081
-----
10761082
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -1163,7 +1169,7 @@ def __init__(
11631169
innovations: bool = True,
11641170
name: str | None = None,
11651171
state_names: list | None = None,
1166-
pop_state: bool = True,
1172+
remove_first_state: bool = True,
11671173
):
11681174
if name is None:
11691175
name = f"Seasonal[s={season_length}]"
@@ -1176,14 +1182,15 @@ def __init__(
11761182
)
11771183
state_names = state_names.copy()
11781184
self.innovations = innovations
1179-
self.pop_state = pop_state
1185+
self.remove_first_state = remove_first_state
11801186

1181-
if self.pop_state:
1187+
if self.remove_first_state:
11821188
# In traditional models, the first state isn't identified, so we can help out the user by automatically
11831189
# discarding it.
11841190
# TODO: Can this be stashed and reconstructed automatically somehow?
11851191
state_names.pop(0)
1186-
k_states = season_length - 1
1192+
1193+
k_states = season_length - int(self.remove_first_state)
11871194

11881195
super().__init__(
11891196
name=name,
@@ -1218,8 +1225,16 @@ def populate_component_properties(self):
12181225
self.shock_names = [f"{self.name}"]
12191226

12201227
def make_symbolic_graph(self) -> None:
1221-
T = np.eye(self.k_states, k=-1)
1222-
T[0, :] = -1
1228+
if self.remove_first_state:
1229+
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
1230+
# all previous states.
1231+
T = np.eye(self.k_states, k=-1)
1232+
T[0, :] = -1
1233+
else:
1234+
# In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
1235+
# circulant matrix that cycles between the states.
1236+
T = np.eye(self.k_states, k=1)
1237+
T[-1, 0] = 1
12231238

12241239
self.ssm["transition", :, :] = T
12251240
self.ssm["design", 0, 0] = 1

tests/statespace/test_structural.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33

44
from collections import defaultdict
5+
from copyreg import remove_extension
56
from typing import Optional
67

78
import numpy as np
@@ -592,13 +593,18 @@ def test_autoregressive_model(order, rng):
592593

593594
@pytest.mark.parametrize("s", [10, 25, 50])
594595
@pytest.mark.parametrize("innovations", [True, False])
595-
def test_time_seasonality(s, innovations, rng):
596+
@pytest.mark.parametrize("remove_first_state", [True, False])
597+
def test_time_seasonality(s, innovations, remove_first_state, rng):
596598
def random_word(rng):
597599
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))
598600

599601
state_names = [random_word(rng) for _ in range(s)]
600602
mod = st.TimeSeasonality(
601-
season_length=s, innovations=innovations, name="season", state_names=state_names
603+
season_length=s,
604+
innovations=innovations,
605+
name="season",
606+
state_names=state_names,
607+
remove_first_state=remove_first_state,
602608
)
603609
x0 = np.zeros(mod.k_states, dtype=floatX)
604610
x0[0] = 1
@@ -615,7 +621,8 @@ def random_word(rng):
615621
# Check coords
616622
mod.build(verbose=False)
617623
_assert_basic_coords_correct(mod)
618-
assert mod.coords["season_state"] == state_names[1:]
624+
test_slice = slice(1, None) if remove_first_state else slice(None)
625+
assert mod.coords["season_state"] == state_names[test_slice]
619626

620627

621628
def get_shift_factor(s):

0 commit comments

Comments
 (0)