Skip to content

Forecast exogenous vars bug fix #510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
130 changes: 78 additions & 52 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization(

return scenario

def _build_forecast_model(
    self, time_index, t0, forecast_index, scenario, filter_output, mvn_method
):
    """
    Build the PyMC model from which forecasts are sampled.

    The model re-creates the Kalman filter graph over the fitting data, picks the filter
    output selected by ``filter_output``, and takes its value at ``t0`` as the initial state
    mean and covariance of a ``LinearGaussianStateSpace`` named ``"forecast"``. Before the
    initial state is sliced, every data variable registered in the model is replaced by a
    constant holding its current value, so that later changes to those data variables do not
    alter ``x0_slice`` and ``P0_slice``.

    Parameters
    ----------
    time_index
        Index of the data used for fitting. Stored as the ``"data_time"`` coordinate and
        searched for ``t0``.
    t0
        Label in ``time_index`` at which the forecast is initialized.
    forecast_index
        Index of the forecast horizon. Replaces the ``TIME_DIM`` coordinate, and its length
        sets the number of forecast steps.
    scenario
        Not used by this method; it is part of the signature that callers pass.
    filter_output
        Name of the Kalman filter output to start from; must be an element of
        ``FILTER_OUTPUT_TYPES``.
    mvn_method
        Forwarded to ``LinearGaussianStateSpace`` as ``method``.

    Returns
    -------
    forecast_model
        The ``pm.Model`` holding the ``x0_slice`` and ``P0_slice`` deterministics and the
        ``"forecast"`` ``LinearGaussianStateSpace``.

    Raises
    ------
    IndexError
        If ``t0`` does not occur in ``time_index``.
    ValueError
        If any name in ``self.data_names``, or ``"data"``, is not among the data variables of
        the forecast model.
    """
    temp_coords = self._fit_coords.copy()

    # Forecast dims are only attached when the fitted model carried the matching coords.
    # This check must run before TIME_DIM is overwritten below.
    dims = None
    if all(dim in temp_coords for dim in (TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM)):
        dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]

    t0_matches = np.flatnonzero(time_index == t0)
    if t0_matches.size == 0:
        # Same exception type as a bare ``[0]`` lookup would raise, but with a usable message.
        raise IndexError(
            f"t0={t0!r} was not found in time_index; cannot select the initial state "
            "for the forecast."
        )
    t0_idx = t0_matches[0]

    # The fitting data keeps its own time axis ("data_time"); TIME_DIM now labels the forecasts.
    temp_coords["data_time"] = time_index
    temp_coords[TIME_DIM] = forecast_index

    # Slicing at t0 removes the leading time axis, so only the state dims remain.
    x0_dims, P0_dims = None, None
    if all(dim in self._fit_coords for dim in (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM)):
        x0_dims = [ALL_STATE_DIM]
        P0_dims = [ALL_STATE_DIM, ALL_STATE_AUX_DIM]

    with pm.Model(coords=temp_coords) as forecast_model:
        (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
            data_dims=["data_time", OBS_STATE_DIM],
        )

        group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
        mu, cov = grouped_outputs[group_idx]

        # Freeze the data: map each data variable to a constant holding its current value.
        sub_dict = {
            data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
            for data_var in forecast_model.data_vars
        }

        missing_data_vars = np.setdiff1d(
            ar1=[*self.data_names, "data"], ar2=[data_var.name for data_var in sub_dict]
        )
        if missing_data_vars.size > 0:
            raise ValueError(f"{missing_data_vars} data used for fitting not found!")

        mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)

        x0 = pm.Deterministic("x0_slice", mu_frozen[t0_idx], dims=x0_dims)
        P0 = pm.Deterministic("P0_slice", cov_frozen[t0_idx], dims=P0_dims)

        # ``matrices`` is passed unfrozen, so it still depends on the model's data variables.
        _ = LinearGaussianStateSpace(
            "forecast",
            x0,
            P0,
            *matrices,
            steps=len(forecast_index),
            dims=dims,
            sequence_names=self.kalman_filter.seq_names,
            k_endog=self.k_endog,
            append_x0=False,
            method=mvn_method,
        )

    return forecast_model

