22
22
import requests
23
23
from httpretty import httprettified
24
24
from requests_kerberos .exceptions import KerberosExchangeError
25
+ from tzlocal import get_localzone_name # type: ignore
25
26
26
27
import trino .exceptions
27
28
from tests .unit .oauth_test_utils import (
48
49
_RetryWithExponentialBackoff ,
49
50
)
50
51
52
+ try :
53
+ from zoneinfo import ZoneInfoNotFoundError # type: ignore
54
+ except ModuleNotFoundError :
55
+ from backports .zoneinfo ._common import ZoneInfoNotFoundError # type: ignore
56
+
51
57
52
58
@mock .patch ("trino.client.TrinoRequest.http" )
53
59
def test_trino_initial_request (mock_requests , sample_post_response_data ):
@@ -81,6 +87,7 @@ def test_request_headers(mock_get_and_post):
81
87
schema = "test_schema"
82
88
user = "test_user"
83
89
source = "test_source"
90
+ timezone = "Europe/Brussels"
84
91
accept_encoding_header = "accept-encoding"
85
92
accept_encoding_value = "identity,deflate,gzip"
86
93
client_info_header = constants .HEADER_CLIENT_INFO
@@ -94,6 +101,7 @@ def test_request_headers(mock_get_and_post):
94
101
source = source ,
95
102
catalog = catalog ,
96
103
schema = schema ,
104
+ timezone = timezone ,
97
105
headers = {
98
106
accept_encoding_header : accept_encoding_value ,
99
107
client_info_header : client_info_value ,
@@ -109,9 +117,10 @@ def assert_headers(headers):
109
117
assert headers [constants .HEADER_SOURCE ] == source
110
118
assert headers [constants .HEADER_USER ] == user
111
119
assert headers [constants .HEADER_SESSION ] == ""
120
+ assert headers [constants .HEADER_TIMEZONE ] == timezone
112
121
assert headers [accept_encoding_header ] == accept_encoding_value
113
122
assert headers [client_info_header ] == client_info_value
114
- assert len (headers .keys ()) == 8
123
+ assert len (headers .keys ()) == 9
115
124
116
125
req .post ("URL" )
117
126
_ , post_kwargs = post .call_args
@@ -1113,3 +1122,62 @@ def test_request_headers_role_empty(mock_get_and_post):
1113
1122
req .get ("URL" )
1114
1123
_ , get_kwargs = get .call_args
1115
1124
assert_headers_with_roles (post_kwargs ["headers" ], None )
1125
+
1126
+
1127
+ def assert_headers_timezone (headers : Dict [str , str ], timezone : str ):
1128
+ assert headers [constants .HEADER_TIMEZONE ] == timezone
1129
+
1130
+
1131
+ def test_request_headers_with_timezone (mock_get_and_post ):
1132
+ get , post = mock_get_and_post
1133
+
1134
+ req = TrinoRequest (
1135
+ host = "coordinator" ,
1136
+ port = 8080 ,
1137
+ client_session = ClientSession (
1138
+ user = "test_user" ,
1139
+ timezone = "Europe/Brussels"
1140
+ ),
1141
+ )
1142
+
1143
+ req .post ("URL" )
1144
+ _ , post_kwargs = post .call_args
1145
+ assert_headers_timezone (post_kwargs ["headers" ], "Europe/Brussels" )
1146
+
1147
+ req .get ("URL" )
1148
+ _ , get_kwargs = get .call_args
1149
+ assert_headers_timezone (post_kwargs ["headers" ], "Europe/Brussels" )
1150
+
1151
+
1152
+ def test_request_headers_without_timezone (mock_get_and_post ):
1153
+ get , post = mock_get_and_post
1154
+
1155
+ req = TrinoRequest (
1156
+ host = "coordinator" ,
1157
+ port = 8080 ,
1158
+ client_session = ClientSession (
1159
+ user = "test_user" ,
1160
+ ),
1161
+ )
1162
+ localzone = get_localzone_name ()
1163
+
1164
+ req .post ("URL" )
1165
+ _ , post_kwargs = post .call_args
1166
+ assert_headers_timezone (post_kwargs ["headers" ], localzone )
1167
+
1168
+ req .get ("URL" )
1169
+ _ , get_kwargs = get .call_args
1170
+ assert_headers_timezone (post_kwargs ["headers" ], localzone )
1171
+
1172
+
1173
+ def test_request_with_invalid_timezone (mock_get_and_post ):
1174
+ with pytest .raises (ZoneInfoNotFoundError ) as zinfo_error :
1175
+ TrinoRequest (
1176
+ host = "coordinator" ,
1177
+ port = 8080 ,
1178
+ client_session = ClientSession (
1179
+ user = "test_user" ,
1180
+ timezone = "INVALID_TIMEZONE"
1181
+ ),
1182
+ )
1183
+ assert str (zinfo_error .value ).startswith ("'No time zone found with key" )
0 commit comments