16
16
17
17
H = NewType ("H" , int )
18
18
W = NewType ("W" , int )
19
- RgbBuf = NewType ("RgbBuf" , np .ndarray )
20
19
21
20
22
21
@dataclass
@@ -84,8 +83,13 @@ def apply(
84
83
) -> list [SegResult ]:
85
84
CLASS_ID = 0
86
85
assert image .mode == "RGB"
87
- buf = RgbBuf (np .array (image ))
88
- raw = self ._run (buf , threshold .confidence , threshold .iou )
86
+ assert image .width > 0
87
+ assert image .height > 0
88
+
89
+ assert 0.0 <= threshold .iou <= 1.0
90
+ assert 0.0 <= threshold .confidence <= 1.0
91
+
92
+ raw = self ._run (image , threshold .confidence , iou_threshold = threshold .iou )
89
93
if raw is None :
90
94
return []
91
95
@@ -105,17 +109,16 @@ def apply(
105
109
return results
106
110
107
111
def _run (
108
- self , img : RgbBuf , conf_threshold : float , iou_threshold : float
112
+ self , image : PILImage , conf_threshold : float , iou_threshold : float
109
113
) -> RawResult | None :
110
114
NM = 32
111
- ih , iw , _ = img .shape
112
-
113
- blob , ratio , (pad_w , pad_h ) = self .preprocess (img )
115
+ assert image .mode == "RGB"
116
+ blob , ratio , (pad_w , pad_h ) = self .preprocess (image )
114
117
assert blob .ndim == 4
115
118
preds = self .session .run (None , {self .input_ .name : blob })
116
119
return self .postprocess (
117
120
preds ,
118
- img_size = (ih , iw ),
121
+ img_size = (H ( image . height ), W ( image . width ) ),
119
122
ratio = ratio ,
120
123
pad_w = pad_w ,
121
124
pad_h = pad_h ,
@@ -125,15 +128,20 @@ def _run(
125
128
)
126
129
127
130
def preprocess (
128
- self , img_buf : RgbBuf
131
+ self , image : PILImage
129
132
) -> tuple [np .ndarray , float , tuple [float , float ]]:
130
133
BORDER_COLOR = (114 , 114 , 114 )
131
134
EPS = 0.1
132
- img = np .array (img_buf )
135
+
136
+ assert image .mode == "RGB"
137
+ img = np .array (image )
138
+
133
139
ih , iw , _ = img .shape
134
140
oh , ow = self .model_height , self .model_width
135
141
r = min (oh / ih , ow / iw )
136
142
rw , rh = round (iw * r ), round (ih * r )
143
+ rw = max (1 , rw )
144
+ rh = max (1 , rh )
137
145
138
146
pad_w , pad_h = [
139
147
(ow - rw ) / 2 ,
@@ -167,16 +175,16 @@ def postprocess(
167
175
B = 1
168
176
NM , MH , MW = (nm , 160 , 160 )
169
177
NUM_CLASSES = 1
170
- C = 4 + NUM_CLASSES + NM
171
178
172
179
x , protos = preds
173
180
assert len (x ) == len (protos ) == B
174
181
protos = protos [0 ]
175
182
x = x [0 ].T
176
183
assert protos .shape == (NM , MH , MW ), protos .shape
177
- assert x .shape == (len (x ), C )
184
+ assert x .shape == (len (x ), 4 + NUM_CLASSES + NM )
178
185
179
186
likely = x [:, 4 : 4 + NUM_CLASSES ].max (axis = 1 ) > conf_threshold
187
+ assert likely .ndim == 1
180
188
x = x [likely ]
181
189
182
190
scores = x [:, 4 : 4 + NUM_CLASSES ].max (axis = 1 )
@@ -335,14 +343,15 @@ def with_border(
335
343
bottom : int ,
336
344
left : int ,
337
345
right : int ,
338
- color : tuple [ int , int , int ] ,
346
+ color : Color ,
339
347
) -> np .ndarray :
340
- import cv2
341
-
342
348
assert img .ndim == 3
343
- return cv2 .copyMakeBorder (
344
- img , top , bottom , left , right , cv2 .BORDER_CONSTANT , value = color
345
- )
349
+ pil_img = Image .fromarray (img )
350
+ ow = pil_img .width + left + right
351
+ oh = pil_img .height + top + bottom
352
+ out = Image .new ("RGB" , (ow , oh ), color )
353
+ out .paste (pil_img , (left , top ))
354
+ return np .array (out ).astype (img .dtype )
346
355
347
356
348
357
def resize (buf : np .ndarray , size : tuple [W , H ]) -> np .ndarray :
0 commit comments