Skip to content

Commit e9defb6

Browse files
committed
Reformat with ruff
1 parent 8fcc1a3 commit e9defb6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+923
-439
lines changed

docs/conf.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414
# import sys
1515
# sys.path.insert(0, os.path.abspath('.'))
1616

17-
import os
18-
import re
1917
import sys
2018
import datetime
19+
import sphinx_rtd_theme
2120

2221
sys.path.append("..")
2322

@@ -68,14 +67,11 @@ def get_version():
6867
# a list of builtin themes.
6968
#
7069

71-
import sphinx_rtd_theme
72-
7370
html_theme = "sphinx_rtd_theme"
7471
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
7572

7673
# import karma_sphinx_theme
7774
# html_theme = "karma_sphinx_theme"
78-
import faculty_sphinx_theme
7975

8076
html_theme = "faculty_sphinx_theme"
8177

misc/generate_table.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,17 @@
44

55

66
WIDTH = 32
7-
COLUMNS = [
8-
"Encoder",
9-
"Weights",
10-
"Params, M",
11-
]
7+
COLUMNS = ["Encoder", "Weights", "Params, M"]
128

139

1410
def wrap_row(r):
1511
return "|{}|".format(r)
1612

1713

1814
header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
19-
separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))
15+
separator = "|".join(
16+
["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)
17+
)
2018

2119
print(wrap_row(header))
2220
print(wrap_row(separator))

misc/generate_table_timm.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,29 @@ def make_table(data):
2424

2525
l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
2626
l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
27-
top = "| " + "Encoder name".ljust(max_len1 - 2) + " | " + "Support dilation".center(max_len2 - 2) + " |\n"
27+
top = (
28+
"| "
29+
+ "Encoder name".ljust(max_len1 - 2)
30+
+ " | "
31+
+ "Support dilation".center(max_len2 - 2)
32+
+ " |\n"
33+
)
2834

2935
table = l1 + top + l2
3036

3137
for k in sorted(data.keys()):
32-
support = "✅".center(max_len2 - 3) if data[k]["has_dilation"] else " ".center(max_len2 - 2)
38+
support = (
39+
"✅".center(max_len2 - 3)
40+
if data[k]["has_dilation"]
41+
else " ".center(max_len2 - 2)
42+
)
3343
table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
3444
table += l1
3545

3646
return table
3747

3848

3949
if __name__ == "__main__":
40-
4150
supported_models = {}
4251

4352
with tqdm(timm.list_models()) as names:

segmentation_models_pytorch/__init__.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ def create_model(
5050
except KeyError:
5151
raise KeyError(
5252
"Wrong architecture type `{}`. Available options are: {}".format(
53-
arch,
54-
list(archs_dict.keys()),
53+
arch, list(archs_dict.keys())
5554
)
5655
)
5756
return model_class(
@@ -61,3 +60,24 @@ def create_model(
6160
classes=classes,
6261
**kwargs,
6362
)
63+
64+
65+
__all__ = [
66+
"datasets",
67+
"encoders",
68+
"decoders",
69+
"losses",
70+
"metrics",
71+
"Unet",
72+
"UnetPlusPlus",
73+
"MAnet",
74+
"Linknet",
75+
"FPN",
76+
"PSPNet",
77+
"DeepLabV3",
78+
"DeepLabV3Plus",
79+
"PAN",
80+
"from_pretrained",
81+
"create_model",
82+
"__version__",
83+
]
+10-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from .model import SegmentationModel
22

3-
from .modules import (
4-
Conv2dReLU,
5-
Attention,
6-
)
3+
from .modules import Conv2dReLU, Attention
74

8-
from .heads import (
9-
SegmentationHead,
10-
ClassificationHead,
11-
)
5+
from .heads import SegmentationHead, ClassificationHead
6+
7+
__all__ = [
8+
"SegmentationModel",
9+
"Conv2dReLU",
10+
"Attention",
11+
"SegmentationHead",
12+
"ClassificationHead",
13+
]

segmentation_models_pytorch/base/heads.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,29 @@
33

44

55
class SegmentationHead(nn.Sequential):
6-
def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
7-
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
8-
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
6+
def __init__(
7+
self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
8+
):
9+
conv2d = nn.Conv2d(
10+
in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
11+
)
12+
upsampling = (
13+
nn.UpsamplingBilinear2d(scale_factor=upsampling)
14+
if upsampling > 1
15+
else nn.Identity()
16+
)
917
activation = Activation(activation)
1018
super().__init__(conv2d, upsampling, activation)
1119

1220

1321
class ClassificationHead(nn.Sequential):
14-
def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
22+
def __init__(
23+
self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
24+
):
1525
if pooling not in ("max", "avg"):
16-
raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
26+
raise ValueError(
27+
"Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
28+
)
1729
pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
1830
flatten = nn.Flatten()
1931
dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()

segmentation_models_pytorch/base/hub_mixin.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
from pathlib import Path
33
from typing import Optional, Union
44
from functools import wraps
5-
from huggingface_hub import PyTorchModelHubMixin, ModelCard, ModelCardData, hf_hub_download
5+
from huggingface_hub import (
6+
PyTorchModelHubMixin,
7+
ModelCard,
8+
ModelCardData,
9+
hf_hub_download,
10+
)
611

712

813
MODEL_CARD = """
@@ -45,15 +50,17 @@
4550

4651
def _format_parameters(parameters: dict):
4752
params = {k: v for k, v in parameters.items() if not k.startswith("_")}
48-
params = [f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"' for k, v in params.items()]
53+
params = [
54+
f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"'
55+
for k, v in params.items()
56+
]
4957
params = ",\n".join([f" {param}" for param in params])
5058
params = "{\n" + f"{params}" + "\n}"
5159
return params
5260

5361

5462
class SMPHubMixin(PyTorchModelHubMixin):
5563
def generate_model_card(self, *args, **kwargs) -> ModelCard:
56-
5764
model_parameters_json = _format_parameters(self._hub_mixin_config)
5865
directory = self._save_directory if hasattr(self, "_save_directory") else None
5966
repo_id = self._repo_id if hasattr(self, "_repo_id") else None
@@ -97,8 +104,9 @@ def _del_attrs(self, attrs):
97104
delattr(self, f"_{attr}")
98105

99106
@wraps(PyTorchModelHubMixin.save_pretrained)
100-
def save_pretrained(self, save_directory: Union[str, Path], *args, **kwargs) -> Optional[str]:
101-
107+
def save_pretrained(
108+
self, save_directory: Union[str, Path], *args, **kwargs
109+
) -> Optional[str]:
102110
# set additional attributes to be used in generate_model_card
103111
self._save_directory = save_directory
104112
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
@@ -132,7 +140,9 @@ def config(self):
132140
@wraps(PyTorchModelHubMixin.from_pretrained)
133141
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
134142
config_path = hf_hub_download(
135-
pretrained_model_name_or_path, filename="config.json", revision=kwargs.get("revision", None)
143+
pretrained_model_name_or_path,
144+
filename="config.json",
145+
revision=kwargs.get("revision", None),
136146
)
137147
with open(config_path, "r") as f:
138148
config = json.load(f)

segmentation_models_pytorch/base/initialization.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
def initialize_decoder(module):
55
for m in module.modules():
6-
76
if isinstance(m, nn.Conv2d):
87
nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
98
if m.bias is not None:

segmentation_models_pytorch/base/model.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,27 @@
44
from .hub_mixin import SMPHubMixin
55

66

7-
class SegmentationModel(
8-
torch.nn.Module,
9-
SMPHubMixin,
10-
):
7+
class SegmentationModel(torch.nn.Module, SMPHubMixin):
118
def initialize(self):
129
init.initialize_decoder(self.decoder)
1310
init.initialize_head(self.segmentation_head)
1411
if self.classification_head is not None:
1512
init.initialize_head(self.classification_head)
1613

1714
def check_input_shape(self, x):
18-
1915
h, w = x.shape[-2:]
2016
output_stride = self.encoder.output_stride
2117
if h % output_stride != 0 or w % output_stride != 0:
22-
new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h
23-
new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w
18+
new_h = (
19+
(h // output_stride + 1) * output_stride
20+
if h % output_stride != 0
21+
else h
22+
)
23+
new_w = (
24+
(w // output_stride + 1) * output_stride
25+
if w % output_stride != 0
26+
else w
27+
)
2428
raise RuntimeError(
2529
f"Wrong input shape height={h}, width={w}. Expected image height and width "
2630
f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."

segmentation_models_pytorch/base/modules.py

-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def __init__(
1717
stride=1,
1818
use_batchnorm=True,
1919
):
20-
2120
if use_batchnorm == "inplace" and InPlaceABN is None:
2221
raise RuntimeError(
2322
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
@@ -83,7 +82,6 @@ def forward(self, x):
8382

8483
class Activation(nn.Module):
8584
def __init__(self, name, **params):
86-
8785
super().__init__()
8886

8987
if name is None or name == "identity":
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset
2+
3+
__all__ = ["OxfordPetDataset", "SimpleOxfordPetDataset"]

segmentation_models_pytorch/datasets/oxford_pet.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
class OxfordPetDataset(torch.utils.data.Dataset):
1212
def __init__(self, root, mode="train", transform=None):
13-
1413
assert mode in {"train", "valid", "test"}
1514

1615
self.root = root
@@ -26,7 +25,6 @@ def __len__(self):
2625
return len(self.filenames)
2726

2827
def __getitem__(self, idx):
29-
3028
filename = self.filenames[idx]
3129
image_path = os.path.join(self.images_directory, filename + ".jpg")
3230
mask_path = os.path.join(self.masks_directory, filename + ".png")
@@ -63,7 +61,6 @@ def _read_split(self):
6361

6462
@staticmethod
6563
def download(root):
66-
6764
# load images
6865
filepath = os.path.join(root, "images.tar.gz")
6966
download_url(
@@ -83,13 +80,18 @@ def download(root):
8380

8481
class SimpleOxfordPetDataset(OxfordPetDataset):
8582
def __getitem__(self, *args, **kwargs):
86-
8783
sample = super().__getitem__(*args, **kwargs)
8884

8985
# resize images
90-
image = np.array(Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR))
91-
mask = np.array(Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST))
92-
trimap = np.array(Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST))
86+
image = np.array(
87+
Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR)
88+
)
89+
mask = np.array(
90+
Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST)
91+
)
92+
trimap = np.array(
93+
Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST)
94+
)
9395

9496
# convert to other format HWC -> CHW
9597
sample["image"] = np.moveaxis(image, -1, 0)
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from .model import DeepLabV3, DeepLabV3Plus
2+
3+
__all__ = ["DeepLabV3", "DeepLabV3Plus"]

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,18 @@ def __init__(
6161
):
6262
super().__init__()
6363
if output_stride not in {8, 16}:
64-
raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))
64+
raise ValueError(
65+
"Output stride should be 8 or 16, got {}.".format(output_stride)
66+
)
6567

6668
self.out_channels = out_channels
6769
self.output_stride = output_stride
6870

6971
self.aspp = nn.Sequential(
7072
ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
71-
SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
73+
SeparableConv2d(
74+
out_channels, out_channels, kernel_size=3, padding=1, bias=False
75+
),
7276
nn.BatchNorm2d(out_channels),
7377
nn.ReLU(),
7478
)
@@ -79,7 +83,9 @@ def __init__(
7983
highres_in_channels = encoder_channels[-4]
8084
highres_out_channels = 48 # proposed by authors of paper
8185
self.block1 = nn.Sequential(
82-
nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
86+
nn.Conv2d(
87+
highres_in_channels, highres_out_channels, kernel_size=1, bias=False
88+
),
8389
nn.BatchNorm2d(highres_out_channels),
8490
nn.ReLU(),
8591
)
@@ -210,10 +216,5 @@ def __init__(
210216
groups=in_channels,
211217
bias=False,
212218
)
213-
pointwise_conv = nn.Conv2d(
214-
in_channels,
215-
out_channels,
216-
kernel_size=1,
217-
bias=bias,
218-
)
219+
pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
219220
super().__init__(dephtwise_conv, pointwise_conv)

0 commit comments

Comments
 (0)