Skip to content

Commit 5731329

Browse files
committed
Reorder value mappers
This makes it easier to identify missing types for example.
1 parent f670a53 commit 5731329

File tree

1 file changed

+73
-68
lines changed

1 file changed

+73
-68
lines changed

trino/mapper.py

+73-68
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,6 @@ def map(self, value: Any) -> Optional[T]:
3131
pass
3232

3333

34-
class NoOpValueMapper(ValueMapper[Any]):
35-
def map(self, value) -> Optional[Any]:
36-
return value
37-
38-
39-
class DecimalValueMapper(ValueMapper[Decimal]):
40-
def map(self, value) -> Optional[Decimal]:
41-
if value is None:
42-
return None
43-
return Decimal(value)
44-
45-
4634
class DoubleValueMapper(ValueMapper[float]):
4735
def map(self, value) -> Optional[float]:
4836
if value is None:
@@ -56,19 +44,25 @@ def map(self, value) -> Optional[float]:
5644
return float(value)
5745

5846

59-
def _create_tzinfo(timezone_str: str) -> tzinfo:
60-
if timezone_str.startswith("+") or timezone_str.startswith("-"):
61-
hours = timezone_str[1:3]
62-
minutes = timezone_str[4:6]
63-
if timezone_str.startswith("-"):
64-
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
65-
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
66-
else:
67-
return ZoneInfo(timezone_str)
47+
class DecimalValueMapper(ValueMapper[Decimal]):
48+
def map(self, value) -> Optional[Decimal]:
49+
if value is None:
50+
return None
51+
return Decimal(value)
6852

6953

70-
def _fraction_to_decimal(fractional_str: str) -> Decimal:
71-
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]
54+
class BinaryValueMapper(ValueMapper[bytes]):
55+
def map(self, value) -> Optional[bytes]:
56+
if value is None:
57+
return None
58+
return base64.b64decode(value.encode("utf8"))
59+
60+
61+
class DateValueMapper(ValueMapper[date]):
62+
def map(self, value) -> Optional[date]:
63+
if value is None:
64+
return None
65+
return date.fromisoformat(value)
7266

7367

7468
class TimeValueMapper(ValueMapper[time]):
@@ -103,13 +97,6 @@ def map(self, value) -> Optional[time]:
10397
).round_to(self.precision).to_python_type()
10498

10599

106-
class DateValueMapper(ValueMapper[date]):
107-
def map(self, value) -> Optional[date]:
108-
if value is None:
109-
return None
110-
return date.fromisoformat(value)
111-
112-
113100
class TimestampValueMapper(ValueMapper[datetime]):
114101
def __init__(self, precision):
115102
self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds)
@@ -139,11 +126,19 @@ def map(self, value) -> Optional[datetime]:
139126
).round_to(self.precision).to_python_type()
140127

141128

142-
class BinaryValueMapper(ValueMapper[bytes]):
143-
def map(self, value) -> Optional[bytes]:
144-
if value is None:
145-
return None
146-
return base64.b64decode(value.encode("utf8"))
129+
def _create_tzinfo(timezone_str: str) -> tzinfo:
130+
if timezone_str.startswith("+") or timezone_str.startswith("-"):
131+
hours = timezone_str[1:3]
132+
minutes = timezone_str[4:6]
133+
if timezone_str.startswith("-"):
134+
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
135+
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
136+
else:
137+
return ZoneInfo(timezone_str)
138+
139+
140+
def _fraction_to_decimal(fractional_str: str) -> Decimal:
141+
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]
147142

148143

149144
class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
@@ -156,6 +151,19 @@ def map(self, values: List[Any]) -> Optional[List[Any]]:
156151
return [self.mapper.map(value) for value in values]
157152

158153

154+
class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
155+
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
156+
self.key_mapper = key_mapper
157+
self.value_mapper = value_mapper
158+
159+
def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
160+
if values is None:
161+
return None
162+
return {
163+
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
164+
}
165+
166+
159167
class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
160168
def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]):
161169
self.mappers = mappers
@@ -172,26 +180,18 @@ def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
172180
)
173181

174182

175-
class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
176-
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
177-
self.key_mapper = key_mapper
178-
self.value_mapper = value_mapper
179-
180-
def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
181-
if values is None:
182-
return None
183-
return {
184-
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
185-
}
186-
187-
188183
class UuidValueMapper(ValueMapper[uuid.UUID]):
189184
def map(self, value: Any) -> Optional[uuid.UUID]:
190185
if value is None:
191186
return None
192187
return uuid.UUID(value)
193188

194189

190+
class NoOpValueMapper(ValueMapper[Any]):
191+
def map(self, value) -> Optional[Any]:
192+
return value
193+
194+
195195
class NoOpRowMapper:
196196
"""
197197
No-op RowMapper which does not perform any transformation
@@ -220,9 +220,32 @@ def create(self, columns, legacy_primitive_types):
220220
def _create_value_mapper(self, column) -> ValueMapper:
221221
col_type = column['rawType']
222222

223+
# primitive types
224+
if col_type in {'double', 'real'}:
225+
return DoubleValueMapper()
226+
if col_type == 'decimal':
227+
return DecimalValueMapper()
228+
if col_type == 'varbinary':
229+
return BinaryValueMapper()
230+
if col_type == 'date':
231+
return DateValueMapper()
232+
if col_type == 'time':
233+
return TimeValueMapper(self._get_precision(column))
234+
if col_type == 'time with time zone':
235+
return TimeWithTimeZoneValueMapper(self._get_precision(column))
236+
if col_type == 'timestamp':
237+
return TimestampValueMapper(self._get_precision(column))
238+
if col_type == 'timestamp with time zone':
239+
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
240+
241+
# structural types
223242
if col_type == 'array':
224243
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
225244
return ArrayValueMapper(value_mapper)
245+
if col_type == 'map':
246+
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
247+
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
248+
return MapValueMapper(key_mapper, value_mapper)
226249
if col_type == 'row':
227250
mappers = []
228251
names = []
@@ -232,26 +255,8 @@ def _create_value_mapper(self, column) -> ValueMapper:
232255
names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None)
233256
types.append(arg['value']['typeSignature']['rawType'])
234257
return RowValueMapper(mappers, names, types)
235-
if col_type == 'map':
236-
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
237-
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
238-
return MapValueMapper(key_mapper, value_mapper)
239-
if col_type == 'decimal':
240-
return DecimalValueMapper()
241-
if col_type in {'double', 'real'}:
242-
return DoubleValueMapper()
243-
if col_type == 'timestamp with time zone':
244-
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
245-
if col_type == 'timestamp':
246-
return TimestampValueMapper(self._get_precision(column))
247-
if col_type == 'time with time zone':
248-
return TimeWithTimeZoneValueMapper(self._get_precision(column))
249-
if col_type == 'time':
250-
return TimeValueMapper(self._get_precision(column))
251-
if col_type == 'date':
252-
return DateValueMapper()
253-
if col_type == 'varbinary':
254-
return BinaryValueMapper()
258+
259+
# others
255260
if col_type == 'uuid':
256261
return UuidValueMapper()
257262
return NoOpValueMapper()

0 commit comments

Comments
 (0)