-
-
Notifications
You must be signed in to change notification settings - Fork 67
Forecast exogenous vars bug fix #510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
439e980
c2cf547
027de41
d196409
d41a109
8635274
4119a0d
252f70a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2203,28 +2203,56 @@ def forecast( | |
|
||
with pm.Model(coords=temp_coords) as forecast_model: | ||
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( | ||
scenario=scenario, | ||
data_dims=["data_time", OBS_STATE_DIM], | ||
) | ||
|
||
for name in self.data_names: | ||
if name in scenario.keys(): | ||
pm.set_data( | ||
{"data": np.zeros((len(forecast_index), self.k_endog))}, | ||
coords={"data_time": np.arange(len(forecast_index))}, | ||
) | ||
break | ||
|
||
group_idx = FILTER_OUTPUT_TYPES.index(filter_output) | ||
mu, cov = grouped_outputs[group_idx] | ||
|
||
if scenario is not None: | ||
sub_dict = { | ||
forecast_model[data_name]: pt.as_tensor_variable( | ||
x=np.atleast_2d(self._exog_data_info[data_name]["value"].T).T, | ||
name=data_name, | ||
) | ||
for data_name in self.data_names | ||
} | ||
|
||
# Will this always be named "data"? | ||
sub_dict[forecast_model["data"]] = pt.as_tensor_variable( | ||
np.atleast_2d(self._fit_data.T).T, name="data" | ||
) | ||
else: | ||
# same here will it always be named data? | ||
sub_dict = { | ||
forecast_model["data"]: pt.as_tensor_variable( | ||
np.atleast_2d(self._fit_data.T).T, name="data" | ||
) | ||
} | ||
|
||
mu, cov = graph_replace([mu, cov], replace=sub_dict, strict=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe give new names, like |
||
|
||
x0 = pm.Deterministic( | ||
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None | ||
) | ||
P0 = pm.Deterministic( | ||
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None | ||
) | ||
|
||
if scenario is not None: | ||
for name in self.data_names: | ||
if name in scenario.keys(): | ||
pm.set_data({name: scenario[name]}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure you can do something like: if scenario is not None:
pm.set_data(scenario | {'data': dummy_obs_data},
coords={"data_time": np.arange(len(forecast_index))}) |
||
|
||
for name in self.data_names: | ||
if name in scenario.keys(): | ||
# same here will it always be named data? | ||
pm.set_data( | ||
{"data": np.zeros((len(forecast_index), self.k_endog))}, | ||
coords={"data_time": np.arange(len(forecast_index))}, | ||
) | ||
break | ||
|
||
_ = LinearGaussianStateSpace( | ||
"forecast", | ||
x0, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can simplify all this by using the model
data_vars
property. PyMC models store all the symbolic variables created withpm.Data
here. They are pytensorSharedVariables
, so the data lives inside them. You can getSharedVariable
data with.get_value()
. That will also ensure the shapes are correct.So you can do something like:
We want free absolutely all data, so this should always do what we want. When there's no
self.data_names
, it should still find the main data (even if we change the name later). When there is, it will freeze that as well.If you want, you can add a sanity check after that makes sure the names of the variables in the keys of
sub_dict
are inself.data_names + ['data']