|
17 | 17 | import functools
|
18 | 18 | import glob
|
19 | 19 | import math
|
20 |
| -import os |
21 | 20 | import random
|
22 | 21 | import threading
|
23 | 22 | import time
|
|
56 | 55 | ModelLoadersMap = dict[str, ModelLoader]
|
57 | 56 |
|
58 | 57 |
|
| 58 | +# TODO(b/277249726): Move this function to utils.py, add tests, and expand usage |
| 59 | +# across HTTP and Interpreter APIs. |
| 60 | +def _validate_config_against_spec(config: JsonDict, spec: types.Spec): |
| 61 | + """Validates that the provided config is compatible with the Spec. |
| 62 | +
|
| 63 | + Args: |
| 64 | + config: The configuration parameters, typically extracted from the data of |
| 65 | + an HTTP Request, that are to be used in a function call. |
| 66 | + spec: A Spec defining the shape of allowed configuration parameters for the |
| 67 | + associated LIT component. |
| 68 | +
|
| 69 | + Raises: |
| 70 | + KeyError: Under two conditions: 1) the `config` is missing one or more |
| 71 | + required fields defined in the `spec`, or 2) the `config` contains fields |
| 72 | + not defined in the `spec`. Either of these conditions would likely result |
| 73 | + in a TypeError (for missing or unexpected arguments) if the `config` was |
| 74 | + used in a call. |
| 75 | + """ |
| 76 | + missing_required_keys = [ |
| 77 | + param_name for param_name, param_type in spec.items() |
| 78 | + if param_type.required and param_name not in config |
| 79 | + ] |
| 80 | + if missing_required_keys: |
| 81 | + raise KeyError(f'Missing required parameters: {missing_required_keys}') |
| 82 | + |
| 83 | + unsupported_keys = [ |
| 84 | + param_name for param_name in config |
| 85 | + if param_name not in spec |
| 86 | + ] |
| 87 | + if unsupported_keys: |
| 88 | + raise KeyError(f'Received unsupported parameters: {unsupported_keys}') |
| 89 | + |
| 90 | + |
59 | 91 | class LitApp(object):
|
60 | 92 | """LIT WSGI application."""
|
61 | 93 |
|
@@ -300,44 +332,68 @@ def _get_dataset(self,
|
300 | 332 | return list(self._datasets[dataset_name].indexed_examples)
|
301 | 333 |
|
302 | 334 | def _create_dataset(self,
|
303 |
| - unused_data, |
| 335 | + data: JsonDict, |
304 | 336 | dataset_name: Optional[str] = None,
|
305 |
| - dataset_path: Optional[str] = None, |
306 | 337 | **unused_kw):
|
307 |
| - """Create dataset from a path, updating and returning the metadata.""" |
| 338 | + """Create a dataset, updating and returning the metadata.""" |
| 339 | + if dataset_name is None: |
| 340 | + raise ValueError('No base dataset specified.') |
308 | 341 |
|
309 |
| - assert dataset_name is not None, 'No dataset specified.' |
310 |
| - assert dataset_path is not None, 'No dataset path specified.' |
311 |
| - |
312 |
| - new_dataset = self._datasets[dataset_name].load(dataset_path) |
313 |
| - if new_dataset is not None: |
314 |
| - new_dataset_name = dataset_name + '-' + os.path.basename(dataset_path) |
315 |
| - self._datasets[new_dataset_name] = new_dataset |
316 |
| - self._info = self._build_metadata() |
317 |
| - return (self._info, new_dataset_name) |
318 |
| - else: |
319 |
| - logging.error('Not able to load: %s', dataset_name) |
320 |
| - return None |
| 342 | + config: Optional[JsonDict] = data.get('config') |
| 343 | + if config is None: |
| 344 | + raise ValueError('No config specified.') |
| 345 | + |
| 346 | + new_name: Optional[str] = config.pop('new_name', None) |
| 347 | + if new_name is None: |
| 348 | + raise ValueError('No name provided for the new dataset.') |
| 349 | + |
| 350 | + if (loader_info := self._dataset_loaders.get(dataset_name)) is None: |
| 351 | + raise ValueError( |
| 352 | + f'No loader information (Cls + init_spec) found for {dataset_name}') |
| 353 | + |
| 354 | + dataset_cls, dataset_init_spec = loader_info |
| 355 | + |
| 356 | + if dataset_init_spec is not None: |
| 357 | + _validate_config_against_spec(config, dataset_init_spec) |
| 358 | + |
| 359 | + new_dataset = dataset_cls(**config) |
| 360 | + annotated_dataset = self._run_annotators(new_dataset) |
| 361 | + self._datasets[new_name] = lit_dataset.IndexedDataset( |
| 362 | + base=annotated_dataset, id_fn=caching.input_hash |
| 363 | + ) |
| 364 | + self._info = self._build_metadata() |
| 365 | + return (self._info, new_name) |
321 | 366 |
|
322 | 367 | def _create_model(self,
|
323 |
| - unused_data, |
| 368 | + data: JsonDict, |
324 | 369 | model_name: Optional[str] = None,
|
325 |
| - model_path: Optional[str] = None, |
326 | 370 | **unused_kw):
|
327 |
| - """Create model from a path, updating and returning the metadata.""" |
328 |
| - |
329 |
| - assert model_name is not None, 'No model specified.' |
330 |
| - assert model_path is not None, 'No model path specified.' |
331 |
| - # Load using the underlying model class, then wrap explicitly in a cache. |
332 |
| - new_model = self._models[model_name].wrapped.load(model_path) |
333 |
| - if new_model is not None: |
334 |
| - new_model_name = model_name + ':' + os.path.basename(model_path) |
335 |
| - self._models[new_model_name] = caching.CachingModelWrapper( |
336 |
| - new_model, new_model_name, cache_dir=self._data_dir) |
337 |
| - self._info = self._build_metadata() |
338 |
| - return (self._info, new_model_name) |
339 |
| - else: |
340 |
| - return None |
| 371 | + """Create a model, updating and returning the metadata.""" |
| 372 | + if model_name is None: |
| 373 | + raise ValueError('No base model specified.') |
| 374 | + |
| 375 | + config: Optional[JsonDict] = data.get('config') |
| 376 | + if config is None: |
| 377 | + raise ValueError('No config specified.') |
| 378 | + |
| 379 | + new_name: Optional[str] = config.pop('new_name', None) |
| 380 | + if new_name is None: |
| 381 | + raise ValueError('No name provided for the new model.') |
| 382 | + |
| 383 | + if (loader_info := self._model_loaders.get(model_name)) is None: |
| 384 | + raise ValueError( |
| 385 | + f'No loader information (Cls + init_spec) found for {model_name}') |
| 386 | + |
| 387 | + model_cls, model_init_spec = loader_info |
| 388 | + |
| 389 | + if model_init_spec is not None: |
| 390 | + _validate_config_against_spec(config, model_init_spec) |
| 391 | + |
| 392 | + new_model = model_cls(**config) |
| 393 | + self._models[new_name] = caching.CachingModelWrapper( |
| 394 | + new_model, new_name, cache_dir=self._data_dir) |
| 395 | + self._info = self._build_metadata() |
| 396 | + return (self._info, new_name) |
341 | 397 |
|
342 | 398 | def _get_generated(self, data, model: str, dataset_name: str, generator: str,
|
343 | 399 | **unused_kw):
|
@@ -505,12 +561,19 @@ def _handler(app: wsgi_app.App, request, environ):
|
505 | 561 | # but for requests from Python we may want to use the invertible encoding
|
506 | 562 | # so that datatypes from remote models are the same as local ones.
|
507 | 563 | response_simple_json = utils.coerce_bool(
|
508 |
| - kw.pop('response_simple_json', True)) |
| 564 | + kw.pop('response_simple_json', True) |
| 565 | + ) |
509 | 566 | data = serialize.from_json(request.data) if len(request.data) else None
|
510 | 567 | # Special handling to dereference IDs.
|
511 |
| - if data and 'inputs' in data.keys() and 'dataset_name' in kw: |
512 |
| - data['inputs'] = self._reconstitute_inputs(data['inputs'], |
513 |
| - kw['dataset_name']) |
| 568 | + if ( |
| 569 | + data |
| 570 | + and 'inputs' in data.keys() |
| 571 | + and len(data.get('inputs')) |
| 572 | + and 'dataset_name' in kw |
| 573 | + ): |
| 574 | + data['inputs'] = self._reconstitute_inputs( |
| 575 | + data['inputs'], kw['dataset_name'] |
| 576 | + ) |
514 | 577 |
|
515 | 578 | outputs = fn(data, **kw)
|
516 | 579 | response_body = serialize.to_json(outputs, simple=response_simple_json)
|
|
0 commit comments