13
13
import threading
14
14
import time
15
15
import uuid
16
+ from tzlocal import get_localzone_name
16
17
from typing import Optional , Dict
17
18
from unittest import mock
18
19
from urllib .parse import urlparse
32
33
from trino .client import TrinoQuery , TrinoRequest , TrinoResult , ClientSession , _DelayExponential , _retry_with , \
33
34
_RetryWithExponentialBackoff
34
35
36
+ try :
37
+ from zoneinfo ._common import ZoneInfoNotFoundError # type: ignore
38
+
39
+ except ModuleNotFoundError :
40
+ from backports .zoneinfo ._common import ZoneInfoNotFoundError # type: ignore
41
+
35
42
36
43
@mock .patch ("trino.client.TrinoRequest.http" )
37
44
def test_trino_initial_request (mock_requests , sample_post_response_data ):
@@ -65,6 +72,7 @@ def test_request_headers(mock_get_and_post):
65
72
schema = "test_schema"
66
73
user = "test_user"
67
74
source = "test_source"
75
+ timezone = "Europe/Brussels"
68
76
accept_encoding_header = "accept-encoding"
69
77
accept_encoding_value = "identity,deflate,gzip"
70
78
client_info_header = constants .HEADER_CLIENT_INFO
@@ -78,6 +86,7 @@ def test_request_headers(mock_get_and_post):
78
86
source = source ,
79
87
catalog = catalog ,
80
88
schema = schema ,
89
+ timezone = timezone ,
81
90
headers = {
82
91
accept_encoding_header : accept_encoding_value ,
83
92
client_info_header : client_info_value ,
@@ -93,9 +102,10 @@ def assert_headers(headers):
93
102
assert headers [constants .HEADER_SOURCE ] == source
94
103
assert headers [constants .HEADER_USER ] == user
95
104
assert headers [constants .HEADER_SESSION ] == ""
105
+ assert headers [constants .HEADER_TIMEZONE ] == timezone
96
106
assert headers [accept_encoding_header ] == accept_encoding_value
97
107
assert headers [client_info_header ] == client_info_value
98
- assert len (headers .keys ()) == 8
108
+ assert len (headers .keys ()) == 9
99
109
100
110
req .post ("URL" )
101
111
_ , post_kwargs = post .call_args
@@ -1056,3 +1066,62 @@ def test_request_headers_role_empty(mock_get_and_post):
1056
1066
req .get ("URL" )
1057
1067
_ , get_kwargs = get .call_args
1058
1068
assert_headers_with_roles (post_kwargs ["headers" ], None )
1069
+
1070
+
1071
+ def assert_headers_timezone (headers : Dict [str , str ], timezone : str ):
1072
+ assert headers [constants .HEADER_TIMEZONE ] == timezone
1073
+
1074
+
1075
+ def test_request_headers_with_timezone (mock_get_and_post ):
1076
+ get , post = mock_get_and_post
1077
+
1078
+ req = TrinoRequest (
1079
+ host = "coordinator" ,
1080
+ port = 8080 ,
1081
+ client_session = ClientSession (
1082
+ user = "test_user" ,
1083
+ timezone = "Europe/Brussels"
1084
+ ),
1085
+ )
1086
+
1087
+ req .post ("URL" )
1088
+ _ , post_kwargs = post .call_args
1089
+ assert_headers_timezone (post_kwargs ["headers" ], "Europe/Brussels" )
1090
+
1091
+ req .get ("URL" )
1092
+ _ , get_kwargs = get .call_args
1093
+ assert_headers_timezone (post_kwargs ["headers" ], "Europe/Brussels" )
1094
+
1095
+
1096
+ def test_request_headers_without_timezone (mock_get_and_post ):
1097
+ get , post = mock_get_and_post
1098
+
1099
+ req = TrinoRequest (
1100
+ host = "coordinator" ,
1101
+ port = 8080 ,
1102
+ client_session = ClientSession (
1103
+ user = "test_user" ,
1104
+ ),
1105
+ )
1106
+ localzone = get_localzone_name ()
1107
+
1108
+ req .post ("URL" )
1109
+ _ , post_kwargs = post .call_args
1110
+ assert_headers_timezone (post_kwargs ["headers" ], localzone )
1111
+
1112
+ req .get ("URL" )
1113
+ _ , get_kwargs = get .call_args
1114
+ assert_headers_timezone (post_kwargs ["headers" ], localzone )
1115
+
1116
+
1117
+ def test_request_with_invalid_timezone (mock_get_and_post ):
1118
+ with pytest .raises (ZoneInfoNotFoundError ) as zinfo_error :
1119
+ TrinoRequest (
1120
+ host = "coordinator" ,
1121
+ port = 8080 ,
1122
+ client_session = ClientSession (
1123
+ user = "test_user" ,
1124
+ timezone = "INVALID_TIMEZONE"
1125
+ ),
1126
+ )
1127
+ assert str (zinfo_error .value ).startswith ("'No time zone found with key" )
0 commit comments