Commit b1c79e0

Make use of pyarrow iter_batches (#661)

1 parent 7b9bc57

File tree: 1 file changed (+46, −7)


awswrangler/s3/_read_parquet.py

Lines changed: 46 additions & 7 deletions
```diff
@@ -279,7 +279,39 @@ def _arrowtable2df(
     return df


-def _read_parquet_chunked(
+def _pyarrow_chunk_generator(
+    pq_file: pyarrow.parquet.ParquetFile,
+    chunked: Union[bool, int],
+    columns: Optional[List[str]],
+    use_threads_flag: bool,
+) -> Iterator[pa.RecordBatch]:
+    if chunked is True:
+        batch_size = 65_536
+    elif isinstance(chunked, int) and chunked > 0:
+        batch_size = chunked
+    else:
+        raise exceptions.InvalidArgument(f"chunked: {chunked}")
+
+    chunks = pq_file.iter_batches(
+        batch_size=batch_size, columns=columns, use_threads=use_threads_flag, use_pandas_metadata=False
+    )
+
+    for chunk in chunks:
+        yield chunk
+
+
+def _row_group_chunk_generator(
+    pq_file: pyarrow.parquet.ParquetFile,
+    columns: Optional[List[str]],
+    use_threads_flag: bool,
+    num_row_groups: int,
+) -> Iterator[pa.Table]:
+    for i in range(num_row_groups):
+        _logger.debug("Reading Row Group %s...", i)
+        yield pq_file.read_row_group(i=i, columns=columns, use_threads=use_threads_flag, use_pandas_metadata=False)
+
+
+def _read_parquet_chunked(  # pylint: disable=too-many-branches
     paths: List[str],
     chunked: Union[bool, int],
     validate_schema: bool,
```
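For context, `ParquetFile.iter_batches` (available since pyarrow 3.0.0) streams bounded-size `RecordBatch` objects instead of materializing a whole row group at once, which is what makes the memory footprint of a chunked read predictable. A minimal standalone sketch of the call the new generator wraps; the file path is hypothetical:

```python
import pyarrow.parquet as pq

# Open the file lazily; no row data is loaded yet.
pq_file = pq.ParquetFile("example.parquet")  # hypothetical local file

# Stream RecordBatches of at most 65,536 rows, the same default
# batch size the commit uses when chunked=True.
for batch in pq_file.iter_batches(batch_size=65_536, use_pandas_metadata=False):
    print(batch.num_rows, batch.schema.names)
```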
```diff
@@ -293,7 +325,7 @@ def _read_parquet_chunked(
     path_root: Optional[str],
     s3_additional_kwargs: Optional[Dict[str, str]],
     use_threads: Union[bool, int],
-) -> Iterator[pd.DataFrame]:  # pylint: disable=too-many-branches
+) -> Iterator[pd.DataFrame]:
     next_slice: Optional[pd.DataFrame] = None
     last_schema: Optional[Dict[str, str]] = None
     last_path: str = ""
@@ -327,12 +359,19 @@ def _read_parquet_chunked(
         num_row_groups: int = pq_file.num_row_groups
         _logger.debug("num_row_groups: %s", num_row_groups)
         use_threads_flag: bool = use_threads if isinstance(use_threads, bool) else bool(use_threads > 1)
-        for i in range(num_row_groups):
-            _logger.debug("Reading Row Group %s...", i)
+        # iter_batches is only available for pyarrow >= 3.0.0
+        if callable(getattr(pq_file, "iter_batches", None)):
+            chunk_generator = _pyarrow_chunk_generator(
+                pq_file=pq_file, chunked=chunked, columns=columns, use_threads_flag=use_threads_flag
+            )
+        else:
+            chunk_generator = _row_group_chunk_generator(
+                pq_file=pq_file, columns=columns, use_threads_flag=use_threads_flag, num_row_groups=num_row_groups
+            )
+
+        for chunk in chunk_generator:
             df: pd.DataFrame = _arrowtable2df(
-                table=pq_file.read_row_group(
-                    i=i, columns=columns, use_threads=use_threads_flag, use_pandas_metadata=False
-                ),
+                table=chunk,
                 categories=categories,
                 safe=safe,
                 map_types=map_types,
```
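The `callable(getattr(...))` guard is plain feature detection: instead of comparing `pyarrow.__version__` strings, the code asks whether the installed `ParquetFile` actually exposes `iter_batches`, and falls back to per-row-group reads when it does not. A minimal sketch of the same pattern outside awswrangler; the helper name is made up:

```python
import pyarrow.parquet as pq

def choose_chunk_source(pq_file: pq.ParquetFile):
    """Hypothetical helper illustrating the feature-detection fallback."""
    # pyarrow >= 3.0.0: stream RecordBatches with a bounded row count.
    if callable(getattr(pq_file, "iter_batches", None)):
        return pq_file.iter_batches(batch_size=65_536)
    # Older pyarrow: fall back to one pyarrow.Table per row group.
    return (pq_file.read_row_group(i) for i in range(pq_file.num_row_groups))
```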

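At the public-API level the change lands in `awswrangler.s3.read_parquet` when `chunked=` is passed. A hedged usage sketch; the S3 path is hypothetical:

```python
import awswrangler as wr

# chunked=True yields DataFrames sized by the backend (up to 65,536
# rows per batch on pyarrow >= 3.0.0); an integer asks for chunks of
# that many rows instead.
for df in wr.s3.read_parquet(path="s3://my-bucket/dataset/", chunked=True):
    print(len(df))
```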