Skip to content

Commit 770665f

Browse files
authored
torchvision tutorial: update deprecated pretrained=True to weights="DEFAULT" (#1998)
1 parent d5f7a40 commit 770665f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

intermediate_source/torchvision_tutorial.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ way of doing it:
221221
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
222222
223223
# load a model pre-trained on COCO
224-
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
224+
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
225225
226226
# replace the classifier with a new one, that has
227227
# num_classes which is user-defined
@@ -242,7 +242,7 @@ way of doing it:
242242
243243
# load a pre-trained model for classification and return
244244
# only the features
245-
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
245+
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
246246
# FasterRCNN needs to know the number of
247247
# output channels in a backbone. For mobilenet_v2, it's 1280
248248
# so we need to add it here
@@ -291,7 +291,7 @@ be using Mask R-CNN:
291291
292292
def get_model_instance_segmentation(num_classes):
293293
# load an instance segmentation model pre-trained on COCO
294-
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
294+
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
295295
296296
# get number of input features for the classifier
297297
in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -344,7 +344,7 @@ expects during training and inference time on sample data.
344344

345345
.. code:: python
346346
347-
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
347+
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
348348
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
349349
data_loader = torch.utils.data.DataLoader(
350350
dataset, batch_size=2, shuffle=True, num_workers=4,

0 commit comments

Comments
 (0)