Skip to content

Commit 8516865

Browse files
author
Benjamin Gorman
authored
Reenable checks and fix errors (#717)
* Set set_type_checking_flag for autodoc-typehints Fixes 11 out of the 17 warnings when running 'make html' WARNING: Cannot resolve forward reference in type annotations of "{}": name 'ignite' is not defined * Reenable and fix flake8-bugbear errors B007 Loop control variable '{}' not used within the loop body. If this is intended, start the name with an underscore. B008 See #727 * Reenable and fix flake8-comprehensions errors C402 Unnecessary generator - rewrite as a dict comprehension. C407 Unnecessary list comprehension - '{}' can take a generator. C416 Unnecessary list comprehension - rewrite using list(). * Reenable and fix flake8 unused import errors F401 '{}' imported but unused * Remove flake8-mypy configurations See #550 * Reenable and fix mypy-torch errors - Argument "{}" to "{}" has incompatible type "{}"; expected "{}" - Incompatible types in assignment (expression has type "{}", variable has type "{}") - Module "{}" is not valid as a type - Module has no attribute "{}" - Name '{}' is not defined - No overload variant of "{}" of "{}" matches argument types * Update the issue number of signature incompatible See #729
1 parent 6bd70bb commit 8516865

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+96
-124
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def setup(app):
9595
add_module_names = True
9696
autosectionlabel_prefix_document = True
9797
napoleon_use_param = True
98+
set_type_checking_flag = True
9899

99100
# Add any paths that contain templates here, relative to this directory.
100101
templates_path = ["_templates"]

monai/apps/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def randomize(self) -> None:
8989
self.rann = self.R.random()
9090

9191
def _generate_data_list(self, dataset_dir: str):
92-
class_names = sorted([x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x))])
92+
class_names = sorted((x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x))))
9393
num_class = len(class_names)
9494
image_files = [
9595
[

monai/config/deviceconfig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,4 @@ def get_torch_version_tuple():
116116
Returns:
117117
tuple of ints represents the pytorch major/minor version.
118118
"""
119-
return tuple([int(x) for x in torch.__version__.split(".")[:2]])
119+
return tuple((int(x) for x in torch.__version__.split(".")[:2]))

monai/data/csv_saver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[dict]
7777
self._data_index += 1
7878
if torch.is_tensor(data):
7979
data = data.detach().cpu().numpy()
80+
assert isinstance(data, np.ndarray)
8081
self._cache_dict[save_key] = data.astype(np.float32)
8182

8283
def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[dict] = None) -> None:

monai/data/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
timeout: float = 0.0,
6464
multiprocessing_context: Optional[Callable] = None,
6565
) -> None:
66-
super().__init__(
66+
super().__init__( # type: ignore # No overload variant matches argument types
6767
dataset=dataset,
6868
batch_size=batch_size,
6969
shuffle=shuffle,

monai/data/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def __init__(self, datasets, transform: Optional[Callable] = None) -> None:
336336
super().__init__(list(datasets), transform=transform)
337337

338338
def __len__(self) -> int:
339-
return min([len(dataset) for dataset in self.data])
339+
return min((len(dataset) for dataset in self.data))
340340

341341
def __getitem__(self, index: int):
342342
def to_list(x):

monai/data/synthetic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create_test_image_2d(
3838
image = np.zeros((width, height))
3939
rs = np.random if random_state is None else random_state
4040

41-
for i in range(num_objs):
41+
for _ in range(num_objs):
4242
x = rs.randint(rad_max, width - rad_max)
4343
y = rs.randint(rad_max, height - rad_max)
4444
rad = rs.randint(5, rad_max)
@@ -85,7 +85,7 @@ def create_test_image_3d(
8585
image = np.zeros((width, height, depth))
8686
rs = np.random if random_state is None else random_state
8787

88-
for i in range(num_objs):
88+
for _ in range(num_objs):
8989
x = rs.randint(rad_max, width - rad_max)
9090
y = rs.randint(rad_max, height - rad_max)
9191
z = rs.randint(rad_max, depth - rad_max)

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def compute_importance_map(
452452
patch_size,
453453
mode: Union[BlendMode, str] = BlendMode.CONSTANT,
454454
sigma_scale: float = 0.125,
455-
device: Optional[Union[torch.device, str]] = None,
455+
device: Optional[torch.device] = None,
456456
):
457457
"""Get importance map for different weight modes.
458458

monai/engines/multi_gpu_supervised_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import ignite.metrics
1616

1717
import torch
18+
from torch.optim.optimizer import Optimizer
1819

1920
from monai.utils import exact_version, optional_import
2021
from monai.engines.utils import get_devices_spec
@@ -34,7 +35,7 @@ def _default_eval_transform(x, y, y_pred):
3435

3536
def create_multigpu_supervised_trainer(
3637
net: torch.nn.Module,
37-
optimizer: torch.optim.Optimizer,
38+
optimizer: Optimizer,
3839
loss_fn,
3940
devices=None,
4041
non_blocking: bool = False,

monai/engines/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import ignite.metrics
1717

1818
import torch
19+
from torch.optim.optimizer import Optimizer
1920
from torch.utils.data import DataLoader
2021

2122
from monai.inferers import Inferer, SimpleInferer
@@ -78,7 +79,7 @@ def __init__(
7879
max_epochs: int,
7980
train_data_loader: DataLoader,
8081
network,
81-
optimizer: torch.optim.Optimizer,
82+
optimizer: Optimizer,
8283
loss_function,
8384
prepare_batch: Callable = default_prepare_batch,
8485
iteration_update: Optional[Callable] = None,

monai/engines/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919
from torch.utils.data import DataLoader
2020

21-
from monai.transforms import apply_transform, Transform
21+
from monai.transforms import apply_transform
2222
from monai.utils import exact_version, optional_import, ensure_tuple
2323
from monai.engines.utils import default_prepare_batch
2424

monai/handlers/lr_schedule_handler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, Optional, TYPE_CHECKING
12+
from typing import Callable, Optional, Union, TYPE_CHECKING
1313

1414
if TYPE_CHECKING:
1515
import ignite.engine
1616

17-
import torch
18-
1917
import logging
2018

19+
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
20+
2121
from monai.utils import ensure_tuple, exact_version, optional_import
2222

2323
Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
@@ -30,7 +30,7 @@ class LrScheduleHandler:
3030

3131
def __init__(
3232
self,
33-
lr_scheduler: torch.optim.lr_scheduler,
33+
lr_scheduler: Union[_LRScheduler, ReduceLROnPlateau],
3434
print_lr: bool = True,
3535
name: Optional[str] = None,
3636
epoch_level: bool = True,
@@ -73,4 +73,6 @@ def __call__(self, engine: "ignite.engine.Engine") -> None:
7373
args = ensure_tuple(self.step_transform(engine))
7474
self.lr_scheduler.step(*args)
7575
if self.print_lr:
76-
self.logger.info(f"Current learning rate: {self.lr_scheduler._last_lr[0]}")
76+
self.logger.info(
77+
f"Current learning rate: {self.lr_scheduler._last_lr[0]}" # type: ignore # Module has no attribute
78+
)

monai/handlers/mean_dice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, Optional, Sequence, Union
12+
from typing import Callable, Optional, Sequence
1313

1414
import torch
1515

@@ -72,13 +72,13 @@ def reset(self) -> None:
7272
self._num_examples = 0
7373

7474
@reinit__is_reduced
75-
def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None:
75+
def update(self, output: Sequence[torch.Tensor]) -> None:
7676
if not len(output) == 2:
7777
raise ValueError("MeanDice metric can only support y_pred and y.")
7878
y_pred, y = output
7979
score = self.dice(y_pred, y)
8080
assert self.dice.not_nans is not None
81-
not_nans = self.dice.not_nans.item()
81+
not_nans = int(self.dice.not_nans.item())
8282

8383
# add all items in current batch
8484
self._sum += score.item() * not_nans

monai/handlers/roc_auc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def __init__(
6363
self.average: Average = Average(average)
6464

6565
def reset(self) -> None:
66-
self._predictions: List[float] = []
67-
self._targets: List[float] = []
66+
self._predictions: List[torch.Tensor] = []
67+
self._targets: List[torch.Tensor] = []
6868

6969
def update(self, output: Sequence[torch.Tensor]) -> None:
7070
y_pred, y = output

monai/losses/dice.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
ValueError: sigmoid=True and softmax=True are not compatible.
6767
6868
"""
69-
super().__init__(reduction=LossReduction(reduction))
69+
super().__init__(reduction=LossReduction(reduction).value)
7070

7171
if sigmoid and softmax:
7272
raise ValueError("sigmoid=True and softmax=True are not compatible.")
@@ -133,11 +133,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-
133133

134134
f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
135135

136-
if self.reduction == LossReduction.MEAN:
136+
if self.reduction == LossReduction.MEAN.value:
137137
f = torch.mean(f) # the batch and channel average
138-
elif self.reduction == LossReduction.SUM:
138+
elif self.reduction == LossReduction.SUM.value:
139139
f = torch.sum(f) # sum over the batch and channel dims
140-
elif self.reduction == LossReduction.NONE:
140+
elif self.reduction == LossReduction.NONE.value:
141141
pass # returns [N, n_classes] losses
142142
else:
143143
raise ValueError(f"reduction={self.reduction} is invalid.")
@@ -219,7 +219,7 @@ def __init__(
219219
ValueError: sigmoid=True and softmax=True are not compatible.
220220
221221
"""
222-
super().__init__(reduction=LossReduction(reduction))
222+
super().__init__(reduction=LossReduction(reduction).value)
223223

224224
self.include_background = include_background
225225
self.to_onehot_y = to_onehot_y
@@ -286,11 +286,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-
286286

287287
f = 1.0 - (2.0 * (intersection * w).sum(1) + smooth) / ((denominator * w).sum(1) + smooth)
288288

289-
if self.reduction == LossReduction.MEAN:
289+
if self.reduction == LossReduction.MEAN.value:
290290
f = torch.mean(f) # the batch and channel average
291-
elif self.reduction == LossReduction.SUM:
291+
elif self.reduction == LossReduction.SUM.value:
292292
f = torch.sum(f) # sum over the batch and channel dims
293-
elif self.reduction == LossReduction.NONE:
293+
elif self.reduction == LossReduction.NONE.value:
294294
pass # returns [N, n_classes] losses
295295
else:
296296
raise ValueError(f"reduction={self.reduction} is invalid.")

monai/losses/focal_loss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
fl(pred, grnd)
5858
5959
"""
60-
super(FocalLoss, self).__init__(weight=weight, reduction=LossReduction(reduction))
60+
super(FocalLoss, self).__init__(weight=weight, reduction=LossReduction(reduction).value)
6161
self.gamma = gamma
6262
self.weight: Optional[torch.Tensor]
6363

@@ -129,10 +129,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
129129
else:
130130
loss = torch.mean(-weight * t * logpt, dim=-1) # N,C
131131

132-
if self.reduction == LossReduction.SUM:
132+
if self.reduction == LossReduction.SUM.value:
133133
return loss.sum()
134-
if self.reduction == LossReduction.NONE:
134+
if self.reduction == LossReduction.NONE.value:
135135
return loss
136-
if self.reduction == LossReduction.MEAN:
136+
if self.reduction == LossReduction.MEAN.value:
137137
return loss.mean()
138138
raise ValueError(f"reduction={self.reduction} is invalid.")

monai/losses/tversky.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
6464
"""
6565

66-
super().__init__(reduction=LossReduction(reduction))
66+
super().__init__(reduction=LossReduction(reduction).value)
6767
self.include_background = include_background
6868
self.to_onehot_y = to_onehot_y
6969

@@ -125,10 +125,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-
125125

126126
score = 1.0 - numerator / denominator
127127

128-
if self.reduction == LossReduction.SUM:
128+
if self.reduction == LossReduction.SUM.value:
129129
return score.sum() # sum over the batch and channel dims
130-
if self.reduction == LossReduction.NONE:
130+
if self.reduction == LossReduction.NONE.value:
131131
return score # returns [N, n_classes] losses
132-
if self.reduction == LossReduction.MEAN:
132+
if self.reduction == LossReduction.MEAN.value:
133133
return score.mean()
134134
raise ValueError(f"reduction={self.reduction} is invalid.")

monai/metrics/rocauc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Union
12+
from typing import cast, Union
1313

1414
import warnings
1515

@@ -34,7 +34,7 @@ def _calculate(y: torch.Tensor, y_pred: torch.Tensor):
3434
nneg = auc = tmp_pos = tmp_neg = 0.0
3535

3636
for i in range(n):
37-
y_i = y[i]
37+
y_i = cast(float, y[i])
3838
if i + 1 < n and y_pred[i] == y_pred[i + 1]:
3939
tmp_pos += y_i
4040
tmp_neg += 1 - y_i

monai/networks/blocks/aspp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Sequence, Tuple
13-
1412
import torch
1513
import torch.nn as nn
1614

monai/networks/blocks/convolutions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Tuple, Union
13-
1412
import numpy as np
1513
import torch.nn as nn
1614

monai/networks/blocks/upsample.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
4949
"""
5050
super().__init__()
51-
scale_factor = ensure_tuple_rep(scale_factor, spatial_dims)
51+
scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
5252
if not out_channels:
5353
out_channels = in_channels
5454
if not with_conv:
@@ -58,11 +58,11 @@ def __init__(
5858
mode = linear_mode[spatial_dims - 1]
5959
self.upsample = nn.Sequential(
6060
Conv[Conv.CONV, spatial_dims](in_channels=in_channels, out_channels=out_channels, kernel_size=1),
61-
nn.Upsample(scale_factor=scale_factor, mode=mode.value, align_corners=align_corners),
61+
nn.Upsample(scale_factor=scale_factor_, mode=mode.value, align_corners=align_corners),
6262
)
6363
else:
6464
self.upsample = Conv[Conv.CONVTRANS, spatial_dims](
65-
in_channels=in_channels, out_channels=out_channels, kernel_size=scale_factor, stride=scale_factor
65+
in_channels=in_channels, out_channels=out_channels, kernel_size=scale_factor_, stride=scale_factor_
6666
)
6767

6868
def forward(self, x: torch.Tensor) -> torch.Tensor:

monai/networks/layers/factories.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def batch_factory(dim: int):
204204
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
205205
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
206206
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
207-
Act.add_factory_callable("softmax", lambda: nn.modules.SoftMax)
208-
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftMax)
207+
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
208+
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
209209

210210

211211
@Conv.factory_function("conv")

monai/networks/layers/spatial_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def forward(self, src, theta, spatial_size: Optional[Union[Sequence[int], int]]
147147
)
148148
)
149149

150-
grid = nn.functional.affine_grid(theta=theta[:, :sr], size=dst_size, align_corners=self.align_corners)
150+
grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)
151151
dst = nn.functional.grid_sample(
152152
input=src.contiguous(),
153153
grid=grid,

monai/networks/nets/classifier.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Sequence, Tuple, Union
12+
from typing import Optional, Sequence, Union
1313

1414
import torch.nn as nn
1515
from monai.networks.layers.factories import Norm, Act, split_args
1616
from monai.networks.nets.regressor import Regressor
17-
from monai.utils import ensure_tuple
1817

1918

2019
class Classifier(Regressor):

monai/networks/nets/densenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
from collections import OrderedDict
13-
from typing import Callable, Sequence, Tuple
13+
from typing import Callable, Sequence
1414

1515
import torch
1616
import torch.nn as nn
@@ -144,7 +144,7 @@ def __init__(
144144
[
145145
("relu", nn.ReLU(inplace=True)),
146146
("norm", avg_pool_type(1)),
147-
("flatten", nn.Flatten(1)),
147+
("flatten", nn.Flatten(1)), # type: ignore # Module has no attribute
148148
("class", nn.Linear(in_channels, out_channels)),
149149
]
150150
)

0 commit comments

Comments
 (0)