Skip to content

Commit 169226e

Browse files
baohe-zhanghashhar
authored andcommitted
Support SET SESSION AUTHORIZATION on trino-python-client
1 parent 856d8e9 commit 169226e

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

tests/unit/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_request_headers(mock_get_and_post):
9090
catalog = "test_catalog"
9191
schema = "test_schema"
9292
user = "test_user"
93+
authorization_user = "test_authorization_user"
9394
source = "test_source"
9495
timezone = "Europe/Brussels"
9596
accept_encoding_header = "accept-encoding"
@@ -103,6 +104,7 @@ def test_request_headers(mock_get_and_post):
103104
port=8080,
104105
client_session=ClientSession(
105106
user=user,
107+
authorization_user=authorization_user,
106108
source=source,
107109
catalog=catalog,
108110
schema=schema,
@@ -127,6 +129,7 @@ def assert_headers(headers):
127129
assert headers[constants.HEADER_SCHEMA] == schema
128130
assert headers[constants.HEADER_SOURCE] == source
129131
assert headers[constants.HEADER_USER] == user
132+
assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user
130133
assert headers[constants.HEADER_SESSION] == ""
131134
assert headers[constants.HEADER_TRANSACTION] is None
132135
assert headers[constants.HEADER_TIMEZONE] == timezone
@@ -140,7 +143,7 @@ def assert_headers(headers):
140143
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
141144
)
142145
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
143-
assert len(headers.keys()) == 12
146+
assert len(headers.keys()) == 13
144147

145148
req.post("URL")
146149
_, post_kwargs = post.call_args

trino/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class ClientSession(object):
8282
8383
:param user: associated with the query. It is useful for access control
8484
and query scheduling.
85+
:param authorization_user: associated with the query. It is useful for access control
86+
and query scheduling.
8587
:param source: associated with the query. It is useful for access
8688
control and query scheduling.
8789
:param catalog: to query. The *catalog* is associated with a Trino
@@ -113,6 +115,7 @@ class ClientSession(object):
113115
def __init__(
114116
self,
115117
user: str,
118+
authorization_user: str = None,
116119
catalog: str = None,
117120
schema: str = None,
118121
source: str = None,
@@ -125,6 +128,7 @@ def __init__(
125128
timezone: str = None,
126129
):
127130
self._user = user
131+
self._authorization_user = authorization_user
128132
self._catalog = catalog
129133
self._schema = schema
130134
self._source = source
@@ -144,6 +148,16 @@ def __init__(
144148
def user(self):
145149
return self._user
146150

151+
@property
152+
def authorization_user(self):
153+
with self._object_lock:
154+
return self._authorization_user
155+
156+
@authorization_user.setter
157+
def authorization_user(self, authorization_user):
158+
with self._object_lock:
159+
self._authorization_user = authorization_user
160+
147161
@property
148162
def catalog(self):
149163
with self._object_lock:
@@ -441,6 +455,7 @@ def http_headers(self) -> Dict[str, str]:
441455
headers[constants.HEADER_SCHEMA] = self._client_session.schema
442456
headers[constants.HEADER_SOURCE] = self._client_session.source
443457
headers[constants.HEADER_USER] = self._client_session.user
458+
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
444459
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
445460
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
446461
headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}"
@@ -631,6 +646,12 @@ def process(self, http_response) -> TrinoStatus:
631646
):
632647
self._client_session.prepared_statements.pop(name, None)
633648

649+
if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers:
650+
self._client_session.authorization_user = http_response.headers[constants.HEADER_SET_AUTHORIZATION_USER]
651+
652+
if constants.HEADER_RESET_AUTHORIZATION_USER in http_response.headers:
653+
self._client_session.authorization_user = None
654+
634655
self._next_uri = response.get("nextUri")
635656

636657
data = response.get("data") if response.get("data") else []

trino/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656

5757
HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"
5858

59+
HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User"
60+
HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User"
61+
HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User"
62+
5963
LENGTH_TYPES = ["char", "varchar"]
6064
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
6165
SCALE_TYPES = ["decimal"]

0 commit comments

Comments
 (0)