Forecast exogenous vars bug fix #510


Merged

Conversation

@Dekermanjian (Contributor) commented Jun 9, 2025

This PR fixes a bug in the Statespace module where erroneous forecasts were produced when exogenous variables were present in the model. It addresses issue #491.

The issue stemmed from updating the model's data prior to running the model forward up to the forecast's initial time index as specified by the user. This caused the forecasts to be produced using incorrect x0 and P0 initializations at the forecast initial time index.

The fix ensures that the model is run forward up to the forecast time index and that the states (filtered, predicted, or smoothed) at that point are frozen using all the observed data; only after that are the new data replacements made to generate forecasts.
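Conceptually, the fixed flow looks like this (a sketch distilled from the review below; mu, cov, sub_dict, and scenario follow the snippets discussed later in this thread and are not defined here):

import pymc as pm
from pytensor.graph.replace import graph_replace

# 1. Freeze the state estimates on the full observed data, so that x0 and P0
#    at the forecast start are computed from what was actually observed...
frozen_mu, frozen_cov = graph_replace([mu, cov], replace=sub_dict, strict=True)

# 2. ...and only then swap in the scenario data for the forecast horizon.
if scenario is not None:
    pm.set_data(scenario)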

@jessegrabowski (Member) left a comment

Looks great! Left some ideas to simplify the code a bit.

Comment on lines 2212 to 2231
if scenario is not None:
    sub_dict = {
        forecast_model[data_name]: pt.as_tensor_variable(
            x=np.atleast_2d(self._exog_data_info[data_name]["value"].T).T,
            name=data_name,
        )
        for data_name in self.data_names
    }

    # Will this always be named "data"?
    sub_dict[forecast_model["data"]] = pt.as_tensor_variable(
        np.atleast_2d(self._fit_data.T).T, name="data"
    )
else:
    # same here, will it always be named "data"?
    sub_dict = {
        forecast_model["data"]: pt.as_tensor_variable(
            np.atleast_2d(self._fit_data.T).T, name="data"
        )
    }
@jessegrabowski (Member) commented:

You can simplify all this by using the model's data_vars property. PyMC models store all the symbolic variables created with pm.Data there. They are pytensor SharedVariables, so the data lives inside them. You can get a SharedVariable's data with .get_value(). That will also ensure the shapes are correct.

So you can do something like:

sub_dict = {
    data_var: pt.as_tensor_variable(data_var.get_value(), name=data_var.name)
    for data_var in forecast_model.data_vars
}

We want to freeze absolutely all data, so this should always do what we want. When there's no self.data_names, it will still find the main data (even if we change the name later). When there is, it will freeze that as well.

If you want, you can add a sanity check afterwards that makes sure the names of the variables in the keys of sub_dict are in self.data_names + ['data'].
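Something like this would do for the check (just a sketch, reusing sub_dict and self.data_names from the snippet above):

allowed_names = set(self.data_names) | {"data"}
frozen_names = {var.name for var in sub_dict}
assert frozen_names <= allowed_names, f"unexpected frozen data vars: {frozen_names - allowed_names}"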

mu, cov = graph_replace([mu, cov], replace=sub_dict, strict=True)
@jessegrabowski (Member) commented:

Maybe give these new names, like frozen_mu or mu_given_training_data? It makes it clear what we're doing here.
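For example, the line above could become (a pure rename, nothing else changes):

frozen_mu, frozen_cov = graph_replace([mu, cov], replace=sub_dict, strict=True)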

if scenario is not None:
    for name in self.data_names:
        if name in scenario.keys():
            pm.set_data({name: scenario[name]})
@jessegrabowski (Member) commented:

pm.set_data can take a dictionary of multiple things, so there's no need to loop here; just call pm.set_data(scenario) directly. I think the scenario has already been validated by this point (by self._validate_scenario_data), so there's no need to make sure the keys are in self.data_names.

I'm pretty sure you can do something like:

if scenario is not None:
    pm.set_data(scenario | {'data': dummy_obs_data},
                coords={"data_time": np.arange(len(forecast_index))})

@Dekermanjian (Contributor, Author) commented:

Thank you, Jesse! Your suggestion cleaned the code up nicely. I added a sanity check for the graph replacements; let me know if that looks alright to you.

@jessegrabowski (Member) left a comment

This is really great!

I guess the last thing I want to ask for is a test that makes sure this works. Basically we want to:

  1. Make a really simple model (LevelTrend + Regression)
  2. Fit it
  3. Run a forecast with two scenarios (it doesn't matter what they are)
  4. Check that the mean of all the data before the forecast starts is the same (you can seed the forecast rng to avoid the issue of random draws)

I think this will test the issue we've been having?

If you have a clearer idea, please go for it. I just want some kind of regression test to make sure this issue is dead once and for all.
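Roughly something like this, as a sketch of the steps above (exog_ss_mod, idata_exog, scenario_a/scenario_b, and forecast_start are placeholder fixture names, and the group, variable, and dimension names of the forecast output are assumptions):

import numpy as np

forecast_a = exog_ss_mod.forecast(
    idata_exog, start=forecast_start, periods=10, scenario=scenario_a, random_seed=1234
)
forecast_b = exog_ss_mod.forecast(
    idata_exog, start=forecast_start, periods=10, scenario=scenario_b, random_seed=1234
)

# The frozen state at the forecast start should not depend on the scenario,
# so the first forecast step should agree across the two runs.
np.testing.assert_allclose(
    forecast_a.posterior_predictive["forecast_latent"].isel(time=0).mean(("chain", "draw")).values,
    forecast_b.posterior_predictive["forecast_latent"].isel(time=0).mean(("chain", "draw")).values,
)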

@Dekermanjian (Contributor, Author) commented:

Yes, I will update the test test_forecast_with_exog_data in test_statespace.py to add a check on the exogenous variable.

When I write the test, is it okay to use lower-level components? Otherwise, if I just use the ss_mod.forecast() method, how would I get the data before the forecast starts?

@jessegrabowski (Member) commented Jun 9, 2025

Don't update the existing test; make a new one called something like test_forecast_consistency_under_exog_scenario.

Good point about not being able to just pull out what we need from the ss_mod.forecast method. We might want to refactor the actual model construction out of forecast and into another method called _build_forecast_model. Then forecast can do all the validation and what-not, then make the model with that function, then sample.

We could then just use _build_forecast_model in the test. If you think the refactor makes sense, we can actually just call the test test_build_forecast_model.
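The split could look roughly like this (a sketch only; the argument names are illustrative, not the final signature):

import pymc as pm

def forecast(self, idata, start, periods, scenario=None, **kwargs):
    # forecast keeps the validation and what-not...
    scenario = self._validate_scenario_data(scenario)
    # ...builds the model with the new method...
    forecast_model, forecast_index = self._build_forecast_model(idata, start, periods, scenario)
    # ...then samples from it.
    with forecast_model:
        return pm.sample_posterior_predictive(idata, **kwargs)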

@jessegrabowski (Member) commented Jun 9, 2025

I did the refactor; could you see if it's enough to get a test going?

It's a bit sloppy -- feel free to adjust

@jessegrabowski (Member) commented Jun 11, 2025

The test is a good start. I think you're right that we probably only need one scenario, so let's do that just to simplify matters.

You're going to need to test a few extra things:

  1. Make sure that there are no non-random-generator SharedVariables among the inputs of the compute graphs of x0_slice and P0_slice in test_forecast_model. You're looking for SharedVariable because that's what a pm.Data is under the hood (docs here if you're not familiar with them). You can do this by using graph_inputs, which returns a generator of all, well, graph inputs:
import numpy as np
from pytensor.compile import SharedVariable
from pytensor.graph.basic import graph_inputs

frozen_shared_inputs = [
    inpt
    for inpt in graph_inputs([test_forecast_model["x0_slice"], test_forecast_model["P0_slice"]])
    if isinstance(inpt, SharedVariable)
    and not isinstance(inpt.get_value(), np.random.Generator)
]

assert len(frozen_shared_inputs) == 0
  2. Make sure there are SharedVariables in the final forecast output, make sure that it's the data_exog, and make sure the values are correctly set by pm.set_data:
unfrozen_shared_inputs = [
    inpt
    for inpt in graph_inputs([test_forecast_model["forecast_combined"]])
    if isinstance(inpt, SharedVariable)
    and not isinstance(inpt.get_value(), np.random.Generator)
]

assert len(unfrozen_shared_inputs) == 1
assert unfrozen_shared_inputs[0].name == "data_exog"

with test_forecast_model:
    dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog))
    pm.set_data(
        {"data_exog": scenario} | {"data": dummy_obs_data},
        coords={"data_time": np.arange(len(forecast_index))},
    )

# TODO: Why is the reshape necessary?
np.testing.assert_allclose(
    unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1))
)

Another thought I had with multiple scenarios was to vary the starting value of the scenario and make sure that x0_slice and P0_slice end up being what they're supposed to be. That is, if you call pm.sample_posterior_predictive(idata_exog, var_names=['x0_slice', 'P0_slice']) in the test, you can check that the "sampled" values of x0_slice and P0_slice in the posterior predictive are just copies of the values of smoothed_state and smoothed_covariance in idata_exog at time=t.

To be clear, I had the idea of several scenarios because I wanted to vary the starting time t for this last check.
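That check could look something like this (a sketch; t0 and the variable/dimension names for the smoothed quantities are assumptions, not settled API):

import numpy as np
import pymc as pm

with test_forecast_model:
    pp = pm.sample_posterior_predictive(idata_exog, var_names=["x0_slice", "P0_slice"])

# x0_slice should just be a copy of the smoothed state at the forecast start t0
np.testing.assert_allclose(
    pp.posterior_predictive["x0_slice"].values,
    idata_exog.posterior["smoothed_state"].sel(time=t0).values,
)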

Some other notes:

  1. To make the tests less miserably slow, use the new pymc sampling mocker. Add this to the top of the file:
import pytest
from pymc.testing import mock_sample_setup_and_teardown

mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)

Then, to every fixture or test that calls pm.sample, add mock_pymc_sample to the function signature. For example:

@pytest.fixture(scope="session")
def idata(pymc_mod, rng, mock_pymc_sample):
    with pymc_mod:
        idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)

    idata.extend(idata_prior)
    return idata

This will monkeypatch pm.sample with pm.sample_prior_predictive, and rejigger the outputs to look like real sample outputs. This should increase the speed of the tests by a factor of roughly infinity.

  2. A super definitive test of our forecast would be to test it against statsmodels. I'm not 100% sure how their API works for setting exogenous variables, nor do I necessarily think we should do it in this PR. But it's something to keep in mind. Maybe I'll make a separate issue for it.

@Dekermanjian (Contributor, Author) commented:

Hey Jesse, I tried adding from pymc.testing import mock_sample_setup_and_teardown, but I was getting import errors. I tried syncing up with main but still got the import error, so I rolled back the sync because I don't think it's a great idea to merge main into a topic branch.

@jessegrabowski (Member) commented:

Is your pymc up to date? The mocker was released in 5.23.0 (check the new features).
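You can check with:

import pymc
print(pymc.__version__)  # mock_sample_setup_and_teardown needs pymc >= 5.23.0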

@Dekermanjian (Contributor, Author) commented:

Oh, duh! I tried updating the wrong module. Okay, I updated the correct module and added in the mock sampling.

@jessegrabowski (Member) left a comment

Looks great! I think we've finally squashed this bug. As one last test, can you confirm that your hurricane notebook works on this PR without the coord hack, and that the forecasts are correct?

Once you do, please feel free to squash and merge :)

@jessegrabowski jessegrabowski merged commit c099fc4 into pymc-devs:main Jun 12, 2025
17 checks passed