1
1
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ import os
2
3
import numpy
4
+ from io import BytesIO
3
5
from matplotlib import pyplot
4
6
7
+ import requests
5
8
import torch
6
9
7
10
from PIL import Image
11
14
12
15
if __name__ == "__main__" :
13
16
# load config from file and command-line arguments
14
- cfg .merge_from_file ("../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml" )
17
+
18
+ project_dir = os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))
19
+ cfg .merge_from_file (
20
+ os .path .join (project_dir ,
21
+ "configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml" ))
15
22
cfg .merge_from_list (["MODEL.DEVICE" , "cpu" ])
16
23
cfg .freeze ()
17
24
@@ -131,9 +138,14 @@ def process_image_with_traced_model(image):
131
138
result_image = combine_masks (original_image , labels , masks , scores , boxes , 0.5 , 1 , rectangle = True )
132
139
return result_image
133
140
141
+ def fetch_image (url ):
142
+ response = requests .get (url )
143
+ return Image .open (BytesIO (response .content )).convert ("RGB" )
134
144
135
145
if __name__ == "__main__" :
136
- pil_image = Image .open ("3915380994_2e611b1779_z.jpg" ).convert ("RGB" )
146
+ pil_image = fetch_image (
147
+ url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" )
148
+
137
149
# convert to BGR format
138
150
image = torch .from_numpy (numpy .array (pil_image )[:, :, [2 , 1 , 0 ]])
139
151
original_image = image
@@ -159,7 +171,8 @@ def end_to_end_model(image):
159
171
pyplot .show ()
160
172
161
173
# second image
162
- image2 = Image .open ('17790319373_bd19b24cfc_k.jpg' ).convert ("RGB" )
174
+ image2 = fetch_image (
175
+ url = 'http://farm4.staticflickr.com/3153/2970773875_164f0c0b83_z.jpg' )
163
176
image2 = image2 .resize ((640 , 480 ), Image .BILINEAR )
164
177
image2 = torch .from_numpy (numpy .array (image2 )[:, :, [2 , 1 , 0 ]])
165
178
result_image2 = process_image_with_traced_model (image2 )
0 commit comments