2
2
from unittest import mock
3
3
4
4
from absl .testing import absltest
5
+ from absl .testing import parameterized
5
6
from lit_nlp .examples .gcp import model_server
6
7
from lit_nlp .examples .prompt_debugging import utils as pd_utils
7
8
import webtest
8
9
9
10
10
- class TestWSGIApp (absltest .TestCase ):
11
+ class TestWSGIApp (parameterized .TestCase ):
11
12
12
- @mock . patch ( 'lit_nlp.examples.prompt_debugging.models.get_models' )
13
- def test_predict_endpoint ( self , mock_get_models ):
13
+ @classmethod
14
+ def setUpClass ( cls ):
14
15
test_model_name = 'lit_on_gcp_test_model'
16
+ sal_name , tok_name = pd_utils .generate_model_group_names (test_model_name )
15
17
test_model_config = f'{ test_model_name } :test_model_path'
16
18
os .environ ['MODEL_CONFIG' ] = test_model_config
17
19
18
- mock_model = mock .MagicMock ()
19
- mock_model .predict .side_effect = [[{'response' : 'test output text' }]]
20
+ generation_model = mock .MagicMock ()
21
+ generation_model .predict .side_effect = [[{'response' : 'test output text' }]]
20
22
21
23
salience_model = mock .MagicMock ()
22
24
salience_model .predict .side_effect = [[{
@@ -30,33 +32,42 @@ def test_predict_endpoint(self, mock_get_models):
30
32
[{'tokens' : ['test' , 'output' , 'text' ]}]
31
33
]
32
34
33
- sal_name , tok_name = pd_utils .generate_model_group_names (test_model_name )
34
-
35
- mock_get_models .return_value = {
36
- test_model_name : mock_model ,
35
+ cls .mock_models = {
36
+ test_model_name : generation_model ,
37
37
sal_name : salience_model ,
38
38
tok_name : tokenize_model ,
39
39
}
40
- app = webtest .TestApp (model_server .get_wsgi_app ())
41
40
42
- response = app .post_json ('/predict' , {'inputs' : 'test_input' })
43
- self .assertEqual (response .status_code , 200 )
44
- self .assertEqual (response .json , [{'response' : 'test output text' }])
45
41
46
- response = app .post_json ('/salience' , {'inputs' : 'test_input' })
47
- self .assertEqual (response .status_code , 200 )
48
- self .assertEqual (
49
- response .json ,
50
- [{
42
+ @parameterized .named_parameters (
43
+ dict (
44
+ testcase_name = 'predict' ,
45
+ endpoint = '/predict' ,
46
+ expected = [{'response' : 'test output text' }],
47
+ ),
48
+ dict (
49
+ testcase_name = 'salience' ,
50
+ endpoint = '/salience' ,
51
+ expected = [{
51
52
'tokens' : ['test' , 'output' , 'text' ],
52
53
'grad_l2' : [0.1234 , 0.3456 , 0.5678 ],
53
54
'grad_dot_input' : [0.1234 , - 0.3456 , 0.5678 ],
54
55
}],
55
- )
56
+ ),
57
+ dict (
58
+ testcase_name = 'tokenize' ,
59
+ endpoint = '/tokenize' ,
60
+ expected = [{'tokens' : ['test' , 'output' , 'text' ]}],
61
+ ),
62
+ )
63
+ @mock .patch ('lit_nlp.examples.prompt_debugging.models.get_models' )
64
+ def test_endpoint (self , mock_get_models , endpoint , expected ):
65
+ mock_get_models .return_value = self .mock_models
66
+ app = webtest .TestApp (model_server .get_wsgi_app ())
56
67
57
- response = app .post_json ('/tokenize' , {'inputs' : 'test_input' })
68
+ response = app .post_json (endpoint , {'inputs' : [{ 'prompt' : 'test input' }] })
58
69
self .assertEqual (response .status_code , 200 )
59
- self .assertEqual (response .json , [{ 'tokens' : [ 'test' , 'output' , 'text' ]}] )
70
+ self .assertEqual (response .json , expected )
60
71
61
72
62
73
if __name__ == '__main__' :
0 commit comments