Skip to content

Commit 905746f

Browse files
fix: cross account catalog_id glue client function calls (#370)
Co-authored-by: nicor88 <[email protected]>
1 parent 23478ef commit 905746f

File tree

2 files changed

+69
-10
lines changed

2 files changed

+69
-10
lines changed

dbt/adapters/athena/impl.py

+47-10
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,15 @@ def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseT
223223
"""
224224
conn = self.connections.get_thread_connection()
225225
client = conn.handle
226+
227+
data_catalog = self._get_data_catalog(relation.database)
228+
catalog_id = get_catalog_id(data_catalog)
229+
226230
with boto3_client_lock:
227231
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
228232

229233
try:
230-
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)
234+
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)
231235
except ClientError as e:
232236
if e.response["Error"]["Code"] == "EntityNotFoundException":
233237
LOGGER.debug(f"Table {relation.render()} does not exists - Ignoring")
@@ -596,16 +600,25 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
596600
conn = self.connections.get_thread_connection()
597601
client = conn.handle
598602

603+
data_catalog = self._get_data_catalog(src_relation.database)
604+
src_catalog_id = get_catalog_id(data_catalog)
605+
599606
with boto3_client_lock:
600607
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
601608

602-
src_table = glue_client.get_table(DatabaseName=src_relation.schema, Name=src_relation.identifier).get("Table")
609+
src_table = glue_client.get_table(
610+
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier
611+
).get("Table")
612+
603613
src_table_partitions = glue_client.get_partitions(
604-
DatabaseName=src_relation.schema, TableName=src_relation.identifier
614+
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, TableName=src_relation.identifier
605615
).get("Partitions")
606616

617+
data_catalog = self._get_data_catalog(target_relation.database)
618+
target_catalog_id = get_catalog_id(data_catalog)
619+
607620
target_table_partitions = glue_client.get_partitions(
608-
DatabaseName=target_relation.schema, TableName=target_relation.identifier
621+
CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableName=target_relation.identifier
609622
).get("Partitions")
610623

611624
target_table_version = {
@@ -618,7 +631,9 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
618631
}
619632

620633
# perform a table swap
621-
glue_client.update_table(DatabaseName=target_relation.schema, TableInput=target_table_version)
634+
glue_client.update_table(
635+
CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableInput=target_table_version
636+
)
622637
LOGGER.debug(f"Table {target_relation.render()} swapped with the content of {src_relation.render()}")
623638

624639
# we delete the target table partitions in any case
@@ -627,6 +642,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
627642
if target_table_partitions:
628643
for partition_batch in get_chunks(target_table_partitions, AthenaAdapter.BATCH_DELETE_PARTITION_API_LIMIT):
629644
glue_client.batch_delete_partition(
645+
CatalogId=target_catalog_id,
630646
DatabaseName=target_relation.schema,
631647
TableName=target_relation.identifier,
632648
PartitionsToDelete=[{"Values": partition["Values"]} for partition in partition_batch],
@@ -635,6 +651,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
635651
if src_table_partitions:
636652
for partition_batch in get_chunks(src_table_partitions, AthenaAdapter.BATCH_CREATE_PARTITION_API_LIMIT):
637653
glue_client.batch_create_partition(
654+
CatalogId=target_catalog_id,
638655
DatabaseName=target_relation.schema,
639656
TableName=target_relation.identifier,
640657
PartitionInputList=[
@@ -676,6 +693,9 @@ def expire_glue_table_versions(
676693
conn = self.connections.get_thread_connection()
677694
client = conn.handle
678695

696+
data_catalog = self._get_data_catalog(relation.database)
697+
catalog_id = get_catalog_id(data_catalog)
698+
679699
with boto3_client_lock:
680700
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
681701

@@ -688,7 +708,10 @@ def expire_glue_table_versions(
688708
location = v["Table"]["StorageDescriptor"]["Location"]
689709
try:
690710
glue_client.delete_table_version(
691-
DatabaseName=relation.schema, TableName=relation.identifier, VersionId=str(version)
711+
CatalogId=catalog_id,
712+
DatabaseName=relation.schema,
713+
TableName=relation.identifier,
714+
VersionId=str(version),
692715
)
693716
deleted_versions.append(version)
694717
LOGGER.debug(f"Deleted version {version} of table {relation.render()} ")
@@ -720,13 +743,16 @@ def persist_docs_to_glue(
720743
conn = self.connections.get_thread_connection()
721744
client = conn.handle
722745

746+
data_catalog = self._get_data_catalog(relation.database)
747+
catalog_id = get_catalog_id(data_catalog)
748+
723749
with boto3_client_lock:
724750
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
725751

726752
# By default, there is no need to update Glue Table
727753
need_udpate_table = False
728754
# Get Table from Glue
729-
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.name)["Table"]
755+
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.name)["Table"]
730756
# Prepare new version of Glue Table picking up significant fields
731757
updated_table = self._get_table_input(table)
732758
# Update table description
@@ -766,7 +792,10 @@ def persist_docs_to_glue(
766792
# It prevents redundant schema version creating after incremental runs.
767793
if need_udpate_table:
768794
glue_client.update_table(
769-
DatabaseName=relation.schema, TableInput=updated_table, SkipArchive=skip_archive_table_version
795+
CatalogId=catalog_id,
796+
DatabaseName=relation.schema,
797+
TableInput=updated_table,
798+
SkipArchive=skip_archive_table_version,
770799
)
771800

772801
@available
@@ -797,11 +826,16 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn
797826
conn = self.connections.get_thread_connection()
798827
client = conn.handle
799828

829+
data_catalog = self._get_data_catalog(relation.database)
830+
catalog_id = get_catalog_id(data_catalog)
831+
800832
with boto3_client_lock:
801833
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
802834

803835
try:
804-
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)["Table"]
836+
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)[
837+
"Table"
838+
]
805839
except ClientError as e:
806840
if e.response["Error"]["Code"] == "EntityNotFoundException":
807841
LOGGER.debug("table not exist, catching the error")
@@ -829,11 +863,14 @@ def delete_from_glue_catalog(self, relation: AthenaRelation) -> None:
829863
conn = self.connections.get_thread_connection()
830864
client = conn.handle
831865

866+
data_catalog = self._get_data_catalog(relation.database)
867+
catalog_id = get_catalog_id(data_catalog)
868+
832869
with boto3_client_lock:
833870
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
834871

835872
try:
836-
glue_client.delete_table(DatabaseName=schema_name, Name=table_name)
873+
glue_client.delete_table(CatalogId=catalog_id, DatabaseName=schema_name, Name=table_name)
837874
LOGGER.debug(f"Deleted table from glue catalog: {relation.render()}")
838875
except ClientError as e:
839876
if e.response["Error"]["Code"] == "EntityNotFoundException":

tests/unit/test_adapter.py

+22
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def test_generate_s3_location(
401401
@mock_glue
402402
@mock_s3
403403
@mock_athena
404+
@mock_sts
404405
def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
405406
table_name = "test_table"
406407
self.adapter.acquire_connection("dummy")
@@ -417,6 +418,7 @@ def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
417418
@mock_glue
418419
@mock_s3
419420
@mock_athena
421+
@mock_sts
420422
def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service):
421423
table_name = "test_table"
422424
self.adapter.acquire_connection("dummy")
@@ -438,6 +440,7 @@ def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog,
438440
@mock_glue
439441
@mock_s3
440442
@mock_athena
443+
@mock_sts
441444
def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
442445
view_name = "view"
443446
self.adapter.acquire_connection("dummy")
@@ -452,6 +455,7 @@ def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
452455
@mock_glue
453456
@mock_s3
454457
@mock_athena
458+
@mock_sts
455459
def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_service):
456460
table_name = "test_table"
457461
self.adapter.acquire_connection("dummy")
@@ -500,6 +504,7 @@ def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service)
500504

501505
@mock_glue
502506
@mock_athena
507+
@mock_sts
503508
def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service):
504509
mock_aws_service.create_data_catalog()
505510
mock_aws_service.create_database()
@@ -517,6 +522,7 @@ def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_se
517522

518523
@mock_glue
519524
@mock_athena
525+
@mock_sts
520526
def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service):
521527
mock_aws_service.create_data_catalog()
522528
mock_aws_service.create_database()
@@ -534,6 +540,7 @@ def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service):
534540
@mock_glue
535541
@mock_s3
536542
@mock_athena
543+
@mock_sts
537544
def test_clean_up_table_delete_table(self, dbt_debug_caplog, mock_aws_service):
538545
mock_aws_service.create_data_catalog()
539546
mock_aws_service.create_database()
@@ -844,6 +851,7 @@ def test_parse_s3_path(self, s3_path, expected):
844851
@mock_athena
845852
@mock_glue
846853
@mock_s3
854+
@mock_sts
847855
def test_swap_table_with_partitions(self, mock_aws_service):
848856
mock_aws_service.create_data_catalog()
849857
mock_aws_service.create_database()
@@ -870,6 +878,7 @@ def test_swap_table_with_partitions(self, mock_aws_service):
870878
@mock_athena
871879
@mock_glue
872880
@mock_s3
881+
@mock_sts
873882
def test_swap_table_without_partitions(self, mock_aws_service):
874883
mock_aws_service.create_data_catalog()
875884
mock_aws_service.create_database()
@@ -894,6 +903,7 @@ def test_swap_table_without_partitions(self, mock_aws_service):
894903
@mock_athena
895904
@mock_glue
896905
@mock_s3
906+
@mock_sts
897907
def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
898908
mock_aws_service.create_data_catalog()
899909
mock_aws_service.create_database()
@@ -931,6 +941,7 @@ def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
931941
@mock_athena
932942
@mock_glue
933943
@mock_s3
944+
@mock_sts
934945
def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service):
935946
mock_aws_service.create_data_catalog()
936947
mock_aws_service.create_database()
@@ -990,6 +1001,7 @@ def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_ca
9901001
@mock_athena
9911002
@mock_glue
9921003
@mock_s3
1004+
@mock_sts
9931005
def test_expire_glue_table_versions(self, mock_aws_service):
9941006
mock_aws_service.create_data_catalog()
9951007
mock_aws_service.create_database()
@@ -1101,6 +1113,7 @@ def test_get_work_group_output_location_not_enforced(self, mock_aws_service):
11011113
@mock_athena
11021114
@mock_glue
11031115
@mock_s3
1116+
@mock_sts
11041117
def test_persist_docs_to_glue_no_comment(self, mock_aws_service):
11051118
mock_aws_service.create_data_catalog()
11061119
mock_aws_service.create_database()
@@ -1142,6 +1155,7 @@ def test_persist_docs_to_glue_no_comment(self, mock_aws_service):
11421155
@mock_athena
11431156
@mock_glue
11441157
@mock_s3
1158+
@mock_sts
11451159
def test_persist_docs_to_glue_comment(self, mock_aws_service):
11461160
mock_aws_service.create_data_catalog()
11471161
mock_aws_service.create_database()
@@ -1194,6 +1208,7 @@ def test_list_schemas(self, mock_aws_service):
11941208

11951209
@mock_athena
11961210
@mock_glue
1211+
@mock_sts
11971212
def test_get_columns_in_relation(self, mock_aws_service):
11981213
mock_aws_service.create_data_catalog()
11991214
mock_aws_service.create_database()
@@ -1214,6 +1229,7 @@ def test_get_columns_in_relation(self, mock_aws_service):
12141229

12151230
@mock_athena
12161231
@mock_glue
1232+
@mock_sts
12171233
def test_get_columns_in_relation_not_found_table(self, mock_aws_service):
12181234
mock_aws_service.create_data_catalog()
12191235
mock_aws_service.create_database()
@@ -1229,6 +1245,7 @@ def test_get_columns_in_relation_not_found_table(self, mock_aws_service):
12291245

12301246
@mock_athena
12311247
@mock_glue
1248+
@mock_sts
12321249
def test_delete_from_glue_catalog(self, mock_aws_service):
12331250
mock_aws_service.create_data_catalog()
12341251
mock_aws_service.create_database()
@@ -1242,6 +1259,7 @@ def test_delete_from_glue_catalog(self, mock_aws_service):
12421259

12431260
@mock_athena
12441261
@mock_glue
1262+
@mock_sts
12451263
def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_aws_service):
12461264
mock_aws_service.create_data_catalog()
12471265
mock_aws_service.create_database()
@@ -1258,6 +1276,7 @@ def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_a
12581276
@mock_glue
12591277
@mock_s3
12601278
@mock_athena
1279+
@mock_sts
12611280
def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service):
12621281
mock_aws_service.create_data_catalog()
12631282
mock_aws_service.create_database()
@@ -1272,6 +1291,7 @@ def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service):
12721291
@mock_glue
12731292
@mock_s3
12741293
@mock_athena
1294+
@mock_sts
12751295
def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_service):
12761296
mock_aws_service.create_data_catalog()
12771297
mock_aws_service.create_database()
@@ -1286,6 +1306,7 @@ def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_servic
12861306
@mock_glue
12871307
@mock_s3
12881308
@mock_athena
1309+
@mock_sts
12891310
def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service):
12901311
mock_aws_service.create_data_catalog()
12911312
mock_aws_service.create_database()
@@ -1300,6 +1321,7 @@ def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service):
13001321
@mock_glue
13011322
@mock_s3
13021323
@mock_athena
1324+
@mock_sts
13031325
def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service):
13041326
mock_aws_service.create_data_catalog()
13051327
mock_aws_service.create_database()

0 commit comments

Comments
 (0)