Skip to content

Commit 5402948

Browse files
authored
[datasets] COCO-Text V2 integration (#1888)
1 parent a8a81bf commit 5402948

File tree

8 files changed

+330
-55
lines changed

8 files changed

+330
-55
lines changed

docs/source/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Supported datasets
6868
* MJSynth from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.
6969
* IIITHWS from `"Generating Synthetic Data for Text Recognition" <https://github.com/kris314/hwnet>`_.
7070
* WILDRECEIPT from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction" <https://arxiv.org/pdf/2103.14470v1.pdf>`_.
71-
71+
* COCO-Text dataset from `"COCO-Text: Dataset and Benchmark for Text Detection and Recognition in Natural Images" <https://arxiv.org/pdf/1601.07140v2>`_.
7272

7373
.. toctree::
7474
:maxdepth: 2

docs/source/modules/datasets.rst

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ doctr.datasets
3636

3737
.. autoclass:: WILDRECEIPT
3838

39+
.. autoclass:: COCOTEXT
40+
3941
Synthetic dataset generator
4042
---------------------------
4143

docs/source/using_doctr/using_datasets.rst

+58-54
Large diffs are not rendered by default.

doctr/datasets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from doctr.file_utils import is_tf_available
22

33
from .generator import *
4+
from .coco_text import *
45
from .cord import *
56
from .detection import *
67
from .doc_artefacts import *

doctr/datasets/coco_text.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (C) 2021-2025, Mindee.
2+
3+
# This program is licensed under the Apache License 2.0.
4+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5+
6+
import json
7+
import os
8+
from pathlib import Path
9+
from typing import Any
10+
11+
import numpy as np
12+
from tqdm import tqdm
13+
14+
from .datasets import AbstractDataset
15+
from .utils import convert_target_to_relative, crop_bboxes_from_image
16+
17+
__all__ = ["COCOTEXT"]
18+
19+
20+
class COCOTEXT(AbstractDataset):
    """
    COCO-Text dataset from `"COCO-Text: Dataset and Benchmark for Text Detection and Recognition in Natural Images"
    <https://arxiv.org/pdf/1601.07140v2>`_ |
    `"homepage" <https://bgshih.github.io/cocotext/>`_.

    >>> # NOTE: You need to download the dataset first.
    >>> from doctr.datasets import COCOTEXT
    >>> train_set = COCOTEXT(train=True, img_folder="/path/to/coco_text/train2014/",
    ...                      label_path="/path/to/coco_text/cocotext.v2.json")
    >>> img, target = train_set[0]
    >>> test_set = COCOTEXT(train=False, img_folder="/path/to/coco_text/train2014/",
    ...                     label_path="/path/to/coco_text/cocotext.v2.json")
    >>> img, target = test_set[0]

    Only annotations whose ``legibility`` is ``"legible"`` are kept; images left without any such
    annotation are skipped. With ``recognition_task=True`` one sample is stored per cropped box, and
    crops whose label is empty or contains a space are dropped.

    Args:
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations file of the dataset
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        detection_task: whether the dataset should be used for detection task
        **kwargs: keyword arguments from `AbstractDataset`.

    Raises:
        ValueError: if both `recognition_task` and `detection_task` are True
        FileNotFoundError: if `label_path`, `img_folder` or one of the listed images cannot be found
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        detection_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
        )
        # Task check
        if recognition_task and detection_task:
            raise ValueError(
                " 'recognition' and 'detection task' cannot be set to True simultaneously. "
                + " To get the whole dataset with boxes and labels leave both parameters to False "
            )

        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to find {label_path if not os.path.exists(label_path) else img_folder}")

        tmp_root = img_folder
        self.train = train
        np_dtype = np.float32
        self.data: list[tuple[str | Path | np.ndarray, str | dict[str, Any] | np.ndarray]] = []

        # The annotation file carries UTF-8 text (`utf8_string`), so do not rely on the platform default encoding
        with open(label_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        # Filter images based on the set
        img_items = [img for img in data["imgs"].items() if (img[1]["set"] == "train") == train]

        # Index the legible annotations by image id once, instead of re-scanning every annotation for each image
        legible_by_image: dict[Any, list[dict[str, Any]]] = {}
        for ann in data["anns"].values():
            if ann.get("legibility") == "legible":
                legible_by_image.setdefault(ann["image_id"], []).append(ann)

        for img_id, img_info in tqdm(img_items, desc="Preparing and Loading COCOTEXT", total=len(img_items)):
            img_path = os.path.join(img_folder, img_info["file_name"])

            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Unable to locate {img_path}")

            # Keys of `imgs` are strings in the JSON file whereas `image_id` is an integer
            annotations = legible_by_image.get(int(img_id))

            if not annotations:  # Some images have no annotations with readable text
                continue

            text_targets = []
            box_targets = []
            for annotation in annotations:
                x, y, w, h = annotation["bbox"]
                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box = np.array(
                        [
                            [x, y],
                            [x + w, y],
                            [x + w, y + h],
                            [x, y + h],
                        ],
                        dtype=np_dtype,
                    )
                else:
                    # (xmin, ymin, xmax, ymax) coordinates
                    box = [x, y, x + w, y + h]
                text_targets.append(annotation["utf8_string"])
                box_targets.append(box)

            # Integer pixel coordinates, clipped so that no coordinate is negative
            geoms = np.asarray(box_targets, dtype=int).clip(min=0)

            if recognition_task:
                # `img_path` already includes `img_folder`: joining it with the root a second time would
                # produce a doubled (non-existent) path whenever `img_folder` is relative
                crops = crop_bboxes_from_image(img_path=img_path, geoms=geoms)
                for crop, label in zip(crops, text_targets):
                    if label and " " not in label:
                        self.data.append((crop, label))

            elif detection_task:
                self.data.append((img_path, geoms))
            else:
                self.data.append((
                    img_path,
                    dict(boxes=geoms, labels=text_targets),
                ))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"

tests/conftest.py

+70
Original file line numberDiff line numberDiff line change
@@ -711,3 +711,73 @@ def mock_wildreceipt_dataset(tmpdir_factory, mock_image_stream):
711711
with open(fn_i, "wb") as f:
712712
f.write(file.getbuffer())
713713
return str(image_folder), str(annotation_file)
714+
715+
716+
@pytest.fixture(scope="session")
def mock_cocotext_dataset(tmpdir_factory, mock_image_stream):
    """Build a minimal on-disk COCO-Text layout for the dataset tests.

    Creates ``datasets/cocotext/train2014`` in a temporary directory, writes a ``cocotext.v2.json``
    annotation file describing three "train" images with one legible annotation each, and stores the
    mock image bytes under each listed file name.

    Returns:
        a tuple (image folder path, annotation file path), both as strings
    """
    root = tmpdir_factory.mktemp("datasets")
    cocotext_root = root.mkdir("cocotext")
    annotations_folder = cocotext_root
    image_folder = cocotext_root.mkdir("train2014")

    # NOTE: the numeric suffix of the first file name differs from its image id (367969);
    # annotations reference images through `image_id`, so the file names only need to exist on disk
    filenames = [
        "COCO_train2014_000000353709.jpg",
        "COCO_train2014_000000077346.jpg",
        "COCO_train2014_000000437996.jpg",
    ]
    labels = {
        "cats": {},
        "anns": {
            "1": {
                "mask": [286.1, 215.5, 285.2, 221.5, 304.6, 222.0, 304.6, 216.9],
                "class": "machine printed",
                "bbox": [285.2, 215.5, 19.4, 6.5],
                "image_id": 367969,
                "id": 108418,
                "language": "english",
                "area": 105.6,
                "utf8_string": "GATO",
                "legibility": "legible",
            },
            "2": {
                "mask": [310.4, 304.6, 319.4, 302.1, 323.2, 318.1, 307.2, 318.1],
                "class": "machine printed",
                "bbox": [307.2, 302.1, 16.0, 16.0],
                "image_id": 77346,
                "id": 196817,
                "language": "english",
                "area": 184.75,
                "utf8_string": "6",
                "legibility": "legible",
            },
            "3": {
                "mask": [212.6, 245.8, 210.1, 248.6, 212.0, 262.8, 221.9, 260.9, 227.4, 244.6],
                "class": "machine printed",
                "bbox": [210.1, 244.6, 17.3, 18.2],
                "image_id": 437996,
                "id": 134765,
                "language": "english",
                "area": 221.31,
                "utf8_string": "17",
                "legibility": "legible",
            },
        },
        "imgs": {
            "367969": {"id": 367969, "set": "train", "width": 640, "file_name": filenames[0], "height": 427},
            "77346": {"id": 77346, "set": "train", "width": 640, "file_name": filenames[1], "height": 427},
            "437996": {"id": 437996, "set": "train", "width": 640, "file_name": filenames[2], "height": 427},
        },
        "imgToAnns": {},
        "info": {},
    }

    annotation_file = annotations_folder.join("cocotext.v2.json")
    with open(annotation_file, "w") as f:
        json.dump(labels, f)

    # Every listed image is a copy of the same mock image bytes
    file = BytesIO(mock_image_stream)
    for img_name in filenames:
        fn = image_folder.join(img_name)
        with open(fn, "wb") as f:
            f.write(file.getbuffer())
    return str(image_folder), str(annotation_file)

tests/pytorch/test_datasets_pt.py

+31
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,37 @@ def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, detec
760760
datasets.WILDRECEIPT(*mock_wildreceipt_dataset, train=True, recognition_task=True, detection_task=True)
761761

762762

763+
@pytest.mark.parametrize("rotate", [True, False])
@pytest.mark.parametrize(
    "input_size, num_samples, recognition, detection",
    [
        [[512, 512], 3, False, False],  # Actual set has 13880 training samples and 3261 test samples
        [[32, 128], 3, True, False],  # recognition
        [[512, 512], 3, False, True],  # detection
    ],
)
def test_cocotext_dataset(input_size, num_samples, rotate, recognition, detection, mock_cocotext_dataset):
    """Load the mocked COCO-Text data for each task / geometry combination and validate the samples.

    Checks the dataset length and repr, runs the validation helper matching the selected task,
    and verifies that enabling both task flags at once raises a ValueError.
    """
    img_folder, label_path = mock_cocotext_dataset
    ds = datasets.COCOTEXT(
        img_folder,
        label_path,
        train=True,
        img_transforms=Resize(input_size),
        use_polygons=rotate,
        recognition_task=recognition,
        detection_task=detection,
    )

    assert len(ds) == num_samples
    assert repr(ds) == f"COCOTEXT(train={True})"

    # Pick the validator that corresponds to the task under test
    if not (recognition or detection):
        _validate_dataset(ds, input_size, is_polygons=rotate)
    elif recognition:
        _validate_dataset_recognition_part(ds, input_size)
    else:
        _validate_dataset_detection_part(ds, input_size, is_polygons=rotate)

    # Both task flags together are rejected
    with pytest.raises(ValueError):
        datasets.COCOTEXT(img_folder, label_path, train=True, recognition_task=True, detection_task=True)
793+
763794
# NOTE: following datasets are only for recognition task
764795

765796

tests/tensorflow/test_datasets_tf.py

+31
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,37 @@ def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, detec
733733
datasets.WILDRECEIPT(*mock_wildreceipt_dataset, train=True, recognition_task=True, detection_task=True)
734734

735735

736+
@pytest.mark.parametrize("rotate", [True, False])
@pytest.mark.parametrize(
    "input_size, num_samples, recognition, detection",
    [
        [[512, 512], 3, False, False],  # Actual set has 13880 training samples and 3261 test samples
        [[32, 128], 3, True, False],  # recognition
        [[512, 512], 3, False, True],  # detection
    ],
)
def test_cocotext_dataset(input_size, num_samples, rotate, recognition, detection, mock_cocotext_dataset):
    """Load the mocked COCO-Text data for each task / geometry combination and validate the samples.

    Checks the dataset length and repr, runs the validation helper matching the selected task,
    and verifies that enabling both task flags at once raises a ValueError.
    """
    ds = datasets.COCOTEXT(
        *mock_cocotext_dataset,
        train=True,
        img_transforms=Resize(input_size),
        use_polygons=rotate,
        recognition_task=recognition,
        detection_task=detection,
    )
    assert len(ds) == num_samples
    assert repr(ds) == f"COCOTEXT(train={True})"
    if recognition:
        _validate_dataset_recognition_part(ds, input_size)
    elif detection:
        _validate_dataset_detection_part(ds, input_size, is_polygons=rotate)
    else:
        _validate_dataset(ds, input_size, is_polygons=rotate)

    # Both task flags together are rejected
    with pytest.raises(ValueError):
        datasets.COCOTEXT(*mock_cocotext_dataset, train=True, recognition_task=True, detection_task=True)
766+
736767
# NOTE: following datasets are only for recognition task
737768

738769

0 commit comments

Comments
 (0)