|
1 | 1 | from io import BytesIO
|
2 | 2 | from pathlib import Path
|
3 |
| -from typing import TYPE_CHECKING, Callable, Iterable, Iterator |
| 3 | +from typing import ( |
| 4 | + TYPE_CHECKING, |
| 5 | + Callable, |
| 6 | + Iterable, |
| 7 | + Iterator, |
| 8 | + Literal, |
| 9 | + TypeVar, |
| 10 | + cast, |
| 11 | + overload, |
| 12 | +) |
4 | 13 |
|
5 | 14 | import srsly
|
6 | 15 | from docling.datamodel.base_models import DocumentStream
|
|
18 | 27 | from pandas import DataFrame
|
19 | 28 | from spacy.language import Language
|
20 | 29 |
|
| 30 | +# Type variable for contexts piped with documents |
| 31 | +_AnyContext = TypeVar("_AnyContext") |
21 | 32 |
|
22 | 33 | TABLE_PLACEHOLDER = "TABLE"
|
23 | 34 | TABLE_ITEM_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
|
@@ -76,12 +87,42 @@ def __call__(self, source: str | Path | bytes | DoclingDocument) -> Doc:
|
76 | 87 | result = self.converter.convert(self._get_source(source)).document
|
77 | 88 | return self._result_to_doc(result)
|
78 | 89 |
|
@overload
def pipe(
    self,
    sources: Iterable[str | Path | bytes],
    as_tuples: Literal[False] = ...,
) -> Iterator[Doc]: ...


@overload
def pipe(
    self,
    sources: Iterable[tuple[str | Path | bytes, _AnyContext]],
    as_tuples: Literal[True] = ...,
) -> Iterator[tuple[Doc, _AnyContext]]: ...


def pipe(
    self,
    sources: (
        Iterable[str | Path | bytes]
        | Iterable[tuple[str | Path | bytes, _AnyContext]]
    ),
    as_tuples: bool = False,
) -> Iterator[Doc] | Iterator[tuple[Doc, _AnyContext]]:
    """Process multiple documents and create spaCy Doc objects.

    Every source is normalised with ``self._get_source``, the whole stream
    is handed to ``self.converter.convert_all`` and each converted document
    is turned into a ``Doc`` with ``self._result_to_doc``.

    Args:
        sources: The documents to convert. When ``as_tuples`` is true, an
            iterable of ``(source, context)`` pairs instead. One-shot
            iterators and generators are supported in both modes.
        as_tuples: If true, yield ``(doc, context)`` pairs so that arbitrary
            caller data travels alongside each document.

    Yields:
        ``Doc`` objects, or ``(Doc, context)`` tuples when ``as_tuples`` is
        true.
    """
    if as_tuples:
        # Local import: keeps this method self-contained without touching
        # the module's import block.
        from itertools import tee

        pairs = cast("Iterable[tuple[str | Path | bytes, _AnyContext]]", sources)
        # Split the pair stream exactly once. Iterating ``sources`` twice
        # (once for the documents, once for the contexts) makes the two
        # consumers steal items from each other whenever ``sources`` is a
        # one-shot iterator, pairing documents with the wrong context and
        # silently dropping every other item.
        source_side, context_side = tee(pairs)
        data = (self._get_source(source) for source, _ in source_side)
        contexts = (context for _, context in context_side)
        results = self.converter.convert_all(data)
        # NOTE(review): pairing by position assumes convert_all yields one
        # result per input, in input order - confirm against the converter.
        for result, context in zip(results, contexts):
            yield (self._result_to_doc(result.document), context)
    else:
        plain_sources = cast("Iterable[str | Path | bytes]", sources)
        data = (self._get_source(source) for source in plain_sources)
        for result in self.converter.convert_all(data):
            yield self._result_to_doc(result.document)
85 | 126 |
|
86 | 127 | def _get_source(self, source: str | Path | bytes) -> str | Path | DocumentStream:
|
87 | 128 | if isinstance(source, (str, Path)):
|
|
0 commit comments