@@ -35,8 +35,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_seed = None
     context.user_api_key = None
 
-    context.completions = []
-    context.concurrent_completion_tasks = []
+    context.tasks_result = []
+    context.concurrent_tasks = []
     context.prompts = []
 
 
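A minimal sketch of the queueing pattern behind the two renamed fields, using hypothetical names (not part of the diff): concurrent_tasks holds pending asyncio.Task objects, and tasks_result collects whatever each task returns once awaited, whether that is a completion or an embedding.

import asyncio

async def fire_and_gather(coros):
    # queue the work first (the role of context.concurrent_tasks) ...
    tasks = [asyncio.create_task(coro) for coro in coros]
    # ... then drain it into results (the role of context.tasks_result)
    return [await task for task in tasks]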
@@ -149,20 +149,20 @@ async def step_request_completion(context, api_error):
                                            server_seed=context.server_seed,
                                            expect_api_error=expect_api_error,
                                            user_api_key=context.user_api_key)
-    context.completions.append(completion)
+    context.tasks_result.append(completion)
     print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
 
 @step(u'{predicted_n} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.completions.pop(), int(predicted_n), re_content)
+    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
 
 
 @step(u'{predicted_n} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.completions.pop(), int(predicted_n))
+    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
 
 
 @step(u'a user prompt {user_prompt}')
@@ -195,13 +195,13 @@ def step_user_api_key(context, user_api_key):
     context.user_api_key = user_api_key
 
 
-@step(u'a user api key ')
+@step(u'no user api key')
 def step_no_user_api_key(context):
     context.user_api_key = None
 
 
-@step(u'no user api key')
-def step_no_user_api_key(context):
+@step(u'a user api key ')
+def step_no_user_api_key_space(context):
     context.user_api_key = None
 
 
@@ -234,7 +234,7 @@ async def step_oai_chat_completions(context, api_error):
                                          if hasattr(context, 'user_api_key') else None,
 
                                          expect_api_error=expect_api_error)
-    context.completions.append(completion)
+    context.tasks_result.append(completion)
     print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
@@ -285,47 +285,38 @@ async def step_oai_chat_completions(context):
                                          if hasattr(context, 'user_api_key') else None)
 
 
-@async_run_until_complete
 @step(u'all prompts are predicted')
-async def step_impl(context):
+@async_run_until_complete
+async def step_all_prompts_are_predicted(context):
     await all_prompts_are_predicted(context)
 
 
 @step(u'all prompts are predicted with {n_predict} tokens')
 @async_run_until_complete
-async def step_all_prompts_are_predicted(context, n_predict):
+async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
     expected_predicted_n = int(n_predict)
     await all_prompts_are_predicted(context, expected_predicted_n)
 
 
 async def all_prompts_are_predicted(context, expected_predicted_n=None):
-    n_completions = await gather_concurrent_completions_tasks(context)
+    n_completions = await gather_tasks_results(context)
     assert n_completions > 0
     for i in range(n_completions):
-        assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=expected_predicted_n)
+        assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
 
 
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    async with aiohttp.ClientSession() as session:
-        async with session.post(f'{context.base_url}/embedding',
-                                json={
-                                    "content": context.text,
-                                }) as response:
-            assert response.status == 200
-            response_json = await response.json()
-            context.embeddings = response_json['embedding']
+    content = context.text
+    base_url = context.base_url
+    context.embeddings = await request_embedding(content, base_url)
 
 
 @step(u'embeddings are generated')
-def step_compute_embeddings(context):
-    assert len(context.embeddings) > 0
-    embeddings_computed = False
-    for emb in context.embeddings:
-        if emb != 0:
-            embeddings_computed = True
-    assert embeddings_computed, f"Embeddings: {context.embeddings}"
+def step_assert_embeddings(context):
+    assert_embeddings(context.embeddings)
 
 
 @step(u'an OAI compatible embeddings computation request for')
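Both embedding steps now delegate to shared helpers added further down in this diff: the inline aiohttp POST moved into request_embedding and the nonzero check into assert_embeddings, so the single-request path and the new concurrent path exercise one implementation. A sketch of the resulting call graph, using only names visible in this diff:

# step_compute_embedding             -> request_embedding(content, base_url)
# step_assert_embeddings             -> assert_embeddings(context.embeddings)
# step_concurrent_embedding_requests -> concurrent_completion_requests(context, request_embedding, ...)
# all_embeddings_are_generated       -> gather_tasks_results(context), then assert_embeddings per result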
@@ -341,6 +332,24 @@ def step_oai_compute_embedding(context):
     context.embeddings = embeddings
 
 
+@step(u'concurrent embedding requests')
+@async_run_until_complete()
+async def step_concurrent_embedding_requests(context):
+    await concurrent_completion_requests(context,
+                                         request_embedding,
+                                         # prompt is inserted automatically
+                                         context.base_url)
+
+
+@step(u'all embeddings are generated')
+@async_run_until_complete()
+async def all_embeddings_are_generated(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    for i in range(n_embedding_requests):
+        assert_embeddings(context.tasks_result.pop())
+
+
 @step(u'tokenizing')
 @async_run_until_complete
 async def step_tokenize(context):
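A hypothetical feature snippet (illustration only; the actual .feature files are outside this diff) showing how the two new steps could be wired together; the prompt-queuing step text is an assumption:

# Scenario: Concurrent embedding requests
#   Given a prompt ...                    # queues entries into context.prompts (assumed step text)
#   And a prompt ...
#   Given concurrent embedding requests
#   Then all embeddings are generated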
@@ -391,7 +400,7 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
     assert n_prompts > 0
     for prompt_no in range(n_prompts):
         shifted_args = [context.prompts.pop(), *args]
-        context.concurrent_completion_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
+        context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
         await asyncio.sleep(0.1)
 
 
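How the generic dispatcher composes with the embedding helper, based only on the call sites in this diff: the prompt popped from context.prompts is prepended to the positional arguments, so it binds to request_embedding's content parameter.

# effectively, per queued prompt:
#   context.concurrent_tasks.append(
#       asyncio.create_task(request_embedding(prompt, context.base_url)))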
@@ -540,6 +549,17 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
+async def request_embedding(content, base_url):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{base_url}/embedding',
+                                json={
+                                    "content": content,
+                                }) as response:
+            assert response.status == 200
+            response_json = await response.json()
+            return response_json['embedding']
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
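A minimal standalone usage sketch of the new helper, assuming a llama.cpp server is already listening at the hypothetical URL below:

import asyncio

async def main():
    # URL is an assumption for illustration
    embedding = await request_embedding("Hello world", "http://localhost:8080")
    print(f"received {len(embedding)} embedding components")

asyncio.run(main())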
@@ -554,12 +574,12 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
                                              f' ```\n{content}\n``` do not match /{re_content}/')
 
 
-async def gather_concurrent_completions_tasks(context):
-    n_completion_tasks = len(context.concurrent_completion_tasks)
-    print(f"Waiting for all {n_completion_tasks} completion responses ...")
-    for task_no in range(n_completion_tasks):
-        context.completions.append(await context.concurrent_completion_tasks.pop())
-    n_completions = len(context.completions)
+async def gather_tasks_results(context):
+    n_tasks = len(context.concurrent_tasks)
+    print(f"Waiting for all {n_tasks} tasks results ...")
+    for task_no in range(n_tasks):
+        context.tasks_result.append(await context.concurrent_tasks.pop())
+    n_completions = len(context.tasks_result)
     return n_completions
 
 
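Note that list.pop() with no index takes from the tail, so the gathering loop awaits the most recently queued task first; results therefore land in tasks_result in reverse submission order, which is harmless because the assertions that consume tasks_result apply the same check to every entry.

# queued [t0, t1, t2] -> awaited and collected as [t2, t1, t0]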
@@ -602,16 +622,25 @@ async def wait_for_health_status(context,
         if counter >= timeout:
             # Sometimes health requests are triggered after completions are predicted
             if expected_http_status_code == 503:
-                if len(context.completions) == 0:
-                    print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks,"
+                if len(context.tasks_result) == 0:
+                    print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                           " busy health check missed, probably too fast inference\x1b[0m")
-                    n_completions = await gather_concurrent_completions_tasks(context)
+                    n_completions = await gather_tasks_results(context)
                     if n_completions > 0:
                         return
 
             assert False, 'timeout exceeded'
 
 
+def assert_embeddings(embeddings):
+    assert len(embeddings) > 0
+    embeddings_computed = False
+    for emb in embeddings:
+        if emb != 0:
+            embeddings_computed = True
+    assert embeddings_computed, f"Embeddings: {embeddings}"
+
+
 async def request_slots_status(context, expected_slots):
     async with aiohttp.ClientSession() as session:
         async with await session.get(f'{context.base_url}/slots') as slots_response:
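The explicit loop in assert_embeddings is equivalent to a single any() check; a condensed sketch (illustration only, the diff keeps the loop):

def assert_embeddings_condensed(embeddings):
    assert len(embeddings) > 0
    # at least one component must be nonzero, otherwise the server returned a zero vector
    assert any(emb != 0 for emb in embeddings), f"Embeddings: {embeddings}"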
@@ -652,6 +681,8 @@ def start_server_background(context):
         server_args.extend(['--n-predict', context.n_server_predict])
     if context.server_api_key is not None:
         server_args.extend(['--api-key', context.server_api_key])
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        server_args.append('--verbose')
     print(f"starting server with: {context.server_path}", *server_args)
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
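Usage sketch for the new switch, assuming the behave-based runner these step definitions belong to:

# DEBUG=ON behave    # start the server under test with --verbose logging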