Skip to content

fix: docs generate support federated query catalogs #324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,15 @@
get_table_type,
)
from dbt.adapters.athena.s3 import S3DataNaming
from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id, get_chunks
from dbt.adapters.athena.utils import (
AthenaCatalogType,
clean_sql_comment,
get_catalog_id,
get_catalog_type,
get_chunks,
)
from dbt.adapters.base import ConstraintSupport, available
from dbt.adapters.base.impl import GET_CATALOG_MACRO_NAME
from dbt.adapters.base.relation import BaseRelation, InformationSchema
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.graph.manifest import Manifest
Expand Down Expand Up @@ -415,29 +422,42 @@ def _get_one_catalog(
manifest: Manifest,
) -> agate.Table:
data_catalog = self._get_data_catalog(information_schema.path.database)
catalog_id = get_catalog_id(data_catalog)
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

catalog = []
paginator = glue_client.get_paginator("get_tables")
for schema, relations in schemas.items():
kwargs = {
"DatabaseName": schema,
"MaxResults": 100,
}
# If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id.
if catalog_id:
kwargs["CatalogId"] = catalog_id
data_catalog_type = get_catalog_type(data_catalog)

for page in paginator.paginate(**kwargs):
for table in page["TableList"]:
if relations and table["Name"] in relations:
catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
if data_catalog_type == AthenaCatalogType.GLUE:
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

catalog = []
paginator = glue_client.get_paginator("get_tables")
for schema, relations in schemas.items():
kwargs = {
"DatabaseName": schema,
"MaxResults": 100,
}
# If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3
# infers it from the account Id.
catalog_id = get_catalog_id(data_catalog)
if catalog_id:
kwargs["CatalogId"] = catalog_id

for page in paginator.paginate(**kwargs):
for table in page["TableList"]:
if relations and table["Name"] in relations:
catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
table = agate.Table.from_object(catalog)
elif data_catalog_type == AthenaCatalogType.LAMBDA:
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
manifest=manifest,
)
else:
raise NotImplementedError(f"Type of catalog {data_catalog_type} not supported.")

table = agate.Table.from_object(catalog)
filtered_table = self._catalog_filter_table(table, manifest)
return self._join_catalog_table_owners(filtered_table, manifest)

Expand Down
13 changes: 12 additions & 1 deletion dbt/adapters/athena/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Generator, List, Optional, TypeVar

from mypy_boto3_athena.type_defs import DataCatalogTypeDef
Expand All @@ -9,7 +10,17 @@ def clean_sql_comment(comment: str) -> str:


def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]:
return catalog["Parameters"]["catalog-id"] if catalog else None
return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None


class AthenaCatalogType(Enum):
GLUE = "GLUE"
LAMBDA = "LAMBDA"
HIVE = "HIVE"


def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]:
return AthenaCatalogType(catalog["Type"]) if catalog else None


T = TypeVar("T")
Expand Down
87 changes: 85 additions & 2 deletions dbt/include/athena/macros/adapters/metadata.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,89 @@
{% macro athena__get_catalog(information_schema, schemas) -%}
{{ return(adapter.get_catalog()) }}
{%- endmacro %}
{%- set query -%}
select * from (
(
with tables as (

select
tables.table_catalog as table_database,
tables.table_schema as table_schema,
tables.table_name as table_name,

case
when views.table_name is not null
then 'view'
when table_type = 'BASE TABLE'
then 'table'
else table_type
end as table_type,

null as table_comment

from {{ information_schema }}.tables
left join {{ information_schema }}.views
on tables.table_catalog = views.table_catalog
and tables.table_schema = views.table_schema
and tables.table_name = views.table_name

),

columns as (

select
table_catalog as table_database,
table_schema as table_schema,
table_name as table_name,
column_name as column_name,
ordinal_position as column_index,
data_type as column_type,
comment as column_comment

from {{ information_schema }}.columns

),

catalog as (

select
tables.table_database,
tables.table_schema,
tables.table_name,
tables.table_type,
tables.table_comment,
columns.column_name,
columns.column_index,
columns.column_type,
columns.column_comment

from tables
join columns
on tables."table_database" = columns."table_database"
and tables."table_schema" = columns."table_schema"
and tables."table_name" = columns."table_name"

)

{%- for schema, relations in schemas.items() -%}
{%- for relation_batch in relations|batch(100) %}
select * from catalog
where "table_schema" = lower('{{ schema }}')
and (
{%- for relation in relation_batch -%}
"table_name" = lower('{{ relation }}')
{%- if not loop.last %} or {% endif -%}
{%- endfor -%}
)

{%- if not loop.last %} union all {% endif -%}
{%- endfor -%}

{%- if not loop.last %} union all {% endif -%}
{%- endfor -%}
)
)
{%- endset -%}
{{ return(run_query(query)) }}
{% endmacro -%}


