This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit eb6d392

shreydesai authored and facebook-github-bot committed
created bilstm dropout condition
Summary: Only drops out embedded_tokens if there is a non-zero probability in the dropout layer. This prevents computing a dropout mask even when it isn't required.

Differential Revision: D16190467

fbshipit-source-id: c6468c916ba381ad0eb3614ba0898cc5a03fd867
1 parent 2a8f365 · commit eb6d392

File tree

1 file changed, 4 additions and 2 deletions

pytext/models/representations/bilstm.py

Lines changed: 4 additions & 2 deletions
@@ -57,7 +57,8 @@ def __init__(
         super().__init__(config)

         self.padding_value: float = padding_value
-        self.dropout = nn.Dropout(config.dropout)
+        if config.dropout > 0.0:
+            self.dropout = nn.Dropout(config.dropout)
         self.lstm = nn.LSTM(
             embed_dim,
             config.lstm_dim,
@@ -95,7 +96,8 @@ def forward(
         Shape of each state is (bsize x num_layers * num_directions x nhid).

         """
-        embedded_tokens = self.dropout(embedded_tokens)
+        if self.dropout.p > 0.0:
+            embedded_tokens = self.dropout(embedded_tokens)

         if states is not None:
             # convert (h0, c0) from (bsz x num_layers*num_directions x nhid) to
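For context, here is a minimal runnable sketch of the guarded-dropout pattern this commit introduces. The `TinyBiLSTM` module and its dimensions are illustrative assumptions, not PyText's actual `BiLSTM` class:

```python
import torch
import torch.nn as nn


class TinyBiLSTM(nn.Module):
    """Hypothetical module illustrating the commit's dropout guard."""

    def __init__(self, embed_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, embedded_tokens: torch.Tensor) -> torch.Tensor:
        # Guard from the commit: only call dropout when p > 0, so no
        # dropout mask is sampled when it would be a no-op anyway.
        if self.dropout.p > 0.0:
            embedded_tokens = self.dropout(embedded_tokens)
        outputs, _states = self.lstm(embedded_tokens)
        return outputs


# Usage sketch: with dropout=0.0 the forward pass never touches nn.Dropout.
model = TinyBiLSTM(embed_dim=8, hidden_dim=4, dropout=0.0)
out = model(torch.randn(2, 5, 8))  # (batch=2, seq=5, embed=8)
print(out.shape)  # torch.Size([2, 5, 8]) — hidden_dim * 2 directions
```

Note that the sketch constructs `nn.Dropout` unconditionally and only guards the call in `forward`, whereas the commit's `__init__` hunk also guards construction behind `config.dropout > 0.0`.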
