Skip to content

Commit f9173e3

Browse files
authored
Add OpenAI LLM provider (#15)
1 parent 0c30076 commit f9173e3

File tree

4 files changed

+51
-1
lines changed

4 files changed

+51
-1
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ name = "shelloracle"
77
dynamic = ["version"]
88
dependencies = [
99
"httpx",
10+
"openai",
1011
"prompt-toolkit",
1112
"tomlkit"
1213
]
@@ -30,6 +31,8 @@ classifiers = [
3031
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
3132
]
3233

34+
[project.optional-dependencies]
35+
3336
[tool.setuptools]
3437
packages = ["shelloracle", "shelloracle.config", "shelloracle.providers"]
3538

shelloracle/provider.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ def get_provider(name: str) -> type[Provider]:
3535
:param name: the provider name
3636
:return: the requested provider
3737
"""
38-
from .providers import Ollama
38+
from .providers import Ollama, OpenAI
3939
providers = {
4040
Ollama.name: Ollama,
41+
OpenAI.name: OpenAI
4142
}
4243
return providers[name]

shelloracle/providers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .ollama import Ollama
2+
from .openai import OpenAI

shelloracle/providers/openai.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from collections.abc import AsyncIterator
2+
3+
from openai import APIError
4+
from openai import AsyncOpenAI as OpenAIClient
5+
6+
from ..config import Setting
7+
from ..provider import Provider, ProviderError
8+
9+
10+
class OpenAI(Provider):
11+
name = "OpenAI"
12+
13+
api_key = Setting(default="")
14+
model = Setting(default="gpt-3.5-turbo")
15+
system_prompt = Setting(
16+
default=(
17+
"Based on the following user description, generate a corresponding Bash command. Focus solely "
18+
"on interpreting the requirements and translating them into a single, executable Bash command. "
19+
"Ensure accuracy and relevance to the user's description. The output should be a valid Bash "
20+
"command that directly aligns with the user's intent, ready for execution in a command-line "
21+
"environment. Output nothing except for the command. No code block, no English explanation, "
22+
"no start/end tags."
23+
)
24+
)
25+
26+
def __init__(self):
27+
if not self.api_key:
28+
raise ProviderError("No API key provided")
29+
self.client = OpenAIClient(api_key=self.api_key)
30+
31+
async def generate(self, prompt: str) -> AsyncIterator[str]:
32+
try:
33+
stream = await self.client.chat.completions.create(
34+
model=self.model,
35+
messages=[
36+
{"role": "system", "content": self.system_prompt},
37+
{"role": "user", "content": prompt}
38+
],
39+
stream=True,
40+
)
41+
async for chunk in stream:
42+
if chunk.choices[0].delta.content is not None:
43+
yield chunk.choices[0].delta.content
44+
except APIError as e:
45+
raise ProviderError(f"Something went wrong while querying OpenAI: {e}") from e

0 commit comments

Comments
 (0)