This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit b5eda6d

shreydesai authored and facebook-github-bot committed
implemented separable convolutions (#830)
Summary:
Pull Request resolved: #830

Implements depthwise separable convolutions. The depthwise convolution spatially convolves each input channel separately, then the pointwise convolution projects this result into a new channel space. Separable convolutions achieve similar performance to regular convolutions with a large reduction in the number of parameters. Grouped convolutions also have cuDNN support, so using them can give latency advantages as well.

Differential Revision: D16466988

fbshipit-source-id: b5aec14c9c21816cd6d090ee045d56bb39e314ac
1 parent afdccfb commit b5eda6d
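
As a rough illustration of the parameter savings described in the summary (not part of this commit; the channel sizes and variable names below are made up for the example), the following sketch builds the same depthwise-plus-pointwise stack that SeparableConv1d uses and compares its parameter count against a plain nn.Conv1d:

import torch.nn as nn

in_channels, out_channels, kernel_size = 256, 256, 7

# Regular convolution: every output channel mixes all input channels.
regular = nn.Conv1d(in_channels, out_channels, kernel_size)

# Depthwise separable: a grouped (per-channel) spatial convolution followed by
# a 1x1 pointwise convolution that projects into the new channel space.
separable = nn.Sequential(
    nn.Conv1d(in_channels, in_channels, kernel_size, groups=in_channels),
    nn.Conv1d(in_channels, out_channels, 1),
)

def num_params(module):
    return sum(p.numel() for p in module.parameters())

print(num_params(regular))    # 256 * 256 * 7 weights + 256 biases = 459,008
print(num_params(separable))  # (256 * 7 + 256) + (256 * 256 + 256) = 67,840

Roughly a 7x reduction in parameters here, while the grouped convolution still covers the same receptive field.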

File tree

1 file changed (+47, -1 lines changed)


pytext/models/representations/deepcnn.py

Lines changed: 47 additions & 1 deletion
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import math
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -25,6 +26,47 @@ def forward(self, x):
         return x[:, :, : -self.trim].contiguous()
 
 
+class SeparableConv1d(nn.Module):
+    """
+    Implements a 1d depthwise separable convolutional layer. In regular convolutional
+    layers, the input channels are mixed with each other to produce each output channel.
+    Depthwise separable convolutions decompose this process into two smaller
+    convolutions -- a depthwise and pointwise convolution.
+
+    The depthwise convolution spatially convolves each input channel separately,
+    then the pointwise convolution projects this result into a new channel space.
+    This process reduces the number of FLOPS used to compute a convolution and also
+    exhibits a regularization effect. The general behavior -- including the input
+    parameters -- is equivalent to `nn.Conv1d`.
+
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        output_channels: int,
+        kernel_size: int,
+        padding: Optional[int],
+        dilation: Optional[int],
+    ):
+        super(SeparableConv1d, self).__init__()
+
+        self.conv = nn.Sequential(
+            nn.Conv1d(
+                input_channels,
+                input_channels,
+                kernel_size,
+                padding=padding,
+                dilation=dilation,
+                groups=input_channels,
+            ),
+            nn.Conv1d(input_channels, output_channels, 1),
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 class DeepCNNRepresentation(RepresentationBase):
     """
     `DeepCNNRepresentation` implements CNN representation layer
@@ -42,6 +84,7 @@ class Config(RepresentationBase.Config):
         cnn: CNNParams = CNNParams()
         dropout: float = 0.3
         activation: Activation = Activation.GLU
+        separable: bool = False
 
     def __init__(self, config: Config, embed_dim: int) -> None:
         super().__init__(config)
@@ -51,7 +94,9 @@ def __init__(self, config: Config, embed_dim: int) -> None:
         weight_norm = config.cnn.weight_norm
         dilated = config.cnn.dilated
         causal = config.cnn.causal
+
         activation = config.activation
+        separable = config.separable
 
         conv_layers = []
         trim_layers = []
@@ -70,8 +115,9 @@ def __init__(self, config: Config, embed_dim: int) -> None:
 
             dilation = 2 ** i if dilated else 1
             padding = (k - 1) * dilation if causal else ((k - 1) // 2) * dilation
+            conv_layer = SeparableConv1d if separable else nn.Conv1d
 
-            single_conv = nn.Conv1d(
+            single_conv = conv_layer(
                 in_channels,
                 (out_channels * 2 if activation == Activation.GLU else out_channels),
                 k,
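
For context, here is a minimal usage sketch (not part of the commit) that exercises the new class, assuming SeparableConv1d can be imported from the modified module; the padding mirrors the non-causal branch computed in DeepCNNRepresentation.__init__:

import torch

from pytext.models.representations.deepcnn import SeparableConv1d

batch_size, seq_len = 8, 50
in_channels, out_channels, k, dilation = 64, 128, 5, 2
padding = ((k - 1) // 2) * dilation  # non-causal padding, as in DeepCNNRepresentation

conv = SeparableConv1d(in_channels, out_channels, k, padding, dilation)
x = torch.rand(batch_size, in_channels, seq_len)

# Output has the same shape an equivalently configured nn.Conv1d would produce.
print(conv(x).shape)  # torch.Size([8, 128, 50])

Inside the representation itself, the same swap is driven by the new Config field: setting separable to True in DeepCNNRepresentation.Config makes the layer-construction loop pick SeparableConv1d instead of nn.Conv1d.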

0 commit comments
