@@ -8,6 +8,22 @@
 from pytext.models.representations.representation_base import RepresentationBase
 
 
+class Trim1d(nn.Module):
+    """
+    Trims a 1d convolutional output. Used to implement history-padding
+    by removing excess padding from the right.
+
+    """
+
+    def __init__(self, trim):
+        super(Trim1d, self).__init__()
+
+        self.trim = trim
+
+    def forward(self, x):
+        return x[:, :, : -self.trim].contiguous()
+
+
 class DeepCNNRepresentation(RepresentationBase):
     """
     `DeepCNNRepresentation` implements CNN representation layer
@@ -32,8 +48,10 @@ def __init__(self, config: Config, embed_dim: int) -> None:
         kernel_sizes = config.cnn.kernel_sizes
         weight_norm = config.cnn.weight_norm
         dilated = config.cnn.dilated
+        causal = config.cnn.causal
 
         conv_layers = []
+        trim_layers = []
         linear_layers = []
         in_channels = embed_dim
 
@@ -48,7 +66,7 @@ def __init__(self, config: Config, embed_dim: int) -> None:
             linear_layers.append(proj)
 
             dilation = 2 ** i if dilated else 1
-            padding = (k - 1) // 2
+            padding = (k - 1) * dilation if causal else ((k - 1) // 2) * dilation
 
             single_conv = nn.Conv1d(
                 in_channels, 2 * out_channels, k, padding=padding, dilation=dilation
@@ -58,6 +76,10 @@ def __init__(self, config: Config, embed_dim: int) -> None:
             )
             conv_layers.append(single_conv)
 
+            trim = Trim1d(padding) if causal else None
+            trim_layers.append(trim)
+
             in_channels = out_channels
 
         self.convs = nn.ModuleList(conv_layers)
+        self.trims = nn.ModuleList(trim_layers)
@@ -71,13 +93,15 @@ def forward(self, inputs: torch.Tensor, *args) -> torch.Tensor:
         inputs = self.dropout(inputs)
         # bsz * seq_len * embed_dim -> bsz * embed_dim * seq_len
         words = inputs.permute(0, 2, 1)
-        for conv, proj in zip(self.convs, self.projections):
-            if proj is None:
-                residual = words
-            else:
+        for conv, trim, proj in zip(self.convs, self.trims, self.projections):
+            if proj:
                 tranposed = words.permute(0, 2, 1)
                 residual = proj(tranposed).permute(0, 2, 1)
+            else:
+                residual = words
             words = conv(words)
+            if trim:
+                words = trim(words)
             words = self.glu(words)
             words = (words + residual) * math.sqrt(0.5)
         return words.permute(0, 2, 1)
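A note on the padding arithmetic above: PyTorch's `nn.Conv1d` pads both ends of the sequence by `padding`, so the causal branch pads by the full receptive-field overhang `(k - 1) * dilation` and relies on `Trim1d` to cut that overhang back off the right end. The net effect is equivalent to left-padding only: every output step depends on the current and earlier inputs alone, and sequence length is preserved. The non-causal branch, `((k - 1) // 2) * dilation`, is ordinary centered "same" padding for odd kernel sizes. A minimal standalone sketch (not part of the commit; sizes are arbitrary) checking the causal case for `k = 3`, `dilation = 4`:

```python
import torch
import torch.nn as nn

k, dilation = 3, 4
padding = (k - 1) * dilation           # causal branch: 8, applied to both ends
conv = nn.Conv1d(16, 16, k, padding=padding, dilation=dilation)

x = torch.randn(2, 16, 50)             # bsz * channels * seq_len
y = conv(x)                            # length grows to 50 + padding
y = y[:, :, :-padding].contiguous()    # what Trim1d(padding) does
assert y.shape == x.shape              # length-preserving and causal
```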
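The rewritten `forward` loop is a stack of gated-linear-unit residual blocks in the convolutional seq2seq style: each convolution emits `2 * out_channels`, `nn.GLU(dim=1)` gates that back down to `out_channels`, the optional trim restores causal alignment, and the residual sum is scaled by `math.sqrt(0.5)` to keep activation variance roughly constant across layers. One loop iteration unrolled as a sketch, with hypothetical sizes and the projection omitted (channels already match, so `residual = words`):

```python
import math
import torch
import torch.nn as nn

glu = nn.GLU(dim=1)                           # halves the channel dimension
conv = nn.Conv1d(16, 32, 3, padding=2)        # 2 * out_channels feeds the GLU
words = torch.randn(4, 16, 30)                # bsz * channels * seq_len

residual = words                              # proj is None: channels already match
words = conv(words)                           # -> (4, 32, 32); excess on the right
words = words[:, :, :-2].contiguous()         # what Trim1d(2) does
words = glu(words)                            # -> (4, 16, 30)
words = (words + residual) * math.sqrt(0.5)   # variance-preserving residual sum
```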
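Assuming the surrounding config plumbing follows pytext's existing pattern (i.e. `CNNParams` gains the `causal` field that `__init__` reads as `config.cnn.causal`), end-to-end usage could look like the following hypothetical sketch; everything outside this diff is an assumption, not part of the commit:

```python
import torch
from pytext.config.module_config import CNNParams
from pytext.models.representations.deepcnn import DeepCNNRepresentation

# `causal=True` is the new flag; `dilated=True` doubles the dilation per layer.
config = DeepCNNRepresentation.Config(cnn=CNNParams(causal=True, dilated=True))
rep = DeepCNNRepresentation(config, embed_dim=64)

tokens = torch.randn(8, 20, 64)   # bsz * seq_len * embed_dim, as forward() expects
out = rep(tokens)                 # bsz * seq_len * out_channels, same layout back
```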