
Commit f5b2c17

styusuf authored and facebook-github-bot committed
Adding a new Module for unsupported layers, a test for unsupported layers, and simple logging for unsupported layers (#1505)
Summary: We are adding a test for unsupported gradient layers. Open to ideas if there is a better way to structure the test. A bit uncomfortable with removing the pyre type validations, since we allow anything to be passed into the GradientUnsupportedLayerOutput class.

Differential Revision: D69792994
1 parent 6f0f748 · commit f5b2c17
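Since `GradientUnsupportedLayerOutput.forward` simply echoes back whatever object it was constructed with, any narrow pyre annotation would be wrong by design. A minimal sketch of that behavior (hypothetical standalone usage, not taken from the commit or its test):

```python
from captum.testing.helpers.basic_models import GradientUnsupportedLayerOutput

# Any Python object can be wrapped, including non-Tensor values that
# autograd cannot differentiate through; hence the pyre-fixme escapes.
layer = GradientUnsupportedLayerOutput(("not", "a", "tensor"))
assert layer() == ("not", "a", "tensor")  # forward() returns the object unchanged
```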

File tree

1 file changed: +24 −1

captum/testing/helpers/basic_models.py (+24 −1)
```diff
@@ -428,8 +428,27 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:
         return (self.relu1(arg1), self.relu2(arg2))
 
 
+class GradientUnsupportedLayerOutput(nn.Module):
+    # pyre-fixme[2]: Parameter must be annotated.
+    def __init__(self, unsupported_layer_output) -> None:
+        super().__init__()
+        # pyre-fixme[4]: Attribute must be annotated.
+        self.unsupported_layer_output = unsupported_layer_output
+
+    @no_type_check
+    # pyre-fixme[3]: Return type must be annotated.
+    def forward(self):
+        return self.unsupported_layer_output
+
+
 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+        # pyre-fixme[2]: Parameter must be annotated.
+        unsupported_layer_output=None,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -445,6 +464,7 @@ def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> N
         self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
         self.multi_relu = MultiRelu(inplace=inplace)
         self.relu = nn.ReLU(inplace=inplace)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput(unsupported_layer_output)
 
         self.linear2 = nn.Linear(4, 2)
         self.linear2.weight = nn.Parameter(torch.ones(2, 4))
@@ -461,6 +481,9 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+        if self.unsupportedLayer is not None:
+            self.unsupportedLayer()
+        ## unsupportedLayer is unused intentionally for testing
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
```
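For context, a hedged sketch of how the new constructor argument might be exercised. This is not the test referenced in D69792994; it assumes `BasicModel_MultiLayer`'s existing 3-feature input and relies on the `nn.Linear(4, 2)` output layer visible in the diff context above:

```python
import torch

from captum.testing.helpers.basic_models import BasicModel_MultiLayer

# Wrap an arbitrary non-Tensor object as the "unsupported" layer output.
model = BasicModel_MultiLayer(unsupported_layer_output=("not", "a", "tensor"))

# The forward pass still works: unsupportedLayer is invoked, but its output
# is intentionally discarded, so the model's Tensor path is unaffected.
out = model(torch.randn(1, 3))
print(out.shape)  # torch.Size([1, 2]), from the final nn.Linear(4, 2)

# An attribution method hooked onto model.unsupportedLayer would observe a
# tuple of strings rather than a Tensor, which is the case the new test targets.
```

A layer whose output is not a Tensor cannot participate in gradient computation, so attribution methods hooking this layer would presumably hit the "simple logging for unsupported layers" path mentioned in the commit title.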
