Skip to content

Commit d24b3fa

Browse files
committed
Fix indices and channels handling
1 parent 5934e82 commit d24b3fa

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

segmentation_models_pytorch/decoders/unet/model.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union, List
1+
from typing import Optional, Union, List, Sequence, Callable
22

33
from segmentation_models_pytorch.encoders import get_encoder
44
from segmentation_models_pytorch.base import (
@@ -25,13 +25,13 @@ class Unet(SegmentationModel):
2525
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
2626
other pretrained weights (see table with available weights for each encoder_name)
2727
encoder_indices: The indices of the encoder features that will be used in the decoder.
28-
If **"first"**, only the first `encoder_depth` features will be used.
28+
If **"first"**, only the first `encoder_depth` features will be used.
2929
If **"last"**, only the last `encoder_depth` features will be used.
3030
If a list of integers, the indices of the encoder features that will be used in the decoder.
31-
Default is **"first"**
31+
If **None**, defaults to **"first"**.
3232
encoder_channels: A list of integers that specify the number of output channels for each encoder layer.
3333
If **None**, the number of encoder output channels stays the same as for specifier `encoder_name`.
34-
If a list of integers, the number of encoder output channels is equal to the provided list,
34+
If a list of integers, the number of encoder output channels is equal to the provided list,
3535
features are adjusted by 1x1 convolutions without non-linearity.
3636
decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
3737
Length of the list should be the same as **encoder_depth**
@@ -67,14 +67,14 @@ def __init__(
6767
encoder_name: str = "resnet34",
6868
encoder_depth: int = 5,
6969
encoder_weights: Optional[str] = "imagenet",
70-
encoder_indices: Union[str, List[int]] = "first",
70+
encoder_indices: Optional[Union[str, List[int]]] = None,
7171
encoder_channels: Optional[List[int]] = None,
7272
decoder_use_batchnorm: bool = True,
73-
decoder_channels: List[int] = (256, 128, 64, 32, 16),
73+
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
7474
decoder_attention_type: Optional[str] = None,
7575
in_channels: int = 3,
7676
classes: int = 1,
77-
activation: Optional[Union[str, callable]] = None,
77+
activation: Optional[Union[str, Callable]] = None,
7878
aux_params: Optional[dict] = None,
7979
):
8080
super().__init__()

segmentation_models_pytorch/encoders/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import timm
22
import functools
33
import torch.utils.model_zoo as model_zoo
4+
from loguru import logger
45

56
from .resnet import resnet_encoders
67
from .dpn import dpn_encoders
@@ -51,6 +52,10 @@
5152
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
5253
if name.startswith("tu-"):
5354
name = name[3:]
55+
56+
if "encoder_indices" in kwargs and kwargs["encoder_indices"] is None:
57+
kwargs["encoder_indices"] = "first"
58+
5459
encoder = TimmUniversalEncoder(
5560
name=name,
5661
in_channels=in_channels,
@@ -61,6 +66,18 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
6166
)
6267
return encoder
6368

69+
encoder_indices = kwargs.pop("encoder_indices", None)
70+
if encoder_indices is not None:
71+
logger.warning(
72+
"Argument `encoder_indices` is supported only for `tu-` encoders (Timm) and will be ignored."
73+
)
74+
75+
encoder_channels = kwargs.pop("encoder_channels", None)
76+
if encoder_channels is not None:
77+
logger.warning(
78+
"Argument `encoder_channels` is supported only for `tu-` encoders (Timm) and will be ignored."
79+
)
80+
6481
try:
6582
Encoder = encoders[name]["encoder"]
6683
except KeyError:

0 commit comments

Comments
 (0)