
Commit bc6f82b

nadah09 authored and LIT team committed
Removes predict_with_metadata() overrides in Model subclasses.

This change leaves lit_nlp.api.model.Model as the only implementer of predict_with_metadata(), which sets up the removal of this method in the future.

PiperOrigin-RevId: 551569578

1 parent ad65fd9 · commit bc6f82b
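With these overrides gone, every subclass now inherits predict_with_metadata() from lit_nlp.api.model.Model. Below is a minimal sketch of the inherited behavior, assuming the base class follows the same delegate-to-predict() pattern as the overrides removed in this commit; the JsonDict stand-in and the class body are simplified for illustration.

    from typing import Iterable

    JsonDict = dict  # simplified stand-in for lit_nlp's JsonDict type alias

    class Model:
      """Sketch of lit_nlp.api.model.Model; only the relevant methods shown."""

      def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
        raise NotImplementedError  # subclasses implement inference here

      def predict_with_metadata(
          self, indexed_inputs: Iterable[JsonDict], **kw
      ) -> Iterable[JsonDict]:
        # As predict(), but inputs are IndexedInput: unwrap the "data" field
        # and delegate. Because self.predict resolves dynamically, a subclass's
        # custom predict() is always the one invoked, which is why the
        # per-subclass overrides removed here were redundant.
        return self.predict((ex["data"] for ex in indexed_inputs), **kw)

This single implementation also explains why the method can eventually be deleted outright: callers can unwrap IndexedInput themselves and call predict() directly.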

4 files changed: +8 −34 lines

lit_nlp/api/model.py (−11)

@@ -285,17 +285,6 @@ def predict(
   ) -> Iterable[JsonDict]:
     return self.wrapped.predict(inputs, *args, **kw)
 
-  # NOTE: if a subclass modifies predict(), it should also override this to
-  # call the custom predict() method - otherwise this will delegate to the
-  # wrapped class and call /that class's/ predict() method, likely leading to
-  # incorrect results.
-  # b/171513556 will solve this problem by removing the need for any
-  # *_with_metadata() methods.
-  def predict_with_metadata(
-      self, indexed_inputs: Iterable[JsonDict], **kw
-  ) -> Iterable[JsonDict]:
-    return self.wrapped.predict_with_metadata(indexed_inputs, **kw)
-
   def load(self, path: str):
     """Load a new model and wrap it with this class."""
     new_model = self.wrapped.load(path)
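The deleted NOTE documents the trap this override existed to patch around: ModelWrapper delegated to self.wrapped.predict_with_metadata(), so a wrapper subclass that customized only predict() would have its customization silently bypassed. A hedged illustration of that failure mode, using a hypothetical UppercaseWrapper subclass:

    class UppercaseWrapper(ModelWrapper):
      """Hypothetical wrapper that post-processes the wrapped model's output."""

      def predict(self, inputs, *args, **kw):
        for out in self.wrapped.predict(inputs, *args, **kw):
          yield {k: str(v).upper() for k, v in out.items()}

      # Under the old ModelWrapper.predict_with_metadata(), calls went straight
      # to self.wrapped.predict_with_metadata() and from there to the *wrapped*
      # model's predict(), skipping the uppercasing above. With the override
      # removed, the inherited Model implementation routes through
      # self.predict(), so the customization is honored with no extra code.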

lit_nlp/examples/models/t5.py (−8)

@@ -436,10 +436,6 @@ def predict(self, inputs):
     outputs = self.wrapped.predict(model_inputs)
     return (utils.remap_dict(mo, self.FIELD_RENAMES) for mo in outputs)
 
-  def predict_with_metadata(self, indexed_inputs):
-    """As predict(), but inputs are IndexedInput."""
-    return self.predict((ex["data"] for ex in indexed_inputs))
-
   def input_spec(self):
     spec = lit_types.remap_spec(self.wrapped.input_spec(), self.FIELD_RENAMES)
     spec["source_language"] = lit_types.CategoryLabel()

@@ -505,10 +501,6 @@ def predict(self, inputs):
       mo["rougeL"] = float(score["rougeL"].fmeasure)
       yield mo
 
-  def predict_with_metadata(self, indexed_inputs):
-    """As predict(), but inputs are IndexedInput."""
-    return self.predict((ex["data"] for ex in indexed_inputs))
-
   def input_spec(self):
     return lit_types.remap_spec(self.wrapped.input_spec(), self.FIELD_RENAMES)

lit_nlp/lib/caching.py (−7)

@@ -285,13 +285,6 @@ def predict(self,
 
     return cached_results
 
-  # TODO(b/171513556): remove this method once we no longer need to override
-  # ModelWrapper.predict_with_metadata()
-  def predict_with_metadata(self, indexed_inputs: Iterable[JsonDict], **kw):
-    """As predict(), but inputs are IndexedInput."""
-    results = self.predict((ex["data"] for ex in indexed_inputs), **kw)
-    return results
-
   def _get_results_from_cache(self, input_keys: list[CacheKey]):
     with self._cache.lock:
       return [self._cache.get(input_key) for input_key in input_keys]
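Dropping this override does not lose caching: the inherited Model.predict_with_metadata() calls self.predict(), which for CachingModelWrapper is the cache-aware predict() above. A brief usage sketch, mirroring test_caching_model_wrapper_use_cache below (the model variable and example values are illustrative):

    from lit_nlp.lib import caching

    wrapper = caching.CachingModelWrapper(model, "test")  # model: any lit_nlp Model
    examples = [{"data": {"val": 1, "_id": "id_to_cache"}, "id": "id_to_cache"}]

    outputs = list(wrapper.predict_with_metadata(examples))  # first call runs the model
    outputs = list(wrapper.predict_with_metadata(examples))  # second call is served from cache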

lit_nlp/lib/caching_test.py (+8 −8)

@@ -37,21 +37,21 @@ def test_caching_model_wrapper_no_dataset_skip_cache(self):
     model = testing_utils.IdentityRegressionModelForTesting()
     wrapper = caching.CachingModelWrapper(model, "test")
     examples = [{"data": {"val": 1}, "id": "my_id"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(2, model.count)
     self.assertEqual({"score": 1}, results[0])
 
   def test_caching_model_wrapper_use_cache(self):
     model = testing_utils.IdentityRegressionModelForTesting()
     wrapper = caching.CachingModelWrapper(model, "test")
     examples = [{"data": {"val": 1, "_id": "id_to_cache"}, "id": "id_to_cache"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
     self.assertEmpty(wrapper._cache._pred_locks)

@@ -60,11 +60,11 @@ def test_caching_model_wrapper_not_cached(self):
     model = testing_utils.IdentityRegressionModelForTesting()
     wrapper = caching.CachingModelWrapper(model, "test")
     examples = [{"data": {"val": 1}, "id": "my_id"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
     examples = [{"data": {"val": 2}, "id": "other_id"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(2, model.count)
     self.assertEqual({"score": 2}, results[0])

@@ -98,14 +98,14 @@ def test_caching_model_wrapper_mixed_list(self):
     subset = examples[:1]
 
     # Run the CachingModelWrapper over a subset of examples
-    results = wrapper.predict_with_metadata(subset)
+    results = list(wrapper.predict_with_metadata(subset))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 0}, results[0])
 
     # Now, run the CachingModelWrapper over all of the examples. This should
     # only pass the examples that were not in subset to the wrapped model, and
     # the total number of inputs processed by the wrapped model should be 3
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(3, model.count)
     self.assertEqual({"score": 0}, results[0])
     self.assertEqual({"score": 1}, results[1])
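The only behavioral wrinkle the tests absorb is laziness: the old CachingModelWrapper override returned an indexable result, while the inherited implementation appears to return a lazy iterable, so results must now be materialized before indexing into them or asserting on model.count. A small sketch of why list() is required:

    gen = wrapper.predict_with_metadata(examples)  # lazy: no prediction has run yet
    results = list(gen)                            # consuming it triggers predict()
    assert results[0] == {"score": 1}              # now safe to index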
