diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 9d97ead3..09611d30 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -794,14 +794,91 @@ def test_array(trino_connection): def test_map(trino_connection): + # primitive types + ( + SqlTest(trino_connection) + .add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) + .add_field(sql="MAP()", python={}) + .add_field(sql="MAP(ARRAY[true, false], ARRAY[false, true])", python={True: False, False: True}) + .add_field(sql="MAP(ARRAY[true, false], ARRAY[true, null])", python={True: True, False: None}) + .add_field(sql="MAP(ARRAY[1, 2], ARRAY[1, null])", python={1: 1, 2: None}) + .add_field(sql="MAP(" + "ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), CAST(1 AS REAL)], " # noqa: E501 + "ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), null])", # noqa: E501 + python={math.nan: math.nan, + -math.inf: -math.inf, + 3.4028235e+38: 3.4028235e+38, + 1.4e-45: 1.4e-45, + math.inf: math.inf, + 1: None}, + has_nan=True) + .add_field(sql="MAP(" + "ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST(1 AS DOUBLE)], " # noqa: E501 + "ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null])", # noqa: E501 + python={math.nan: math.nan, + -math.inf: -math.inf, + 1.7976931348623157e+308: 1.7976931348623157e+308, + 5e-324: 5e-324, + math.inf: math.inf, + 1: None}, + has_nan=True) + .add_field(sql="MAP(ARRAY[CAST('NaN' AS DOUBLE)], ARRAY[CAST('NaN' AS DOUBLE)])", + python={math.nan: math.nan}, + has_nan=True) + .add_field(sql="MAP(ARRAY[1.2, 2.4, 4.8], ARRAY[1.2, 2.4, null])", + python={Decimal("1.2"): Decimal("1.2"), Decimal("2.4"): Decimal("2.4"), Decimal("4.8"): None}) + .add_field(sql="MAP(" + "ARRAY[CAST('hello' AS VARCHAR), CAST('null' AS VARCHAR)], " + "ARRAY[CAST('hello' AS VARCHAR), null])", + python={'hello': 'hello', 'null': None}) + .add_field(sql="MAP(ARRAY[CAST('a' AS CHAR(4)), CAST('null' AS CHAR(4))], ARRAY[CAST('a' AS CHAR), null])", + python={'a ': 'a', 'null': None}) + .add_field(sql="MAP(ARRAY[X'', X'65683F', X'00'], ARRAY[X'', X'65683F', null])", + python={b'': b'', b'eh?': b'eh?', b'\x00': None}) + .add_field(sql="MAP(ARRAY[JSON '1', JSON '{}', JSON 'null'], ARRAY[JSON '1', JSON '{}', null])", + python={'1': '1', '{}': '{}', 'null': None}) + ).execute() + + # temporal types + tz_india = create_timezone("+05:30") + tz_new_york = create_timezone("-05:00") + tz_los_angeles = create_timezone("America/Los_Angeles") + time_1 = time(1, 1, 1) + time_2 = time(23, 59, 59) + datetime_1 = datetime(1970, 1, 1, 1, 1, 1) + datetime_2 = datetime(2023, 1, 1, 23, 59, 59) SqlTest(trino_connection) \ - .add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) \ - .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None}) \ - .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[2.4, null])", python={'a': Decimal("2.4"), 'b': None}) \ - .add_field(sql="MAP(ARRAY[2.4, 4.8], ARRAY[CAST(4.9E-324 AS DOUBLE), null])", - python={Decimal("2.4"): 5e-324, Decimal("4.8"): None}) \ + .add_field(sql="MAP(ARRAY[DATE '1970-01-01', DATE '2023-01-01'], ARRAY[DATE '1970-01-01', null])", + python={date(1970, 1, 1): date(1970, 1, 1), date(2023, 1, 1): None}) \ + .add_field(sql="MAP(ARRAY[TIME '01:01:01', TIME '23:59:59'], ARRAY[TIME '01:01:01', null])", + python={time_1: time_1, time_2: None}) \ + .add_field(sql="MAP(" + "ARRAY[TIME '01:01:01 +05:30', TIME '23:59:59 -05:00'], " + "ARRAY[TIME '01:01:01 +05:30', null])", + python={time_1.replace(tzinfo=tz_india): time_1.replace(tzinfo=tz_india), + time_2.replace(tzinfo=tz_new_york): None}) \ + .add_field(sql="MAP(" + "ARRAY[TIMESTAMP '1970-01-01 01:01:01', TIMESTAMP '2023-01-01 23:59:59'], " + "ARRAY[TIMESTAMP '1970-01-01 01:01:01', null])", + python={datetime_1: datetime_1, datetime_2: None}) \ + .add_field(sql="MAP(" + "ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', TIMESTAMP '2023-01-01 23:59:59 America/Los_Angeles'], " # noqa: E501 + "ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null])", + python={datetime_1.replace(tzinfo=tz_india): datetime_1.replace(tzinfo=tz_india), + datetime_2.replace(tzinfo=tz_los_angeles): None}) \ .execute() + # structural types - note that none of these below tests work in the Trino JDBC Driver either. + # TODO: https://github.com/trinodb/trino-python-client/issues/442 + # Unhashable types like lists and dicts cannot be used as keys so these values cannot be represented as Python + # objects at all. + # .add_field(sql="MAP(ARRAY[ARRAY[1, 2]], ARRAY[null])", python={[1, 2]: None}) + # .add_field(sql="MAP(ARRAY[MAP(ARRAY[1], ARRAY[2])], ARRAY[null])", python={{1: 2}: None}) + + # TODO: fails because server sends [[{"[1, 2]":null}]] as response whereas it sends [[[1,2]]] as response for ROW + # types that are not enclosed in a MAP while the RowValueMapper expects values to be lists. + # .add_field(sql="MAP(ARRAY[ROW(1, 2)], ARRAY[CAST(null AS VARCHAR)])", python={(1, 2): None}) + def test_row(trino_connection): SqlTest(trino_connection) \ diff --git a/trino/mapper.py b/trino/mapper.py index 0794a646..f29dd666 100644 --- a/trino/mapper.py +++ b/trino/mapper.py @@ -31,6 +31,29 @@ def map(self, value: Any) -> Optional[T]: pass +class BooleanValueMapper(ValueMapper[bool]): + def map(self, value: Any) -> Optional[bool]: + if value is None: + return None + if isinstance(value, bool): + return value + if str(value).lower() == 'true': + return True + if str(value).lower() == 'false': + return False + raise ValueError(f"Server sent unexpected value {value} of type {type(value)} for boolean") + + +class IntegerValueMapper(ValueMapper[int]): + def map(self, value: Any) -> Optional[int]: + if value is None: + return None + if isinstance(value, int): + return value + # int(3.1) == 3 but server won't send such values for integer types + return int(value) + + class DoubleValueMapper(ValueMapper[float]): def map(self, value) -> Optional[float]: if value is None: @@ -51,6 +74,13 @@ def map(self, value) -> Optional[Decimal]: return Decimal(value) +class StringValueMapper(ValueMapper[str]): + def map(self, value: Any) -> Optional[str]: + if value is None: + return None + return str(value) + + class BinaryValueMapper(ValueMapper[bytes]): def map(self, value) -> Optional[bytes]: if value is None: @@ -221,12 +251,20 @@ def _create_value_mapper(self, column) -> ValueMapper: col_type = column['rawType'] # primitive types + if col_type == 'boolean': + return BooleanValueMapper() + if col_type in {'tinyint', 'smallint', 'integer', 'bigint'}: + return IntegerValueMapper() if col_type in {'double', 'real'}: return DoubleValueMapper() if col_type == 'decimal': return DecimalValueMapper() + if col_type in {'varchar', 'char'}: + return StringValueMapper() if col_type == 'varbinary': return BinaryValueMapper() + if col_type == 'json': + return StringValueMapper() if col_type == 'date': return DateValueMapper() if col_type == 'time':