20
20
import random
21
21
import threading
22
22
import time
23
- from typing import Optional , Mapping , Sequence , Union , Callable , Iterable
23
+ from typing import Callable , Iterable , Optional , Mapping , Sequence , TypedDict , Union
24
24
25
25
from absl import logging
26
26
55
55
ModelLoadersMap = dict [str , ModelLoader ]
56
56
57
57
58
+ # LINT.IfChange
59
+ class ComponentInfo (TypedDict ):
60
+ configSpec : types .Spec # pylint: disable=invalid-name # Named for JSON struct
61
+ metaSpec : types .Spec # pylint: disable=invalid-name # Named for JSON struct
62
+ description : str
63
+ # LINT.ThenChange(./client/lib/types.ts)
64
+
65
+
66
+ def _get_component_info (
67
+ obj : lit_components .Interpreter ,
68
+ ) -> ComponentInfo :
69
+ """Returns the ComponentInfo for an Interpreter, Generator, Metric, etc."""
70
+ return ComponentInfo (
71
+ configSpec = obj .config_spec (),
72
+ metaSpec = obj .meta_spec (),
73
+ description = obj .description (),
74
+ )
75
+
76
+
77
+ def _get_compatible_names (
78
+ candidates : Mapping [str , lit_components .Interpreter ],
79
+ model : lit_model .Model ,
80
+ dataset : lit_dataset .Dataset ,
81
+ ) -> Sequence [str ]:
82
+ """Returns the names of the candidates compatible with the model/dataset."""
83
+ return [
84
+ name
85
+ for name , candidate in candidates .items ()
86
+ if candidate .is_compatible (model = model , dataset = dataset )
87
+ ]
88
+
89
+
58
90
class LitApp (object ):
59
91
"""LIT WSGI application."""
60
92
@@ -80,22 +112,29 @@ def _build_metadata(self):
80
112
81
113
compat_gens : set [str ] = set ()
82
114
compat_interps : set [str ] = set ()
115
+ compat_metrics : set [str ] = set ()
83
116
84
117
for d in info ['datasets' ]:
85
118
dataset : lit_dataset .Dataset = self ._datasets [d ]
86
- compat_gens .update ([
87
- name for name , gen in self ._generators .items ()
88
- if gen .is_compatible (model = model , dataset = dataset )
89
- ])
90
- compat_interps .update ([
91
- name for name , interp in self ._interpreters .items ()
92
- if interp .is_compatible (model = model , dataset = dataset )
93
- ])
94
-
95
- info ['generators' ] = [name for name in self ._generators .keys ()
96
- if name in compat_gens ]
97
- info ['interpreters' ] = [name for name in self ._interpreters .keys ()
98
- if name in compat_interps ]
119
+ compat_gens .update (
120
+ _get_compatible_names (self ._generators , model , dataset )
121
+ )
122
+ compat_interps .update (
123
+ _get_compatible_names (self ._interpreters , model , dataset )
124
+ )
125
+ compat_metrics .update (
126
+ _get_compatible_names (self ._metrics , model , dataset )
127
+ )
128
+
129
+ info ['generators' ] = [
130
+ name for name in self ._generators .keys () if name in compat_gens
131
+ ]
132
+ info ['interpreters' ] = [
133
+ name for name in self ._interpreters .keys () if name in compat_interps
134
+ ]
135
+ info ['metrics' ] = [
136
+ name for name in self ._metrics .keys () if name in compat_metrics
137
+ ]
99
138
model_info [name ] = info
100
139
101
140
dataset_info = {}
@@ -106,21 +145,19 @@ def _build_metadata(self):
106
145
'size' : len (ds ),
107
146
}
108
147
109
- generator_info = {}
110
- for name , gen in self ._generators .items ():
111
- generator_info [name ] = {
112
- 'configSpec' : gen .config_spec (),
113
- 'metaSpec' : gen .meta_spec (),
114
- 'description' : gen .description ()
115
- }
148
+ generator_info : Mapping [str , ComponentInfo ] = {
149
+ name : _get_component_info (gen ) for name , gen in self ._generators .items ()
150
+ }
116
151
117
- interpreter_info = {}
118
- for name , interpreter in self ._interpreters .items ():
119
- interpreter_info [name ] = {
120
- 'configSpec' : interpreter .config_spec (),
121
- 'metaSpec' : interpreter .meta_spec (),
122
- 'description' : interpreter .description ()
123
- }
152
+ interpreter_info : Mapping [str , ComponentInfo ] = {
153
+ name : _get_component_info (interp )
154
+ for name , interp in self ._interpreters .items ()
155
+ }
156
+
157
+ metrics_info : Mapping [str , ComponentInfo ] = {
158
+ name : _get_component_info (metric )
159
+ for name , metric in self ._metrics .items ()
160
+ }
124
161
125
162
init_specs = {
126
163
'datasets' : {n : s for n , (_ , s ) in self ._dataset_loaders .items ()},
@@ -133,6 +170,7 @@ def _build_metadata(self):
133
170
'datasets' : dataset_info ,
134
171
'generators' : generator_info ,
135
172
'interpreters' : interpreter_info ,
173
+ 'metrics' : metrics_info ,
136
174
'layouts' : self ._layouts ,
137
175
# Global configuration
138
176
'demoMode' : self ._demo_mode ,
@@ -569,6 +607,7 @@ def __init__(
569
607
datasets : Mapping [str , lit_dataset .Dataset ],
570
608
generators : Optional [Mapping [str , lit_components .Generator ]] = None ,
571
609
interpreters : Optional [Mapping [str , lit_components .Interpreter ]] = None ,
610
+ metrics : Optional [Mapping [str , lit_components .Metrics ]] = None ,
572
611
annotators : Optional [list [lit_components .Annotator ]] = None ,
573
612
layouts : Optional [layout .LitComponentLayouts ] = None ,
574
613
dataset_loaders : Optional [DatasetLoadersMap ] = None ,
@@ -657,6 +696,11 @@ def __init__(
657
696
else :
658
697
self ._interpreters = core .default_interpreters (self ._models )
659
698
699
+ if metrics is not None :
700
+ self ._metrics = metrics
701
+ else :
702
+ self ._metrics = core .default_metrics ()
703
+
660
704
# Component to sync state from TS -> Python. Used in notebooks.
661
705
if sync_state :
662
706
self .ui_state_tracker = ui_state .UIStateTracker ()
0 commit comments