Skip to content

Add support for TIMEZONE #252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,24 @@ conn = trino.dbapi.connect(
)
```

## Timezone

The time zone for the session can be set explicitly using an IANA time zone
name. When not set, the time zone defaults to the client-side local time zone.

```python
import trino
conn = trino.dbapi.connect(
host='localhost',
port=443,
user='username',
timezone='Europe/Brussels',
)
```

> **NOTE: The behaviour until version 0.320.0 was the same as setting the session time zone to UTC.**
> **To preserve that behaviour, pass `timezone='UTC'` when creating the connection.**

## SSL

### SSL verification
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@
"Programming Language :: Python :: Implementation :: PyPy",
"Topic :: Database :: Front-Ends",
],
python_requires=">=3.7",
install_requires=["pytz", "requests"],
python_requires='>=3.7',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert the unrelated change on python_requires

install_requires=["pytz", "requests", "tzlocal"],
extras_require={
"all": all_require,
"kerberos": kerberos_require,
Expand Down
29 changes: 29 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import pytz
import requests
from tzlocal import get_localzone_name # type: ignore

import trino
from tests.integration.conftest import trino_version
Expand Down Expand Up @@ -1107,3 +1108,31 @@ def test_prepared_statements(run_trino):
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
cur.fetchall()
assert cur._request._client_session.prepared_statements == {}


def test_set_timezone_in_connection(run_trino):
    """A time zone passed to the connection is the one the session reports.

    Connects with ``timezone="Europe/Brussels"`` and checks that
    ``SELECT current_timezone()`` returns that same name.
    """
    expected_zone = "Europe/Brussels"
    _, host, port = run_trino

    connection = trino.dbapi.Connection(
        host=host,
        port=port,
        user="test",
        catalog="tpch",
        timezone=expected_zone,
    )
    cursor = connection.cursor()
    cursor.execute('SELECT current_timezone()')
    rows = cursor.fetchall()

    # The query yields a single row with a single column: the zone name.
    assert rows[0][0] == expected_zone


def test_connection_without_timezone(run_trino):
    """Without an explicit time zone the session uses the client's local zone.

    Connects without a ``timezone`` argument and compares the result of
    ``SELECT current_timezone()`` with the name returned by
    ``get_localzone_name()``.
    """
    _, host, port = run_trino

    trino_connection = trino.dbapi.Connection(
        host=host, port=port, user="test", catalog="tpch"
    )
    cur = trino_connection.cursor()
    cur.execute('SELECT current_timezone()')
    res = cur.fetchall()
    session_tz = res[0][0]
    localzone = get_localzone_name()

    # Workaround for difference between Trino timezone and tzlocal for UTC:
    # a local zone of "Etc/UTC" is accepted when the session reports "UTC".
    # Parentheses are used instead of a trailing backslash so the expression
    # is self-contained.
    assert (
        session_tz == localzone
        or (session_tz == "UTC" and localzone == "Etc/UTC")
    )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I conclude from this test if we set the timezone header to "Etc/UTC", Trino actually understands this but returns UTC as the current_timezone().

Is my understanding correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly

# Workaround for difference between Trino timezone and tzlocal for UTC
70 changes: 69 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import requests
from httpretty import httprettified
from requests_kerberos.exceptions import KerberosExchangeError
from tzlocal import get_localzone_name # type: ignore

import trino.exceptions
from tests.unit.oauth_test_utils import (
Expand All @@ -48,6 +49,11 @@
_RetryWithExponentialBackoff,
)

try:
from zoneinfo import ZoneInfoNotFoundError # type: ignore
except ModuleNotFoundError:
from backports.zoneinfo._common import ZoneInfoNotFoundError # type: ignore


@mock.patch("trino.client.TrinoRequest.http")
def test_trino_initial_request(mock_requests, sample_post_response_data):
Expand Down Expand Up @@ -81,6 +87,7 @@ def test_request_headers(mock_get_and_post):
schema = "test_schema"
user = "test_user"
source = "test_source"
timezone = "Europe/Brussels"
accept_encoding_header = "accept-encoding"
accept_encoding_value = "identity,deflate,gzip"
client_info_header = constants.HEADER_CLIENT_INFO
Expand All @@ -94,6 +101,7 @@ def test_request_headers(mock_get_and_post):
source=source,
catalog=catalog,
schema=schema,
timezone=timezone,
headers={
accept_encoding_header: accept_encoding_value,
client_info_header: client_info_value,
Expand All @@ -109,9 +117,10 @@ def assert_headers(headers):
assert headers[constants.HEADER_SOURCE] == source
assert headers[constants.HEADER_USER] == user
assert headers[constants.HEADER_SESSION] == ""
assert headers[constants.HEADER_TIMEZONE] == timezone
assert headers[accept_encoding_header] == accept_encoding_value
assert headers[client_info_header] == client_info_value
assert len(headers.keys()) == 8
assert len(headers.keys()) == 9

req.post("URL")
_, post_kwargs = post.call_args
Expand Down Expand Up @@ -1113,3 +1122,62 @@ def test_request_headers_role_empty(mock_get_and_post):
req.get("URL")
_, get_kwargs = get.call_args
assert_headers_with_roles(post_kwargs["headers"], None)


def assert_headers_timezone(headers: Dict[str, str], timezone: str):
    """Assert that the time zone header in ``headers`` equals ``timezone``.

    :param headers: HTTP request headers to inspect.
    :param timezone: the time zone name the header is expected to carry.
    :raises KeyError: if the time zone header is absent from ``headers``.
    """
    sent_timezone = headers[constants.HEADER_TIMEZONE]
    assert sent_timezone == timezone


def test_request_headers_with_timezone(mock_get_and_post):
    """An explicit session time zone is sent on both POST and GET requests."""
    get, post = mock_get_and_post

    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(
            user="test_user",
            timezone="Europe/Brussels"
        ),
    )

    req.post("URL")
    _, post_kwargs = post.call_args
    assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")

    req.get("URL")
    _, get_kwargs = get.call_args
    # Inspect the GET call's own headers; asserting on post_kwargs here would
    # only re-check the POST request and leave the GET path unverified.
    assert_headers_timezone(get_kwargs["headers"], "Europe/Brussels")


def test_request_headers_without_timezone(mock_get_and_post):
    """With no session time zone, requests carry the local zone name.

    The expected value is whatever ``get_localzone_name()`` returns on the
    machine running the test.
    """
    get, post = mock_get_and_post

    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(
            user="test_user",
        ),
    )
    localzone = get_localzone_name()

    req.post("URL")
    _, post_kwargs = post.call_args
    assert_headers_timezone(post_kwargs["headers"], localzone)

    req.get("URL")
    _, get_kwargs = get.call_args
    # Inspect the GET call's own headers; asserting on post_kwargs here would
    # only re-check the POST request and leave the GET path unverified.
    assert_headers_timezone(get_kwargs["headers"], localzone)


def test_request_with_invalid_timezone(mock_get_and_post):
    """Building a request with an unknown time zone name raises.

    Expects ``ZoneInfoNotFoundError`` with a message that starts with
    ``'No time zone found with key``.
    """
    bad_session_kwargs = {"user": "test_user", "timezone": "INVALID_TIMEZONE"}

    with pytest.raises(ZoneInfoNotFoundError) as excinfo:
        TrinoRequest(
            host="coordinator",
            port=8080,
            client_session=ClientSession(**bad_session_kwargs),
        )

    # Checked after the ``with`` block: statements following the raising call
    # inside the block would never run.
    message = str(excinfo.value)
    assert message.startswith("'No time zone found with key")
19 changes: 19 additions & 0 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,18 @@
import pytz
import requests
from pytz.tzinfo import BaseTzInfo
from tzlocal import get_localzone_name # type: ignore

import trino.logging
from trino import constants, exceptions

try:
from zoneinfo import ZoneInfo # type: ignore

except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo # type: ignore


__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"]

logger = trino.logging.get_logger(__name__)
Expand Down Expand Up @@ -107,6 +115,7 @@ class ClientSession(object):
:param client_tags: Client tags as list of strings.
:param roles: roles for the current session. Some connectors do not
support role management. See connector documentation for more details.
:param timezone: The timezone for query processing. Defaults to the system's local timezone.
"""

def __init__(
Expand All @@ -121,6 +130,7 @@ def __init__(
extra_credential: List[Tuple[str, str]] = None,
client_tags: List[str] = None,
roles: Dict[str, str] = None,
timezone: str = None,
):
self._user = user
self._catalog = catalog
Expand All @@ -134,6 +144,9 @@ def __init__(
self._roles = roles.copy() if roles is not None else {}
self._prepared_statements: Dict[str, str] = {}
self._object_lock = threading.Lock()
self._timezone = timezone or get_localzone_name()
if timezone: # Check timezone validity
ZoneInfo(timezone)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for other reviewers this throws on invalid timezones (invalidity is determined by the OS's tzdb files or the tzdata package shipped by Python itself):

ZoneInfoNotFoundError: 'No time zone found with key Foobar'


@property
def user(self):
Expand Down Expand Up @@ -214,6 +227,11 @@ def prepared_statements(self, prepared_statements):
with self._object_lock:
self._prepared_statements = prepared_statements

@property
def timezone(self):
    """Return the session's time zone, read while holding the object lock."""
    with self._object_lock:
        current_timezone = self._timezone
    return current_timezone

def __getstate__(self):
state = self.__dict__.copy()
del state["_object_lock"]
Expand Down Expand Up @@ -415,6 +433,7 @@ def http_headers(self) -> Dict[str, str]:
headers[constants.HEADER_SCHEMA] = self._client_session.schema
headers[constants.HEADER_SOURCE] = self._client_session.source
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
if len(self._client_session.roles.values()):
headers[constants.HEADER_ROLE] = ",".join(
# ``name`` must not contain ``=``
Expand Down
1 change: 1 addition & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
HEADER_TIMEZONE = "X-Trino-Time-Zone"

HEADER_SESSION = "X-Trino-Session"
HEADER_SET_SESSION = "X-Trino-Set-Session"
Expand Down
2 changes: 2 additions & 0 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
client_tags=None,
experimental_python_types=False,
roles=None,
timezone=None,
):
self.host = host
self.port = port
Expand All @@ -129,6 +130,7 @@ def __init__(
extra_credential=extra_credential,
client_tags=client_tags,
roles=roles,
timezone=timezone,
)
# mypy cannot follow module import
if http_session is None:
Expand Down