-
Notifications
You must be signed in to change notification settings - Fork 189
Add support for TIMEZONE #252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import pytest | ||
import pytz | ||
import requests | ||
from tzlocal import get_localzone_name # type: ignore | ||
hashhar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import trino | ||
from tests.integration.conftest import trino_version | ||
|
@@ -1107,3 +1108,31 @@ def test_prepared_statements(run_trino): | |
cur.execute('DEALLOCATE PREPARE test_prepared_statements') | ||
cur.fetchall() | ||
assert cur._request._client_session.prepared_statements == {} | ||
|
||
|
||
def test_set_timezone_in_connection(run_trino): | ||
_, host, port = run_trino | ||
|
||
trino_connection = trino.dbapi.Connection( | ||
host=host, port=port, user="test", catalog="tpch", timezone="Europe/Brussels" | ||
) | ||
cur = trino_connection.cursor() | ||
cur.execute('SELECT current_timezone()') | ||
res = cur.fetchall() | ||
assert res[0][0] == "Europe/Brussels" | ||
|
||
|
||
def test_connection_without_timezone(run_trino): | ||
_, host, port = run_trino | ||
|
||
trino_connection = trino.dbapi.Connection( | ||
host=host, port=port, user="test", catalog="tpch" | ||
) | ||
cur = trino_connection.cursor() | ||
cur.execute('SELECT current_timezone()') | ||
res = cur.fetchall() | ||
session_tz = res[0][0] | ||
localzone = get_localzone_name() | ||
assert session_tz == localzone or \ | ||
(session_tz == "UTC" and localzone == "Etc/UTC") \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I conclude from this test if we set the timezone header to "Etc/UTC", Trino actually understands this but returns UTC as the Is my understanding correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly |
||
# Workaround for difference between Trino timezone and tzlocal for UTC |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,10 +50,18 @@ | |
import pytz | ||
import requests | ||
from pytz.tzinfo import BaseTzInfo | ||
from tzlocal import get_localzone_name # type: ignore | ||
|
||
import trino.logging | ||
from trino import constants, exceptions | ||
|
||
try: | ||
from zoneinfo import ZoneInfo # type: ignore | ||
|
||
except ModuleNotFoundError: | ||
from backports.zoneinfo import ZoneInfo # type: ignore | ||
|
||
|
||
__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"] | ||
|
||
logger = trino.logging.get_logger(__name__) | ||
|
@@ -107,6 +115,7 @@ class ClientSession(object): | |
:param client_tags: Client tags as list of strings. | ||
:param roles: roles for the current session. Some connectors do not | ||
support role management. See connector documentation for more details. | ||
:param timezone: The timezone for query processing. Defaults to the system's local timezone. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -121,6 +130,7 @@ def __init__( | |
extra_credential: List[Tuple[str, str]] = None, | ||
client_tags: List[str] = None, | ||
roles: Dict[str, str] = None, | ||
timezone: str = None, | ||
): | ||
self._user = user | ||
self._catalog = catalog | ||
|
@@ -134,6 +144,9 @@ def __init__( | |
self._roles = roles.copy() if roles is not None else {} | ||
self._prepared_statements: Dict[str, str] = {} | ||
self._object_lock = threading.Lock() | ||
self._timezone = timezone or get_localzone_name() | ||
if timezone: # Check timezone validity | ||
ZoneInfo(timezone) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just for other reviewers this throws on invalid timezones (invalidity is determined by the OS's tzdb files or the tzdata package shipped by Python itself):
|
||
|
||
@property | ||
def user(self): | ||
|
@@ -214,6 +227,11 @@ def prepared_statements(self, prepared_statements): | |
with self._object_lock: | ||
self._prepared_statements = prepared_statements | ||
|
||
@property | ||
def timezone(self): | ||
with self._object_lock: | ||
return self._timezone | ||
|
||
def __getstate__(self): | ||
state = self.__dict__.copy() | ||
del state["_object_lock"] | ||
|
@@ -415,6 +433,7 @@ def http_headers(self) -> Dict[str, str]: | |
headers[constants.HEADER_SCHEMA] = self._client_session.schema | ||
headers[constants.HEADER_SOURCE] = self._client_session.source | ||
headers[constants.HEADER_USER] = self._client_session.user | ||
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone | ||
if len(self._client_session.roles.values()): | ||
headers[constants.HEADER_ROLE] = ",".join( | ||
# ``name`` must not contain ``=`` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please revert the unrelated change on python_requires