1
- from typing import Optional , Union , List
1
+ from typing import Optional , Union , List , Sequence , Callable
2
2
3
3
from segmentation_models_pytorch .encoders import get_encoder
4
4
from segmentation_models_pytorch .base import (
@@ -25,13 +25,13 @@ class Unet(SegmentationModel):
25
25
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
26
26
other pretrained weights (see table with available weights for each encoder_name)
27
27
encoder_indices: The indices of the encoder features that will be used in the decoder.
28
- If **"first"**, only the first `encoder_depth` features will be used.
28
+ If **"first"**, only the first `encoder_depth` features will be used.
29
29
If **"last"**, only the last `encoder_depth` features will be used.
30
30
If a list of integers, the indices of the encoder features that will be used in the decoder.
31
- Default is **"first"**
31
+ If **None**, defaults to **"first"**.
32
32
encoder_channels: A list of integers that specify the number of output channels for each encoder layer.
33
33
If **None**, the number of encoder output channels stays the same as for specifier `encoder_name`.
34
- If a list of integers, the number of encoder output channels is equal to the provided list,
34
+ If a list of integers, the number of encoder output channels is equal to the provided list,
35
35
features are adjusted by 1x1 convolutions without non-linearity.
36
36
decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
37
37
Length of the list should be the same as **encoder_depth**
@@ -67,14 +67,14 @@ def __init__(
67
67
encoder_name : str = "resnet34" ,
68
68
encoder_depth : int = 5 ,
69
69
encoder_weights : Optional [str ] = "imagenet" ,
70
- encoder_indices : Union [str , List [int ]] = "first" ,
70
+ encoder_indices : Optional [ Union [str , List [int ]]] = None ,
71
71
encoder_channels : Optional [List [int ]] = None ,
72
72
decoder_use_batchnorm : bool = True ,
73
- decoder_channels : List [int ] = (256 , 128 , 64 , 32 , 16 ),
73
+ decoder_channels : Sequence [int ] = (256 , 128 , 64 , 32 , 16 ),
74
74
decoder_attention_type : Optional [str ] = None ,
75
75
in_channels : int = 3 ,
76
76
classes : int = 1 ,
77
- activation : Optional [Union [str , callable ]] = None ,
77
+ activation : Optional [Union [str , Callable ]] = None ,
78
78
aux_params : Optional [dict ] = None ,
79
79
):
80
80
super ().__init__ ()
0 commit comments