Skip to content

Commit 62ffa49

Browse files
mdesmethashhar
authored andcommitted
Encode roles as in JDBC driver
1 parent e4a3f0f commit 62ffa49

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

tests/unit/test_client.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import json
1313
import threading
1414
import time
15+
import urllib
1516
import uuid
1617
from typing import Dict, Optional
1718
from unittest import mock
@@ -105,6 +106,13 @@ def test_request_headers(mock_get_and_post):
105106
headers={
106107
accept_encoding_header: accept_encoding_value,
107108
client_info_header: client_info_value,
109+
},
110+
roles={
111+
"hive": "ALL",
112+
"system": "analyst",
113+
"catalog1": "NONE",
114+
# ensure backwards compatibility
115+
"catalog2": "ROLE{catalog2_role}",
108116
}
109117
),
110118
http_scheme="http",
@@ -121,7 +129,13 @@ def assert_headers(headers):
121129
assert headers[constants.HEADER_CLIENT_CAPABILITIES] == "PARAMETRIC_DATETIME"
122130
assert headers[accept_encoding_header] == accept_encoding_value
123131
assert headers[client_info_header] == client_info_value
124-
assert len(headers.keys()) == 10
132+
assert headers[constants.HEADER_ROLE] == (
133+
"hive=ALL,"
134+
"system=" + urllib.parse.quote("ROLE{analyst}") + ","
135+
"catalog1=NONE,"
136+
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
137+
)
138+
assert len(headers.keys()) == 11
125139

126140
req.post("URL")
127141
_, post_kwargs = post.call_args
@@ -1095,14 +1109,15 @@ def test_request_headers_role_admin(mock_get_and_post):
10951109
roles={"system": "admin"}
10961110
),
10971111
)
1112+
roles = "system=" + urllib.parse.quote("ROLE{admin}")
10981113

10991114
req.post("URL")
11001115
_, post_kwargs = post.call_args
1101-
assert_headers_with_roles(post_kwargs["headers"], "system=admin")
1116+
assert_headers_with_roles(post_kwargs["headers"], roles)
11021117

11031118
req.get("URL")
11041119
_, get_kwargs = get.call_args
1105-
assert_headers_with_roles(post_kwargs["headers"], "system=admin")
1120+
assert_headers_with_roles(post_kwargs["headers"], roles)
11061121

11071122

11081123
def test_request_headers_role_empty(mock_get_and_post):

trino/client.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
MAX_PYTHON_TEMPORAL_PRECISION_POWER = 6
8686
MAX_PYTHON_TEMPORAL_PRECISION = POWERS_OF_TEN[MAX_PYTHON_TEMPORAL_PRECISION_POWER]
8787

88+
ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$")
89+
8890

8991
class ClientSession(object):
9092
"""
@@ -143,7 +145,7 @@ def __init__(
143145
self._transaction_id = transaction_id
144146
self._extra_credential = extra_credential
145147
self._client_tags = client_tags.copy() if client_tags is not None else list()
146-
self._roles = roles.copy() if roles is not None else {}
148+
self._roles = self._format_roles(roles) if roles is not None else {}
147149
self._prepared_statements: Dict[str, str] = {}
148150
self._object_lock = threading.Lock()
149151
self._timezone = timezone or get_localzone_name()
@@ -234,6 +236,15 @@ def timezone(self):
234236
with self._object_lock:
235237
return self._timezone
236238

239+
def _format_roles(self, roles):
240+
formatted_roles = {}
241+
for catalog, role in roles.items():
242+
if role in ("NONE", "ALL") or ROLE_PATTERN.match(role) is not None:
243+
formatted_roles[catalog] = role
244+
else:
245+
formatted_roles[catalog] = f"ROLE{{{role}}}"
246+
return formatted_roles
247+
237248
def __getstate__(self):
238249
state = self.__dict__.copy()
239250
del state["_object_lock"]

0 commit comments

Comments
 (0)