Feature/transformer refactorisation #1915
base: master
Changes from all commits
084541f
492c3e5
30ccd78
fd13830
35e2f5b
9eecccd
297afb9
fa8a152
84e8d4a
fa65ed2
a136e4c
87e0a44
3bcbb30
a370a5e
634294d
0bfde7c
0aec5f6
e04f1bf
@@ -8,6 +8,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer
Review thread:
- maybe import just […]
- After reading up a little and checking out the implementation, it turns out that […]
- WDYT @dennisbader ?

from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.components import glu_variants, layer_norm_variants
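The `Transformer` import discussed in the review thread above is used later in this diff only for `Transformer.generate_square_subsequent_mask`, which builds the causal mask applied to the decoder input. A minimal sketch of what that mask looks like (the sequence length and printed values are illustrative, not taken from this PR, and assume a recent PyTorch where the method is static):

```python
import torch
from torch.nn import Transformer

# Causal ("square subsequent") mask for a target sequence of length 4:
# position i may attend to positions <= i; future positions are -inf.
tgt_mask = Transformer.generate_square_subsequent_mask(4)
print(tgt_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```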
@@ -188,6 +190,7 @@ def __init__(
self.target_size = output_size
self.nr_params = nr_params
self.target_length = self.output_chunk_length
self.d_model = d_model

self.encoder = nn.Linear(input_size, d_model)
self.positional_encoding = _PositionalEncoding(
@@ -281,48 +284,105 @@ def __init__(
custom_decoder=custom_decoder,
)

self.decoder = nn.Linear(
d_model, self.target_length * self.target_size * self.nr_params
)

def _create_transformer_inputs(self, data):
# '_TimeSeriesSequentialDataset' stores time series in the
# (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
# module needs it in the (input_chunk_length, batch_size, input_size) format.
# Therefore, the first two dimensions need to be swapped.
src = data.permute(1, 0, 2)
tgt = src[-1:, :, :]

return src, tgt
self.decoder = nn.Linear(d_model, self.target_size * self.nr_params)

@io_processor
def forward(self, x_in: Tuple):
data, _ = x_in
# Here we create 'src' and 'tgt', the inputs for the encoder and decoder
# side of the Transformer architecture
src, tgt = self._create_transformer_inputs(data)
"""
During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
During inference x_in = tuple(past_target + past_covariates, static_covariates)

'_TimeSeriesSequentialDataset' stores time series in the
(batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
module needs it in the (input_chunk_length, batch_size, input_size) format.
Therefore, the first two dimensions need to be swapped.
"""
src = x_in[0].permute(1, 0, 2)
pad_size = (0, self.input_size - self.target_size)

# start token consists only of the target series; past covariates are substituted with 0 padding
Review comment: Why would you not include the past covariates in the start token?
start_token = src[-1:, :, : self.target_size]
start_token_padded = F.pad(start_token, pad_size)

if len(x_in) == 3:
tgt = x_in[-1].permute(1, 0, 2)
tgt = F.pad(tgt, pad_size)
tgt = torch.cat([start_token_padded, tgt], dim=0)
return self._prediction_step(src, tgt)[:, :-1, :, :]

tgt = start_token_padded

predictions = []
for _ in range(self.target_length):
pred = self._prediction_step(src, tgt)[:, -1, :, :]
predictions.append(pred)
tgt = torch.cat(
[tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
dim=0,
) # take average of samples
return torch.stack(predictions, dim=1)

def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
target_length = tgt.shape[0]
device, tensor_type = src.device, src.dtype
# "math.sqrt(self.input_size)" is a normalization factor
# see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017)
src = self.encoder(src) * math.sqrt(self.input_size)
src = self.positional_encoding(src)
src = self.encoder(src) * math.sqrt(self.d_model)
tgt = self.encoder(tgt) * math.sqrt(self.d_model)

tgt = self.encoder(tgt) * math.sqrt(self.input_size)
src = self.positional_encoding(src)
tgt = self.positional_encoding(tgt)

x = self.transformer(src=src, tgt=tgt)
tgt_mask = Transformer.generate_square_subsequent_mask(
target_length, device
).to(dtype=tensor_type)

x = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
out = self.decoder(x)

# Here we change the data format
# from (1, batch_size, output_chunk_length * output_size)
# to (batch_size, output_chunk_length, output_size, nr_params)
predictions = out[0, :, :]
predictions = out.permute(1, 0, 2)
predictions = predictions.view(
-1, self.target_length, self.target_size, self.nr_params
-1, target_length, self.target_size, self.nr_params
)

return predictions

def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
train_batch = list(train_batch)
future_targets = train_batch[-1]
train_batch.append(future_targets)
return super().training_step(train_batch, batch_idx)

def _produce_train_output(self, input_batch: Tuple):
"""
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
training.

Parameters:
input_batch
``(past_target, past_covariates, static_covariates, future_target)`` during training

``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
"""

past_target, past_covariates, static_covariates = input_batch[:3]
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = [
torch.cat([past_target, past_covariates], dim=2)
if past_covariates is not None
else past_target,
static_covariates,
]

# add future targets when training (teacher forcing)
if len(input_batch) == 4:
inpt.append(input_batch[-1])
return self(inpt)


class TransformerModel(PastCovariatesTorchModel):
def __init__(
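The forward pass in the hunk above follows the standard encoder-decoder pattern: during training, the start token plus the known future targets are fed to the decoder in one shot under a causal mask (teacher forcing); during inference, the decoder is called step by step, with each new prediction appended to `tgt`. The following is a stripped-down, self-contained sketch of that pattern, not the darts implementation: toy dimensions, no covariates, no positional encoding, and random tensors in place of real inputs.

```python
import torch
import torch.nn as nn
from torch.nn import Transformer

d_model, horizon, batch = 16, 4, 2
model = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=1, num_decoder_layers=1)

src = torch.randn(8, batch, d_model)   # (input_length, batch, d_model)
start = src[-1:]                       # last encoder step used as start token

# Teacher forcing (training): feed start token + known future, mask future positions.
future = torch.randn(horizon, batch, d_model)
tgt = torch.cat([start, future], dim=0)
mask = Transformer.generate_square_subsequent_mask(tgt.shape[0])
out_train = model(src, tgt, tgt_mask=mask)      # (horizon + 1, batch, d_model)

# Autoregressive decoding (inference): grow tgt one step at a time.
tgt = start
for _ in range(horizon):
    mask = Transformer.generate_square_subsequent_mask(tgt.shape[0])
    step = model(src, tgt, tgt_mask=mask)[-1:]  # newest decoder output
    tgt = torch.cat([tgt, step], dim=0)
preds = tgt[1:]                                 # (horizon, batch, d_model)
```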
@@ -351,7 +411,7 @@ def __init__(
The multi-head attention mechanism is highly parallelizable, which makes the transformer architecture
very suitable to be trained with GPUs.

The transformer architecture implemented here is based on [1]_.
The transformer architecture implemented here is based on [1]_ and uses teacher forcing [4]_.

This model supports past covariates (known for `input_chunk_length` points before prediction time).
@@ -388,7 +448,7 @@ def __init__(
Fraction of neurons affected by Dropout (default=0.1).
activation
The activation function of encoder/decoder intermediate layer, (default='relu').
can be one of the glu variant's FeedForward Network (FFN)[2]. A feedforward network is a
can be one of the glu variant's FeedForward Network (FFN)[3]. A feedforward network is a
fully-connected layer with an activation. The glu variant's FeedForward Networks are a series
of FFNs designed to work better with Transformer based models. ["GLU", "Bilinear", "ReGLU", "GEGLU",
"SwiGLU", "ReLU", "GELU"] or one of the pytorch internal activations ["relu", "gelu"]
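For readers unfamiliar with the GLU-variant feed-forward networks referenced in the docstring above (Shazeer, 2020, "GLU Variants Improve Transformer"), here is a minimal sketch of one such variant. It only illustrates the idea, not the implementation in darts.models.components.glu_variants; the class name and hidden sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Feed-forward block with a SwiGLU gate, in the spirit of Shazeer (2020).

    Illustrative only; darts provides its own variants in
    darts.models.components.glu_variants.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff)   # gate branch
        self.w_value = nn.Linear(d_model, d_ff)  # value branch
        self.w_out = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(x W_gate) * (x W_value), then project back to d_model
        return self.w_out(self.dropout(F.silu(self.w_gate(x)) * self.w_value(x)))

ffn = SwiGLUFeedForward(d_model=64, d_ff=256)
out = ffn(torch.randn(10, 2, 64))  # same shape in and out: (seq, batch, d_model)
```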
@@ -541,17 +601,7 @@ def encode_year(idx):
.. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arXiv https://arxiv.org/abs/2002.05202.
.. [3] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p

Notes
-----
Disclaimer:
This current implementation is fully functional and can already produce some good predictions. However,
it is still limited in how it uses the Transformer architecture because the `tgt` input of
`torch.nn.Transformer` is not utilized to its full extent. Currently, we simply pass the last value of the
`src` input to `tgt`. To get closer to the way the Transformer is usually used in language models, we
should allow the model to consume its own output as part of the `tgt` argument, such that when predicting
sequences of values, the input to the `tgt` argument would grow as outputs of the transformer model are
added to it. Of course, the training of the model would have to be adapted accordingly.
.. [4] Teacher Forcing PyTorch tutorial: https://github.com/pytorch/examples/tree/main/word_language_model

Examples
--------
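The docstring's Examples section is cut off in this view. For orientation, a minimal usage sketch of the model under the standard darts API follows; the dataset and hyperparameter values are illustrative and not taken from this PR.

```python
from darts.datasets import AirPassengersDataset
from darts.models import TransformerModel

series = AirPassengersDataset().load()

# illustrative hyperparameters only
model = TransformerModel(
    input_chunk_length=12,
    output_chunk_length=6,
    d_model=64,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=6)
```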