@@ -60,6 +60,7 @@ def _build_metadata(self):
       info = {
           'description': model.description(),
           'spec': {
+              'initSpec': self._model_init_specs[name],
               'input': model.input_spec(),
               'output': model.output_spec(),
           }
@@ -96,6 +97,7 @@ def _build_metadata(self):
     dataset_info = {}
     for name, ds in self._datasets.items():
       dataset_info[name] = {
+          'init': self._dataset_init_specs[name],
           'spec': ds.spec(),
           'description': ds.description(),
           'size': len(ds),
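Note: with the two hunks above, the metadata served to the frontend now carries each component's constructor spec, under 'initSpec' for models and 'init' for datasets. A minimal sketch of how a consumer might read those fields; the dict below is a hand-built stand-in shaped like the structures assembled above, and the top-level 'models'/'datasets' keys are assumptions for illustration, not a confirmed LIT payload layout:

    # Hand-built stand-in for the metadata assembled in _build_metadata().
    get_info_response = {
        'models': {
            'sst_tiny': {'spec': {'initSpec': {'model_path': '...'},
                                  'input': {}, 'output': {}}},
        },
        'datasets': {
            'sst_dev': {'init': {'split': '...'}, 'spec': {}, 'size': 0},
        },
    }

    for name, model_info in get_info_response['models'].items():
      # Captured from model.init_spec() before CachingModelWrapper hid it.
      print(name, model_info['spec']['initSpec'])

    for name, ds_info in get_info_response['datasets'].items():
      # Captured from dataset.init_spec() before annotation and indexing.
      print(name, ds_info['init'])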
@@ -285,7 +287,7 @@ def _get_dataset(self,
                    dataset_name: Optional[str] = None,
                    **unused_kw) -> list[IndexedInput]:
     """Attempt to get dataset, or override with a specific path."""
-    return self._datasets[dataset_name].indexed_examples
+    return list(self._datasets[dataset_name].indexed_examples)

   def _create_dataset(self,
                       unused_data,
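Note: one effect of the one-line change above is that callers now receive a shallow copy rather than the dataset's own `indexed_examples` list, so mutating the result can no longer corrupt the cached examples. A minimal sketch of the aliasing hazard the copy prevents; names are illustrative, not LIT's:

    # Illustrative sketch of the aliasing that list(...) prevents.
    cached = [{'id': 'a'}, {'id': 'b'}]  # stands in for ds.indexed_examples

    # Old behavior: the caller got the cached list itself, so any mutation
    # leaked back into the dataset.
    aliased = cached
    aliased.append({'id': 'new'})
    assert len(cached) == 3

    # New behavior: list(...) hands back a shallow copy; the cache is intact.
    cached = [{'id': 'a'}, {'id': 'b'}]
    copied = list(cached)
    copied.append({'id': 'new'})
    assert len(cached) == 2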
@@ -549,30 +551,38 @@ def __init__(
     # client code to manually merge when this is the desired behavior.
     self._layouts = dict(layout.DEFAULT_LAYOUTS, **(layouts or {}))

-    # Wrap models in caching wrapper
-    self._models = {
-        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
-        for name, model in models.items()
-    }
-
-    self._datasets: dict[str, lit_dataset.Dataset] = dict(datasets)
-    # TODO(b/202210900): get rid of this, just dynamically create the empty
-    # dataset on the frontend.
-    self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)
-
-    self._annotators = annotators or []
-
+    self._model_init_specs: dict[str, Optional[types.Spec]] = {}
+    self._models: dict[str, caching.CachingModelWrapper] = {}
+    for name, model in models.items():
+      # We need to extract and store the results of the original
+      # model.init_spec() here so that we don't lose access to those fields
+      # after LIT wraps the model in a CachingModelWrapper.
+      self._model_init_specs[name] = model.init_spec()
+      # Wrap the model in a caching wrapper and add it to the app.
+      self._models[name] = caching.CachingModelWrapper(model, name,
+                                                       cache_dir=data_dir)
+
+    self._annotators: list[lit_components.Annotator] = annotators or []

     self._saved_datapoints = {}
     self._saved_datapoints_lock = threading.Lock()

-    # Run annotation on each dataset, creating an annotated dataset and
-    # replace the datasets with the annotated versions.
-    for ds_key, ds in self._datasets.items():
-      self._datasets[ds_key] = self._run_annotators(ds)
-
-    # Index all datasets
-    self._datasets = lit_dataset.IndexedDataset.index_all(
-        self._datasets, caching.input_hash)
+    tmp_datasets: dict[str, lit_dataset.Dataset] = dict(datasets)
+    # TODO(b/202210900): get rid of this, just dynamically create the empty
+    # dataset on the frontend.
+    tmp_datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)
+
+    self._dataset_init_specs: dict[str, Optional[types.Spec]] = {}
+    self._datasets: dict[str, lit_dataset.IndexedDataset] = {}
+    for name, ds in tmp_datasets.items():
+      # We need to extract and store the results of the original
+      # dataset.init_spec() here so that we don't lose access to those fields
+      # after LIT goes through the dataset annotation and indexing process.
+      self._dataset_init_specs[name] = ds.init_spec()
+      # Annotate the dataset.
+      annotated_ds = self._run_annotators(ds)
+      # Index the annotated dataset and add it to the app.
+      self._datasets[name] = lit_dataset.IndexedDataset(
+          base=annotated_ds, id_fn=caching.input_hash)

     # Generator initialization
     if generators is not None:
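The captured specs are only as useful as the `init_spec()` implementations behind them. A minimal sketch of a model exposing one, assuming the public `lit_nlp` API of this era; `TinyModel` and its constructor fields are invented for illustration, and the specific LitType classes used in the spec are assumptions:

    # Hypothetical model whose constructor arguments surface via init_spec().
    from lit_nlp.api import model as lit_model
    from lit_nlp.api import types as lit_types


    class TinyModel(lit_model.Model):

      def __init__(self, model_path: str, threshold: float = 0.5):
        self._model_path = model_path
        self._threshold = threshold

      def init_spec(self) -> lit_types.Spec:
        # These are the fields LitApp captures into self._model_init_specs
        # before wrapping the model in a CachingModelWrapper.
        return {
            'model_path': lit_types.String(),
            'threshold': lit_types.Scalar(default=0.5),
        }

      def input_spec(self) -> lit_types.Spec:
        return {'text': lit_types.TextSegment()}

      def output_spec(self) -> lit_types.Spec:
        return {'score': lit_types.Scalar()}

      def predict_minibatch(self, inputs):
        # Trivial stand-in prediction; a real model would run inference here.
        return [{'score': 0.0} for _ in inputs]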