diff --git a/migrations/versions/a692c8b52308_add_workspace_system_prompt.py b/migrations/versions/a692c8b52308_add_workspace_system_prompt.py new file mode 100644 index 00000000..24af5b3c --- /dev/null +++ b/migrations/versions/a692c8b52308_add_workspace_system_prompt.py @@ -0,0 +1,26 @@ +"""add_workspace_system_prompt + +Revision ID: a692c8b52308 +Revises: 5c2f3eee5f90 +Create Date: 2025-01-17 16:33:58.464223 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a692c8b52308" +down_revision: Union[str, None] = "5c2f3eee5f90" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add column to workspaces table + op.execute("ALTER TABLE workspaces ADD COLUMN system_prompt TEXT DEFAULT NULL;") + + +def downgrade() -> None: + op.execute("ALTER TABLE workspaces DROP COLUMN system_prompt;") diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index c2a32436..cca4e691 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -248,7 +248,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: except Exception as e: logger.error(f"Failed to record context: {context}.", error=str(e)) - async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: + async def add_workspace(self, workspace_name: str) -> Workspace: """Add a new workspace to the DB. This handles validation and insertion of a new workspace. @@ -256,8 +256,7 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: It may raise a ValidationError if the workspace name is invalid. or a AlreadyExistsError if the workspace already exists. """ - workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name) - + workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, system_prompt=None) sql = text( """ INSERT INTO workspaces (id, name) @@ -275,6 +274,21 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: raise AlreadyExistsError(f"Workspace {workspace_name} already exists.") return added_workspace + async def update_workspace(self, workspace: Workspace) -> Workspace: + sql = text( + """ + UPDATE workspaces SET + name = :name, + system_prompt = :system_prompt + WHERE id = :id + RETURNING * + """ + ) + updated_workspace = await self._execute_update_pydantic_model( + workspace, sql, should_raise=True + ) + return updated_workspace + async def update_session(self, session: Session) -> Optional[Session]: sql = text( """ @@ -392,11 +406,11 @@ async def get_workspaces(self) -> List[WorkspaceActive]: workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql) return workspaces - async def get_workspace_by_name(self, name: str) -> List[Workspace]: + async def get_workspace_by_name(self, name: str) -> Optional[Workspace]: sql = text( """ SELECT - id, name + id, name, system_prompt FROM workspaces WHERE name = :name """ @@ -422,7 +436,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]: sql = text( """ SELECT - w.id, w.name, s.id as session_id, s.last_update + w.id, w.name, w.system_prompt, s.id as session_id, s.last_update FROM sessions s INNER JOIN workspaces w ON w.id = s.active_workspace_id """ diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index fe5dbb68..6120ea1f 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -43,6 +43,7 @@ class Setting(BaseModel): class Workspace(BaseModel): id: str name: str + system_prompt: Optional[str] @field_validator("name", mode="plain") @classmethod @@ -98,5 +99,6 @@ class WorkspaceActive(BaseModel): class ActiveWorkspace(BaseModel): id: str name: str + system_prompt: Optional[str] session_id: str last_update: datetime.datetime diff --git a/src/codegate/pipeline/cli/cli.py b/src/codegate/pipeline/cli/cli.py index 333de7c9..bfe2bfda 100644 --- a/src/codegate/pipeline/cli/cli.py +++ b/src/codegate/pipeline/cli/cli.py @@ -8,7 +8,7 @@ PipelineResult, PipelineStep, ) -from codegate.pipeline.cli.commands import Version, Workspace +from codegate.pipeline.cli.commands import SystemPrompt, Version, Workspace HELP_TEXT = """ ## CodeGate CLI\n @@ -32,6 +32,7 @@ async def codegate_cli(command): available_commands = { "version": Version().exec, "workspace": Workspace().exec, + "system-prompt": SystemPrompt().exec, } out_func = available_commands.get(command[0]) if out_func is None: diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py index f5a5d694..8902f409 100644 --- a/src/codegate/pipeline/cli/commands.py +++ b/src/codegate/pipeline/cli/commands.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import Awaitable, Callable, Dict, List, Tuple from pydantic import ValidationError @@ -8,11 +8,24 @@ from codegate.workspaces import crud +class NoFlagValueError(Exception): + pass + + +class NoSubcommandError(Exception): + pass + + class CodegateCommand(ABC): @abstractmethod async def run(self, args: List[str]) -> str: pass + @property + @abstractmethod + def command_name(self) -> str: + pass + @property @abstractmethod def help(self) -> str: @@ -28,6 +41,10 @@ class Version(CodegateCommand): async def run(self, args: List[str]) -> str: return f"CodeGate version: {__version__}" + @property + def command_name(self) -> str: + return "version" + @property def help(self) -> str: return ( @@ -38,17 +55,107 @@ def help(self) -> str: ) -class Workspace(CodegateCommand): +class CodegateCommandSubcommand(CodegateCommand): + + @property + @abstractmethod + def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]: + """ + List of subcommands that the command accepts. + """ + pass + + @property + @abstractmethod + def flags(self) -> List[str]: + """ + List of flags that the command accepts. + Example: ["-w", "-f"] + """ + pass + + def _parse_flags_and_subocomand(self, args: List[str]) -> Tuple[Dict[str, str], List[str], str]: + """ + Reads the flags and subcommand from the args + The flags are expected to be at the start of the args and are optional. + """ + i = 0 + read_flags = {} + # Parse all recognized flags at the start + while i < len(args): + if args[i] in self.flags: + flag_name = args[i] + if i + 1 >= len(args): + raise NoFlagValueError(f"Flag {flag_name} needs a value, but none provided.") + read_flags[flag_name] = args[i + 1] + i += 2 + else: + # Once we encounter something that's not a recognized flag, + # we assume it's the subcommand + break + + if i >= len(args): + raise NoSubcommandError("No subcommand found after optional flags.") + + subcommand = args[i] + i += 1 + + # The rest of the arguments after the subcommand + rest = args[i:] + return read_flags, rest, subcommand + + async def run(self, args: List[str]) -> str: + """ + Try to parse the flags and subcommand and execute the subcommand + """ + try: + flags, rest, subcommand = self._parse_flags_and_subocomand(args) + except NoFlagValueError: + return ( + f"Error reading the command. Flag without value found. " + f"Use `codegate {self.command_name} -h` to see available subcommands" + ) + except NoSubcommandError: + return ( + f"Submmand not found " + f"Use `codegate {self.command_name} -h` to see available subcommands" + ) + + command_to_execute = self.subcommands.get(subcommand) + if command_to_execute is None: + return ( + f"Submmand not found " + f"Use `codegate {self.command_name} -h` to see available subcommands" + ) + + return await command_to_execute(flags, rest) + + +class Workspace(CodegateCommandSubcommand): def __init__(self): self.workspace_crud = crud.WorkspaceCrud() - self.commands = { + + @property + def command_name(self) -> str: + return "workspace" + + @property + def flags(self) -> List[str]: + """ + No flags for the workspace command + """ + return [] + + @property + def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]: + return { "list": self._list_workspaces, "add": self._add_workspace, "activate": self._activate_workspace, } - async def _list_workspaces(self, *args: List[str]) -> str: + async def _list_workspaces(self, flags: Dict[str, str], args: List[str]) -> str: """ List all workspaces """ @@ -61,7 +168,7 @@ async def _list_workspaces(self, *args: List[str]) -> str: respond_str += "\n" return respond_str - async def _add_workspace(self, args: List[str]) -> str: + async def _add_workspace(self, flags: Dict[str, str], args: List[str]) -> str: """ Add a workspace """ @@ -83,7 +190,7 @@ async def _add_workspace(self, args: List[str]) -> str: return f"Workspace **{new_workspace_name}** has been added" - async def _activate_workspace(self, args: List[str]) -> str: + async def _activate_workspace(self, flags: Dict[str, str], args: List[str]) -> str: """ Activate a workspace """ @@ -104,16 +211,6 @@ async def _activate_workspace(self, args: List[str]) -> str: return "An error occurred while activating the workspace" return f"Workspace **{workspace_name}** has been activated" - async def run(self, args: List[str]) -> str: - if not args: - return "Please provide a command. Use `codegate workspace -h` to see available commands" - command = args[0] - command_to_execute = self.commands.get(command) - if command_to_execute is not None: - return await command_to_execute(args[1:]) - else: - return "Command not found. Use `codegate workspace -h` to see available commands" - @property def help(self) -> str: return ( @@ -130,3 +227,90 @@ def help(self) -> str: " - *args*:\n\n" " - `workspace_name`" ) + + +class SystemPrompt(CodegateCommandSubcommand): + + def __init__(self): + self.workspace_crud = crud.WorkspaceCrud() + + @property + def command_name(self) -> str: + return "system-prompt" + + @property + def flags(self) -> List[str]: + """ + Flags for the system-prompt command. + -w: Workspace name + """ + return ["-w"] + + @property + def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]: + return { + "set": self._set_system_prompt, + "show": self._show_system_prompt, + } + + async def _set_system_prompt(self, flags: Dict[str, str], args: List[str]) -> str: + """ + Set the system prompt of a workspace + If a workspace name is not provided, the active workspace is used + """ + if len(args) == 0: + return ( + "Please provide a workspace name and a system prompt. " + "Use `codegate workspace system-prompt -w `" + ) + + workspace_name = flags.get("-w") + if not workspace_name: + active_workspace = await self.workspace_crud.get_active_workspace() + workspace_name = active_workspace.name + + try: + updated_worksapce = await self.workspace_crud.update_workspace_system_prompt( + workspace_name, args + ) + except crud.WorkspaceDoesNotExistError: + return ( + f"Workspace system prompt not updated. Workspace `{workspace_name}` doesn't exist" + ) + + return f"Workspace `{updated_worksapce.name}` system prompt updated." + + async def _show_system_prompt(self, flags: Dict[str, str], args: List[str]) -> str: + """ + Show the system prompt of a workspace + If a workspace name is not provided, the active workspace is used + """ + workspace_name = flags.get("-w") + if not workspace_name: + active_workspace = await self.workspace_crud.get_active_workspace() + workspace_name = active_workspace.name + + try: + workspace = await self.workspace_crud.get_workspace_by_name(workspace_name) + except crud.WorkspaceDoesNotExistError: + return f"Workspace `{workspace_name}` doesn't exist" + + return f"Workspace **{workspace.name}** system prompt:\n\n{workspace.system_prompt}." + + @property + def help(self) -> str: + return ( + "### CodeGate System Prompt\n" + "Manage the system prompts of workspaces.\n\n" + "*Note*: If you want to update the system prompt using files please go to the " + "[dashboard](http://localhost:9090).\n\n" + "**Usage**: `codegate system-prompt -w `\n\n" + "*args*:\n" + "- `workspace_name`: Optional workspace name. If not specified will use the " + "active workspace\n\n" + "Available commands:\n" + "- `set`: Set the system prompt of the workspace\n" + " - *args*:\n" + " - `system_prompt`: The system prompt to set\n" + " - **Usage**: `codegate system-prompt -w set `\n" + ) diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py index ee7310da..00efaa0c 100644 --- a/src/codegate/pipeline/system_prompt/codegate.py +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -1,4 +1,4 @@ -import json +from typing import Optional from litellm import ChatCompletionRequest, ChatCompletionSystemMessage @@ -7,6 +7,7 @@ PipelineResult, PipelineStep, ) +from codegate.workspaces.crud import WorkspaceCrud class SystemPrompt(PipelineStep): @@ -16,7 +17,7 @@ class SystemPrompt(PipelineStep): """ def __init__(self, system_prompt: str): - self._system_message = ChatCompletionSystemMessage(content=system_prompt, role="system") + self.codegate_system_prompt = system_prompt @property def name(self) -> str: @@ -25,6 +26,44 @@ def name(self) -> str: """ return "system-prompt" + async def _get_workspace_system_prompt(self) -> str: + wksp_crud = WorkspaceCrud() + workspace = await wksp_crud.get_active_workspace() + if not workspace: + return "" + + return workspace.system_prompt + + async def _construct_system_prompt( + self, + wrksp_sys_prompt: str, + req_sys_prompt: Optional[str], + should_add_codegate_sys_prompt: bool, + ) -> ChatCompletionSystemMessage: + + def _start_or_append(existing_prompt: str, new_prompt: str) -> str: + if existing_prompt: + return existing_prompt + "\n\nHere are additional instructions:\n\n" + new_prompt + return new_prompt + + system_prompt = "" + # Add codegate system prompt if secrets or bad packages are found at the beginning + if should_add_codegate_sys_prompt: + system_prompt = _start_or_append(system_prompt, self.codegate_system_prompt) + + # Add workspace system prompt if present + if wrksp_sys_prompt: + system_prompt = _start_or_append(system_prompt, wrksp_sys_prompt) + + # Add request system prompt if present + if req_sys_prompt and "codegate" not in req_sys_prompt.lower(): + system_prompt = _start_or_append(system_prompt, req_sys_prompt) + + return system_prompt + + async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool: + return context.secrets_found or context.bad_packages_found + async def process( self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: @@ -33,8 +72,12 @@ async def process( to the existing system prompt """ - # Nothing to do if no secrets or bad_packages are found - if not (context.secrets_found or context.bad_packages_found): + wrksp_sys_prompt = await self._get_workspace_system_prompt() + should_add_codegate_sys_prompt = await self._should_add_codegate_system_prompt(context) + + # Nothing to do if no secrets or bad_packages are found and we don't have a workspace + # system prompt + if not should_add_codegate_sys_prompt and not wrksp_sys_prompt: return PipelineResult(request=request, context=context) new_request = request.copy() @@ -42,23 +85,22 @@ async def process( if "messages" not in new_request: new_request["messages"] = [] - request_system_message = None + request_system_message = {} for message in new_request["messages"]: if message["role"] == "system": request_system_message = message + req_sys_prompt = request_system_message.get("content") - if request_system_message is None: - # Add system message - context.add_alert(self.name, trigger_string=json.dumps(self._system_message)) - new_request["messages"].insert(0, self._system_message) - elif "codegate" not in request_system_message["content"].lower(): - # Prepend to the system message - prepended_message = ( - self._system_message["content"] - + "\n Here are additional instructions. \n " - + request_system_message["content"] - ) - context.add_alert(self.name, trigger_string=prepended_message) - request_system_message["content"] = prepended_message + system_prompt = await self._construct_system_prompt( + wrksp_sys_prompt, req_sys_prompt, should_add_codegate_sys_prompt + ) + context.add_alert(self.name, trigger_string=system_prompt) + if not request_system_message: + # Insert the system prompt at the beginning of the messages + sytem_message = ChatCompletionSystemMessage(content=system_prompt, role="system") + new_request["messages"].insert(0, sytem_message) + else: + # Update the existing system prompt + request_system_message["content"] = system_prompt return PipelineResult(request=new_request, context=context) diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index d14f24ab..2b44466d 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -60,7 +60,7 @@ async def _is_workspace_active( sessions = await self._db_reader.get_sessions() # The current implementation expects only one active session if len(sessions) != 1: - raise RuntimeError("Something went wrong. No active session found.") + raise WorkspaceCrudError("Something went wrong. More than one session found.") session = sessions[0] return (session.active_workspace_id == selected_workspace.id, session, selected_workspace) @@ -82,3 +82,26 @@ async def activate_workspace(self, workspace_name: str): db_recorder = DbRecorder() await db_recorder.update_session(session) return + + async def update_workspace_system_prompt( + self, workspace_name: str, sys_prompt_lst: List[str] + ) -> Workspace: + selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name) + if not selected_workspace: + raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") + + system_prompt = " ".join(sys_prompt_lst) + workspace_update = Workspace( + id=selected_workspace.id, + name=selected_workspace.name, + system_prompt=system_prompt, + ) + db_recorder = DbRecorder() + updated_workspace = await db_recorder.update_workspace(workspace_update) + return updated_workspace + + async def get_workspace_by_name(self, workspace_name: str) -> Workspace: + workspace = await self._db_reader.get_workspace_by_name(workspace_name) + if not workspace: + raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") + return workspace diff --git a/tests/pipeline/system_prompt/test_system_prompt.py b/tests/pipeline/system_prompt/test_system_prompt.py index 06f92733..f17735e6 100644 --- a/tests/pipeline/system_prompt/test_system_prompt.py +++ b/tests/pipeline/system_prompt/test_system_prompt.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from litellm.types.llms.openai import ChatCompletionRequest @@ -14,7 +14,7 @@ def test_init_with_system_message(self): """ test_message = "Test system prompt" step = SystemPrompt(system_prompt=test_message) - assert step._system_message["content"] == test_message + assert step.codegate_system_prompt == test_message @pytest.mark.asyncio async def test_process_system_prompt_insertion(self): @@ -29,6 +29,7 @@ async def test_process_system_prompt_insertion(self): # Create system prompt step system_prompt = "Security analysis system prompt" step = SystemPrompt(system_prompt=system_prompt) + step._get_workspace_system_prompt = AsyncMock(return_value="") # Mock the get_last_user_message method step.get_last_user_message = Mock(return_value=(user_message, 0)) @@ -62,6 +63,7 @@ async def test_process_system_prompt_update(self): # Create system prompt step system_prompt = "Security analysis system prompt" step = SystemPrompt(system_prompt=system_prompt) + step._get_workspace_system_prompt = AsyncMock(return_value="") # Mock the get_last_user_message method step.get_last_user_message = Mock(return_value=(user_message, 0)) @@ -74,7 +76,7 @@ async def test_process_system_prompt_update(self): assert result.request["messages"][0]["role"] == "system" assert ( result.request["messages"][0]["content"] - == system_prompt + "\n Here are additional instructions. \n " + request_system_message + == system_prompt + "\n\nHere are additional instructions:\n\n" + request_system_message ) assert result.request["messages"][1]["role"] == "user" assert result.request["messages"][1]["content"] == user_message @@ -96,6 +98,7 @@ async def test_edge_cases(self, edge_case): system_prompt = "Security edge case prompt" step = SystemPrompt(system_prompt=system_prompt) + step._get_workspace_system_prompt = AsyncMock(return_value="") # Mock get_last_user_message to return None step.get_last_user_message = Mock(return_value=None) diff --git a/tests/pipeline/workspace/test_workspace.py b/tests/pipeline/workspace/test_workspace.py index e45376fa..27db0519 100644 --- a/tests/pipeline/workspace/test_workspace.py +++ b/tests/pipeline/workspace/test_workspace.py @@ -42,7 +42,7 @@ async def test_list_workspaces(mock_workspaces, expected_output): workspace_commands.workspace_crud.get_workspaces = mock_get_workspaces # Call the method - result = await workspace_commands._list_workspaces() + result = await workspace_commands._list_workspaces(None, None) # Check the result assert result == expected_output @@ -83,7 +83,7 @@ async def test_add_workspaces(args, existing_workspaces, expected_message): mock_recorder.add_workspace = AsyncMock() # Call the method - result = await workspace_commands._add_workspace(args) + result = await workspace_commands._add_workspace(None, args) # Assertions assert result == expected_message