Commit f74798a

RyanMullins authored and LIT team committed
Adds init_spec() info to LitMetadata for Models and Datasets.
Removes superfluous log content from init_spec() impls.

PiperOrigin-RevId: 502646185
1 parent: c7fa619

7 files changed (+65 additions, -31 deletions)

lit_nlp/api/dataset.py (+3 -3)

@@ -98,7 +98,7 @@ def description(self) -> str:
     """
     return self._description or inspect.getdoc(self) or ''  # pytype: disable=bad-return-type

-  def init_spec(self) -> Optional[Spec]:
+  def init_spec(self) -> Optional[types.Spec]:
     """Attempts to infer a Spec describing a Dataset's constructor parameters.

     The Dataset base class attempts to infer a Spec for the constructor using
@@ -119,8 +119,8 @@ def init_spec(self) -> Optional[Spec]:
       spec = types.infer_spec_for_func(self.__init__)
     except TypeError as e:
       spec = None
-      logging.warning("Unable to infer init spec for model '%s'. %s",
-                      self.__class__.__name__, str(e), exc_info=True)
+      logging.warning("Unable to infer init spec for dataset '%s'. %s",
+                      self.__class__.__name__, str(e))
     return spec

   def load(self, path: str):
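
For context on the inference above, here is a minimal, hypothetical sketch of how a Dataset subclass picks up init_spec() from the base class. ToyDataset and its constructor parameters are illustrative, not LIT source; only Dataset and types.infer_spec_for_func() come from the diff.

from lit_nlp.api import dataset as lit_dataset


class ToyDataset(lit_dataset.Dataset):
  """A dataset whose constructor parameters can be inferred."""

  def __init__(self, split: str = 'validation', max_examples: int = 100):
    # A real dataset would load `max_examples` examples for `split` here.
    self._examples = []


# init_spec(), inherited from the base class, inspects __init__ via
# types.infer_spec_for_func() and returns a Spec with one entry per typed
# constructor parameter, or None (with the warning above) if inference fails.
print(ToyDataset().init_spec())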

lit_nlp/api/model.py (+1 -1)

@@ -115,7 +115,7 @@ def init_spec(self) -> Optional[Spec]:
     except TypeError as e:
       spec = None
       logging.warning("Unable to infer init spec for model '%s'. %s",
-                      self.__class__.__name__, str(e), exc_info=True)
+                      self.__class__.__name__, str(e))
     return spec

   def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:

lit_nlp/app.py (+32 -22)

@@ -60,6 +60,7 @@ def _build_metadata(self):
       info = {
           'description': model.description(),
           'spec': {
+              'init': self._model_init_specs[name],
               'input': model.input_spec(),
               'output': model.output_spec(),
           }
@@ -96,6 +97,7 @@ def _build_metadata(self):
     dataset_info = {}
     for name, ds in self._datasets.items():
       dataset_info[name] = {
+          'initSpec': self._dataset_init_specs[name],
           'spec': ds.spec(),
           'description': ds.description(),
           'size': len(ds),
@@ -285,7 +287,7 @@ def _get_dataset(self,
                    dataset_name: Optional[str] = None,
                    **unused_kw) -> list[IndexedInput]:
     """Attempt to get dataset, or override with a specific path."""
-    return self._datasets[dataset_name].indexed_examples
+    return list(self._datasets[dataset_name].indexed_examples)

   def _create_dataset(self,
                       unused_data,
@@ -549,30 +551,38 @@ def __init__(
     # client code to manually merge when this is the desired behavior.
     self._layouts = dict(layout.DEFAULT_LAYOUTS, **(layouts or {}))

-    # Wrap models in caching wrapper
-    self._models = {
-        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
-        for name, model in models.items()
-    }
-
-    self._datasets: dict[str, lit_dataset.Dataset] = dict(datasets)
-    # TODO(b/202210900): get rid of this, just dynamically create the empty
-    # dataset on the frontend.
-    self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)
-
-    self._annotators = annotators or []
-
+    self._model_init_specs: dict[str, Optional[types.Spec]] = {}
+    self._models: dict[str, caching.CachingModelWrapper] = {}
+    for name, model in models.items():
+      # We need to extract and store the results of the original
+      # model.init_spec() here so that we don't lose access to those fields
+      # after LIT wraps the model in a CachingModelWrapper.
+      self._model_init_specs[name] = model.init_spec()
+      # Wrap the model in a caching wrapper and add it to the app.
+      self._models[name] = caching.CachingModelWrapper(model, name,
+                                                       cache_dir=data_dir)
+
+    self._annotators: list[lit_components.Annotator] = annotators or []
     self._saved_datapoints = {}
     self._saved_datapoints_lock = threading.Lock()

-    # Run annotation on each dataset, creating an annotated dataset and
-    # replace the datasets with the annotated versions.
-    for ds_key, ds in self._datasets.items():
-      self._datasets[ds_key] = self._run_annotators(ds)
-
-    # Index all datasets
-    self._datasets = lit_dataset.IndexedDataset.index_all(
-        self._datasets, caching.input_hash)
+    tmp_datasets: dict[str, lit_dataset.Dataset] = dict(datasets)
+    # TODO(b/202210900): get rid of this, just dynamically create the empty
+    # dataset on the frontend.
+    tmp_datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)
+
+    self._dataset_init_specs: dict[str, Optional[types.Spec]] = {}
+    self._datasets: dict[str, lit_dataset.IndexedDataset] = {}
+    for name, ds in tmp_datasets.items():
+      # We need to extract and store the results of the original
+      # dataset.init_spec() here so that we don't lose access to those fields
+      # after LIT goes through the dataset annotation and indexing process.
+      self._dataset_init_specs[name] = ds.init_spec()
+      # Annotate the dataset.
+      annotated_ds = self._run_annotators(ds)
+      # Index the annotated dataset and add it to the app.
+      self._datasets[name] = lit_dataset.IndexedDataset(
+          base=annotated_ds, id_fn=caching.input_hash)

     # Generator initialization
     if generators is not None:
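
The reordering above matters because init_spec() inference is signature-based. The following standalone sketch (hypothetical ToyModel and ToyCachingWrapper classes, not LIT source) shows what would be lost if the spec were read after wrapping: inspection would see the wrapper's constructor, not the wrapped model's.

import inspect
from typing import Optional


class ToyModel:
  def __init__(self, checkpoint_path: str, batch_size: int = 16):
    self.checkpoint_path = checkpoint_path
    self.batch_size = batch_size


class ToyCachingWrapper:
  """Stand-in for caching.CachingModelWrapper."""

  def __init__(self, model: ToyModel, name: str,
               cache_dir: Optional[str] = None):
    self.wrapped = model


# Inference from the wrapper would describe (model, name, cache_dir) and
# lose the model's own (checkpoint_path, batch_size) parameters.
print(inspect.signature(ToyModel.__init__))
print(inspect.signature(ToyCachingWrapper.__init__))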

lit_nlp/client/lib/testing_utils.ts (+15 -1)

@@ -17,7 +17,7 @@

 import 'jasmine';

-import {AttentionHeads, BooleanLitType, CategoryLabel, Embeddings, MulticlassPreds, Scalar, TextSegment, TokenGradients, Tokens} from './lit_types';
+import {AttentionHeads, BooleanLitType, CategoryLabel, Embeddings, MulticlassPreds, Scalar, StringLitType, TextSegment, TokenGradients, Tokens} from './lit_types';
 import {LitMetadata, SerializedLitMetadata} from './types';
 import {createLitType} from './utils';

@@ -56,6 +56,7 @@ export const mockMetadata: LitMetadata = {
   'models': {
     'sst_0_micro': {
       'spec': {
+        'init': {},
         'input': {
           'passage': createLitType(TextSegment),
           'passage_tokens':
@@ -89,6 +90,7 @@ export const mockMetadata: LitMetadata = {
     },
     'sst_1_micro': {
       'spec': {
+        'init': {},
         'input': {
           'passage': createLitType(TextSegment),
           'passage_tokens':
@@ -123,13 +125,17 @@ export const mockMetadata: LitMetadata = {
   },
   'datasets': {
     'sst_dev': {
+      'initSpec': {
+        'split': createLitType(StringLitType),
+      },
       'size': 872,
       'spec': {
         'passage': createLitType(TextSegment),
         'label': createLitType(CategoryLabel, {'vocab': ['0', '1']}),
       }
     },
     'color_test': {
+      'initSpec': null,
       'size': 2,
       'spec': {
         'testNumFeat0': createLitType(Scalar),
@@ -139,6 +145,7 @@ export const mockMetadata: LitMetadata = {
       }
     },
     'penguin_dev': {
+      'initSpec': {},
       'size': 10,
       'spec': {
         'body_mass_g': createLitType(Scalar, {
@@ -197,6 +204,7 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
   'models': {
     'sst_0_micro': {
       'spec': {
+        'init': {},
         'input': {
           'passage': {'__name__': 'TextSegment', 'required': true},
           'passage_tokens':
@@ -243,6 +251,7 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
     },
     'sst_1_micro': {
       'spec': {
+        'init': {},
         'input': {
           'passage': {'__name__': 'TextSegment', 'required': true},
           'passage_tokens':
@@ -290,6 +299,9 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
   },
   'datasets': {
     'sst_dev': {
+      'initSpec': {
+        'split': {'__name__': 'StringLitType', 'required': true}
+      },
       'size': 872,
       'spec': {
         'passage': {'__name__': 'TextSegment', 'required': true},
@@ -298,6 +310,7 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
       }
     },
     'color_test': {
+      'initSpec': null,
       'size': 2,
       'spec': {
         'testNumFeat0': {'__name__': 'Scalar', 'required': true},
@@ -315,6 +328,7 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
       }
     },
     'penguin_dev': {
+      'initSpec': {},
       'size': 10,
       'spec': {
         'body_mass_g': {'__name__': 'Scalar', 'step': 1, 'required': true},

lit_nlp/client/lib/types.ts (+2 -0)

@@ -46,6 +46,7 @@ export interface ComponentInfo {

 export interface DatasetInfo {
   size: number;
+  initSpec: Spec | null;  // using null here because None ==> null in Python
   spec: Spec;
   description?: string;
 }
@@ -64,6 +65,7 @@ export interface CallConfig {
 }

 export interface ModelSpec {
+  init: Spec | null;  // using null here because None ==> null in Python
   input: Spec;
   output: Spec;
 }
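
Taken together with the app.py changes, these interfaces describe the JSON payload the server now emits. Below is a hedged sketch, written in Python for consistency with the server code, of that shape; the names and values are illustrative stand-ins, not actual LIT output.

# Hypothetical metadata payload matching DatasetInfo and ModelSpec above.
example_metadata = {
    'models': {
        'some_model': {
            'spec': {
                'init': None,  # ModelSpec.init; Python None -> JSON null
                'input': {},   # serialized input Spec
                'output': {},  # serialized output Spec
            },
        },
    },
    'datasets': {
        'some_dataset': {
            'initSpec': None,  # DatasetInfo.initSpec; None -> null
            'size': 0,
            'spec': {},        # serialized dataset Spec
        },
    },
}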

lit_nlp/client/lib/utils.ts (+7 -4)

@@ -175,14 +175,17 @@ export function cloneSpec(spec: Spec): Spec {
  */
 export function deserializeLitTypesInLitMetadata(
     metadata: SerializedLitMetadata): LitMetadata {
+
   for (const model of Object.keys(metadata.models)) {
-    metadata.models[model].spec.input =
-        deserializeLitTypesInSpec(metadata.models[model].spec.input);
-    metadata.models[model].spec.output =
-        deserializeLitTypesInSpec(metadata.models[model].spec.output);
+    const {spec} = metadata.models[model];
+    spec.init = spec.init ? deserializeLitTypesInSpec(spec.init) : null;
+    spec.input = deserializeLitTypesInSpec(spec.input);
+    spec.output = deserializeLitTypesInSpec(spec.output);
   }

   for (const dataset of Object.keys(metadata.datasets)) {
+    metadata.datasets[dataset].initSpec = metadata.datasets[dataset].initSpec ?
+        deserializeLitTypesInSpec(metadata.datasets[dataset].initSpec) : null;
     metadata.datasets[dataset].spec =
         deserializeLitTypesInSpec(metadata.datasets[dataset].spec);
   }

lit_nlp/client/services/classification_service_test.ts (+5 -0)

@@ -13,6 +13,7 @@ MULTICLASS_PRED_WITH_THRESHOLD.null_idx = 0;
 MULTICLASS_PRED_WITH_THRESHOLD.vocab = ['0', '1'];
 MULTICLASS_PRED_WITH_THRESHOLD.threshold = 0.3;
 const MULTICLASS_SPEC_WITH_THRESHOLD: ModelSpec = {
+  init: null,
   input: {},
   output: {[FIELD_NAME]: MULTICLASS_PRED_WITH_THRESHOLD}
 };
@@ -21,25 +22,29 @@ const MULTICLASS_PRED_WITHOUT_THRESHOLD = new MulticlassPreds();
 MULTICLASS_PRED_WITHOUT_THRESHOLD.null_idx = 0;
 MULTICLASS_PRED_WITHOUT_THRESHOLD.vocab = ['0', '1'];
 const MULTICLASS_SPEC_WITHOUT_THRESHOLD: ModelSpec = {
+  init: null,
   input: {},
   output: {[FIELD_NAME]: MULTICLASS_PRED_WITHOUT_THRESHOLD}
 };

 const MULTICLASS_PRED_NO_VOCAB = new MulticlassPreds();
 MULTICLASS_PRED_NO_VOCAB.null_idx = 0;
 const INVALID_SPEC_NO_VOCAB: ModelSpec = {
+  init: null,
   input: {},
   output: {[FIELD_NAME]: MULTICLASS_PRED_NO_VOCAB}
 };

 const MULTICLASS_PRED_NO_NULL_IDX = new MulticlassPreds();
 MULTICLASS_PRED_NO_NULL_IDX.vocab = ['0', '1'];
 const INVALID_SPEC_NO_NULL_IDX: ModelSpec = {
+  init: null,
   input: {},
   output: {[FIELD_NAME]: MULTICLASS_PRED_NO_NULL_IDX}
 };

 const INVALID_SPEC_NO_MULTICLASS_PRED: ModelSpec = {
+  init: null,
   input: {},
   output: {}
 };
