diff --git a/README.md b/README.md index f3ea42c3..299694b3 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ $ pip install trino Use the DBAPI interface to query Trino: +if `host` is a valid url, the port and http schema will be automatically determined. For example `https://my-trino-server:9999` will assign the `http_schema` property to `https` and port to `9999`. + ```python from trino.dbapi import connect diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 7065dd4b..b56466a2 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -29,7 +29,7 @@ ) from trino import constants from trino.auth import OAuth2Authentication -from trino.dbapi import connect +from trino.dbapi import Connection, connect @patch("trino.dbapi.trino.client") @@ -272,3 +272,40 @@ def test_role_is_set_when_specified(mock_client): _, passed_role = mock_client.ClientSession.call_args assert passed_role["roles"] == roles + + +def test_hostname_parsing(): + https_server_with_port = Connection("https://mytrinoserver.domain:9999") + assert https_server_with_port.host == "mytrinoserver.domain" + assert https_server_with_port.port == 9999 + assert https_server_with_port.http_scheme == constants.HTTPS + + https_server_without_port = Connection("https://mytrinoserver.domain") + assert https_server_without_port.host == "mytrinoserver.domain" + assert https_server_without_port.port == 8080 + assert https_server_without_port.http_scheme == constants.HTTPS + + http_server_with_port = Connection("http://mytrinoserver.domain:9999") + assert http_server_with_port.host == "mytrinoserver.domain" + assert http_server_with_port.port == 9999 + assert http_server_with_port.http_scheme == constants.HTTP + + http_server_without_port = Connection("http://mytrinoserver.domain") + assert http_server_without_port.host == "mytrinoserver.domain" + assert http_server_without_port.port == 8080 + assert http_server_without_port.http_scheme == constants.HTTP + + http_server_with_path = Connection("http://mytrinoserver.domain/some_path") + assert http_server_with_path.host == "mytrinoserver.domain/some_path" + assert http_server_with_path.port == 8080 + assert http_server_with_path.http_scheme == constants.HTTP + + only_hostname = Connection("mytrinoserver.domain") + assert only_hostname.host == "mytrinoserver.domain" + assert only_hostname.port == 8080 + assert only_hostname.http_scheme == constants.HTTP + + only_hostname_with_path = Connection("mytrinoserver.domain/some_path") + assert only_hostname_with_path.host == "mytrinoserver.domain/some_path" + assert only_hostname_with_path.port == 8080 + assert only_hostname_with_path.http_scheme == constants.HTTP diff --git a/trino/dbapi.py b/trino/dbapi.py index ae28d088..0147017f 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -23,6 +23,7 @@ import uuid from decimal import Decimal from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types +from urllib.parse import urlparse import trino.client import trino.exceptions @@ -92,7 +93,7 @@ class Connection(object): def __init__( self, - host, + host: str, port=constants.DEFAULT_PORT, user=None, source=constants.DEFAULT_SOURCE, @@ -114,8 +115,11 @@ def __init__( roles=None, timezone=None, ): - self.host = host - self.port = port + # Automatically assign http_schema, port based on hostname + parsed_host = urlparse(host, allow_fragments=False) + + self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path + self.port = port if parsed_host.port is None else parsed_host.port self.user = user self.source = source self.catalog = catalog @@ -141,7 +145,7 @@ def __init__( else: self._http_session = http_session self.http_headers = http_headers - self.http_scheme = http_scheme + self.http_scheme = http_scheme if not parsed_host.scheme else parsed_host.scheme self.auth = auth self.extra_credential = extra_credential self.redirect_handler = redirect_handler