Skip to content

Commit 4ef9aea

Browse files
authored
Merge pull request facebookresearch#2 from aaronlelevier/trace_model
Two small improvements for trace_model.py
2 parents b1b03ae + 4905a65 commit 4ef9aea

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

demo/trace_model.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
import os
23
import numpy
4+
from io import BytesIO
35
from matplotlib import pyplot
46

7+
import requests
58
import torch
69

710
from PIL import Image
@@ -11,7 +14,11 @@
1114

1215
if __name__ == "__main__":
1316
# load config from file and command-line arguments
14-
cfg.merge_from_file("../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml")
17+
18+
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19+
cfg.merge_from_file(
20+
os.path.join(project_dir,
21+
"configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"))
1522
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
1623
cfg.freeze()
1724

@@ -131,9 +138,14 @@ def process_image_with_traced_model(image):
131138
result_image = combine_masks(original_image, labels, masks, scores, boxes, 0.5, 1, rectangle=True)
132139
return result_image
133140

141+
def fetch_image(url):
142+
response = requests.get(url)
143+
return Image.open(BytesIO(response.content)).convert("RGB")
134144

135145
if __name__ == "__main__":
136-
pil_image = Image.open("3915380994_2e611b1779_z.jpg").convert("RGB")
146+
pil_image = fetch_image(
147+
url="http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg")
148+
137149
# convert to BGR format
138150
image = torch.from_numpy(numpy.array(pil_image)[:, :, [2, 1, 0]])
139151
original_image = image
@@ -159,7 +171,8 @@ def end_to_end_model(image):
159171
pyplot.show()
160172

161173
# second image
162-
image2 = Image.open('17790319373_bd19b24cfc_k.jpg').convert("RGB")
174+
image2 = fetch_image(
175+
url='http://farm4.staticflickr.com/3153/2970773875_164f0c0b83_z.jpg')
163176
image2 = image2.resize((640, 480), Image.BILINEAR)
164177
image2 = torch.from_numpy(numpy.array(image2)[:, :, [2, 1, 0]])
165178
result_image2 = process_image_with_traced_model(image2)

0 commit comments

Comments
 (0)