Description
Is your feature request related to a problem? Please describe.
I think a big issue with using data lakes is not being able to ingest data without worrying about the uniqueness of the data. Working around this with deeply nested partitions slows down querying, especially if the resulting tables consist of only one record. There are heavy-duty solutions that tackle this problem, but they are JVM-based and come with a lot of overhead and cost, for example Apache Hudi. For smaller cloud users, who like the ability to scale down to 0, this is not a feasible approach.
Describe the solution you'd like
As a daily user of both awswrangler and Apache Arrow, I wrote a pip package to deal with some of my troubles. As part of the package, I am building a class that inherits from pq.ParquetDataset and allows you to do upserts / removals of data (see the code below).
The code currently treats partitions in isolation: it splits any upsert into per-partition tables and then, for each partition, reads the existing data, applies the upsert, and saves the result. I am still adding the following functionality:
- Add created_time & removed_time columns instead of removing records, to allow time-shifted (point-in-time) reads
- Add sharding of files (hash-bucketing files by the unique constraint) to allow for easier upserts (see the sketch after this list)
- Add a locking mechanism to avoid race conditions on the same partition/shard
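To illustrate the sharding item above: records could be assigned to a fixed number of hash buckets based on their unique-key columns, so an upsert only has to rewrite the buckets it touches rather than a whole partition. A minimal sketch of the idea (the names shard_id and num_shards are illustrative, not part of the package):

import zlib

def shard_id(unique_key: tuple, num_shards: int = 16) -> int:
    # Map a unique-key tuple to a stable bucket number in [0, num_shards)
    digest = zlib.crc32("|".join(map(str, unique_key)).encode("utf-8"))
    return digest % num_shards

# An upsert for key ("device-42", "2021-03-01") would then only rewrite
# <partition>/shard_<shard_id(key)>.parquet instead of the whole partition.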
I would like to know if there is interest in integrating such functionality into awswrangler, as it has convenient methods for syncing with Glue. I saw there is already some work in an internal module (_merge_upsert_table.py); however, that approach would not work for tables containing TBs of data. I can help write the pull request if necessary.
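For context, the snippet below imports a few helpers (drop_duplicates, head, split) from the package's ops module, which is not included here. A rough, hypothetical sketch of what drop_duplicates and split might look like in pure pyarrow (the real implementations in the package may well differ):

import pyarrow as pa

def split(table: pa.Table, columns: list):
    # Group row indices by the values of the given columns;
    # returns (value_tuple, [row_indices]) pairs, one per distinct combination.
    groups = {}
    cols = [table.column(c).to_pylist() for c in columns]
    for i, key in enumerate(zip(*cols)):
        groups.setdefault(key, []).append(i)
    return list(groups.items())

def drop_duplicates(table: pa.Table, on: list, keep: str = "last"):
    # Deduplicate rows on the `on` columns. keep='first'/'last' keeps one row
    # per key; keep='drop' removes every key that occurs more than once
    # (which is what turns a concat of old rows + rows-to-delete into a deletion).
    cols = [table.column(c).to_pylist() for c in on]
    seen = {}
    for i, key in enumerate(zip(*cols)):
        seen.setdefault(key, []).append(i)
    if keep == "first":
        idxs = [v[0] for v in seen.values()]
    elif keep == "last":
        idxs = [v[-1] for v in seen.values()]
    else:  # 'drop'
        idxs = [v[0] for v in seen.values() if len(v) == 1]
    return table.take(sorted(idxs))

With those helpers in mind, the class itself: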
import os, time, s3fs
import pyarrow as pa
import pyarrow.parquet as pq
from ops import drop_duplicates, head, split
class ParquetUniqueDataset(pq.ParquetDataset):
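    """pq.ParquetDataset subclass (legacy API) adding per-partition upsert/delete of rows identified by a user-defined unique key."""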
def __init__(self, *args, **kwargs):
self.args, self.kwargs = args, kwargs
# super().__init__(*args, **kwargs)
self.path = args[0]
self.tables = {}
self.load()
def load(self, verbose=False):
# Initiate the parent class
super().__init__(*self.args, **self.kwargs)
self.meta = self.pieces[0].get_metadata()
self.columns = [c['path_in_schema'] for c in self.meta.row_group(0).to_dict()['columns']]
# Partition information
self.partition_cols = [c[0] for c in self.pieces[0].partition_keys]
self.partitions_ = [p.partition_keys for p in self.pieces]
self.partitions_val = [tuple(v[1] for v in p) for p in self.partitions_]
if verbose:
print("Loaded all the data:", [p.path for p in self.pieces])
print("Column names:", self.columns)
def set_unique(self, columns):
self.unique_cols = [u for u in columns if u not in self.partition_cols]
return self
def partition_dict(self, partition_val):
return dict(zip(self.partition_cols, list(partition_val)))
def get_path(self, partition_val, name):
partition = self.partition_dict(partition_val)
return self.path + '/' + '/'.join(str(k) + '=' + str(v) for k, v in partition.items()) + '/' + name + '.parquet'
def get_idxs(self, partition_val):
return [idx for idx, p in enumerate(self.partitions_val) if p == partition_val]
# Cleaning tables
def concat(self, tables):
return pa.concat_tables([t.select(self.columns) for t in tables])
def sanitize(self, table):
# TODO: Add casting to default schema of class
return table.select(self.partition_cols + self.columns)
def cleanup(self):
for p in set(self.partitions_val):
print("Cleaning up:", p)
table = self.read_parts(p)
table_dedup = drop_duplicates(table, on=self.unique_cols, keep='last')
            self.save(table_dedup, p)  # write the deduplicated partition back in place
# Reading / writing tables
def read_parts(self, partition_val=None):
# See what pieces we need to load
idxs = (self.get_idxs(partition_val) if partition_val else range(len(self.pieces)))
for i in idxs:
if self.pieces[i].path not in self.tables.keys():
print("Reading {} as it is not in cache".format(self.pieces[i].path))
self.tables[self.pieces[i].path] = self.pieces[i].read(columns=self.columns, partitions=self.partitions)
return self.concat([self.tables[self.pieces[i].path] for i in idxs])
def save(self, table, partition_val):
paths_old = [self.pieces[i].path for i in self.get_idxs(partition_val)]
paths_new = [self.get_path(partition_val, 'file0')]
# Delete old which are not written in new
for path in [p for p in paths_old if p not in paths_new]:
os.remove(path)
# Write new table
pq.write_table(table.select(self.columns), paths_new[0])
# Add table to caching
self.tables[paths_new[0]] = table
# Reload the Dataset if file names have changed
if set(paths_old) != set(paths_new):
return True
else:
return False
# Upsertion
def upsert_part(self, table, partition_val, partition_idxs, keep):
# If partition exists, gather original data, before deduplication. Else use new table
        rows_b4 = 0  # row count before the upsert (0 when the partition is new)
if partition_val in self.partitions_val:
table_part = self.read_parts(partition_val)
rows_b4 = table_part.num_rows
table_new = self.concat([table_part, table.take(partition_idxs)])
else:
table_new = table.take(partition_idxs)
table_dedup = drop_duplicates(table_new, on=self.unique_cols, keep=keep)
print("Upserting data for partition {0}. Added {1} unique records".format(partition_val, table_dedup.num_rows - rows_b4))
return self.save(table_dedup, partition_val)
def upsert(self, table, keep='last'):
table = self.sanitize(table)
# bools = self.pool.map(lambda p: self.upsert_part(*p), [(table, val, idxs, keep) for val, idxs in split(table=table, columns=self.partition_cols)])
bools = [self.upsert_part(table, val, idxs, keep) for val, idxs in split(table=table, columns=self.partition_cols)]
        if any(bools):
print("Reloading dataset!")
self.load()
# Deletion by table
def delete_part(self, table, partition_val, partition_idxs):
if partition_val in self.partitions_val:
table_part = self.read_parts(partition_val)
table_new = self.concat([table_part, table.take(partition_idxs)])
table_dedup = drop_duplicates(table_new, on=self.unique_cols, keep='drop')
print("Removing data for partition {0}. Removed {1} unique records".format(partition_val, table_part.num_rows - table_dedup.num_rows))
return self.save(table_dedup, partition_val)
else:
            print("No data exists for partition:", self.partition_dict(partition_val))
return False
def delete(self, table):
table = self.sanitize(table)
# bools = self.pool.map(lambda p: self.delete_part(*p), [(table, val, idxs) for val, idxs in split(table=table, columns=self.partition_cols)])
bools = [self.delete_part(table, val, idxs) for val, idxs in split(table=table, columns=self.partition_cols)]
        if any(bools):
self.load()
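For completeness, a hypothetical usage sketch. The dataset path, column names and example values are made up; it assumes an older pyarrow where pq.ParquetDataset exposes .pieces, and that the path already contains a hive-partitioned dataset (e.g. date=2021-03-01/file0.parquet):

import pyarrow as pa

ds = ParquetUniqueDataset("data/events")   # path with existing hive-style partitions
ds.set_unique(["event_id"])                # uniqueness constraint within each partition

new_rows = pa.table({
    "date": ["2021-03-01", "2021-03-01"],  # partition column
    "event_id": [101, 102],                # unique key
    "value": [3.14, 2.72],
})

ds.upsert(new_rows)               # inserts new keys, overwrites existing ones per partition
ds.delete(new_rows.slice(0, 1))   # removes the row with event_id=101 again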