{% macro athena__list_schemas(database) -%}
Expand Down
1 change: 1 addition & 0 deletions tests/unit/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
CATALOG_ID = "12345678910"
DATA_CATALOG_NAME = "awsdatacatalog"
SHARED_DATA_CATALOG_NAME = "9876543210"
FEDERATED_QUERY_CATALOG_NAME = "federated_query_data_source"
DATABASE_NAME = "test_dbt_athena"
BUCKET = "test-dbt-athena"
AWS_REGION = "eu-west-1"
Expand Down
127 changes: 125 additions & 2 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter
from dbt.adapters.athena.exceptions import S3LocationException
from dbt.adapters.athena.relation import AthenaRelation, TableType
from dbt.adapters.athena.utils import AthenaCatalogType
from dbt.clients import agate_helper
from dbt.contracts.connection import ConnectionState
from dbt.contracts.files import FileHash
Expand All @@ -28,6 +29,7 @@
BUCKET,
DATA_CATALOG_NAME,
DATABASE_NAME,
FEDERATED_QUERY_CATALOG_NAME,
S3_STAGING_DIR,
SHARED_DATA_CATALOG_NAME,
)
Expand Down Expand Up @@ -66,6 +68,7 @@ def setup_method(self, _):
("awsdatacatalog", "quux"),
("awsdatacatalog", "baz"),
(SHARED_DATA_CATALOG_NAME, "foo"),
(FEDERATED_QUERY_CATALOG_NAME, "foo"),
}
self.mock_manifest.nodes = {
"model.root.model1": CompiledNode(
Expand Down Expand Up @@ -212,6 +215,42 @@ def setup_method(self, _):
raw_code="select * from source_table",
language="",
),
"model.root.model5": CompiledNode(
name="model5",
database=FEDERATED_QUERY_CATALOG_NAME,
schema="foo",
resource_type=NodeType.Model,
unique_id="model.root.model5",
alias="bar",
fqn=["root", "model5"],
package_name="root",
refs=[],
sources=[],
depends_on=DependsOn(),
config=NodeConfig.from_dict(
{
"enabled": True,
"materialized": "table",
"persist_docs": {},
"post-hook": [],
"pre-hook": [],
"vars": {},
"meta": {"owner": "data-engineers"},
"quoting": {},
"column_types": {},
"tags": [],
}
),
tags=[],
path="model5.sql",
original_file_path="model5.sql",
compiled=True,
extra_ctes_injected=False,
extra_ctes=[],
checksum=FileHash.from_contents(""),
raw_code="select * from source_table",
language="",
),
}

@property
Expand Down Expand Up @@ -612,9 +651,85 @@ def test__get_one_catalog_shared_catalog(self, mock_aws_service):
for row in actual.rows.values():
assert row.values() in expected_rows

@mock_athena
@mock.patch.object(AthenaAdapter, "execute_macro")
def test__get_one_catalog_federated_query_catalog(self, mock_execute, mock_aws_service):
column_names = (
"table_database",
"table_schema",
"table_name",
"table_type",
"table_comment",
"column_name",
"column_index",
"column_type",
"column_comment",
)

rows = [
(FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None),
(FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None),
(FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None),
]

mock_execute.return_value = agate.Table(rows=rows, column_names=column_names)
mock_aws_service.create_data_catalog(
catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA
)
mock_information_schema = mock.MagicMock()
mock_information_schema.path.database = FEDERATED_QUERY_CATALOG_NAME

self.adapter.acquire_connection("dummy")
actual = self.adapter._get_one_catalog(
mock_information_schema,
mock.MagicMock(),
self.mock_manifest,
)

expected_column_names = (
"table_database",
"table_schema",
"table_name",
"table_type",
"table_comment",
"column_name",
"column_index",
"column_type",
"column_comment",
"table_owner",
)
expected_rows = [
(FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"),
(FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"),
(FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"),
]

assert actual.column_names == expected_column_names
assert len(actual.rows) == len(expected_rows)
for row in actual.rows.values():
assert row.values() in expected_rows

@mock_athena
def test__get_one_catalog_unsupported_type(self, mock_aws_service):
catalog_name = "example_hive_catalog"
catalog_type = AthenaCatalogType.HIVE
mock_aws_service.create_data_catalog(catalog_name=catalog_name, catalog_type=catalog_type)
mock_information_schema = mock.MagicMock()
mock_information_schema.path.database = catalog_name

self.adapter.acquire_connection("dummy")

with pytest.raises(NotImplementedError) as exc:
self.adapter._get_one_catalog(
mock_information_schema,
mock.MagicMock(),
self.mock_manifest,
)
assert exc.message == f"Type of catalog {catalog_type.value} not supported: {catalog_name}"

def test__get_catalog_schemas(self):
res = self.adapter._get_catalog_schemas(self.mock_manifest)
assert len(res.keys()) == 2
assert len(res.keys()) == 3

information_schema_0 = list(res.keys())[0]
assert information_schema_0.name == "INFORMATION_SCHEMA"
Expand All @@ -632,6 +747,14 @@ def test__get_catalog_schemas(self):
assert set(relations.keys()) == {"foo"}
assert list(relations.values()) == [{"bar"}]

information_schema_1 = list(res.keys())[2]
assert information_schema_1.name == "INFORMATION_SCHEMA"
assert information_schema_1.schema is None
assert information_schema_1.database == FEDERATED_QUERY_CATALOG_NAME
relations = list(res.values())[1]
assert set(relations.keys()) == {"foo"}
assert list(relations.values()) == [{"bar"}]

@mock_athena
@mock_sts
def test__get_data_catalog(self, mock_aws_service):
Expand Down Expand Up @@ -696,7 +819,7 @@ def test_list_relations_without_caching_with_non_glue_data_catalog(
self, parent_list_relations_without_caching, mock_aws_service
):
data_catalog_name = "other_data_catalog"
mock_aws_service.create_data_catalog(data_catalog_name, "HIVE")
mock_aws_service.create_data_catalog(data_catalog_name, AthenaCatalogType.HIVE)
schema_relation = self.adapter.Relation.create(
database=data_catalog_name,
schema=DATABASE_NAME,
Expand Down
Loading