This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 694941f

shreydesai authored and facebook-github-bot committed

implemented causal convolutions for deep cnn representation

Differential Revision: D16403554
fbshipit-source-id: 85ffe203f12cee7fb113c821673e7a9529a81623

1 parent 70df2c8 · commit 694941f

File tree

2 files changed: +30 −5 lines


pytext/config/module_config.py
Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,8 @@ class CNNParams(ConfigBase):
     weight_norm: bool = False
     # Enables dilated convolutions
     dilated: bool = False
+    # Enables causal convolutions
+    causal: bool = False


 class PoolingType(Enum):
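
For reference, a minimal sketch (not part of this commit) of how the new flag might be set when constructing CNNParams; kernel_sizes and weight_norm are existing fields referenced by deepcnn.py below, and the values here are illustrative:

    from pytext.config.module_config import CNNParams

    # Hypothetical configuration enabling dilated + causal convolutions.
    cnn = CNNParams(
        kernel_sizes=[3, 3, 3],  # read as config.cnn.kernel_sizes in deepcnn.py
        weight_norm=False,
        dilated=True,   # per-layer dilation grows as 2 ** i
        causal=True,    # left-pad by (k - 1) * dilation, trim the right
    )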

pytext/models/representations/deepcnn.py
Lines changed: 28 additions & 5 deletions

@@ -8,6 +8,22 @@
 from pytext.models.representations.representation_base import RepresentationBase


+class Trim1d(nn.Module):
+    """
+    Trims a 1d convolutional output. Used to implement history-padding
+    by removing excess padding from the right.
+
+    """
+
+    def __init__(self, trim):
+        super(Trim1d, self).__init__()
+
+        self.trim = trim
+
+    def forward(self, x):
+        return x[:, :, : -self.trim].contiguous()
+
+
 class DeepCNNRepresentation(RepresentationBase):
     """
     `DeepCNNRepresentation` implements CNN representation layer
@@ -32,8 +48,10 @@ def __init__(self, config: Config, embed_dim: int) -> None:
         kernel_sizes = config.cnn.kernel_sizes
         weight_norm = config.cnn.weight_norm
         dilated = config.cnn.dilated
+        causal = config.cnn.causal

         conv_layers = []
+        trim_layers = []
         linear_layers = []
         in_channels = embed_dim

@@ -48,7 +66,7 @@ def __init__(self, config: Config, embed_dim: int) -> None:
             linear_layers.append(proj)

             dilation = 2 ** i if dilated else 1
-            padding = (k - 1) // 2
+            padding = (k - 1) * dilation if causal else ((k - 1) // 2) * dilation

             single_conv = nn.Conv1d(
                 in_channels, 2 * out_channels, k, padding=padding, dilation=dilation
@@ -58,6 +76,9 @@ def __init__(self, config: Config, embed_dim: int) -> None:
             )
             conv_layers.append(single_conv)

+            trim = Trim1d(padding) if causal else None
+            trim_layers.append(trim)
+
             in_channels = out_channels

         self.convs = nn.ModuleList(conv_layers)
@@ -71,13 +92,15 @@ def forward(self, inputs: torch.Tensor, *args) -> torch.Tensor:
         inputs = self.dropout(inputs)
         # bsz * seq_len * embed_dim -> bsz * embed_dim * seq_len
         words = inputs.permute(0, 2, 1)
-        for conv, proj in zip(self.convs, self.projections):
-            if proj is None:
-                residual = words
-            else:
+        for conv, trim, proj in zip(self.convs, self.trims, self.projections):
+            if proj:
                 tranposed = words.permute(0, 2, 1)
                 residual = proj(tranposed).permute(0, 2, 1)
+            else:
+                residual = words
             words = conv(words)
+            if trim:
+                words = trim(words)
             words = self.glu(words)
             words = (words + residual) * math.sqrt(0.5)
         return words.permute(0, 2, 1)
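
To see why this yields causal convolutions, a standalone PyTorch sketch (illustrative, not part of the commit): padding a Conv1d by (k - 1) * dilation and trimming that amount from the right, exactly as Trim1d does, preserves the sequence length while blocking any flow from future positions:

    import torch
    import torch.nn as nn

    # One layer with the diff's causal settings (example values).
    k, dilation = 3, 2
    pad = (k - 1) * dilation
    conv = nn.Conv1d(1, 1, k, padding=pad, dilation=dilation)

    x = torch.randn(1, 1, 10, requires_grad=True)
    y = conv(x)[:, :, :-pad]  # the same right-trim Trim1d performs

    assert y.shape[-1] == x.shape[-1]  # sequence length preserved

    # The output at t = 4 must receive no gradient from inputs at t > 4.
    y[0, 0, 4].backward()
    assert torch.all(x.grad[0, 0, 5:] == 0)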
