
Commit c00db4e

Add new CSPNet preset and add manual padding. (#2212)
* csp stages zero padding
* zero padding for all stages and maxpooling
* add manual padding
* remove activation for conv3
* add zero padding for stages
* indentation fix
* remove zero padding for avg_down
* add more presets and timm conversion
* fix configs and timm preset logic
* change preset loader logic
* preset loader fix
1 parent 38ca305 commit c00db4e

File tree: 7 files changed (+209, -67 lines)


keras_hub/src/models/cspnet/cspnet_backbone.py

Lines changed: 51 additions & 26 deletions
@@ -81,7 +81,7 @@ class CSPNetBackbone(FeaturePyramidBackbone):
 
     # Pretrained backbone
    model = keras_hub.models.CSPNetBackbone.from_preset(
-        "cspdarknet53_ra_imagenet"
+        "csp_darknet_53_ra_imagenet"
    )
    model(input_data)
 
@@ -357,18 +357,6 @@ def apply(x):
             dtype=dtype,
             name=f"{name}_bottleneck_block_bn_3",
         )(x)
-        if activation == "leaky_relu":
-            x = layers.LeakyReLU(
-                negative_slope=0.01,
-                dtype=dtype,
-                name=f"{name}_bottleneck_block_activation_3",
-            )(x)
-        else:
-            x = layers.Activation(
-                activation,
-                dtype=dtype,
-                name=f"{name}_bottleneck_block_activation_3",
-            )(x)
 
         x = layers.add(
             [x, shortcut], dtype=dtype, name=f"{name}_bottleneck_block_add"
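The removed block is the "remove activation for conv3" item from the commit message: the activation that used to sit between the third batch norm and the residual add is dropped, so the block now goes conv -> bn -> add(shortcut), matching timm's reference BottleneckBlock, where the last conv is built without an activation and the nonlinearity is applied only after the shortcut add.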
@@ -673,6 +661,13 @@ def apply(x):
                     name=f"{name}_csp_activation_1",
                 )(x)
         else:
+            if strides > 1:
+                x = layers.ZeroPadding2D(
+                    1,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"{name}_csp_conv_pad_1",
+                )(x)
             x = layers.Conv2D(
                 filters=down_chs,
                 kernel_size=3,
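A minimal sketch of the padding pattern this and the following hunks introduce (layer names, shapes, and the use_bias flag here are illustrative, not the backbone code): an explicit, symmetric ZeroPadding2D in front of a stride-2 "valid" convolution. Keras' padding="same" pads asymmetrically on even inputs (the extra row/column goes to the bottom/right), whereas timm/PyTorch convolutions with padding=1 pad symmetrically, so making the padding explicit keeps ported weights spatially aligned.

import numpy as np
from keras import layers

x = np.random.randn(1, 64, 64, 32).astype("float32")

# Symmetric padding of 1, then a stride-2 conv with no implicit padding.
x = layers.ZeroPadding2D(1, name="example_conv_pad_1")(x)
x = layers.Conv2D(64, kernel_size=3, strides=2, padding="valid", use_bias=False)(x)

print(x.shape)  # (1, 32, 32, 64): same shape as padding="same", different window alignment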
@@ -882,6 +877,13 @@ def apply(x):
                     name=f"{name}_cs3_activation_1",
                 )(x)
         else:
+            if strides > 1:
+                x = layers.ZeroPadding2D(
+                    1,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"{name}_cs3_conv_pad_1",
+                )(x)
             x = layers.Conv2D(
                 filters=down_chs,
                 kernel_size=3,
@@ -1062,6 +1064,13 @@ def apply(x):
                     name=f"{name}_dark_activation_1",
                 )(x)
         else:
+            if strides > 1:
+                x = layers.ZeroPadding2D(
+                    1,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"{name}_dark_conv_pad_1",
+                )(x)
             x = layers.Conv2D(
                 filters=filters,
                 kernel_size=3,
@@ -1091,18 +1100,18 @@ def apply(x):
                     dtype=dtype,
                     name=f"{name}_dark_activation_1",
                 )(x)
-            for i in range(depth):
-                x = block_fn(
-                    filters=block_channels,
-                    dilation=dilation,
-                    bottle_ratio=bottle_ratio,
-                    groups=groups,
-                    activation=activation,
-                    data_format=data_format,
-                    channel_axis=channel_axis,
-                    dtype=dtype,
-                    name=f"{name}_block_{i}",
-                )(x)
+        for i in range(depth):
+            x = block_fn(
+                filters=block_channels,
+                dilation=dilation,
+                bottle_ratio=bottle_ratio,
+                groups=groups,
+                activation=activation,
+                data_format=data_format,
+                channel_axis=channel_axis,
+                dtype=dtype,
+                name=f"{name}_block_{i}",
+            )(x)
         return x
 
     return apply
@@ -1135,6 +1144,13 @@ def apply(x):
                 or (i == last_idx and strides > 2 and not pooling)
                 else 1
             )
+            if conv_strides > 1:
+                x = layers.ZeroPadding2D(
+                    (kernel_size - 1) // 2,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"csp_stem_pad_{i}",
+                )(x)
             x = layers.Conv2D(
                 filters=chs,
                 kernel_size=kernel_size,
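In the stem the pad amount generalizes to (kernel_size - 1) // 2: 1 for the 3x3 DarkNet-style stem convolutions and 3 for the 7x7 ResNet-style stems used by the new presets below.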
@@ -1167,10 +1183,19 @@ def apply(x):
 
         if pooling == "max":
             assert strides > 2
+            # Use manual padding to handle edge case scenario to ignore zero's
+            # as max value instead consider negative values from Leaky Relu type
+            # of activations.
+            pad_width = [[1, 1], [1, 1]]
+            if data_format == "channels_last":
+                pad_width += [[0, 0]]
+            else:
+                pad_width = [[0, 0]] + pad_width
+            pad_width = [[0, 0]] + pad_width
+            x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf"))
             x = layers.MaxPooling2D(
                 pool_size=3,
                 strides=2,
-                padding="same",
                 data_format=data_format,
                 dtype=dtype,
                 name="csp_stem_pool",

keras_hub/src/models/cspnet/cspnet_backbone_test.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@ def setUp(self):
             "expand_ratio": (2.0,) + (1.0,),
             "block_type": "dark_block",
             "stage_type": "csp",
+            "stem_padding": "same",
         }
         self.input_size = 64
         self.input_data = ops.ones((2, self.input_size, self.input_size, 3))
@@ -38,9 +39,9 @@ def test_backbone_basics(self, stage_type, block_type):
                 "stage_type": stage_type,
             },
             input_data=self.input_data,
-            expected_output_shape=(2, 6, 6, 48),
+            expected_output_shape=(2, 8, 8, 48),
             expected_pyramid_output_keys=["P2", "P3", "P4"],
-            expected_pyramid_image_sizes=[(30, 30), (14, 14), (6, 6)],
+            expected_pyramid_image_sizes=[(32, 32), (16, 16), (8, 8)],
         )
 
     @pytest.mark.large
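The updated expectations follow from the stem/stage geometry above: with explicit symmetric padding every stride-2 step halves the 64x64 test input exactly, giving pyramid sizes 32 -> 16 -> 8 instead of the previous 30 -> 14 -> 6.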

keras_hub/src/models/cspnet/cspnet_presets.py

Lines changed: 38 additions & 3 deletions
@@ -6,11 +6,46 @@
             "description": (
                 "A CSP-DarkNet (Cross-Stage-Partial) image classification model"
                 " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
-                "a 224x224 resolution."
+                "a 256x256 resolution."
             ),
-            "params": 26652512,
+            "params": 27642184,
             "path": "cspnet",
         },
-        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/1",
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/2",
+    },
+    "csp_resnext_50_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "A CSP-ResNeXt (Cross-Stage-Partial) image classification model"
+                " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                "a 256x256 resolution."
+            ),
+            "params": 20569896,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnext_50_ra_imagenet/1",
+    },
+    "csp_resnet_50_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "A CSP-ResNet (Cross-Stage-Partial) image classification model"
+                " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                "a 256x256 resolution."
+            ),
+            "params": 21616168,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnet_50_ra_imagenet/1",
+    },
+    "darknet_53_imagenet": {
+        "metadata": {
+            "description": (
+                "A DarkNet image classification model pre-trained on the"
+                "ImageNet 1k dataset at a 256x256 resolution."
+            ),
+            "params": 41609928,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/darknet_53_imagenet/1",
     },
 }
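A hedged usage sketch for the presets registered above (it assumes the corresponding Kaggle assets are published; the 256x256 input follows the preset descriptions):

import numpy as np
import keras_hub

images = np.random.uniform(size=(1, 256, 256, 3)).astype("float32")

# Any of the registered names works here, e.g. "csp_resnet_50_ra_imagenet",
# "csp_resnext_50_ra_imagenet", or "darknet_53_imagenet".
backbone = keras_hub.models.CSPNetBackbone.from_preset("csp_darknet_53_ra_imagenet")
features = backbone(images)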

keras_hub/src/utils/timm/convert_cspnet.py

Lines changed: 94 additions & 23 deletions
@@ -17,10 +17,69 @@ def convert_backbone_config(timm_config):
         bottle_ratio = (0.5,) + (1.0,)
         block_ratio = (1.0,) + (0.5,)
         expand_ratio = (2.0,) + (1.0,)
+        stem_padding = "same"
+        stem_pooling = None
         stage_type = "csp"
+        groups = 1
         block_type = "dark_block"
         down_growth = True
-        stackwise_strides = 2
+        stackwise_strides = [2, 2, 2, 2, 2]
+        avg_down = False
+        cross_linear = False
+    elif timm_architecture == "cspresnet50":
+        stem_filters = 64
+        stem_kernel_size = 7
+        stem_strides = 4
+        stackwise_depth = [3, 3, 5, 2]
+        stackwise_strides = [1, 2, 2, 2]
+        stackwise_num_filters = [128, 256, 512, 1024]
+        block_type = "bottleneck_block"
+        stage_type = "csp"
+        bottle_ratio = [0.5]
+        block_ratio = [1.0]
+        expand_ratio = [2.0]
+        stem_padding = "valid"
+        stem_pooling = "max"
+        avg_down = False
+        groups = 1
+        down_growth = False
+        cross_linear = True
+    elif timm_architecture == "cspresnext50":
+        stem_filters = 64
+        stem_kernel_size = 7
+        stem_strides = 4
+        stackwise_depth = [3, 3, 5, 2]
+        stackwise_num_filters = [256, 512, 1024, 2048]
+        bottle_ratio = [1.0]
+        block_ratio = [0.5]
+        expand_ratio = [1.0]
+        stage_type = "csp"
+        block_type = "bottleneck_block"
+        stem_pooling = "max"
+        stackwise_strides = [1, 2, 2, 2]
+        groups = 32
+        stem_padding = "valid"
+        avg_down = False
+        down_growth = False
+        cross_linear = True
+    elif timm_architecture == "darknet53":
+        stem_filters = 32
+        stem_kernel_size = 3
+        stem_strides = 1
+        stackwise_depth = [1, 2, 8, 8, 4]
+        stackwise_num_filters = [64, 128, 256, 512, 1024]
+        bottle_ratio = [0.5]
+        block_ratio = [1.0]
+        groups = 1
+        expand_ratio = [1.0]
+        stage_type = "dark"
+        block_type = "dark_block"
+        stem_pooling = None
+        stackwise_strides = [2, 2, 2, 2, 2]
+        stem_padding = "same"
+        avg_down = False
+        down_growth = False
+        cross_linear = False
     else:
         raise ValueError(
             f"Currently, the architecture {timm_architecture} is not supported."
@@ -38,6 +97,11 @@ def convert_backbone_config(timm_config):
         block_type=block_type,
         stackwise_strides=stackwise_strides,
         down_growth=down_growth,
+        stem_pooling=stem_pooling,
+        stem_padding=stem_padding,
+        avg_down=avg_down,
+        cross_linear=cross_linear,
+        groups=groups,
     )
 
 
@@ -81,21 +145,36 @@ def port_batch_normalization(hf_weight_prefix, keras_layer_name):
     stackwise_depth = backbone.stackwise_depth
     stage_type = backbone.stage_type
     block_type = backbone.block_type
+    strides = backbone.stackwise_strides
 
     for idx, block in enumerate(stackwise_depth):
-        port_conv2d(
-            f"stages.{idx}.conv_down.conv",
-            f"stage_{idx}_{stage_type}_conv_down_1",
-        )
-        port_batch_normalization(
-            f"stages.{idx}.conv_down.bn", f"stage_{idx}_{stage_type}_bn_1"
-        )
-        port_conv2d(
-            f"stages.{idx}.conv_exp.conv", f"stage_{idx}_{stage_type}_conv_exp"
-        )
-        port_batch_normalization(
-            f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
-        )
+        if strides[idx] != 1 or stage_type == "dark":
+            if strides[idx] == 2 and backbone.avg_down:
+                port_conv2d(
+                    f"stages.{idx}.conv_down.1.conv",
+                    f"stage_{idx}_{stage_type}_conv_down_1",
+                )
+                port_batch_normalization(
+                    f"stages.{idx}.conv_down.1.bn",
+                    f"stage_{idx}_{stage_type}_bn_1",
+                )
+            else:
+                port_conv2d(
+                    f"stages.{idx}.conv_down.conv",
+                    f"stage_{idx}_{stage_type}_conv_down_1",
+                )
+                port_batch_normalization(
+                    f"stages.{idx}.conv_down.bn",
+                    f"stage_{idx}_{stage_type}_bn_1",
+                )
+        if stage_type != "dark":
+            port_conv2d(
+                f"stages.{idx}.conv_exp.conv",
+                f"stage_{idx}_{stage_type}_conv_exp",
+            )
+            port_batch_normalization(
+                f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
+            )
 
         for i in range(block):
             port_conv2d(
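As a concrete example of the mapping above, for a downsampling csp stage at idx == 0 the ported pairs are stages.0.conv_down.conv -> stage_0_csp_conv_down_1 and stages.0.conv_exp.conv -> stage_0_csp_conv_exp; for dark stages the expansion conv is skipped entirely.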
@@ -133,16 +212,8 @@ def port_batch_normalization(hf_weight_prefix, keras_layer_name):
                 f"stages.{idx}.conv_transition_b.bn",
                 f"stage_{idx}_{stage_type}_transition_b_bn",
             )
-            port_conv2d(
-                f"stages.{idx}.conv_transition.conv",
-                f"stage_{idx}_{stage_type}_conv_transition",
-            )
-            port_batch_normalization(
-                f"stages.{idx}.conv_transition.bn",
-                f"stage_{idx}_{stage_type}_transition_bn",
-            )
 
-        else:
+        if stage_type != "dark":
             port_conv2d(
                 f"stages.{idx}.conv_transition.conv",
                 f"stage_{idx}_{stage_type}_conv_transition",

keras_hub/src/utils/timm/convert_cspnet_test.py

Lines changed: 5 additions & 5 deletions
@@ -6,15 +6,15 @@
 from keras_hub.src.tests.test_case import TestCase
 
 
-class TimmDenseNetBackboneTest(TestCase):
+class TimmCSPNetBackboneTest(TestCase):
     @pytest.mark.large
-    def test_convert_densenet_backbone(self):
+    def test_convert_cspnet_backbone(self):
         model = Backbone.from_preset("hf://timm/cspdarknet53.ra_in1k")
-        outputs = model.predict(ops.ones((1, 224, 224, 3)))
-        self.assertEqual(outputs.shape, (1, 5, 5, 1024))
+        outputs = model.predict(ops.ones((1, 256, 256, 3)))
+        self.assertEqual(outputs.shape, (1, 8, 8, 1024))
 
     @pytest.mark.large
-    def test_convert_densenet_classifier(self):
+    def test_convert_cspnet_classifier(self):
         model = ImageClassifier.from_preset("hf://timm/cspdarknet53.ra_in1k")
         outputs = model.predict(ops.ones((1, 512, 512, 3)))
         self.assertEqual(outputs.shape, (1, 1000))
