Skip to content

Commit 1bf9336

Browse files
committed
Add support for TIMEZONE
1 parent f97aea6 commit 1bf9336

File tree

8 files changed

+138
-2
lines changed

8 files changed

+138
-2
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ repos:
1313
additional_dependencies:
1414
- "types-pytz"
1515
- "types-requests"
16+
- "types-tzlocal"

README.md

+14
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,20 @@ conn = trino.dbapi.connect(
357357
)
358358
```
359359

360+
## Timezone
361+
362+
The time zone for the session can be explicitly set using the official IANA time zone name. When not set the time zone defaults to the client side local timezone.
363+
364+
```python
365+
import trino
366+
conn = trino.dbapi.connect(
367+
host='localhost',
368+
port=443,
369+
user='username',
370+
timezone="Europe/Brussels",
371+
)
372+
```
373+
360374
## SSL
361375

362376
### SSL verification

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
"Topic :: Database :: Front-Ends",
7777
],
7878
python_requires='>=3.7',
79-
install_requires=["pytz", "requests"],
79+
install_requires=["pytz", "requests", "tzlocal"],
8080
extras_require={
8181
"all": all_require,
8282
"kerberos": kerberos_require,

tests/integration/test_dbapi_integration.py

+29
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717
import pytz
1818
import requests
19+
from tzlocal import get_localzone_name
1920

2021
import trino
2122
from tests.integration.conftest import trino_version
@@ -1120,3 +1121,31 @@ def test_prepared_statements(run_trino):
11201121
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
11211122
cur.fetchall()
11221123
assert cur._request._client_session.prepared_statements == {}
1124+
1125+
1126+
def test_set_timezone_in_connection(run_trino):
1127+
_, host, port = run_trino
1128+
1129+
trino_connection = trino.dbapi.Connection(
1130+
host=host, port=port, user="test", catalog="tpch", timezone="Europe/Brussels"
1131+
)
1132+
cur = trino_connection.cursor()
1133+
cur.execute('SELECT current_timezone()')
1134+
res = cur.fetchall()
1135+
assert res[0][0] == "Europe/Brussels"
1136+
1137+
1138+
def test_connection_without_timezone(run_trino):
1139+
_, host, port = run_trino
1140+
1141+
trino_connection = trino.dbapi.Connection(
1142+
host=host, port=port, user="test", catalog="tpch"
1143+
)
1144+
cur = trino_connection.cursor()
1145+
cur.execute('SELECT current_timezone()')
1146+
res = cur.fetchall()
1147+
session_tz = res[0][0]
1148+
localzone = get_localzone_name()
1149+
assert session_tz == localzone or \
1150+
(session_tz == "UTC" and localzone == "Etc/UTC") \
1151+
# Workaround for difference between Trino timezone and tzlocal for UTC

tests/unit/test_client.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
import time
1515
import uuid
16+
from tzlocal import get_localzone_name
1617
from typing import Optional, Dict
1718
from unittest import mock
1819
from urllib.parse import urlparse
@@ -32,6 +33,12 @@
3233
from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \
3334
_RetryWithExponentialBackoff
3435

36+
try:
37+
from zoneinfo._common import ZoneInfoNotFoundError # type: ignore
38+
39+
except ModuleNotFoundError:
40+
from backports.zoneinfo._common import ZoneInfoNotFoundError # type: ignore
41+
3542

3643
@mock.patch("trino.client.TrinoRequest.http")
3744
def test_trino_initial_request(mock_requests, sample_post_response_data):
@@ -65,6 +72,7 @@ def test_request_headers(mock_get_and_post):
6572
schema = "test_schema"
6673
user = "test_user"
6774
source = "test_source"
75+
timezone = "Europe/Brussels"
6876
accept_encoding_header = "accept-encoding"
6977
accept_encoding_value = "identity,deflate,gzip"
7078
client_info_header = constants.HEADER_CLIENT_INFO
@@ -78,6 +86,7 @@ def test_request_headers(mock_get_and_post):
7886
source=source,
7987
catalog=catalog,
8088
schema=schema,
89+
timezone=timezone,
8190
headers={
8291
accept_encoding_header: accept_encoding_value,
8392
client_info_header: client_info_value,
@@ -93,9 +102,10 @@ def assert_headers(headers):
93102
assert headers[constants.HEADER_SOURCE] == source
94103
assert headers[constants.HEADER_USER] == user
95104
assert headers[constants.HEADER_SESSION] == ""
105+
assert headers[constants.HEADER_TIMEZONE] == timezone
96106
assert headers[accept_encoding_header] == accept_encoding_value
97107
assert headers[client_info_header] == client_info_value
98-
assert len(headers.keys()) == 8
108+
assert len(headers.keys()) == 9
99109

100110
req.post("URL")
101111
_, post_kwargs = post.call_args
@@ -1056,3 +1066,62 @@ def test_request_headers_role_empty(mock_get_and_post):
10561066
req.get("URL")
10571067
_, get_kwargs = get.call_args
10581068
assert_headers_with_roles(post_kwargs["headers"], None)
1069+
1070+
1071+
def assert_headers_timezone(headers: Dict[str, str], timezone: str):
1072+
assert headers[constants.HEADER_TIMEZONE] == timezone
1073+
1074+
1075+
def test_request_headers_with_timezone(mock_get_and_post):
1076+
get, post = mock_get_and_post
1077+
1078+
req = TrinoRequest(
1079+
host="coordinator",
1080+
port=8080,
1081+
client_session=ClientSession(
1082+
user="test_user",
1083+
timezone="Europe/Brussels"
1084+
),
1085+
)
1086+
1087+
req.post("URL")
1088+
_, post_kwargs = post.call_args
1089+
assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")
1090+
1091+
req.get("URL")
1092+
_, get_kwargs = get.call_args
1093+
assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")
1094+
1095+
1096+
def test_request_headers_without_timezone(mock_get_and_post):
1097+
get, post = mock_get_and_post
1098+
1099+
req = TrinoRequest(
1100+
host="coordinator",
1101+
port=8080,
1102+
client_session=ClientSession(
1103+
user="test_user",
1104+
),
1105+
)
1106+
localzone = get_localzone_name()
1107+
1108+
req.post("URL")
1109+
_, post_kwargs = post.call_args
1110+
assert_headers_timezone(post_kwargs["headers"], localzone)
1111+
1112+
req.get("URL")
1113+
_, get_kwargs = get.call_args
1114+
assert_headers_timezone(post_kwargs["headers"], localzone)
1115+
1116+
1117+
def test_request_with_invalid_timezone(mock_get_and_post):
1118+
with pytest.raises(ZoneInfoNotFoundError) as zinfo_error:
1119+
TrinoRequest(
1120+
host="coordinator",
1121+
port=8080,
1122+
client_session=ClientSession(
1123+
user="test_user",
1124+
timezone="INVALID_TIMEZONE"
1125+
),
1126+
)
1127+
assert str(zinfo_error.value).startswith("'No time zone found with key")

trino/client.py

+20
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import urllib.parse
4444
from datetime import datetime, timedelta, timezone
4545
from decimal import Decimal
46+
from tzlocal import get_localzone_name
4647
from typing import Any, Dict, List, Optional, Tuple, Union
4748

4849
import pytz
@@ -51,6 +52,13 @@
5152
import trino.logging
5253
from trino import constants, exceptions
5354

55+
try:
56+
from zoneinfo import ZoneInfo # type: ignore
57+
58+
except ModuleNotFoundError:
59+
from backports.zoneinfo import ZoneInfo # type: ignore
60+
61+
5462
__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"]
5563

5664
logger = trino.logging.get_logger(__name__)
@@ -100,6 +108,8 @@ class ClientSession(object):
100108
:param client_tags: Client tags as list of strings.
101109
:param roles: roles for the current session. Some connectors do not
102110
support role management. See connector documentation for more details.
111+
:param timezone: The timezone for query processing. Defaults to the timezone
112+
of the Trino cluster, and not the timezone of the client.
103113
"""
104114

105115
def __init__(
@@ -114,6 +124,7 @@ def __init__(
114124
extra_credential: List[Tuple[str, str]] = None,
115125
client_tags: List[str] = None,
116126
roles: Dict[str, str] = None,
127+
timezone: str = None,
117128
):
118129
self._user = user
119130
self._catalog = catalog
@@ -127,6 +138,9 @@ def __init__(
127138
self._roles = roles.copy() if roles is not None else {}
128139
self._prepared_statements: Dict[str, str] = {}
129140
self._object_lock = threading.Lock()
141+
if timezone: # Check timezone validity
142+
ZoneInfo(timezone)
143+
self._timezone = timezone or get_localzone_name()
130144

131145
@property
132146
def user(self):
@@ -207,6 +221,11 @@ def prepared_statements(self, prepared_statements):
207221
with self._object_lock:
208222
self._prepared_statements = prepared_statements
209223

224+
@property
225+
def timezone(self):
226+
with self._object_lock:
227+
return self._timezone
228+
210229
def __getstate__(self):
211230
state = self.__dict__.copy()
212231
del state["_object_lock"]
@@ -408,6 +427,7 @@ def http_headers(self) -> Dict[str, str]:
408427
headers[constants.HEADER_SCHEMA] = self._client_session.schema
409428
headers[constants.HEADER_SOURCE] = self._client_session.source
410429
headers[constants.HEADER_USER] = self._client_session.user
430+
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
411431
if len(self._client_session.roles.values()):
412432
headers[constants.HEADER_ROLE] = ",".join(
413433
# ``name`` must not contain ``=``

trino/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
3535
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
3636
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
37+
HEADER_TIMEZONE = "X-Trino-Time-Zone"
3738

3839
HEADER_SESSION = "X-Trino-Session"
3940
HEADER_SET_SESSION = "X-Trino-Set-Session"

trino/dbapi.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
client_tags=None,
112112
experimental_python_types=False,
113113
roles=None,
114+
timezone=None,
114115
):
115116
self.host = host
116117
self.port = port
@@ -130,6 +131,7 @@ def __init__(
130131
extra_credential=extra_credential,
131132
client_tags=client_tags,
132133
roles=roles,
134+
timezone=timezone,
133135
)
134136
# mypy cannot follow module import
135137
if http_session is None:

0 commit comments

Comments
 (0)