Commit 6c0e6f4 (parent: 30f802d)

server: tests: add a concurrent embedding scenario for issue ggml-org#5655
and allow enabling server VERBOSE mode.

File tree: 7 files changed (+117 −50 lines)

File renamed without changes.

examples/server/tests/README.md (+4 −3)

```diff
@@ -12,8 +12,9 @@ Server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_
 3. Start the test: `./tests.sh`
 
 It's possible to override some scenario steps values with environment variables:
-- `$PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
-- `$LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
+- `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
+- `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
+- `DEBUG` -> "ON" to enable server verbose mode `--verbose`
 
 ### Run @bug, @wip or @wrong_usage annotated scenario
 
@@ -23,4 +24,4 @@ Feature or Scenario must be annotated with `@llama.cpp` to be included in the de
 - `@wip` to focus on a scenario working in progress
 
 To run a scenario annotated with `@bug`, start:
-`./tests.sh --tags bug`
+`DEBUG=ON ./tests.sh --no-skipped --tags bug`
```
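For context, the new `DEBUG` variable is consumed in `steps.py` when the server is launched (see the diff further below). A minimal sketch of that pattern, with placeholder host/port arguments and defaults taken from the README above:

```python
import os
import subprocess

# Placeholder values for illustration; the real harness builds these from the scenario steps.
server_path = os.environ.get('LLAMA_SERVER_BIN_PATH', '../../../build/bin/server')
server_args = ['--host', 'localhost', '--port', os.environ.get('PORT', '8080')]

# Mirrors the check added to start_server_background() in this commit.
if os.environ.get('DEBUG', '') == 'ON':
    server_args.append('--verbose')

server_process = subprocess.Popen([str(arg) for arg in [server_path, *server_args]])
```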

examples/server/tests/features/environment.py (+2 −2)

```diff
@@ -24,7 +24,7 @@ def after_scenario(context, scenario):
             for line in f:
                 print(line)
         if not is_server_listening(context.server_fqdn, context.server_port):
-            print("ERROR: Server has crashed")
+            print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
 
     if not pid_exists(context.server_process.pid):
         assert False, f"Server not running pid={context.server_process.pid} ..."
@@ -41,7 +41,7 @@ def after_scenario(context, scenario):
        time.sleep(0.1)
        attempts += 1
        if attempts > 5:
-           print(f"Server dandling exits, killing all {context.server_path} ...")
+           print(f"Server dangling exits, killing all {context.server_path} ...")
            process = subprocess.run(['killall', '-9', context.server_path],
                                     stderr=subprocess.PIPE,
                                     universal_newlines=True)
```
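
The cleanup above retries for a short while and then falls back to `killall -9`. A minimal sketch of that shutdown loop, assuming a POSIX system and using `os.kill(pid, 0)` as the liveness probe (the real harness wraps this in its own `pid_exists()` helper):

```python
import os
import subprocess
import time

def pid_exists(pid):
    # Liveness probe: signal 0 checks existence/permissions without killing the process.
    try:
        os.kill(pid, 0)
    except OSError:
        return False
    return True

def shutdown_server(server_process, server_path, max_attempts=5):
    server_process.kill()  # ask the OS to terminate the server
    attempts = 0
    while server_process.poll() is None and pid_exists(server_process.pid):
        time.sleep(0.1)
        attempts += 1
        if attempts > max_attempts:
            # Same last-resort cleanup as the harness: kill every dangling server process.
            print(f"Server dangling exits, killing all {server_path} ...")
            subprocess.run(['killall', '-9', server_path],
                           stderr=subprocess.PIPE,
                           universal_newlines=True)
            break
```
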
examples/server/tests/features/issues.feature (new file, +36)

```diff
@@ -0,0 +1,36 @@
+# List of ongoing issues
+@bug
+Feature: Issues
+  # Issue #5655
+  Scenario: Multi users embeddings
+    Given a server listening on localhost:8080
+    And a model file stories260K.gguf
+    And a model alias tinyllama-2
+    And 42 as server seed
+    And 64 KV cache size
+    And 2 slots
+    And continuous batching
+    And embeddings extraction
+    Then the server is starting
+    Then the server is healthy
+
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    Given concurrent embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
```
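
The "concurrent embedding requests" step above fans one `/embedding` POST out per queued prompt. A minimal standalone sketch of that pattern, assuming a server already listening on localhost:8080 with embeddings enabled; the helper mirrors the `request_embedding()` function added to `steps.py` in this commit:

```python
import asyncio
import aiohttp

BASE_URL = 'http://localhost:8080'  # assumed; matches the scenario's server config

async def request_embedding(content, base_url):
    # Same request shape as the helper added to steps.py.
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/embedding',
                                json={"content": content}) as response:
            assert response.status == 200
            response_json = await response.json()
            return response_json['embedding']

async def main():
    prompts = [
        "Write a very long story about AI.",
        "Write another very long music lyrics.",
        "Write a very long poem.",
        "Write a very long joke.",
    ]
    # One concurrent task per prompt, mirroring step_concurrent_embedding_requests().
    tasks = [asyncio.create_task(request_embedding(p, BASE_URL)) for p in prompts]
    embeddings = await asyncio.gather(*tasks)
    # Same check as assert_embeddings(): at least one non-zero component per vector.
    for emb in embeddings:
        assert any(v != 0 for v in emb), f"Embeddings: {emb}"

if __name__ == '__main__':
    asyncio.run(main())
```
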

examples/server/tests/features/steps/steps.py (+70 −39)

```diff
@@ -35,8 +35,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_seed = None
     context.user_api_key = None
 
-    context.completions = []
-    context.concurrent_completion_tasks = []
+    context.tasks_result = []
+    context.concurrent_tasks = []
     context.prompts = []
 
 
@@ -149,20 +149,20 @@ async def step_request_completion(context, api_error):
                                           server_seed=context.server_seed,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key)
-    context.completions.append(completion)
+    context.tasks_result.append(completion)
     print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
 
 @step(u'{predicted_n} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.completions.pop(), int(predicted_n), re_content)
+    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
 
 
 @step(u'{predicted_n} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.completions.pop(), int(predicted_n))
+    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
 
 
 @step(u'a user prompt {user_prompt}')
@@ -195,13 +195,13 @@ def step_user_api_key(context, user_api_key):
     context.user_api_key = user_api_key
 
 
-@step(u'a user api key ')
+@step(u'no user api key')
 def step_no_user_api_key(context):
     context.user_api_key = None
 
 
-@step(u'no user api key')
-def step_no_user_api_key(context):
+@step(u'a user api key ')
+def step_no_user_api_key_space(context):
     context.user_api_key = None
 
 
@@ -234,7 +234,7 @@ async def step_oai_chat_completions(context, api_error):
                                             if hasattr(context, 'user_api_key') else None,
 
                                             expect_api_error=expect_api_error)
-    context.completions.append(completion)
+    context.tasks_result.append(completion)
     print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
@@ -285,47 +285,38 @@ async def step_oai_chat_completions(context):
                                   if hasattr(context, 'user_api_key') else None)
 
 
-@async_run_until_complete
 @step(u'all prompts are predicted')
-async def step_impl(context):
+@async_run_until_complete
+async def step_all_prompts_are_predicted(context):
     await all_prompts_are_predicted(context)
 
 
 @step(u'all prompts are predicted with {n_predict} tokens')
 @async_run_until_complete
-async def step_all_prompts_are_predicted(context, n_predict):
+async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
     expected_predicted_n = int(n_predict)
     await all_prompts_are_predicted(context, expected_predicted_n)
 
 
 async def all_prompts_are_predicted(context, expected_predicted_n=None):
-    n_completions = await gather_concurrent_completions_tasks(context)
+    n_completions = await gather_tasks_results(context)
     assert n_completions > 0
     for i in range(n_completions):
-        assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=expected_predicted_n)
+        assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
 
 
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    async with aiohttp.ClientSession() as session:
-        async with session.post(f'{context.base_url}/embedding',
-                                json={
-                                    "content": context.text,
-                                }) as response:
-            assert response.status == 200
-            response_json = await response.json()
-            context.embeddings = response_json['embedding']
+    content = context.text
+    base_url = context.base_url
+    context.embeddings = await request_embedding(content, base_url)
 
 
 @step(u'embeddings are generated')
-def step_compute_embeddings(context):
-    assert len(context.embeddings) > 0
-    embeddings_computed = False
-    for emb in context.embeddings:
-        if emb != 0:
-            embeddings_computed = True
-    assert embeddings_computed, f"Embeddings: {context.embeddings}"
+def step_assert_embeddings(context):
+    assert_embeddings(context.embeddings)
 
 
 @step(u'an OAI compatible embeddings computation request for')
@@ -341,6 +332,24 @@ def step_oai_compute_embedding(context):
     context.embeddings = embeddings
 
 
+@step(u'concurrent embedding requests')
+@async_run_until_complete()
+async def step_concurrent_embedding_requests(context):
+    await concurrent_completion_requests(context,
+                                         request_embedding,
+                                         # prompt is inserted automatically
+                                         context.base_url)
+
+
+@step(u'all embeddings are generated')
+@async_run_until_complete()
+async def all_embeddings_are_generated(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    for i in range(n_embedding_requests):
+        assert_embeddings(context.tasks_result.pop())
+
+
 @step(u'tokenizing')
 @async_run_until_complete
 async def step_tokenize(context):
@@ -391,7 +400,7 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs)
     assert n_prompts > 0
     for prompt_no in range(n_prompts):
         shifted_args = [context.prompts.pop(), *args]
-        context.concurrent_completion_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
+        context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
         await asyncio.sleep(0.1)
 
 
@@ -540,6 +549,17 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
+async def request_embedding(content, base_url):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{base_url}/embedding',
+                                json={
+                                    "content": content,
+                                }) as response:
+            assert response.status == 200
+            response_json = await response.json()
+            return response_json['embedding']
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
@@ -554,12 +574,12 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
                f' ```\n{content}\n``` do not match /{re_content}/')
 
 
-async def gather_concurrent_completions_tasks(context):
-    n_completion_tasks = len(context.concurrent_completion_tasks)
-    print(f"Waiting for all {n_completion_tasks} completion responses...")
-    for task_no in range(n_completion_tasks):
-        context.completions.append(await context.concurrent_completion_tasks.pop())
-    n_completions = len(context.completions)
+async def gather_tasks_results(context):
+    n_tasks = len(context.concurrent_tasks)
+    print(f"Waiting for all {n_tasks} tasks results...")
+    for task_no in range(n_tasks):
+        context.tasks_result.append(await context.concurrent_tasks.pop())
+    n_completions = len(context.tasks_result)
     return n_completions
 
 
@@ -602,16 +622,25 @@ async def wait_for_health_status(context,
         if counter >= timeout:
             # Sometimes health requests are triggered after completions are predicted
             if expected_http_status_code == 503:
-                if len(context.completions) == 0:
-                    print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks,"
+                if len(context.tasks_result) == 0:
+                    print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                           " busy health check missed, probably too fast inference\x1b[0m")
-                    n_completions = await gather_concurrent_completions_tasks(context)
+                    n_completions = await gather_tasks_results(context)
                     if n_completions > 0:
                         return
 
             assert False, 'timeout exceeded'
 
 
+def assert_embeddings(embeddings):
+    assert len(embeddings) > 0
+    embeddings_computed = False
+    for emb in embeddings:
+        if emb != 0:
+            embeddings_computed = True
+    assert embeddings_computed, f"Embeddings: {embeddings}"
+
+
 async def request_slots_status(context, expected_slots):
     async with aiohttp.ClientSession() as session:
         async with await session.get(f'{context.base_url}/slots') as slots_response:
@@ -652,6 +681,8 @@ def start_server_background(context):
         server_args.extend(['--n-predict', context.n_server_predict])
     if context.server_api_key is not None:
         server_args.extend(['--api-key', context.server_api_key])
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        server_args.append('--verbose')
     print(f"starting server with: {context.server_path}", *server_args)
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
```
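The renamed `concurrent_tasks` / `tasks_result` pair implements a simple fan-out/fan-in queue: steps push `asyncio.create_task(...)` handles onto `concurrent_tasks`, and `gather_tasks_results()` awaits and pops them into `tasks_result`. A self-contained sketch of that bookkeeping, with a dummy coroutine standing in for a completion or embedding request:

```python
import asyncio

concurrent_tasks = []   # pending asyncio.Task handles (fan-out)
tasks_result = []       # awaited results (fan-in)

async def fake_request(prompt):
    # Stand-in for request_completion()/request_embedding(); echoes after a short delay.
    await asyncio.sleep(0.01)
    return f"result for: {prompt}"

async def main():
    # Fan-out: one task per prompt, mirroring concurrent_completion_requests().
    for prompt in ["prompt A", "prompt B", "prompt C"]:
        concurrent_tasks.append(asyncio.create_task(fake_request(prompt)))

    # Fan-in: await each pending task and collect its result, mirroring gather_tasks_results().
    for _ in range(len(concurrent_tasks)):
        tasks_result.append(await concurrent_tasks.pop())
    assert len(concurrent_tasks) == 0, f"{len(concurrent_tasks)} pending requests"
    print(tasks_result)

asyncio.run(main())
```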

examples/server/tests/features/wrong_usage.feature renamed to examples/server/tests/features/wrong_usages.feature (+4 −5)

```diff
@@ -4,19 +4,18 @@ Feature: Wrong usage of llama.cpp server
 
   #3969 The user must always set --n-predict option
   # to cap the number of tokens any completion request can generate
-  # or pass n_predict or max_tokens in the request.
+  # or pass n_predict/max_tokens in the request.
   Scenario: Infinite loop
     Given a server listening on localhost:8080
     And a model file stories260K.gguf
-    And 1 slots
-    And 32 KV cache size
     # Uncomment below to fix the issue
     #And 64 server max tokens to predict
     Then the server is starting
     Given a prompt:
       """
       Go to: infinite loop
       """
+    # Uncomment below to fix the issue
+    #And 128 max tokens to predict
     Given concurrent completion requests
-
-    Then all prompts are predicted
+    Then all prompts are predicted
```

examples/server/tests/tests.sh (+1 −1)

```diff
@@ -5,7 +5,7 @@ set -eu
 if [ $# -lt 1 ]
 then
   # Start @llama.cpp scenario
-  behave --summary --stop --no-capture --tags llama.cpp
+  behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
 else
   behave "$@"
 fi
```
