
Commit f019279

RyanMullins authored and LIT team committed
Promotes metrics info to top-level property of LitMetadata.
PiperOrigin-RevId: 530374619
1 parent 0caf827 commit f019279
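
The change lifts metrics out of the nested 'metrics' interpreter group: each model's metadata now lists its compatible metric names, and the metadata root carries its own ComponentInfo map for metrics. A rough sketch of the resulting LitMetadata shape (names taken from the mocks and defaults in this diff; specs elided):

metadata = {
    'models': {
        'sst_0_micro': {
            'spec': {},                            # elided
            'datasets': ['sst_dev'],
            'generators': ['word_replacer'],
            'interpreters': ['grad_norm', 'lime'],
            'metrics': ['multiclass', 'paired'],   # new: per-model compatible metrics
        },
    },
    'generators': {},                              # ComponentInfo per generator (elided)
    'interpreters': {},                            # ComponentInfo per interpreter (elided)
    'metrics': {                                   # new: top-level ComponentInfo map
        'multiclass': {'configSpec': {}, 'metaSpec': {}, 'description': '...'},
    },
}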

File tree: 7 files changed, +110 −50 lines


lit_nlp/app.py

+72-28
@@ -20,7 +20,7 @@
 import random
 import threading
 import time
-from typing import Optional, Mapping, Sequence, Union, Callable, Iterable
+from typing import Callable, Iterable, Optional, Mapping, Sequence, TypedDict, Union
 
 from absl import logging
 
@@ -55,6 +55,38 @@
 ModelLoadersMap = dict[str, ModelLoader]
 
 
+# LINT.IfChange
+class ComponentInfo(TypedDict):
+  configSpec: types.Spec  # pylint: disable=invalid-name # Named for JSON struct
+  metaSpec: types.Spec  # pylint: disable=invalid-name # Named for JSON struct
+  description: str
+# LINT.ThenChange(./client/lib/types.ts)
+
+
+def _get_component_info(
+    obj: lit_components.Interpreter,
+) -> ComponentInfo:
+  """Returns the ComponentInfo for an Interpreter, Generator, Metric, etc."""
+  return ComponentInfo(
+      configSpec=obj.config_spec(),
+      metaSpec=obj.meta_spec(),
+      description=obj.description(),
+  )
+
+
+def _get_compatible_names(
+    candidates: Mapping[str, lit_components.Interpreter],
+    model: lit_model.Model,
+    dataset: lit_dataset.Dataset,
+) -> Sequence[str]:
+  """Returns the names of the candidates compatible with the model/dataset."""
+  return [
+      name
+      for name, candidate in candidates.items()
+      if candidate.is_compatible(model=model, dataset=dataset)
+  ]
+
+
 class LitApp(object):
   """LIT WSGI application."""
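
The new ComponentInfo TypedDict is a plain dict at runtime, keyed to match the client-side interface in types.ts below. A minimal sketch of the value _get_component_info would produce for a single component; the spec contents here are placeholders, not real LIT specs:

# Illustrative only: real values come from the component's config_spec(),
# meta_spec(), and description() methods.
example_info: ComponentInfo = ComponentInfo(
    configSpec={},
    metaSpec={},
    description='Computes corpus-level metrics over model predictions.',
)
# Serialized into LitMetadata this becomes:
# {"configSpec": {}, "metaSpec": {}, "description": "Computes corpus-level ..."}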

@@ -80,22 +112,29 @@ def _build_metadata(self):
 
       compat_gens: set[str] = set()
       compat_interps: set[str] = set()
+      compat_metrics: set[str] = set()
 
       for d in info['datasets']:
         dataset: lit_dataset.Dataset = self._datasets[d]
-        compat_gens.update([
-            name for name, gen in self._generators.items()
-            if gen.is_compatible(model=model, dataset=dataset)
-        ])
-        compat_interps.update([
-            name for name, interp in self._interpreters.items()
-            if interp.is_compatible(model=model, dataset=dataset)
-        ])
-
-      info['generators'] = [name for name in self._generators.keys()
-                            if name in compat_gens]
-      info['interpreters'] = [name for name in self._interpreters.keys()
-                              if name in compat_interps]
+        compat_gens.update(
+            _get_compatible_names(self._generators, model, dataset)
+        )
+        compat_interps.update(
+            _get_compatible_names(self._interpreters, model, dataset)
+        )
+        compat_metrics.update(
+            _get_compatible_names(self._metrics, model, dataset)
+        )
+
+      info['generators'] = [
+          name for name in self._generators.keys() if name in compat_gens
+      ]
+      info['interpreters'] = [
+          name for name in self._interpreters.keys() if name in compat_interps
+      ]
+      info['metrics'] = [
+          name for name in self._metrics.keys() if name in compat_metrics
+      ]
       model_info[name] = info
 
     dataset_info = {}
@@ -106,21 +145,19 @@ def _build_metadata(self):
           'size': len(ds),
       }
 
-    generator_info = {}
-    for name, gen in self._generators.items():
-      generator_info[name] = {
-          'configSpec': gen.config_spec(),
-          'metaSpec': gen.meta_spec(),
-          'description': gen.description()
-      }
+    generator_info: Mapping[str, ComponentInfo] = {
+        name: _get_component_info(gen) for name, gen in self._generators.items()
+    }
 
-    interpreter_info = {}
-    for name, interpreter in self._interpreters.items():
-      interpreter_info[name] = {
-          'configSpec': interpreter.config_spec(),
-          'metaSpec': interpreter.meta_spec(),
-          'description': interpreter.description()
-      }
+    interpreter_info: Mapping[str, ComponentInfo] = {
+        name: _get_component_info(interp)
+        for name, interp in self._interpreters.items()
+    }
+
+    metrics_info: Mapping[str, ComponentInfo] = {
+        name: _get_component_info(metric)
+        for name, metric in self._metrics.items()
+    }
 
     init_specs = {
         'datasets': {n: s for n, (_, s) in self._dataset_loaders.items()},
@@ -133,6 +170,7 @@ def _build_metadata(self):
         'datasets': dataset_info,
         'generators': generator_info,
         'interpreters': interpreter_info,
+        'metrics': metrics_info,
         'layouts': self._layouts,
         # Global configuration
         'demoMode': self._demo_mode,
@@ -569,6 +607,7 @@ def __init__(
       datasets: Mapping[str, lit_dataset.Dataset],
       generators: Optional[Mapping[str, lit_components.Generator]] = None,
       interpreters: Optional[Mapping[str, lit_components.Interpreter]] = None,
+      metrics: Optional[Mapping[str, lit_components.Metrics]] = None,
       annotators: Optional[list[lit_components.Annotator]] = None,
       layouts: Optional[layout.LitComponentLayouts] = None,
       dataset_loaders: Optional[DatasetLoadersMap] = None,
@@ -657,6 +696,11 @@ def __init__(
     else:
       self._interpreters = core.default_interpreters(self._models)
 
+    if metrics is not None:
+      self._metrics = metrics
+    else:
+      self._metrics = core.default_metrics()
+
     # Component to sync state from TS -> Python. Used in notebooks.
     if sync_state:
       self.ui_state_tracker = ui_state.UIStateTracker()
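
With the new constructor argument, a LitApp (or a demo server that forwards keyword arguments to it) can be handed an explicit metrics mapping; omitting it falls back to core.default_metrics(). A rough usage sketch, assuming models and datasets mappings are defined elsewhere:

from lit_nlp import app as lit_app
from lit_nlp.components import metrics

# `models` and `datasets` are hypothetical mappings built elsewhere.
lit_demo = lit_app.LitApp(
    models=models,
    datasets=datasets,
    metrics={
        'multiclass': metrics.MulticlassMetrics(),
        'bleu': metrics.CorpusBLEU(),
    },
)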

lit_nlp/client/lib/testing_utils.ts

+10-4
@@ -85,7 +85,8 @@ export const mockMetadata: LitMetadata = {
     'generators':
         ['word_replacer', 'scrambler', 'backtranslation', 'hotflip'],
     'interpreters':
-        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap']
+        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap'],
+    'metrics': []
   },
   'sst_1_micro': {
     'spec': {
@@ -118,7 +119,8 @@ export const mockMetadata: LitMetadata = {
     'generators':
         ['word_replacer', 'scrambler', 'backtranslation', 'hotflip'],
     'interpreters':
-        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap']
+        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap'],
+    'metrics': []
   }
 },
 'datasets': {
@@ -182,6 +184,7 @@ export const mockMetadata: LitMetadata = {
   'pca': emptySpec(),
   'umap': emptySpec(),
 },
+'metrics': {},
 'initSpecs': {
   'datasets': {
     'sst_dev': {'split': createLitType(StringLitType)},
@@ -250,7 +253,8 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
     'generators':
         ['word_replacer', 'scrambler', 'backtranslation', 'hotflip'],
     'interpreters':
-        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap']
+        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap'],
+    'metrics': []
   },
   'sst_1_micro': {
     'spec': {
@@ -296,7 +300,8 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
     'generators':
         ['word_replacer', 'scrambler', 'backtranslation', 'hotflip'],
     'interpreters':
-        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap']
+        ['grad_norm', 'grad_sum', 'lime', 'metrics', 'pca', 'umap'],
+    'metrics': []
   }
 },
 'datasets': {
@@ -375,6 +380,7 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
   'pca': emptySpec(),
   'umap': emptySpec(),
 },
+'metrics': {},
 'initSpecs': {
   'datasets': {
     'sst_dev': {

lit_nlp/client/lib/types.ts

+4
@@ -42,11 +42,13 @@ interface InitSpecMap {
   [name: string]: Spec|null;  // using null here because None ==> null in Python
 }
 
+// LINT.IfChange
 export interface ComponentInfo {
   configSpec: Spec;
   metaSpec: Spec;
   description?: string;
 }
+// LINT.ThenChange(../../app.py)
 
 export interface DatasetInfo {
   size: number;
@@ -76,6 +78,7 @@ export interface ModelInfo {
   datasets: string[];
   generators: string[];
   interpreters: string[];
+  metrics: string[];
   spec: ModelSpec;
   description?: string;
 }
@@ -89,6 +92,7 @@ export interface LitMetadata {
   datasets: DatasetInfoMap;
   generators: ComponentInfoMap;
   interpreters: ComponentInfoMap;
+  metrics: ComponentInfoMap;
   layouts: LitComponentLayouts;
   demoMode: boolean;
   defaultLayout: string;

lit_nlp/client/modules/metrics_module.ts

+1
@@ -299,6 +299,7 @@ export class MetricsModule extends LitModule {
   /** Convert the metricsMap information into table data for display. */
   @computed
   get tableData(): TableHeaderAndData {
+    // TODO(b/254832560): Use this.appState.metadata.metrics here.
     const {metaSpec} = this.appState.metadata.interpreters['metrics'];
     if (metaSpec == null) return {'header': [], 'data': []};

lit_nlp/client/services/classification_service_test.ts

+4-2
@@ -85,7 +85,8 @@ describe('classification service test', () => {
       spec,
       datasets: [],
       generators: [],
-      interpreters: []
+      interpreters: [],
+      metrics: []
     }}
   };
 
@@ -165,7 +166,8 @@ describe('classification service test', () => {
       spec,
       datasets: [],
       generators: [],
-      interpreters: []
+      interpreters: [],
+      metrics: []
     }}
   };
 

lit_nlp/components/core.py

+17-16
@@ -51,13 +51,8 @@
 _UMAP_AVAILABLE = False
 # pytype: enable=import-error # pylint: enable=g-import-not-at-top
 
-ComponentGroup = lit_components.ComponentGroup
-Generator = lit_components.Generator
-Interpreter = lit_components.Interpreter
-Model = lit_model.Model
 
-
-def default_generators() -> dict[str, Generator]:
+def default_generators() -> dict[str, lit_components.Generator]:
   """Returns a dict of the default generators used in a LitApp."""
   return {
       'Ablation Flip': ablation_flip.AblationFlip(),
@@ -67,21 +62,23 @@ def default_generators() -> dict[str, Generator]:
   }
 
 
-def required_interpreters() -> dict[str, Interpreter]:
+def required_interpreters() -> dict[str, lit_components.Interpreter]:
   """Returns a dict of required interpreters.
 
   These are used by multiple core modules, and without them the frontend will
   likely throw errors.
   """
   # Ensure the prediction analysis interpreters are included.
-  prediction_analysis_interpreters: dict[str, Interpreter] = {
+  prediction_analysis_interpreters: dict[str, lit_components.Interpreter] = {
       'classification': classification_results.ClassificationInterpreter(),
       'regression': regression_results.RegressionInterpreter(),
   }
   return prediction_analysis_interpreters
 
 
-def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
+def default_interpreters(
+    models: dict[str, lit_model.Model]
+) -> dict[str, lit_components.Interpreter]:
   """Returns a dict of the default interpreters (and metrics) used in a LitApp.
 
   Args:
@@ -91,7 +88,7 @@ def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
   interpreters = required_interpreters()
 
   # Ensure the embedding-based interpreters are included.
-  embedding_interpreters: dict[str, Interpreter] = {
+  embedding_interpreters: dict[str, lit_components.Interpreter] = {
       'nearest neighbors': nearest_neighbors.NearestNeighbors(),
       # Embedding projectors expose a standard interface, but get special
       # handling so we can precompute the projections if requested.
@@ -103,21 +100,23 @@ def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
       umap.UmapModel
   )
 
-  gradient_map_interpreters: dict[str, Interpreter] = {
+  gradient_map_interpreters: dict[str, lit_components.Interpreter] = {
      'Grad L2 Norm': gradient_maps.GradientNorm(),
      'Grad ⋅ Input': gradient_maps.GradientDotInput(),
      'Integrated Gradients': gradient_maps.IntegratedGradients(),
      'LIME': lime_explainer.LIME(),
   }
 
   # pyformat: disable
-  core_interpreters: dict[str, Interpreter] = {
+  core_interpreters: dict[str, lit_components.Interpreter] = {
       'Model-provided salience': model_salience.ModelSalience(models),
       'counterfactual explainer': lemon_explainer.LEMON(),
       'tcav': tcav.TCAV(),
       'curves': curves.CurvesInterpreter(),
       'thresholder': thresholder.Thresholder(),
-      'metrics': default_metrics(),
+      # TODO(b/254832560): Remove this "metrics" record from the core
+      # interpreters once the front-end Metrics module has been updated.
+      'metrics': lit_components.ComponentGroup(default_metrics()),
       'pdp': pdp.PdpInterpreter(),
       'Salience Clustering': salience_clustering.SalienceClustering(
           dict(gradient_map_interpreters)),
@@ -133,13 +132,15 @@ def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
   return interpreters
 
 
-def default_metrics() -> ComponentGroup:
-  return ComponentGroup({
+# TODO(b/254833485): Update typing to be a dict[str, lit_components.Metrics]
+# once the Wrapper classes in metrics.py inherit from lit_components.Metrics.
+def default_metrics() -> dict[str, lit_components.Interpreter]:
+  return {
       'regression': metrics.RegressionMetrics(),
      'multiclass': metrics.MulticlassMetrics(),
      'multilabel': metrics.MultilabelMetrics(),
      'paired': metrics.MulticlassPairedMetrics(),
      'bleu': metrics.CorpusBLEU(),
      'rouge': metrics.RougeL(),
      'exactmatch': metrics.ExactMatchMetrics(),
-  })
+  }
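
Note the shape change here: default_metrics() now returns a plain dict of metric components, and only the legacy 'metrics' interpreter entry wraps that dict in a ComponentGroup. A short sketch of how the two call sites relate after this change:

# Sketch: the same dict feeds both the new top-level metadata and the
# legacy grouped interpreter entry.
metrics_dict = core.default_metrics()                        # dict[str, Interpreter]
legacy_entry = lit_components.ComponentGroup(metrics_dict)   # old 'metrics' interpreter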

lit_nlp/components/metrics.py

+2
@@ -152,6 +152,8 @@ def run_with_metadata(self,
     return ret
 
 
+# TODO(b/254833485): Convert to inherit from lit_components.Metrics so that
+# promotion of Metrics to a top-level class is more direct.
 class ClassificationMetricsWrapper(lit_components.Interpreter):
   """Wrapper for classification metrics interpreters.
