@@ -279,7 +279,39 @@ def _arrowtable2df(
     return df
 
 
-def _read_parquet_chunked(
+def _pyarrow_chunk_generator(
+    pq_file: pyarrow.parquet.ParquetFile,
+    chunked: Union[bool, int],
+    columns: Optional[List[str]],
+    use_threads_flag: bool,
+) -> Iterator[pa.RecordBatch]:
+    if chunked is True:
+        batch_size = 65_536
+    elif isinstance(chunked, int) and chunked > 0:
+        batch_size = chunked
+    else:
+        raise exceptions.InvalidArgument(f"chunked: {chunked}")
+
+    chunks = pq_file.iter_batches(
+        batch_size=batch_size, columns=columns, use_threads=use_threads_flag, use_pandas_metadata=False
+    )
+
+    for chunk in chunks:
+        yield chunk
+
+
+def _row_group_chunk_generator(
+    pq_file: pyarrow.parquet.ParquetFile,
+    columns: Optional[List[str]],
+    use_threads_flag: bool,
+    num_row_groups: int,
+) -> Iterator[pa.Table]:
+    for i in range(num_row_groups):
+        _logger.debug("Reading Row Group %s...", i)
+        yield pq_file.read_row_group(i=i, columns=columns, use_threads=use_threads_flag, use_pandas_metadata=False)
+
+
+def _read_parquet_chunked(  # pylint: disable=too-many-branches
     paths: List[str],
     chunked: Union[bool, int],
     validate_schema: bool,
@@ -293,7 +325,7 @@ def _read_parquet_chunked(
     path_root: Optional[str],
     s3_additional_kwargs: Optional[Dict[str, str]],
     use_threads: Union[bool, int],
-) -> Iterator[pd.DataFrame]:  # pylint: disable=too-many-branches
+) -> Iterator[pd.DataFrame]:
     next_slice: Optional[pd.DataFrame] = None
     last_schema: Optional[Dict[str, str]] = None
     last_path: str = ""
@@ -327,12 +359,19 @@ def _read_parquet_chunked(
         num_row_groups: int = pq_file.num_row_groups
         _logger.debug("num_row_groups: %s", num_row_groups)
         use_threads_flag: bool = use_threads if isinstance(use_threads, bool) else bool(use_threads > 1)
-        for i in range(num_row_groups):
-            _logger.debug("Reading Row Group %s...", i)
+        # iter_batches is only available for pyarrow >= 3.0.0
+        if callable(getattr(pq_file, "iter_batches", None)):
+            chunk_generator = _pyarrow_chunk_generator(
+                pq_file=pq_file, chunked=chunked, columns=columns, use_threads_flag=use_threads_flag
+            )
+        else:
+            chunk_generator = _row_group_chunk_generator(
+                pq_file=pq_file, columns=columns, use_threads_flag=use_threads_flag, num_row_groups=num_row_groups
+            )
+
+        for chunk in chunk_generator:
             df: pd.DataFrame = _arrowtable2df(
-                table=pq_file.read_row_group(
-                    i=i, columns=columns, use_threads=use_threads_flag, use_pandas_metadata=False
-                ),
+                table=chunk,
                 categories=categories,
                 safe=safe,
                 map_types=map_types,
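
For reference, below is a minimal standalone sketch (not part of the commit) of the fallback pattern this change introduces: prefer ParquetFile.iter_batches, which pyarrow added in 3.0.0, and fall back to reading row group by row group on older versions. The helper name iter_parquet_chunks and the path "example.parquet" are illustrative only.

from typing import Iterator

import pyarrow as pa
import pyarrow.parquet


def iter_parquet_chunks(path: str, batch_size: int = 65_536) -> Iterator[pa.Table]:
    # Hypothetical helper, not from the codebase: yield a Parquet file
    # as a stream of Tables of roughly batch_size rows when possible.
    pq_file = pyarrow.parquet.ParquetFile(path)
    if callable(getattr(pq_file, "iter_batches", None)):
        # pyarrow >= 3.0.0: stream RecordBatches with a fixed row count.
        for batch in pq_file.iter_batches(batch_size=batch_size):
            yield pa.Table.from_batches([batch])
    else:
        # Older pyarrow: chunk boundaries fall on the file's row groups,
        # so chunk sizes depend on how the file was written.
        for i in range(pq_file.num_row_groups):
            yield pq_file.read_row_group(i)


for table in iter_parquet_chunks("example.parquet"):
    print(table.num_rows)

Note the behavioral difference between the two branches: iter_batches honors batch_size exactly, while the row-group fallback yields whole row groups of whatever size the writer chose; judging from the next_slice bookkeeping visible in the second hunk, _read_parquet_chunked still trims chunks downstream when chunked is an integer.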