|
3 | 3 | import logging
|
4 | 4 | import time
|
5 | 5 |
|
6 |
| -from typing import Iterable, List, Optional |
| 6 | +from typing import Iterable, List, Optional, Any |
7 | 7 | from contextlib import contextmanager
|
8 | 8 | import datetime as dt
|
9 | 9 | from datetime import datetime, timedelta
|
|
32 | 32 | add_allowed_ip,
|
33 | 33 | delete_allowed_ip,
|
34 | 34 | )
|
| 35 | +from exasol.saas.client.openapi.types import UNSET |
35 | 36 |
|
36 | 37 |
|
37 | 38 | LOG = logging.getLogger(__name__)
|
@@ -76,6 +77,75 @@ def create_saas_client(
|
76 | 77 | )
|
77 | 78 |
|
78 | 79 |
|
| 80 | +def _get_database_id( |
| 81 | + account_id: str, |
| 82 | + client: openapi.AuthenticatedClient, |
| 83 | + database_name: str, |
| 84 | +) -> str: |
| 85 | + """ |
| 86 | + Finds the database id, given an optional database name. If the name is not |
| 87 | + provided returns an id of any non-deleted database. The latter option may be |
| 88 | + useful for testing. |
| 89 | + """ |
| 90 | + dbs = list_databases.sync(account_id, client=client) |
| 91 | + dbs = list(filter(lambda db: (db.name == database_name) and # type: ignore |
| 92 | + (db.deleted_at is UNSET) and # type: ignore |
| 93 | + (db.deleted_by is UNSET), dbs)) # type: ignore |
| 94 | + if not dbs: |
| 95 | + raise RuntimeError(f'SaaS database {database_name} was not found.') |
| 96 | + return dbs[0].id |
| 97 | + |
| 98 | + |
| 99 | +def get_connection_params( |
| 100 | + host: str, |
| 101 | + account_id: str, |
| 102 | + pat: str, |
| 103 | + database_id: str | None = None, |
| 104 | + database_name: str | None = None, |
| 105 | +) -> dict[str, Any]: |
| 106 | + """ |
| 107 | + Gets the database connection parameters, such as those required by pyexasol: |
| 108 | + - dns |
| 109 | + - user |
| 110 | + - password. |
| 111 | + Returns the parameters in a dictionary that can be used as kwargs when |
| 112 | + creating a connection, like in the code below: |
| 113 | +
|
| 114 | + connection_params = get_connection_params(...) |
| 115 | + connection = pyexasol.connect(**connection_params) |
| 116 | +
|
| 117 | + Args: |
| 118 | + host: SaaS service URL. |
| 119 | + account_id: User account ID |
| 120 | + pat: Personal Access Token. |
| 121 | + database_id: Database ID, id known. |
| 122 | + database_name: Database name, in case the id is unknown. |
| 123 | + """ |
| 124 | + |
| 125 | + with create_saas_client(host, pat) as client: |
| 126 | + if not database_id: |
| 127 | + if not database_name: |
| 128 | + raise ValueError(('To get SaaS connection parameters, ' |
| 129 | + 'either database name or database id must be provided.')) |
| 130 | + database_id = _get_database_id(account_id, client, database_name=database_name) |
| 131 | + clusters = list_clusters.sync(account_id, |
| 132 | + database_id, |
| 133 | + client=client) |
| 134 | + cluster_id = next(filter(lambda cl: cl.main_cluster, clusters)).id # type: ignore |
| 135 | + connections = get_cluster_connection.sync(account_id, |
| 136 | + database_id, |
| 137 | + cluster_id, |
| 138 | + client=client) |
| 139 | + if connections is None: |
| 140 | + raise RuntimeError('Failed to get the SaaS connection data.') |
| 141 | + connection_params = { |
| 142 | + 'dsn': f'{connections.dns}:{connections.port}', |
| 143 | + 'user': connections.db_username, |
| 144 | + 'password': pat |
| 145 | + } |
| 146 | + return connection_params |
| 147 | + |
| 148 | + |
79 | 149 | class OpenApiAccess:
|
80 | 150 | """
|
81 | 151 | This class is meant to be used only in the context of the API
|
@@ -190,7 +260,6 @@ def poll_status():
|
190 | 260 | if poll_status() not in success:
|
191 | 261 | raise DatabaseStartupFailure()
|
192 | 262 |
|
193 |
| - |
194 | 263 | def clusters(
|
195 | 264 | self,
|
196 | 265 | database_id: str,
|
|
0 commit comments