Skip to content

Commit 9ea4ab2

Browse files
RyanMullinsLIT team
authored and
LIT team
committed
Adds components.core library w/ functions for default interpreters + generators
PiperOrigin-RevId: 469241495
1 parent ab057b5 commit 9ea4ab2

File tree

2 files changed

+112
-73
lines changed

2 files changed

+112
-73
lines changed

lit_nlp/app.py

+5-73
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,7 @@
2828
from lit_nlp.api import layout
2929
from lit_nlp.api import model as lit_model
3030
from lit_nlp.api import types
31-
from lit_nlp.components import ablation_flip
32-
from lit_nlp.components import classification_results
33-
from lit_nlp.components import curves
34-
from lit_nlp.components import gradient_maps
35-
from lit_nlp.components import hotflip
36-
from lit_nlp.components import lemon_explainer
37-
from lit_nlp.components import lime_explainer
38-
from lit_nlp.components import metrics
39-
from lit_nlp.components import model_salience
40-
from lit_nlp.components import nearest_neighbors
41-
from lit_nlp.components import pca
42-
from lit_nlp.components import pdp
43-
from lit_nlp.components import projection
44-
from lit_nlp.components import regression_results
45-
from lit_nlp.components import salience_clustering
46-
from lit_nlp.components import scrambler
47-
from lit_nlp.components import shap_explainer
48-
from lit_nlp.components import tcav
49-
from lit_nlp.components import thresholder
50-
from lit_nlp.components import umap
51-
from lit_nlp.components import word_replacer
31+
from lit_nlp.components import core
5232
from lit_nlp.lib import caching
5333
from lit_nlp.lib import serialize
5434
from lit_nlp.lib import ui_state
@@ -492,65 +472,17 @@ def __init__(
492472
self._datasets = lit_dataset.IndexedDataset.index_all(
493473
self._datasets, caching.input_hash)
494474

475+
# Generator initialization
495476
if generators is not None:
496477
self._generators = generators
497478
else:
498-
self._generators = {
499-
'Ablation Flip': ablation_flip.AblationFlip(),
500-
'Hotflip': hotflip.HotFlip(),
501-
'Scrambler': scrambler.Scrambler(),
502-
'Word Replacer': word_replacer.WordReplacer(),
503-
}
479+
self._generators = core.default_generators()
504480

481+
# Interpreter initialization
505482
if interpreters is not None:
506483
self._interpreters = interpreters
507-
508484
else:
509-
metrics_group = lit_components.ComponentGroup({
510-
'regression': metrics.RegressionMetrics(),
511-
'multiclass': metrics.MulticlassMetrics(),
512-
'paired': metrics.MulticlassPairedMetrics(),
513-
'bleu': metrics.CorpusBLEU(),
514-
'rouge': metrics.RougeL(),
515-
})
516-
gradient_map_interpreters = {
517-
'Grad L2 Norm': gradient_maps.GradientNorm(),
518-
'Grad ⋅ Input': gradient_maps.GradientDotInput(),
519-
'Integrated Gradients': gradient_maps.IntegratedGradients(),
520-
'LIME': lime_explainer.LIME(),
521-
}
522-
# pyformat: disable
523-
self._interpreters: dict[str, lit_components.Interpreter] = {
524-
'Model-provided salience': model_salience.ModelSalience(self._models),
525-
'counterfactual explainer': lemon_explainer.LEMON(),
526-
'tcav': tcav.TCAV(),
527-
'curves': curves.CurvesInterpreter(),
528-
'thresholder': thresholder.Thresholder(),
529-
'metrics': metrics_group,
530-
'pdp': pdp.PdpInterpreter(),
531-
'Salience Clustering': salience_clustering.SalienceClustering(
532-
gradient_map_interpreters),
533-
'Tabular SHAP': shap_explainer.TabularShapExplainer(),
534-
}
535-
# pyformat: enable
536-
self._interpreters.update(gradient_map_interpreters)
537-
538-
# Ensure the prediction analysis interpreters are included.
539-
prediction_analysis_interpreters = {
540-
'classification': classification_results.ClassificationInterpreter(),
541-
'regression': regression_results.RegressionInterpreter(),
542-
}
543-
# Ensure the embedding-based interpreters are included.
544-
embedding_based_interpreters = {
545-
'nearest neighbors': nearest_neighbors.NearestNeighbors(),
546-
# Embedding projectors expose a standard interface, but get special
547-
# handling so we can precompute the projections if requested.
548-
'pca': projection.ProjectionManager(pca.PCAModel),
549-
'umap': projection.ProjectionManager(umap.UmapModel),
550-
}
551-
self._interpreters = dict(**self._interpreters,
552-
**prediction_analysis_interpreters,
553-
**embedding_based_interpreters)
485+
self._interpreters = core.default_interpreters(self._models)
554486

555487
# Component to sync state from TS -> Python. Used in notebooks.
556488
if sync_state:

lit_nlp/components/core.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Helpers for getting default values for LitApp configurations."""
16+
from typing import Union
17+
from lit_nlp.api import components as lit_components
18+
from lit_nlp.api import model as lit_model
19+
from lit_nlp.components import ablation_flip
20+
from lit_nlp.components import classification_results
21+
from lit_nlp.components import curves
22+
from lit_nlp.components import gradient_maps
23+
from lit_nlp.components import hotflip
24+
from lit_nlp.components import lemon_explainer
25+
from lit_nlp.components import lime_explainer
26+
from lit_nlp.components import metrics
27+
from lit_nlp.components import model_salience
28+
from lit_nlp.components import nearest_neighbors
29+
from lit_nlp.components import pca
30+
from lit_nlp.components import pdp
31+
from lit_nlp.components import projection
32+
from lit_nlp.components import regression_results
33+
from lit_nlp.components import salience_clustering
34+
from lit_nlp.components import scrambler
35+
from lit_nlp.components import shap_explainer
36+
from lit_nlp.components import tcav
37+
from lit_nlp.components import thresholder
38+
from lit_nlp.components import umap
39+
from lit_nlp.components import word_replacer
40+
41+
ComponentGroup = lit_components.ComponentGroup
42+
Generator = lit_components.Generator
43+
Interpreter = lit_components.Interpreter
44+
Model = lit_model.Model
45+
46+
47+
def default_generators() -> dict[str, Generator]:
  """Builds the generators a LitApp falls back to when none are supplied.

  Returns:
    A new dict, keyed by display name, holding one freshly constructed
    instance of each default generator: 'Ablation Flip', 'Hotflip',
    'Scrambler' and 'Word Replacer'.
  """
  generators: dict[str, Generator] = {}
  generators['Ablation Flip'] = ablation_flip.AblationFlip()
  generators['Hotflip'] = hotflip.HotFlip()
  generators['Scrambler'] = scrambler.Scrambler()
  generators['Word Replacer'] = word_replacer.WordReplacer()
  return generators
55+
56+
57+
def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
  """Builds the default interpreters (and metrics) used in a LitApp.

  Args:
    models: The models included in the LitApp, keyed by name. They are passed
      to the 'Model-provided salience' component, since models may provide
      their own salience information.

  Returns:
    A new dict, keyed by display name, containing the core interpreters and
    the 'metrics' component group, followed by the gradient-based salience
    interpreters, the prediction analysis interpreters and the
    embedding-based interpreters, in that order.
  """
  # Embedding-based interpreters. The projectors expose a standard interface,
  # but get special handling so projections can be precomputed if requested.
  embedding_components: dict[str, Interpreter] = {}
  embedding_components['nearest neighbors'] = (
      nearest_neighbors.NearestNeighbors())
  embedding_components['pca'] = projection.ProjectionManager(pca.PCAModel)
  embedding_components['umap'] = projection.ProjectionManager(umap.UmapModel)

  # Salience methods. This same dict object is also handed to
  # SalienceClustering below before being merged into the result.
  salience_components: dict[str, Interpreter] = {}
  salience_components['Grad L2 Norm'] = gradient_maps.GradientNorm()
  salience_components['Grad ⋅ Input'] = gradient_maps.GradientDotInput()
  salience_components['Integrated Gradients'] = (
      gradient_maps.IntegratedGradients())
  salience_components['LIME'] = lime_explainer.LIME()

  # Metrics are bundled into one group exposed under the 'metrics' key.
  metric_components = {
      'regression': metrics.RegressionMetrics(),
      'multiclass': metrics.MulticlassMetrics(),
      'paired': metrics.MulticlassPairedMetrics(),
      'bleu': metrics.CorpusBLEU(),
      'rouge': metrics.RougeL(),
  }
  metrics_bundle: ComponentGroup = ComponentGroup(metric_components)

  # Prediction analysis interpreters.
  prediction_components: dict[str, Interpreter] = {}
  prediction_components['classification'] = (
      classification_results.ClassificationInterpreter())
  prediction_components['regression'] = (
      regression_results.RegressionInterpreter())

  # NOTE(review): the 'metrics' entry is a ComponentGroup, while the declared
  # return type only mentions Interpreter -- confirm the annotation is intended.
  combined: dict[str, Union[ComponentGroup, Interpreter]] = {}
  combined['Model-provided salience'] = model_salience.ModelSalience(models)
  combined['counterfactual explainer'] = lemon_explainer.LEMON()
  combined['tcav'] = tcav.TCAV()
  combined['curves'] = curves.CurvesInterpreter()
  combined['thresholder'] = thresholder.Thresholder()
  combined['metrics'] = metrics_bundle
  combined['pdp'] = pdp.PdpInterpreter()
  combined['Salience Clustering'] = salience_clustering.SalienceClustering(
      salience_components)
  combined['Tabular SHAP'] = shap_explainer.TabularShapExplainer()

  # Append the remaining groups after the core entries; their key sets are
  # fixed and do not overlap with each other or with the entries above.
  for extra_components in (salience_components, prediction_components,
                           embedding_components):
    combined.update(extra_components)
  return combined

0 commit comments

Comments
 (0)