Skip to content

Commit 706e5d3

Browse files
Address comments
1 parent af0ad7d commit 706e5d3

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

dbt/adapters/athena/impl.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import posixpath as path
44
import tempfile
5+
from enum import Enum
56
from itertools import chain
67
from textwrap import dedent
78
from threading import Lock
@@ -53,6 +54,12 @@
5354
boto3_client_lock = Lock()
5455

5556

57+
class AthenaCatalogType(Enum):
58+
GLUE = "GLUE"
59+
LAMBDA = "LAMBDA"
60+
HIVE = "HIVE"
61+
62+
5663
class AthenaAdapter(SQLAdapter):
5764
BATCH_CREATE_PARTITION_API_LIMIT = 100
5865
BATCH_DELETE_PARTITION_API_LIMIT = 25
@@ -416,9 +423,8 @@ def _get_one_catalog(
416423
manifest: Manifest,
417424
) -> agate.Table:
418425
data_catalog = self._get_data_catalog(information_schema.path.database)
419-
catalog_id = get_catalog_id(data_catalog)
420426

421-
if data_catalog["Type"] == "GLUE":
427+
if data_catalog["Type"] == AthenaCatalogType.GLUE.value:
422428
conn = self.connections.get_thread_connection()
423429
client = conn.handle
424430
with boto3_client_lock:
@@ -433,6 +439,7 @@ def _get_one_catalog(
433439
}
434440
# If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3
435441
# infers it from the account Id.
442+
catalog_id = get_catalog_id(data_catalog)
436443
if catalog_id:
437444
kwargs["CatalogId"] = catalog_id
438445

@@ -441,13 +448,15 @@ def _get_one_catalog(
441448
if relations and table["Name"] in relations:
442449
catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
443450
table = agate.Table.from_object(catalog)
444-
else:
451+
elif data_catalog["Type"] == AthenaCatalogType.LAMBDA.value:
445452
kwargs = {"information_schema": information_schema, "schemas": schemas}
446453
table = self.execute_macro(
447454
GET_CATALOG_MACRO_NAME,
448455
kwargs=kwargs,
449456
manifest=manifest,
450457
)
458+
else:
459+
raise NotImplementedError(f"Type of catalog {data_catalog['Type']} not supported: {data_catalog['Name']}")
451460

452461
filtered_table = self._catalog_filter_table(table, manifest)
453462
return self._join_catalog_table_owners(filtered_table, manifest)

tests/unit/test_adapter.py

+18
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,24 @@ def test__get_one_catalog_federated_query_catalog(self, mock_execute, mock_aws_s
706706
for row in actual.rows.values():
707707
assert row.values() in expected_rows
708708

709+
@mock_athena
710+
def test__get_one_catalog_unsupported_type(self, mock_aws_service):
711+
catalog_name = "example_hive_catalog"
712+
catalog_type = "HIVE"
713+
mock_aws_service.create_data_catalog(catalog_name=catalog_name, catalog_type=catalog_type)
714+
mock_information_schema = mock.MagicMock()
715+
mock_information_schema.path.database = catalog_name
716+
717+
self.adapter.acquire_connection("dummy")
718+
719+
with pytest.raises(NotImplementedError) as exc:
720+
self.adapter._get_one_catalog(
721+
mock_information_schema,
722+
mock.MagicMock(),
723+
self.mock_manifest,
724+
)
725+
assert exc.message == f"Type of catalog {catalog_type} not supported: {catalog_name}"
726+
709727
def test__get_catalog_schemas(self):
710728
res = self.adapter._get_catalog_schemas(self.mock_manifest)
711729
assert len(res.keys()) == 3

0 commit comments

Comments
 (0)