9
9
10
10
11
11
DEBUG = False
12
+ tf_graph = "graph.pb"
13
+ torch_graph = "pt-minimal.pt"
12
14
13
15
14
16
class Capturing (list ):
@@ -114,7 +116,7 @@ def test_numpy_tensor(self):
114
116
con .tensorset ("trying" , stringarr )
115
117
116
118
def test_modelset_errors (self ):
117
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
119
+ model_path = os .path .join (MODEL_DIR , tf_graph )
118
120
model_pb = load_model (model_path )
119
121
con = self .get_client ()
120
122
with self .assertRaises (ValueError ):
@@ -139,7 +141,7 @@ def test_modelset_errors(self):
139
141
)
140
142
141
143
def test_modelget_meta (self ):
142
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
144
+ model_path = os .path .join (MODEL_DIR , tf_graph )
143
145
model_pb = load_model (model_path )
144
146
con = self .get_client ()
145
147
con .modelset (
@@ -160,7 +162,7 @@ def test_modelget_meta(self):
160
162
)
161
163
162
164
def test_modelrun_non_list_input_output (self ):
163
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
165
+ model_path = os .path .join (MODEL_DIR , tf_graph )
164
166
model_pb = load_model (model_path )
165
167
con = self .get_client ()
166
168
con .modelset (
@@ -173,7 +175,7 @@ def test_modelrun_non_list_input_output(self):
173
175
174
176
def test_nonasciichar (self ):
175
177
nonascii = "ĉ"
176
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
178
+ model_path = os .path .join (MODEL_DIR , tf_graph )
177
179
model_pb = load_model (model_path )
178
180
con = self .get_client ()
179
181
con .modelset (
@@ -192,8 +194,8 @@ def test_nonasciichar(self):
192
194
self .assertTrue ((np .allclose (tensor , [4.0 , 9.0 ])))
193
195
194
196
def test_run_tf_model (self ):
195
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
196
- bad_model_path = os .path .join (MODEL_DIR , "pt-minimal.pt" )
197
+ model_path = os .path .join (MODEL_DIR , tf_graph )
198
+ bad_model_path = os .path .join (MODEL_DIR , torch_graph )
197
199
198
200
model_pb = load_model (model_path )
199
201
wrong_model_pb = load_model (bad_model_path )
@@ -295,7 +297,7 @@ def test_run_onnxdl_model(self):
295
297
self .assertTrue (np .allclose (outtensor , [4.0 ]))
296
298
297
299
def test_run_pytorch_model (self ):
298
- model_path = os .path .join (MODEL_DIR , "pt-minimal.pt" )
300
+ model_path = os .path .join (MODEL_DIR , torch_graph )
299
301
ptmodel = load_model (model_path )
300
302
con = self .get_client ()
301
303
con .modelset ("pt_model" , "torch" , "cpu" , ptmodel , tag = "v1.0" )
@@ -317,7 +319,7 @@ def test_run_tflite_model(self):
317
319
self .assertTrue (np .allclose (output , [8 ]))
318
320
319
321
def test_info (self ):
320
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
322
+ model_path = os .path .join (MODEL_DIR , tf_graph )
321
323
model_pb = load_model (model_path )
322
324
con = self .get_client ()
323
325
con .modelset ("m" , "tf" , "cpu" , model_pb , inputs = ["a" , "b" ], outputs = ["mul" ])
@@ -345,13 +347,13 @@ def test_info(self):
345
347
self .assertEqual (first_info , third_info ) # before modelrun and after reset
346
348
347
349
def test_model_scan (self ):
348
- model_path = os .path .join (MODEL_DIR , "graph.pb" )
350
+ model_path = os .path .join (MODEL_DIR , tf_graph )
349
351
model_pb = load_model (model_path )
350
352
con = self .get_client ()
351
353
con .modelset (
352
354
"m" , "tf" , "cpu" , model_pb , inputs = ["a" , "b" ], outputs = ["mul" ], tag = "v1.2"
353
355
)
354
- model_path = os .path .join (MODEL_DIR , "pt-minimal.pt" )
356
+ model_path = os .path .join (MODEL_DIR , torch_graph )
355
357
ptmodel = load_model (model_path )
356
358
con = self .get_client ()
357
359
# TODO: RedisAI modelscan issue
@@ -377,7 +379,7 @@ class DagTestCase(RedisAITestBase):
377
379
def setUp (self ):
378
380
super ().setUp ()
379
381
con = self .get_client ()
380
- model_path = os .path .join (MODEL_DIR , "pt-minimal.pt" )
382
+ model_path = os .path .join (MODEL_DIR , torch_graph )
381
383
ptmodel = load_model (model_path )
382
384
con .modelset ("pt_model" , "torch" , "cpu" , ptmodel , tag = "v7.0" )
383
385
0 commit comments