Skip to content

Commit fb79a6d

Browse files
committed
add deepseek provider
1 parent 52e11b3 commit fb79a6d

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

README.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
</p>
1010

1111
ShellOracle is an innovative terminal utility designed for intelligent shell command generation, bringing a new level of
12-
efficiency to your command-line interactions. ShellOracle currently supports Ollama, OpenAI, LocalAI, and Grok!
12+
efficiency to your command-line interactions. ShellOracle currently supports Ollama, OpenAI, Deepseek, LocalAI, and Grok!
1313

1414
![ShellOracle](https://i.imgur.com/lqTW1lO.gif)
1515

@@ -98,6 +98,11 @@ Refer to the [Ollama docs](https://ollama.ai) for installation, available models
9898
To use ShellOracle with OpenAI's models, create an [API key](https://platform.openai.com/account/api-keys). Edit
9999
your `~/.shelloracle/config.toml` to change your provider and enter your API key.
100100
101+
### Deepseek
102+
103+
To use ShellOracle with Deepseek's models, create an [API key](https://platform.deepseek.com/api_keys). Edit
104+
your `~/.shelloracle/config.toml` to change your provider and enter your API key.
105+
101106
### LocalAI
102107

103108
Refer to the [LocalAI docs](https://localai.io/) for installation, available models, and usage.

src/shelloracle/providers/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,13 @@ def __get__(self, instance: Provider, owner: type[Provider]) -> T:
7474

7575

7676
def _providers() -> dict[str, type[Provider]]:
77+
from shelloracle.providers.deepseek import Deepseek
7778
from shelloracle.providers.localai import LocalAI
7879
from shelloracle.providers.ollama import Ollama
7980
from shelloracle.providers.openai import OpenAI
8081
from shelloracle.providers.xai import XAI
8182

82-
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI}
83+
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI, Deepseek.name: Deepseek}
8384

8485

8586
def get_provider(name: str) -> type[Provider]:

src/shelloracle/providers/deepseek.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from collections.abc import AsyncIterator
2+
3+
from openai import APIError, AsyncOpenAI
4+
5+
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
6+
7+
8+
class Deepseek(Provider):
9+
name = "Deepseek"
10+
11+
api_key = Setting(default="")
12+
model = Setting(default="deepseek-chat")
13+
14+
def __init__(self):
15+
if not self.api_key:
16+
msg = "No API key provided"
17+
raise ProviderError(msg)
18+
self.client = AsyncOpenAI(base_url="https://api.deepseek.com/v1", api_key=self.api_key)
19+
20+
async def generate(self, prompt: str) -> AsyncIterator[str]:
21+
try:
22+
stream = await self.client.chat.completions.create(
23+
model=self.model,
24+
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
25+
stream=True,
26+
)
27+
async for chunk in stream:
28+
if chunk.choices[0].delta.content is not None:
29+
yield chunk.choices[0].delta.content
30+
except APIError as e:
31+
msg = f"Something went wrong while querying Deepseek: {e}"
32+
raise ProviderError(msg) from e

tests/providers/test_deepseek.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
3+
from shelloracle.providers.deepseek import Deepseek
4+
5+
6+
class TestOpenAI:
7+
@pytest.fixture
8+
def deepseek_config(self, set_config):
9+
config = {
10+
"shelloracle": {"provider": "Deepseek"},
11+
"provider": {
12+
"Deepseek": {
13+
"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
14+
"model": "grok-beta",
15+
}
16+
},
17+
}
18+
set_config(config)
19+
20+
@pytest.fixture
21+
def deepseek_instance(self, deepseek_config):
22+
return Deepseek()
23+
24+
def test_name(self):
25+
assert Deepseek.name == "Deepseek"
26+
27+
def test_api_key(self, deepseek_instance):
28+
assert (
29+
deepseek_instance.api_key
30+
== "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
31+
)
32+
33+
def test_model(self, deepseek_instance):
34+
assert deepseek_instance.model == "grok-beta"
35+
36+
@pytest.mark.asyncio
37+
async def test_generate(self, mock_asyncopenai, deepseek_instance):
38+
result = ""
39+
async for response in deepseek_instance.generate(""):
40+
result += response
41+
assert result == "head -c 100 /dev/urandom | hexdump -C"

0 commit comments

Comments
 (0)