
Commit 4c9fd2d

Author: Benjamin Gorman (committed)

Reenable and fix mypy-torch errors

Fixed mypy error categories:

- Argument "{}" to "{}" has incompatible type "{}"; expected "{}"
- Incompatible types in assignment (expression has type "{}", variable has type "{}")
- Module "{}" is not valid as a type
- Module has no attribute "{}"
- Name '{}' is not defined
- No overload variant of "{}" of "{}" matches argument types

1 parent a77770b · commit 4c9fd2d
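To make the categories concrete, a hypothetical two-line snippet (not from this commit) that trips the assignment error above, together with the rename-style fix several diffs below use:

    def label(n: int) -> str:
        n = str(n)  # error: Incompatible types in assignment
                    # (expression has type "str", variable has type "int")
        return n

    def label_fixed(n: int) -> str:
        n_ = str(n)  # a fresh name keeps each variable at a single type
        return n_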

File tree: 22 files changed, +60 −62 lines

monai/data/csv_saver.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[dict]
         self._data_index += 1
         if torch.is_tensor(data):
             data = data.detach().cpu().numpy()
+        assert isinstance(data, np.ndarray)
         self._cache_dict[save_key] = data.astype(np.float32)
 
     def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[dict] = None) -> None:
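After `torch.is_tensor(data)` the checker still sees `data` as the full Union, so the new assert is what narrows it to `np.ndarray` before `.astype` is called. A standalone sketch of the pattern (function name invented here):

    from typing import Union

    import numpy as np
    import torch

    def to_float32(data: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        assert isinstance(data, np.ndarray)  # narrows Union[Tensor, ndarray] to ndarray
        return data.astype(np.float32)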

monai/data/dataloader.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def __init__(
         timeout: float = 0.0,
         multiprocessing_context: Optional[Callable] = None,
     ) -> None:
-        super().__init__(
+        super().__init__(  # type: ignore # No overload variant matches argument types
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
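Naming the suppressed error after `# type: ignore` is deliberate: with `warn_unused_ignores=True` (enabled in setup.cfg at the end of this diff), a stale ignore is itself reported, and the note records what it was for. A minimal sketch of the same forwarding pattern, assuming a trivial dataset:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class RangeDataset(Dataset):
        def __len__(self) -> int:
            return 10

        def __getitem__(self, index: int) -> torch.Tensor:
            return torch.tensor(index)

    class MyLoader(DataLoader):
        def __init__(self, dataset: Dataset, batch_size: int = 1) -> None:
            super().__init__(  # type: ignore # No overload variant matches argument types
                dataset=dataset,
                batch_size=batch_size,
            )

    loader = MyLoader(RangeDataset(), batch_size=2)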

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
@@ -452,7 +452,7 @@ def compute_importance_map(
     patch_size,
     mode: Union[BlendMode, str] = BlendMode.CONSTANT,
     sigma_scale: float = 0.125,
-    device: Optional[Union[torch.device, str]] = None,
+    device: Optional[torch.device] = None,
 ):
     """Get importance map for different weight modes.

monai/engines/multi_gpu_supervised_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 import ignite.metrics
 
 import torch
+from torch.optim.optimizer import Optimizer
 
 from monai.utils import exact_version, optional_import
 from monai.engines.utils import get_devices_spec
@@ -34,7 +35,7 @@ def _default_eval_transform(x, y, y_pred):
 
 def create_multigpu_supervised_trainer(
     net: torch.nn.Module,
-    optimizer: torch.optim.Optimizer,
+    optimizer: Optimizer,
     loss_fn,
     devices=None,
     non_blocking: bool = False,
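Annotating with `torch.optim.Optimizer` made mypy resolve the name through the package's re-export, which the stubs of this era rejected (the "Module ... is not valid as a type" / "Module has no attribute" categories in the commit message); importing the class from its defining module sidesteps that. A small sketch of the imported class in use (function name invented here):

    from torch.optim.optimizer import Optimizer  # the class's defining module

    def set_lr(optimizer: Optimizer, lr: float) -> None:
        # param_groups is a documented attribute of every Optimizer
        for group in optimizer.param_groups:
            group["lr"] = lr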

monai/engines/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 import ignite.metrics
 
 import torch
+from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
 from monai.inferers import Inferer, SimpleInferer
@@ -78,7 +79,7 @@ def __init__(
         max_epochs: int,
         train_data_loader: DataLoader,
         network,
-        optimizer: torch.optim.Optimizer,
+        optimizer: Optimizer,
         loss_function,
         prepare_batch: Callable = default_prepare_batch,
         iteration_update: Optional[Callable] = None,

monai/handlers/lr_schedule_handler.py

Lines changed: 7 additions & 5 deletions
@@ -9,15 +9,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Optional, TYPE_CHECKING
+from typing import Callable, Optional, Union, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import ignite.engine
 
-import torch
-
 import logging
 
+from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+
 from monai.utils import ensure_tuple, exact_version, optional_import
 
 Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
@@ -30,7 +30,7 @@ class LrScheduleHandler:
 
     def __init__(
         self,
-        lr_scheduler: torch.optim.lr_scheduler,
+        lr_scheduler: Union[_LRScheduler, ReduceLROnPlateau],
         print_lr: bool = True,
         name: Optional[str] = None,
         epoch_level: bool = True,
@@ -73,4 +73,6 @@ def __call__(self, engine: "ignite.engine.Engine") -> None:
         args = ensure_tuple(self.step_transform(engine))
         self.lr_scheduler.step(*args)
         if self.print_lr:
-            self.logger.info(f"Current learning rate: {self.lr_scheduler._last_lr[0]}")
+            self.logger.info(
+                f"Current learning rate: {self.lr_scheduler._last_lr[0]}"  # type: ignore # Module has no attribute
+            )
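The old annotation never named a class: `torch.optim.lr_scheduler` is a module. The Union covers both scheduler families, which in this torch generation sit in separate hierarchies, since `ReduceLROnPlateau` does not subclass `_LRScheduler`. A minimal usage sketch, assuming a single-parameter model:

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, _LRScheduler

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1)

    step_sched: _LRScheduler = StepLR(opt, step_size=10)  # _LRScheduler subclass
    plateau_sched = ReduceLROnPlateau(opt)                # separate hierarchy, hence the Union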

monai/handlers/mean_dice.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence
 
 import torch
 
@@ -72,13 +72,13 @@ def reset(self) -> None:
         self._num_examples = 0
 
     @reinit__is_reduced
-    def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None:
+    def update(self, output: Sequence[torch.Tensor]) -> None:
         if not len(output) == 2:
             raise ValueError("MeanDice metric can only support y_pred and y.")
         y_pred, y = output
         score = self.dice(y_pred, y)
         assert self.dice.not_nans is not None
-        not_nans = self.dice.not_nans.item()
+        not_nans = int(self.dice.not_nans.item())
 
         # add all items in current batch
         self._sum += score.item() * not_nans

monai/handlers/roc_auc.py

Lines changed: 2 additions & 2 deletions
@@ -63,8 +63,8 @@ def __init__(
         self.average: Average = Average(average)
 
     def reset(self) -> None:
-        self._predictions: List[float] = []
-        self._targets: List[float] = []
+        self._predictions: List[torch.Tensor] = []
+        self._targets: List[torch.Tensor] = []
 
     def update(self, output: Sequence[torch.Tensor]) -> None:
         y_pred, y = output

monai/losses/dice.py

Lines changed: 8 additions & 8 deletions
@@ -66,7 +66,7 @@ def __init__(
             ValueError: sigmoid=True and softmax=True are not compatible.
 
         """
-        super().__init__(reduction=LossReduction(reduction))
+        super().__init__(reduction=LossReduction(reduction).value)
 
         if sigmoid and softmax:
             raise ValueError("sigmoid=True and softmax=True are not compatible.")
@@ -133,11 +133,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-
 
         f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
 
-        if self.reduction == LossReduction.MEAN:
+        if self.reduction == LossReduction.MEAN.value:
             f = torch.mean(f)  # the batch and channel average
-        elif self.reduction == LossReduction.SUM:
+        elif self.reduction == LossReduction.SUM.value:
             f = torch.sum(f)  # sum over the batch and channel dims
-        elif self.reduction == LossReduction.NONE:
+        elif self.reduction == LossReduction.NONE.value:
             pass  # returns [N, n_classes] losses
         else:
             raise ValueError(f"reduction={self.reduction} is invalid.")
@@ -219,7 +219,7 @@ def __init__(
             ValueError: sigmoid=True and softmax=True are not compatible.
 
         """
-        super().__init__(reduction=LossReduction(reduction))
+        super().__init__(reduction=LossReduction(reduction).value)
 
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
@@ -286,11 +286,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-
 
         f = 1.0 - (2.0 * (intersection * w).sum(1) + smooth) / ((denominator * w).sum(1) + smooth)
 
-        if self.reduction == LossReduction.MEAN:
+        if self.reduction == LossReduction.MEAN.value:
             f = torch.mean(f)  # the batch and channel average
-        elif self.reduction == LossReduction.SUM:
+        elif self.reduction == LossReduction.SUM.value:
             f = torch.sum(f)  # sum over the batch and channel dims
-        elif self.reduction == LossReduction.NONE:
+        elif self.reduction == LossReduction.NONE.value:
             pass  # returns [N, n_classes] losses
         else:
             raise ValueError(f"reduction={self.reduction} is invalid.")
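The torch stubs declare `_Loss.__init__` as taking `reduction: str`, so passing the `LossReduction` member itself is an incompatible argument; storing `.value` keeps `self.reduction` a plain string, which is why every comparison switches to `.value` too (a `str` never compares equal to an enum member). A minimal sketch, assuming an enum shaped like MONAI's `LossReduction`:

    from enum import Enum

    class LossReduction(Enum):  # stand-in matching monai.utils.enums
        NONE = "none"
        MEAN = "mean"
        SUM = "sum"

    reduction = "mean"
    stored: str = LossReduction(reduction).value  # validates the string, stores a str

    assert stored == LossReduction.MEAN.value  # str vs str: True
    assert stored != LossReduction.MEAN        # str vs enum member: never equal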

monai/losses/focal_loss.py

Lines changed: 4 additions & 4 deletions
@@ -57,7 +57,7 @@ def __init__(
             fl(pred, grnd)
 
         """
-        super(FocalLoss, self).__init__(weight=weight, reduction=LossReduction(reduction))
+        super(FocalLoss, self).__init__(weight=weight, reduction=LossReduction(reduction).value)
         self.gamma = gamma
         self.weight: Optional[torch.Tensor]
 
@@ -129,10 +129,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
         else:
             loss = torch.mean(-weight * t * logpt, dim=-1)  # N,C
 
-        if self.reduction == LossReduction.SUM:
+        if self.reduction == LossReduction.SUM.value:
             return loss.sum()
-        if self.reduction == LossReduction.NONE:
+        if self.reduction == LossReduction.NONE.value:
             return loss
-        if self.reduction == LossReduction.MEAN:
+        if self.reduction == LossReduction.MEAN.value:
             return loss.mean()
         raise ValueError(f"reduction={self.reduction} is invalid.")

monai/losses/tversky.py

Lines changed: 4 additions & 4 deletions
@@ -63,7 +63,7 @@ def __init__(
 
         """
 
-        super().__init__(reduction=LossReduction(reduction))
+        super().__init__(reduction=LossReduction(reduction).value)
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
 
@@ -125,10 +125,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-
 
         score = 1.0 - numerator / denominator
 
-        if self.reduction == LossReduction.SUM:
+        if self.reduction == LossReduction.SUM.value:
             return score.sum()  # sum over the batch and channel dims
-        if self.reduction == LossReduction.NONE:
+        if self.reduction == LossReduction.NONE.value:
             return score  # returns [N, n_classes] losses
-        if self.reduction == LossReduction.MEAN:
+        if self.reduction == LossReduction.MEAN.value:
             return score.mean()
         raise ValueError(f"reduction={self.reduction} is invalid.")

monai/metrics/rocauc.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import cast, Union
 
 import warnings
 
@@ -34,7 +34,7 @@ def _calculate(y: torch.Tensor, y_pred: torch.Tensor):
     nneg = auc = tmp_pos = tmp_neg = 0.0
 
    for i in range(n):
-        y_i = y[i]
+        y_i = cast(float, y[i])
         if i + 1 < n and y_pred[i] == y_pred[i + 1]:
             tmp_pos += y_i
             tmp_neg += 1 - y_i
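`typing.cast` only changes the static type; at runtime `y_i` is still a 0-d tensor, so the loop behaves exactly as before. A minimal sketch:

    from typing import cast

    import torch

    y = torch.tensor([0.0, 1.0, 1.0])
    y_i = cast(float, y[1])  # no runtime effect; reinterprets the type for mypy
    tmp_pos = 0.0
    tmp_pos += y_i           # now type-checks as float arithmetic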

monai/networks/blocks/upsample.py

Lines changed: 3 additions & 3 deletions
@@ -48,7 +48,7 @@ def __init__(
             align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
         """
         super().__init__()
-        scale_factor = ensure_tuple_rep(scale_factor, spatial_dims)
+        scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
         if not out_channels:
             out_channels = in_channels
         if not with_conv:
@@ -58,11 +58,11 @@ def __init__(
             mode = linear_mode[spatial_dims - 1]
             self.upsample = nn.Sequential(
                 Conv[Conv.CONV, spatial_dims](in_channels=in_channels, out_channels=out_channels, kernel_size=1),
-                nn.Upsample(scale_factor=scale_factor, mode=mode.value, align_corners=align_corners),
+                nn.Upsample(scale_factor=scale_factor_, mode=mode.value, align_corners=align_corners),
             )
         else:
             self.upsample = Conv[Conv.CONVTRANS, spatial_dims](
-                in_channels=in_channels, out_channels=out_channels, kernel_size=scale_factor, stride=scale_factor
+                in_channels=in_channels, out_channels=out_channels, kernel_size=scale_factor_, stride=scale_factor_
            )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
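Rebinding the parameter `scale_factor` to the tuple returned by `ensure_tuple_rep` changes the variable's type mid-function, which mypy flags as an incompatible assignment; the trailing-underscore name gives the normalized value its own, single type. A hypothetical sketch of the same rename (names invented here):

    from typing import Sequence, Tuple, Union

    def normalize(kernel_size: Union[Sequence[int], int], dims: int) -> Tuple[int, ...]:
        if isinstance(kernel_size, int):
            kernel_size_ = (kernel_size,) * dims  # fresh name: always a tuple
        else:
            kernel_size_ = tuple(kernel_size)
        return kernel_size_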

monai/networks/layers/factories.py

Lines changed: 2 additions & 2 deletions
@@ -204,8 +204,8 @@ def batch_factory(dim: int):
 Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
 Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
 Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
-Act.add_factory_callable("softmax", lambda: nn.modules.SoftMax)
-Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftMax)
+Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
+Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
 
 
 @Conv.factory_function("conv")

monai/networks/layers/spatial_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def forward(self, src, theta, spatial_size: Optional[Union[Sequence[int], int]]
             )
         )
 
-        grid = nn.functional.affine_grid(theta=theta[:, :sr], size=dst_size, align_corners=self.align_corners)
+        grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)
         dst = nn.functional.grid_sample(
             input=src.contiguous(),
             grid=grid,

monai/networks/nets/densenet.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def __init__(
             [
                 ("relu", nn.ReLU(inplace=True)),
                 ("norm", avg_pool_type(1)),
-                ("flatten", nn.Flatten(1)),
+                ("flatten", nn.Flatten(1)),  # type: ignore # Module has no attribute
                 ("class", nn.Linear(in_channels, out_channels)),
             ]
         )

monai/networks/nets/generator.py

Lines changed: 3 additions & 1 deletion
@@ -81,7 +81,7 @@ def __init__(
         self.dropout = dropout
         self.bias = bias
 
-        self.flatten = nn.Flatten()
+        self.flatten = nn.Flatten()  # type: ignore # Module has no attribute
         self.linear = nn.Linear(int(np.prod(self.latent_shape)), int(np.prod(start_shape)))
         self.reshape = Reshape(*start_shape)
         self.conv = nn.Sequential()
@@ -102,6 +102,8 @@ def _get_layer(self, in_channels: int, out_channels: int, strides, is_last: bool
         is True this is the final layer and is not expected to include activation and normalization layers.
         """
 
+        layer: Union[Convolution, nn.Sequential]
+
         layer = Convolution(
             in_channels=in_channels,
             strides=strides,
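Without the declaration, mypy infers the type of `layer` from its first assignment and rejects a later branch that assigns the other type; declaring the Union up front makes both assignments legal. A minimal sketch with torch modules standing in for `Convolution`:

    from typing import Union

    import torch.nn as nn

    def build_layer(is_last: bool) -> nn.Module:
        layer: Union[nn.Conv2d, nn.Sequential]  # declared before either branch
        layer = nn.Conv2d(1, 1, 3)
        if not is_last:
            layer = nn.Sequential(layer, nn.ReLU())
        return layer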

monai/networks/nets/regressor.py

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,8 @@ def _get_layer(self, in_channels: int, out_channels: int, strides, is_last: bool
         is True this is the final layer and is not expected to include activation and normalization layers.
         """
 
+        layer: Union[ResidualUnit, Convolution]
+
         if self.num_res_units > 0:
             layer = ResidualUnit(
                 subunits=self.num_res_units,

monai/transforms/post/array.py

Lines changed: 1 addition & 1 deletion
@@ -344,5 +344,5 @@ def __call__(self, img: torch.Tensor):
         else:
             raise RuntimeError("the dimensions of img should be 4 or 5.")
 
-        torch.clamp_(contour_img, min=0.0, max=1.0)
+        contour_img.clamp_(min=0.0, max=1.0)
         return contour_img
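Both spellings clamp in place at runtime; the change matters only to the checker, since the stubs of this era declared the `Tensor.clamp_` method but not the module-level `torch.clamp_` (a "Module has no attribute" case). A minimal sketch:

    import torch

    img = torch.linspace(-1.0, 2.0, steps=4)  # [-1., 0., 1., 2.]
    img.clamp_(min=0.0, max=1.0)              # in-place method declared on Tensor
    print(img)                                # tensor([0., 0., 1., 1.])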

monai/transforms/spatial/array.py

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,7 @@
 https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
 """
 
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Union
 
 import warnings
 
@@ -47,6 +47,8 @@
 
 nib, _ = optional_import("nibabel")
 
+_torch_interp: Callable
+
 if get_torch_version_tuple() >= (1, 5):
     # additional argument since torch 1.5 (to avoid warnings)
     def _torch_interp(**kwargs):
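`_torch_interp` is defined differently depending on the installed torch version, so mypy sees two conditional definitions of one name; the module-level `Callable` annotation, declared before the branch, gives the name a single type both definitions satisfy. A minimal sketch of the same shape:

    import sys
    from typing import Callable

    _format_version: Callable
    if sys.version_info >= (3, 8):
        def _format_version() -> str:
            return f"{sys.version_info.major}.{sys.version_info.minor}"
    else:
        def _format_version() -> str:
            return str(sys.version_info.major) + "." + str(sys.version_info.minor)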

monai/utils/misc.py

Lines changed: 6 additions & 6 deletions
@@ -197,9 +197,9 @@ def set_determinism(
     """
     if seed is None:
         # cast to 32 bit seed for CUDA
-        seed_ = torch.default_generator.seed() % (np.iinfo(np.int32).max + 1)
-        if not torch.cuda._is_in_bad_fork():
-            torch.cuda.manual_seed_all(seed_)
+        seed_ = torch.default_generator.seed() % (np.iinfo(np.int32).max + 1)  # type: ignore # Module has no attribute
+        if not torch.cuda._is_in_bad_fork():  # type: ignore # Module has no attribute
+            torch.cuda.manual_seed_all(seed_)  # type: ignore # Module has no attribute
     else:
         torch.manual_seed(seed)
 
@@ -214,7 +214,7 @@ def set_determinism(
         func(seed)
 
     if seed is not None:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True  # type: ignore # Module has no attribute
+        torch.backends.cudnn.benchmark = False  # type: ignore # Module has no attribute
     else:
-        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.deterministic = False  # type: ignore # Module has no attribute

setup.cfg

Lines changed: 1 addition & 14 deletions
@@ -91,20 +91,7 @@ warn_no_return=True
 # lint-style cleanliness for typing needs to be disabled; returns more errors
 # than the full run.
 warn_redundant_casts=True
-warn_unused_ignores=False
-
-
-[mypy-torch.*]
-# --- Temporary disabling to allow smaller PRs
-# Subsequent fixes necessary to make this tool pass are not yet
-# merged into the master branch, but requested as separate
-# commits. Can not enable until all outstanding typehint
-# PR's are approved
-## Temporarily disable all mypy warnings
-ignore_errors = True
-# ^^^^^^ Temporary disabling to allow smaller PRs
-follow_imports = skip
-follow_imports_for_stubs = True
+warn_unused_ignores=True
 
 [mypy-versioneer]
 # Always ignore any type issues in the versioneer.py file
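Dropping the `[mypy-torch.*]` override is the "reenable" of the commit title: torch-related errors are now checked everywhere, which is what forced the per-file fixes above. Turning on `warn_unused_ignores` also audits every `# type: ignore` added in this commit; a hypothetical illustration of what it reports once an ignore goes stale:

    x: int = 1  # type: ignore
    # if the ignored error disappears (e.g. after a stub upgrade), mypy now
    # reports: unused "type: ignore" comment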
