Does this repo's implementation of maskrcnn work with negative samples? #80

Closed
rbavery opened this issue Oct 15, 2019 · 20 comments

@rbavery

rbavery commented Oct 15, 2019

❓ Questions and Help

I'm working with satellite imagery where some image tiles have over 100 objects to segment and some images have none (these are negative samples). The negative samples still provide information on what is not a segmentation (in my case, types of agriculture land cover). This issue on another facebookresearch repo seems to indicate that negative samples can be very useful but that they are not always supported out of the box: facebookresearch/maskrcnn-benchmark#169

Does this repo support negative samples? Does it depend on the maskrcnn model I pretrain from in the model zoo?

@ppwwyyxx
Contributor

We allow input images to have no objects, if that's what you mean.

However, the default data loader filters out such images here:

dataset_dicts = get_detection_dataset_dicts(
    cfg.DATASETS.TRAIN,
    filter_empty=True,
    min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
    if cfg.MODEL.KEYPOINT_ON
    else 0,
    proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
)

If you need to keep them, you need to write a custom data loader.
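
To make the effect of the filter concrete, here is a minimal sketch in plain Python. It is not detectron2's implementation: the function name `drop_records_without_annotations` is hypothetical, the records only mimic the dataset-dict layout ("file_name", "annotations") used later in this thread, and the real filter may apply further criteria.

```python
def drop_records_without_annotations(dataset_dicts):
    """Return only the records that contain at least one annotation."""
    return [r for r in dataset_dicts if len(r.get("annotations", [])) > 0]


records = [
    {"file_name": "tile_0.png", "annotations": [{"bbox": [0, 0, 10, 10], "category_id": 0}]},
    {"file_name": "tile_1.png", "annotations": []},  # negative sample: no objects
]

kept = drop_records_without_annotations(records)
print([r["file_name"] for r in kept])  # the negative sample is dropped
```

A loader that keeps negative samples would simply skip this step and pass all records on to training.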

@rbavery
Author

rbavery commented Oct 15, 2019

So if I wrote a data loader that accomplishes the same thing as setting filter_empty=False, would there be any other issues with training with negative samples? Or is this unknown? Thanks for the advice!

@rbavery
Author

rbavery commented Oct 15, 2019

I guess I'm wondering what the rationale is for filtering out empty annotations by default. Is it because some segmentation models like Mask R-CNN can't make use of negative samples?

@ppwwyyxx
Contributor

ppwwyyxx commented Oct 16, 2019

would there be any other issues with training with negative samples

There shouldn't be. But we don't normally do this, so there might be things we are not aware of. If you find any issues, please report them to us.

what the rationale is to filter out empty annotations by default,

Mainly because that's the tradition I think.

@ppwwyyxx ppwwyyxx added the usage label Oct 16, 2019
@rbgirshick
Contributor

IIRC, the filtering was first added in detectron (v1) when we implemented keypoint prediction in Mask R-CNN. In that case only images that contain people annotated with keypoints are useful for training and this leaves a large number of images without positive examples. We found that training was faster and keypoint AP was no worse if the images without positives were filtered out. Since it had little impact on the usual COCO training data, filtering of these images was left on as the default. Detectron2 has inherited this choice, but in retrospect the design decision is a bit overfit to considerations of the COCO dataset.

@ppwwyyxx -- perhaps we should consider making this filtering configurable?

@rbavery
Author

rbavery commented Oct 16, 2019

Thanks to both of you for the explanations. IMO then, the option should probably be configurable, since filtering is a bit of an unexpected and hidden behavior for folks looking to use Detectron2 on their own custom datasets. I'm happy to make a PR that turns empty filtering into a cfg attribute if the maintainers agree this would be valuable. I can also test this on my own datasets to see whether any errors pop up and whether including negative/empty samples improves segmentation accuracy.

@rbgirshick
Contributor

@rbavery -- sounds good, a PR is welcome!

@ppwwyyxx
Contributor

ppwwyyxx commented Oct 20, 2019

Added in 62522b6.
Now, set cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS=False to include negative samples.
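
As a rough sketch of how that option might be set in a training script: the helper name `build_cfg_with_negative_samples` is hypothetical, and the config file path is whatever you already use. `get_cfg`, `merge_from_file`, and the `DATALOADER.FILTER_EMPTY_ANNOTATIONS` key are the ones that appear in this thread.

```python
def build_cfg_with_negative_samples(config_file):
    """Load a detectron2 config and keep images that have no annotations."""
    # Imported inside the function so this sketch can be loaded
    # even on a machine where detectron2 is not installed.
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    # Do not drop images whose "annotations" list is empty.
    cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False
    return cfg
```

The rest of the setup (datasets, weights, solver settings) would be configured on the returned `cfg` as usual.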

@edwardchaos

would there be any other issues with training with negative samples

There shouldn't be. But we don't normally do this, so there might be things we are not aware of. If you find any issues, please report them to us.

what the rationale is to filter out empty annotations by default,

Mainly because that's the tradition I think.

Hi, I'm using a custom dataset with torch DataLoader.
The model I'm using is ./model_final_a3ec72.pkl
Introducing negative samples fails in fast_rcnn.py(173):

File "/home/edwardchaos/detectron2_repo/detectron2/modeling/roi_heads/fast_rcnn.py", line 173, in _log_accuracy
pred_classes = self.pred_class_logits.argmax(dim=1)
RuntimeError: cannot perform reduction function argmax on a tensor with no elements because the operation does not have an identity

If I add a check to skip empty prediction in _log_accuracy function of fast_rcnn.py:

num_instances = self.gt_classes.numel()
if self.pred_class_logits.nelement() == 0:
    return
pred_classes = self.pred_class_logits.argmax(dim=1)
bg_class_ind = self.pred_class_logits.shape[1] - 1

It then fails at:
File "/home/edwardchaos/anaconda3/envs/detectron2/lib/python3.7/site-packages/torch/nn/functional.py", line 1838, in nll_loss
ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: invalid argument 2: non-empty vector or matrix expected at /opt/conda/conda-bld/pytorch_1570711283072/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:31

@ppwwyyxx
Contributor

ppwwyyxx commented Nov 7, 2019

Negative samples should not make self.pred_class_logits empty. Please provide exact steps that let others reproduce the issue if you think it is a detectron2 issue.

@edwardchaos

I found this comment in fast_rcnn.py:

proposals (list[Instances]): A list of N Instances, where Instances i stores the
proposals for image i, in the field "proposal_boxes".
When training, each Instances must have ground-truth labels
stored in the field "gt_classes" and "gt_boxes".

Maybe Mask R-CNN accepts negative samples; I'll just use samples with ground-truth labels for my purposes.

@MLaurenceFournier

I'm working on a project based on Torchvision where negative samples also cause problems. Since Torchvision currently doesn't support them, I'm wondering if it would be worth switching to Detectron2. As of today, what is the answer to this issue's title: "Does this repo's implementation of maskrcnn work with negative samples?" What about Faster R-CNN? I understand a data loader option has been added, but it's still unclear whether the models support negative samples.

@ppwwyyxx
Contributor

Images without annotations are supported in R-CNN.

@Hvorost

Hvorost commented Dec 10, 2019

Hi, maybe someone can help me. I'm trying to follow this notebook, https://github.com/TannerGilbert/Microcontroller-Detection-with-Detectron2/blob/master/Detectron2_Detect_Microcontrollers.ipynb, with some modifications: I want to train the model without filtering out empty annotations, so I set DATALOADER.FILTER_EMPTY_ANNOTATIONS = False. My dataset looks like this (image not preserved), and my code is:
def get_watermark_dicts(csv_file, img_dir):
    df = pd.read_csv(csv_file)
    df['filename'] = df['filename'].map(lambda x: img_dir+x)

    classes = ['watermark', 'not']

    df['class_int'] = df['class'].map(lambda x: classes.index(x))

    dataset_dicts = []
    for filename in df['filename'].unique().tolist():

        record = {}

        height, width = cv2.imread(filename).shape[:2]

        record["file_name"] = filename
        record["height"] = height
        record["width"] = width

        objs = []
        for index, row in df[(df['filename']==filename)].iterrows():
            obj = {
                'bbox': [row['xmin'], row['ymin'], row['xmax'], row['ymax']],
                'bbox_mode': BoxMode.XYXY_ABS,
                'category_id': row['class_int'],
                "iscrowd": 0
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

cfg = get_cfg()
cfg.merge_from_file("detectron2_repo/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.DATASETS.TRAIN = ('/home/ubuntu/watermark/train',)
cfg.DATASETS.TEST = () # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 4
cfg.MODEL.WEIGHTS = "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl" # initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.MAX_ITER = 2000
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

But my training process crashed with:
AssertionError: Caught AssertionError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ubuntu/detectron2_repo/detectron2/data/common.py", line 39, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/home/ubuntu/detectron2_repo/detectron2/utils/serialize.py", line 23, in __call__
    return self._obj(*args, **kwargs)
  File "/home/ubuntu/detectron2_repo/detectron2/data/dataset_mapper.py", line 131, in __call__
    annos, image_shape, mask_format=self.mask_format
  File "/home/ubuntu/detectron2_repo/detectron2/data/detection_utils.py", line 234, in annotations_to_instances
    boxes.clip(image_size)
  File "/home/ubuntu/detectron2_repo/detectron2/structures/boxes.py", line 130, in clip
    assert torch.isfinite(self.tensor).all()
AssertionError
If someone can help, I would be very grateful.

@ppwwyyxx
Contributor

It does not allow boxes with NaNs.
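
One way to track down which records trip that assertion is to scan the dataset dicts before training. The sketch below is plain Python and independent of detectron2; the function name `find_non_finite_boxes` is hypothetical. A plausible cause, stated here only as an assumption, is a CSV row for an image without objects whose coordinate columns are empty and are therefore read as NaN.

```python
import math


def find_non_finite_boxes(dataset_dicts):
    """Return (file_name, annotation_index) pairs whose bbox contains NaN or inf."""
    bad = []
    for record in dataset_dicts:
        for i, anno in enumerate(record.get("annotations", [])):
            if not all(math.isfinite(float(v)) for v in anno["bbox"]):
                bad.append((record["file_name"], i))
    return bad


records = [
    {"file_name": "a.jpg", "annotations": [{"bbox": [1.0, 2.0, 30.0, 40.0]}]},
    {"file_name": "b.jpg", "annotations": [{"bbox": [float("nan")] * 4}]},
]
print(find_non_finite_boxes(records))  # lists the offending annotation in b.jpg
```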

@Hvorost

Hvorost commented Dec 11, 2019

It does not allow boxes with NaNs.

I understand that, but how can I train the model with negative samples? How can I use images that have no boxes? If you can explain, I would be very grateful.

@ppwwyyxx
Contributor

You can train with an image whose "annotations" is an empty list.
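
A minimal sketch of building such a record, assuming rows shaped like the CSV columns in the code above (xmin, ymin, xmax, ymax, class_int). The helper name `rows_to_record` is hypothetical, and `bbox_mode` is passed in by the caller so the sketch runs without detectron2; real code would pass `BoxMode.XYXY_ABS`.

```python
import math


def rows_to_record(file_name, height, width, rows, bbox_mode):
    """Build one dataset dict; rows without finite box coordinates add no annotation."""
    annotations = []
    for row in rows:
        box = [row.get(k) for k in ("xmin", "ymin", "xmax", "ymax")]
        if any(v is None or not math.isfinite(float(v)) for v in box):
            continue  # placeholder row for a negative image: nothing to annotate
        annotations.append({
            "bbox": [float(v) for v in box],
            "bbox_mode": bbox_mode,
            "category_id": int(row["class_int"]),
            "iscrowd": 0,
        })
    return {"file_name": file_name, "height": height, "width": width,
            "annotations": annotations}


XYXY_ABS = "XYXY_ABS"  # stand-in value; real code would pass BoxMode.XYXY_ABS
nan = float("nan")
positive = rows_to_record(
    "p.jpg", 100, 200,
    [{"xmin": 1, "ymin": 2, "xmax": 30, "ymax": 40, "class_int": 0}], XYXY_ABS)
negative = rows_to_record(
    "n.jpg", 100, 200,
    [{"xmin": nan, "ymin": nan, "xmax": nan, "ymax": nan, "class_int": 1}], XYXY_ABS)
print(len(positive["annotations"]), negative["annotations"])
```

With this layout the negative image still appears in the dataset, but its "annotations" list is empty rather than holding a box full of NaNs.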

@Hvorost

Hvorost commented Dec 11, 2019

Thanks, I'll try.

@Hvorost

Hvorost commented Dec 13, 2019

I still get the same error. As I understand it, there should be nothing to clip in boxes.clip(image_size), because there are no coordinates, but the failure is the same as when I had NaN in the bbox list. I now use an empty list for the negative sample images. The question is still open; please help.

@Hvorost

Hvorost commented Dec 13, 2019

Full traceback:

AssertionError Traceback (most recent call last)
in
16 trainer = DefaultTrainer(cfg)
17 trainer.resume_or_load(resume=False)
---> 18 trainer.train()

~/detectron2_repo/detectron2/engine/defaults.py in train(self)
367 OrderedDict of results, if evaluation is enabled. Otherwise None.
368 """
--> 369 super().train(self.start_iter, self.max_iter)
370 if hasattr(self, "_last_eval_results") and comm.is_main_process():
371 verify_results(self.cfg, self._last_eval_results)

~/detectron2_repo/detectron2/engine/train_loop.py in train(self, start_iter, max_iter)
130 for self.iter in range(start_iter, max_iter):
131 self.before_step()
--> 132 self.run_step()
133 self.after_step()
134 finally:

~/detectron2_repo/detectron2/engine/train_loop.py in run_step(self)
204 If your want to do something with the data, you can wrap the dataloader.
205 """
--> 206 data = next(self._data_loader_iter)
207 data_time = time.perf_counter() - start
208

~/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
799 if len(self._task_info[self._rcvd_idx]) == 2:
800 data = self._task_info.pop(self._rcvd_idx)[1]
--> 801 return self._process_data(data)
802
803 assert not self._shutdown and self._tasks_outstanding > 0

~/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
844 self._try_put_index()
845 if isinstance(data, ExceptionWrapper):
--> 846 data.reraise()
847 return data
848

~/.local/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
383 # (https://bugs.python.org/issue2651), so we work around it.
384 msg = KeyErrorMessage(msg)
--> 385 raise self.exc_type(msg)

AssertionError: Caught AssertionError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ubuntu/detectron2_repo/detectron2/data/common.py", line 39, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/home/ubuntu/detectron2_repo/detectron2/utils/serialize.py", line 23, in __call__
    return self._obj(*args, **kwargs)
  File "/home/ubuntu/detectron2_repo/detectron2/data/dataset_mapper.py", line 131, in __call__
    annos, image_shape, mask_format=self.mask_format
  File "/home/ubuntu/detectron2_repo/detectron2/data/detection_utils.py", line 234, in annotations_to_instances
    boxes.clip(image_size)
  File "/home/ubuntu/detectron2_repo/detectron2/structures/boxes.py", line 130, in clip
    assert torch.isfinite(self.tensor).all()
AssertionError

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Apr 7, 2021