
Commit 9baac29

Code health update on model server tests
1 parent 2488aa7 commit 9baac29

File tree

1 file changed: +32 -21 lines


lit_nlp/examples/gcp/model_server_test.py

+32 -21
@@ -2,21 +2,23 @@
 from unittest import mock
 
 from absl.testing import absltest
+from absl.testing import parameterized
 from lit_nlp.examples.gcp import model_server
 from lit_nlp.examples.prompt_debugging import utils as pd_utils
 import webtest
 
 
-class TestWSGIApp(absltest.TestCase):
+class TestWSGIApp(parameterized.TestCase):
 
-  @mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
-  def test_predict_endpoint(self, mock_get_models):
+  @classmethod
+  def setUpClass(cls):
     test_model_name = 'lit_on_gcp_test_model'
+    sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)
     test_model_config = f'{test_model_name}:test_model_path'
     os.environ['MODEL_CONFIG'] = test_model_config
 
-    mock_model = mock.MagicMock()
-    mock_model.predict.side_effect = [[{'response': 'test output text'}]]
+    generation_model = mock.MagicMock()
+    generation_model.predict.side_effect = [[{'response': 'test output text'}]]
 
     salience_model = mock.MagicMock()
     salience_model.predict.side_effect = [[{
@@ -30,33 +32,42 @@ def test_predict_endpoint(self, mock_get_models):
         [{'tokens': ['test', 'output', 'text']}]
     ]
 
-    sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)
-
-    mock_get_models.return_value = {
-        test_model_name: mock_model,
+    cls.mock_models = {
+        test_model_name: generation_model,
         sal_name: salience_model,
         tok_name: tokenize_model,
     }
-    app = webtest.TestApp(model_server.get_wsgi_app())
 
-    response = app.post_json('/predict', {'inputs': 'test_input'})
-    self.assertEqual(response.status_code, 200)
-    self.assertEqual(response.json, [{'response': 'test output text'}])
 
-    response = app.post_json('/salience', {'inputs': 'test_input'})
-    self.assertEqual(response.status_code, 200)
-    self.assertEqual(
-        response.json,
-        [{
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='predict',
+          endpoint='/predict',
+          expected=[{'response': 'test output text'}],
+      ),
+      dict(
+          testcase_name='salience',
+          endpoint='/salience',
+          expected=[{
            'tokens': ['test', 'output', 'text'],
            'grad_l2': [0.1234, 0.3456, 0.5678],
            'grad_dot_input': [0.1234, -0.3456, 0.5678],
        }],
-    )
+      ),
+      dict(
+          testcase_name='tokenize',
+          endpoint='/tokenize',
+          expected=[{'tokens': ['test', 'output', 'text']}],
+      ),
+  )
+  @mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
+  def test_endpoint(self, mock_get_models, endpoint, expected):
+    mock_get_models.return_value = self.mock_models
+    app = webtest.TestApp(model_server.get_wsgi_app())
 
-    response = app.post_json('/tokenize', {'inputs': 'test_input'})
+    response = app.post_json(endpoint, {'inputs': [{'prompt': 'test input'}]})
     self.assertEqual(response.status_code, 200)
-    self.assertEqual(response.json, [{'tokens': ['test', 'output', 'text']}])
+    self.assertEqual(response.json, expected)
 
 
 if __name__ == '__main__':
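
Note on the pattern this commit adopts: absl's @parameterized.named_parameters expands a single test body into one generated test case per dict, named <method_name>_<testcase_name>; testcase_name is consumed for naming and the remaining keys are passed to the test as keyword arguments. A minimal sketch of that mechanism (the EndpointNameTest class and test body below are hypothetical, not part of the LIT codebase):

from absl.testing import absltest
from absl.testing import parameterized


class EndpointNameTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(testcase_name='predict', endpoint='/predict'),
      dict(testcase_name='salience', endpoint='/salience'),
      dict(testcase_name='tokenize', endpoint='/tokenize'),
  )
  def test_endpoint_path(self, endpoint):
    # Expands into test_endpoint_path_predict, _salience, and _tokenize.
    self.assertTrue(endpoint.startswith('/'))


if __name__ == '__main__':
  absltest.main()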

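The test exercises the server through webtest.TestApp, which wraps any WSGI callable and issues requests in-process; post_json serializes the body as JSON and sets the matching content type. A minimal sketch of that usage, with a hypothetical echo_app standing in for model_server.get_wsgi_app():

import webtest


def echo_app(environ, start_response):
  # Hypothetical stand-in for the model server: echoes the JSON body back.
  length = int(environ.get('CONTENT_LENGTH') or 0)
  body = environ['wsgi.input'].read(length)
  start_response('200 OK', [('Content-Type', 'application/json')])
  return [body]


app = webtest.TestApp(echo_app)
response = app.post_json('/predict', {'inputs': [{'prompt': 'test input'}]})
assert response.status_code == 200
assert response.json == {'inputs': [{'prompt': 'test input'}]}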