def forecast(
self,
idata: InferenceData,
Expand Down Expand Up @@ -2139,8 +2202,6 @@ def forecast(
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.

"""
filter_time_dim = TIME_DIM

_validate_filter_arg(filter_output)

compile_kwargs = kwargs.pop("compile_kwargs", {})
Expand Down Expand Up @@ -2185,58 +2246,23 @@ def forecast(
use_scenario_index=use_scenario_index,
)
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
temp_coords = self._fit_coords.copy()

dims = None
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]

t0_idx = np.flatnonzero(time_index == t0)[0]

temp_coords["data_time"] = time_index
temp_coords[TIME_DIM] = forecast_index

mu_dims, cov_dims = None, None
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
mu_dims = ["data_time", ALL_STATE_DIM]
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]

with pm.Model(coords=temp_coords) as forecast_model:
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
scenario=scenario,
data_dims=["data_time", OBS_STATE_DIM],
)

for name in self.data_names:
if name in scenario.keys():
pm.set_data(
{"data": np.zeros((len(forecast_index), self.k_endog))},
coords={"data_time": np.arange(len(forecast_index))},
)
break

group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
mu, cov = grouped_outputs[group_idx]

x0 = pm.Deterministic(
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
)
P0 = pm.Deterministic(
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)
forecast_model = self._build_forecast_model(
time_index=time_index,
t0=t0,
forecast_index=forecast_index,
scenario=scenario,
filter_output=filter_output,
mvn_method=mvn_method,
)

_ = LinearGaussianStateSpace(
"forecast",
x0,
P0,
*matrices,
steps=len(forecast_index),
dims=dims,
sequence_names=self.kalman_filter.seq_names,
k_endog=self.k_endog,
append_x0=False,
method=mvn_method,
)
with forecast_model:
if scenario is not None:
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
pm.set_data(
scenario | {"data": dummy_obs_data},
coords={"data_time": np.arange(len(forecast_index))},
)

forecast_model.rvs_to_initial_values = {
k: None for k in forecast_model.rvs_to_initial_values.keys()
Expand Down
101 changes: 96 additions & 5 deletions tests/statespace/core/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import pytest

from numpy.testing import assert_allclose
from pymc.testing import mock_sample_setup_and_teardown
from pytensor.compile import SharedVariable
from pytensor.graph.basic import graph_inputs

from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
from pymc_extras.statespace.models import structural as st
Expand All @@ -30,6 +33,7 @@
floatX = pytensor.config.floatX
nile = load_nile_test_data()
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)


def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
Expand Down Expand Up @@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
)
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])

exog_ss_mod.build_statespace_graph(exog_data["y"])
exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True)

return struct_model

Expand Down Expand Up @@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):


@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
def idata(pymc_mod, rng, mock_pymc_sample):
with pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand All @@ -222,7 +226,7 @@ def idata(pymc_mod, rng):


@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
with exog_pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand All @@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng):


@pytest.fixture(scope="session")
def idata_no_exog(pymc_mod_no_exog, rng):
def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
with pymc_mod_no_exog:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand All @@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng):


@pytest.fixture(scope="session")
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample):
with pymc_mod_no_exog_dt:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand Down Expand Up @@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
assert_allclose(regression_effect, regression_effect_expected)


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog):
    """
    Check the model returned by ``_build_forecast_model`` for a model with exogenous data.

    The test verifies that:

    * ``x0_slice`` and ``P0_slice`` depend on no shared variables other than random generators,
    * ``forecast_combined`` depends on exactly one such shared variable, named ``"data_exog"``,
    * the data registered in the forecast model has the same per-variable mean as the data
      in the fitted model,
    * after ``pm.set_data`` the exogenous shared variable holds the scenario values, and
    * the sampled ``x0_slice`` / ``P0_slice`` match the fitted predicted state / covariance
      at ``t0``.
    """

    def non_rng_shared_inputs(outputs):
        # Shared variables feeding ``outputs``, ignoring random generator state.
        return [
            var
            for var in graph_inputs(outputs)
            if isinstance(var, SharedVariable)
            and not isinstance(var.get_value(), np.random.Generator)
        ]

    fit_data = {var.name: var.get_value() for var in exog_pymc_mod.data_vars}

    scenario = pd.DataFrame(
        {
            "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
            "x1": rng.choice(2, size=10, replace=True).astype(float),
        }
    ).set_index("date")

    time_index = exog_ss_mod._get_fit_time_index()
    t0, forecast_index = exog_ss_mod._build_forecast_index(
        time_index=time_index,
        start=exog_data.index[-1],
        end=scenario.index[-1],
        scenario=scenario,
    )
    n_steps = len(forecast_index)

    forecast_model = exog_ss_mod._build_forecast_model(
        time_index=time_index,
        t0=t0,
        forecast_index=forecast_index,
        scenario=scenario,
        filter_output="predicted",
        mvn_method="svd",
    )

    # The initial state must not depend on any non-random-generator SharedVariable.
    frozen_inputs = non_rng_shared_inputs([forecast_model.x0_slice, forecast_model.P0_slice])
    assert not frozen_inputs

    # The forecast itself keeps exactly one live shared input (in this case): the exogenous data.
    live_inputs = non_rng_shared_inputs([forecast_model.forecast_combined])
    assert len(live_inputs) == 1
    exog_shared = live_inputs[0]
    assert exog_shared.name == "data_exog"

    # Snapshot the forecast model's data before it is overwritten with the scenario below.
    forecast_model_data = {var.name: var.get_value() for var in forecast_model.data_vars}

    with forecast_model:
        pm.set_data(
            {"data_exog": scenario, "data": np.zeros((n_steps, exog_ss_mod.k_endog))},
            coords={"data_time": np.arange(n_steps)},
        )
        idata_forecast = pm.sample_posterior_predictive(
            idata_exog, var_names=["x0_slice", "P0_slice"]
        )

    # The replaced data must match the exogenous scenario.
    np.testing.assert_allclose(exog_shared.get_value(), scenario["x1"].values.reshape((-1, 1)))

    # The data needed to initialize the forecasts must not change.
    for name, values in fit_data.items():
        assert values.mean() == forecast_model_data[name].mean()

    # The frozen covariance and state must match the fitted filter outputs at t0.
    for fitted_name, slice_name in (
        ("predicted_covariance", "P0_slice"),
        ("predicted_state", "x0_slice"),
    ):
        np.testing.assert_allclose(
            idata_exog.posterior[fitted_name].sel(time=t0).mean(("chain", "draw")).values,
            idata_forecast.posterior_predictive[slice_name].mean(("chain", "draw")).values,
        )


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
Expand Down