diff --git a/src/codegate/cli.py b/src/codegate/cli.py index be5096f6..455d9001 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -16,7 +16,7 @@ from codegate.config import Config, ConfigurationError from codegate.db.connection import init_db_sync, init_session_if_not_exists from codegate.pipeline.factory import PipelineFactory -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers import crud as provendcrud from codegate.providers.copilot.provider import CopilotProvider from codegate.server import init_app @@ -331,8 +331,8 @@ def serve( # noqa: C901 click.echo("Existing Certificates are already present.") # Initialize secrets manager and pipeline factory - secrets_manager = SecretsManager() - pipeline_factory = PipelineFactory(secrets_manager) + sensitive_data_manager = SensitiveDataManager() + pipeline_factory = PipelineFactory(sensitive_data_manager) app = init_app(pipeline_factory) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 0baa322a..ddcd5a61 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -12,34 +12,23 @@ from codegate.clients.clients import ClientType from codegate.db.models import Alert, AlertSeverity, Output, Prompt from codegate.extract_snippets.message_extractor import CodeSnippet -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager logger = structlog.get_logger("codegate") @dataclass class PipelineSensitiveData: - manager: SecretsManager + manager: SensitiveDataManager session_id: str - api_key: Optional[str] = None model: Optional[str] = None - provider: Optional[str] = None - api_base: Optional[str] = None def secure_cleanup(self): """Securely cleanup sensitive data for this session""" if self.manager is None or self.session_id == "": return - self.manager.cleanup_session(self.session_id) self.session_id = "" - - # Securely wipe the API key using the same method as secrets manager - if self.api_key is not None: - api_key_bytes = bytearray(self.api_key.encode()) - self.manager.crypto.wipe_bytearray(api_key_bytes) - self.api_key = None - self.model = None @@ -274,19 +263,19 @@ class InputPipelineInstance: def __init__( self, pipeline_steps: List[PipelineStep], - secret_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, is_fim: bool, client: ClientType = ClientType.GENERIC, ): self.pipeline_steps = pipeline_steps - self.secret_manager = secret_manager + self.sensitive_data_manager = sensitive_data_manager self.is_fim = is_fim self.context = PipelineContext(client=client) # we create the sesitive context here so that it is not shared between individual requests # TODO: could we get away with just generating the session ID for an instance? 
self.context.sensitive = PipelineSensitiveData( - manager=self.secret_manager, + manager=self.sensitive_data_manager, session_id=str(uuid.uuid4()), ) self.context.metadata["is_fim"] = is_fim @@ -343,12 +332,12 @@ class SequentialPipelineProcessor: def __init__( self, pipeline_steps: List[PipelineStep], - secret_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, client_type: ClientType, is_fim: bool, ): self.pipeline_steps = pipeline_steps - self.secret_manager = secret_manager + self.sensitive_data_manager = sensitive_data_manager self.is_fim = is_fim self.instance = self._create_instance(client_type) @@ -356,7 +345,7 @@ def _create_instance(self, client_type: ClientType) -> InputPipelineInstance: """Create a new pipeline instance for processing a request""" return InputPipelineInstance( self.pipeline_steps, - self.secret_manager, + self.sensitive_data_manager, self.is_fim, client_type, ) diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index acde51b4..813459d5 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -12,18 +12,18 @@ PiiRedactionNotifier, PiiUnRedactionStep, ) -from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.secrets import ( CodegateSecrets, SecretRedactionNotifier, SecretUnredactionStep, ) +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.pipeline.system_prompt.codegate import SystemPrompt class PipelineFactory: - def __init__(self, secrets_manager: SecretsManager): - self.secrets_manager = secrets_manager + def __init__(self, sensitive_data_manager: SensitiveDataManager): + self.sensitive_data_manager = sensitive_data_manager def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor: input_steps: List[PipelineStep] = [ @@ -32,7 +32,7 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr # and without obfuscating the secrets, we'd leak the secrets during those # later steps CodegateSecrets(), - CodegatePii(), + CodegatePii(self.sensitive_data_manager), CodegateCli(), CodegateContextRetriever(), SystemPrompt( @@ -41,7 +41,7 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr ] return SequentialPipelineProcessor( input_steps, - self.secrets_manager, + self.sensitive_data_manager, client_type, is_fim=False, ) @@ -49,11 +49,11 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr def create_fim_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor: fim_steps: List[PipelineStep] = [ CodegateSecrets(), - CodegatePii(), + CodegatePii(self.sensitive_data_manager), ] return SequentialPipelineProcessor( fim_steps, - self.secrets_manager, + self.sensitive_data_manager, client_type, is_fim=True, ) diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py index a1ed5bed..96442824 100644 --- a/src/codegate/pipeline/pii/analyzer.py +++ b/src/codegate/pipeline/pii/analyzer.py @@ -1,5 +1,4 @@ -import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional import structlog from presidio_analyzer import AnalyzerEngine @@ -7,41 +6,11 @@ from codegate.db.models import AlertSeverity from codegate.pipeline.base import PipelineContext +from codegate.pipeline.sensitive_data.session_store import SessionStore logger = structlog.get_logger("codegate.pii.analyzer") -class PiiSessionStore: - """ - A class to manage PII 
(Personally Identifiable Information) session storage. - - Attributes: - session_id (str): The unique identifier for the session. If not provided, a new UUID - is generated. mappings (Dict[str, str]): A dictionary to store mappings between UUID - placeholders and PII. - - Methods: - add_mapping(pii: str) -> str: - Adds a PII string to the session store and returns a UUID placeholder for it. - - get_pii(uuid_placeholder: str) -> str: - Retrieves the PII string associated with the given UUID placeholder. If the placeholder - is not found, returns the placeholder itself. - """ - - def __init__(self, session_id: str = None): - self.session_id = session_id or str(uuid.uuid4()) - self.mappings: Dict[str, str] = {} - - def add_mapping(self, pii: str) -> str: - uuid_placeholder = f"<{str(uuid.uuid4())}>" - self.mappings[uuid_placeholder] = pii - return uuid_placeholder - - def get_pii(self, uuid_placeholder: str) -> str: - return self.mappings.get(uuid_placeholder, uuid_placeholder) - - class PiiAnalyzer: """ PiiAnalyzer class for analyzing and anonymizing text containing PII. @@ -52,12 +21,12 @@ class PiiAnalyzer: Get or create the singleton instance of PiiAnalyzer. analyze: text (str): The text to analyze for PII. - Tuple[str, List[Dict[str, Any]], PiiSessionStore]: The anonymized text, a list of + Tuple[str, List[Dict[str, Any]], SessionStore]: The anonymized text, a list of found PII details, and the session store. entities (List[str]): The PII entities to analyze for. restore_pii: anonymized_text (str): The text with anonymized PII. - session_store (PiiSessionStore): The PiiSessionStore used for anonymization. + session_store (SessionStore): The SessionStore used for anonymization. str: The text with original PII restored. """ @@ -95,13 +64,11 @@ def __init__(self): # Create analyzer with custom NLP engine self.analyzer = AnalyzerEngine(nlp_engine=nlp_engine) self.anonymizer = AnonymizerEngine() - self.session_store = PiiSessionStore() + self.session_store = SessionStore() PiiAnalyzer._instance = self - def analyze( - self, text: str, context: Optional[PipelineContext] = None - ) -> Tuple[str, List[Dict[str, Any]], PiiSessionStore]: + def analyze(self, text: str, context: Optional[PipelineContext] = None) -> List: # Prioritize credit card detection first entities = [ "PHONE_NUMBER", @@ -125,81 +92,30 @@ def analyze( language="en", score_threshold=0.3, # Lower threshold to catch more potential matches ) + return analyzer_results - # Track found PII - found_pii = [] - - # Only anonymize if PII was found - if analyzer_results: - # Log each found PII instance and anonymize - anonymized_text = text - for result in analyzer_results: - pii_value = text[result.start : result.end] - uuid_placeholder = self.session_store.add_mapping(pii_value) - pii_info = { - "type": result.entity_type, - "value": pii_value, - "score": result.score, - "start": result.start, - "end": result.end, - "uuid_placeholder": uuid_placeholder, - } - found_pii.append(pii_info) - anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder) - - # Log each PII detection with its UUID mapping - logger.info( - "PII detected and mapped", - pii_type=result.entity_type, - score=f"{result.score:.2f}", - uuid=uuid_placeholder, - # Don't log the actual PII value for security - value_length=len(pii_value), - session_id=self.session_store.session_id, - ) - - # Log summary of all PII found in this analysis - if found_pii and context: - # Create notification string for alert - notify_string = ( - f"**PII Detected** 🔒\n" - f"- Total 
PII Found: {len(found_pii)}\n" - f"- Types Found: {', '.join(set(p['type'] for p in found_pii))}\n" - ) - context.add_alert( - self._name, - trigger_string=notify_string, - severity_category=AlertSeverity.CRITICAL, - ) - - logger.info( - "PII analysis complete", - total_pii_found=len(found_pii), - pii_types=[p["type"] for p in found_pii], - session_id=self.session_store.session_id, - ) - - # Return the anonymized text, PII details, and session store - return anonymized_text, found_pii, self.session_store - - # If no PII found, return original text, empty list, and session store - return text, [], self.session_store - - def restore_pii(self, anonymized_text: str, session_store: PiiSessionStore) -> str: + def restore_pii(self, session_id: str, anonymized_text: str) -> str: """ Restore the original PII (Personally Identifiable Information) in the given anonymized text. This method replaces placeholders in the anonymized text with their corresponding original - PII values using the mappings stored in the provided PiiSessionStore. + PII values using the mappings stored in the provided SessionStore. Args: anonymized_text (str): The text containing placeholders for PII. - session_store (PiiSessionStore): The session store containing mappings of placeholders + session_id (str): The session id containing mappings of placeholders to original PII. Returns: str: The text with the original PII restored. """ - for uuid_placeholder, original_pii in session_store.mappings.items(): + session_data = self.session_store.get_by_session_id(session_id) + if not session_data: + logger.warning( + "No active PII session found for given session ID. Unable to restore PII." + ) + return anonymized_text + + for uuid_placeholder, original_pii in session_data.items(): anonymized_text = anonymized_text.replace(uuid_placeholder, original_pii) return anonymized_text diff --git a/src/codegate/pipeline/pii/manager.py b/src/codegate/pipeline/pii/manager.py deleted file mode 100644 index 54112713..00000000 --- a/src/codegate/pipeline/pii/manager.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import structlog - -from codegate.pipeline.base import PipelineContext -from codegate.pipeline.pii.analyzer import PiiAnalyzer, PiiSessionStore - -logger = structlog.get_logger("codegate") - - -class PiiManager: - """ - Manages the analysis and restoration of Personally Identifiable Information - (PII) in text. - - Attributes: - analyzer (PiiAnalyzer): The singleton instance of PiiAnalyzer used for - PII detection and restoration. - session_store (PiiSessionStore): The session store for the current PII session. - - Methods: - __init__(): - Initializes the PiiManager with the singleton PiiAnalyzer instance and sets the - session store. - - analyze(text: str) -> Tuple[str, List[Dict[str, Any]]]: - Analyzes the given text for PII, anonymizes it, and logs the detected PII details. - Args: - text (str): The text to be analyzed for PII. - Returns: - Tuple[str, List[Dict[str, Any]]]: A tuple containing the anonymized text and - a list of found PII details. - - restore_pii(anonymized_text: str) -> str: - Restores the PII in the given anonymized text using the current session. - Args: - anonymized_text (str): The text with anonymized PII to be restored. - Returns: - str: The text with restored PII. - """ - - def __init__(self): - """ - Initialize the PiiManager with the singleton PiiAnalyzer instance. 
- """ - self.analyzer = PiiAnalyzer.get_instance() - # Always use the analyzer's session store - self._session_store = self.analyzer.session_store - - @property - def session_store(self) -> PiiSessionStore: - """Get the current session store.""" - # Always return the analyzer's current session store - return self.analyzer.session_store - - def analyze( - self, text: str, context: Optional[PipelineContext] = None - ) -> Tuple[str, List[Dict[str, Any]]]: - # Call analyzer and get results - anonymized_text, found_pii, _ = self.analyzer.analyze(text, context=context) - - # Log found PII details (without modifying the found_pii list) - if found_pii: - for pii in found_pii: - logger.info( - "PII detected", - pii_type=pii["type"], - value="*" * len(pii["value"]), # Don't log actual value - score=f"{pii['score']:.2f}", - ) - - # Return the exact same objects we got from the analyzer - return anonymized_text, found_pii - - def restore_pii(self, anonymized_text: str) -> str: - """ - Restore PII in the given anonymized text using the current session. - """ - if self.session_store is None: - logger.warning("No active PII session found. Unable to restore PII.") - return anonymized_text - - # Use the analyzer's restore_pii method with the current session store - return self.analyzer.restore_pii(anonymized_text, self.session_store) diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index f0b9f271..fde89428 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple +import uuid import regex as re import structlog @@ -6,13 +7,15 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.config import Config +from codegate.db.models import AlertSeverity from codegate.pipeline.base import ( PipelineContext, PipelineResult, PipelineStep, ) from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep -from codegate.pipeline.pii.manager import PiiManager +from codegate.pipeline.pii.analyzer import PiiAnalyzer +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager from codegate.pipeline.systemmsg import add_or_update_system_message logger = structlog.get_logger("codegate") @@ -25,7 +28,7 @@ class CodegatePii(PipelineStep): Methods: __init__: - Initializes the CodegatePii pipeline step and sets up the PiiManager. + Initializes the CodegatePii pipeline step and sets up the SensitiveDataManager. name: Returns the name of the pipeline step. @@ -37,14 +40,15 @@ class CodegatePii(PipelineStep): Processes the chat completion request to detect and redact PII. Updates the request with anonymized text and stores PII details in the context metadata. - restore_pii(anonymized_text: str) -> str: - Restores the original PII from the anonymized text using the PiiManager. + restore_pii(session_id: str, anonymized_text: str) -> str: + Restores the original PII from the anonymized text using the SensitiveDataManager. 
""" - def __init__(self): + def __init__(self, sensitive_data_manager: SensitiveDataManager): """Initialize the CodegatePii pipeline step.""" super().__init__() - self.pii_manager = PiiManager() + self.sensitive_data_manager = sensitive_data_manager + self.analyzer = PiiAnalyzer.get_instance() @property def name(self) -> str: @@ -65,6 +69,68 @@ def _get_redacted_snippet(self, message: str, pii_details: List[Dict[str, Any]]) return message[start:end] + def process_results( + self, session_id: str, text: str, results: List, context: PipelineContext + ) -> Tuple[List, str]: + # Track found PII + found_pii = [] + + # Log each found PII instance and anonymize + anonymized_text = text + for result in results: + pii_value = text[result.start : result.end] + + # add to session store + obj = SensitiveData(original=pii_value, service="pii", type=result.entity_type) + uuid_placeholder = self.sensitive_data_manager.store(session_id, obj) + anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder) + + # Add to found PII list + pii_info = { + "type": result.entity_type, + "value": pii_value, + "score": result.score, + "start": result.start, + "end": result.end, + "uuid_placeholder": uuid_placeholder, + } + found_pii.append(pii_info) + + # Log each PII detection with its UUID mapping + logger.info( + "PII detected and mapped", + pii_type=result.entity_type, + score=f"{result.score:.2f}", + uuid=uuid_placeholder, + # Don't log the actual PII value for security + value_length=len(pii_value), + session_id=session_id, + ) + + # Log summary of all PII found in this analysis + if found_pii and context: + # Create notification string for alert + notify_string = ( + f"**PII Detected** 🔒\n" + f"- Total PII Found: {len(found_pii)}\n" + f"- Types Found: {', '.join(set(p['type'] for p in found_pii))}\n" + ) + context.add_alert( + self.name, + trigger_string=notify_string, + severity_category=AlertSeverity.CRITICAL, + ) + + logger.info( + "PII analysis complete", + total_pii_found=len(found_pii), + pii_types=[p["type"] for p in found_pii], + session_id=session_id, + ) + + # Return the anonymized text, PII details, and session store + return found_pii, anonymized_text + async def process( self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: @@ -75,23 +141,28 @@ async def process( total_pii_found = 0 all_pii_details: List[Dict[str, Any]] = [] last_redacted_text = "" + session_id = context.sensitive.session_id for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: # This is where analyze and anonymize the text original_text = str(message["content"]) - anonymized_text, pii_details = self.pii_manager.analyze(original_text, context) - - if pii_details: - total_pii_found += len(pii_details) - all_pii_details.extend(pii_details) - new_request["messages"][i]["content"] = anonymized_text - - # If this is a user message, grab the redacted snippet! - if message.get("role") == "user": - last_redacted_text = self._get_redacted_snippet( - anonymized_text, pii_details - ) + results = self.analyzer.analyze(original_text, context) + if results: + pii_details, anonymized_text = self.process_results( + session_id, original_text, results, context + ) + + if pii_details: + total_pii_found += len(pii_details) + all_pii_details.extend(pii_details) + new_request["messages"][i]["content"] = anonymized_text + + # If this is a user message, grab the redacted snippet! 
+ if message.get("role") == "user": + last_redacted_text = self._get_redacted_snippet( + anonymized_text, pii_details + ) logger.info(f"Total PII instances redacted: {total_pii_found}") @@ -99,9 +170,10 @@ async def process( context.metadata["redacted_pii_count"] = total_pii_found context.metadata["redacted_pii_details"] = all_pii_details context.metadata["redacted_text"] = last_redacted_text + context.metadata["session_id"] = session_id if total_pii_found > 0: - context.metadata["pii_manager"] = self.pii_manager + context.metadata["sensitive_data_manager"] = self.sensitive_data_manager system_message = ChatCompletionSystemMessage( content=Config.get_config().prompts.pii_redacted, @@ -113,8 +185,31 @@ async def process( return PipelineResult(request=new_request, context=context) - def restore_pii(self, anonymized_text: str) -> str: - return self.pii_manager.restore_pii(anonymized_text) + def restore_pii(self, session_id: str, anonymized_text: str) -> str: + """ + Restore the original PII (Personally Identifiable Information) in the given anonymized text. + + This method replaces placeholders in the anonymized text with their corresponding original + PII values using the mappings stored in the provided SessionStore. + + Args: + anonymized_text (str): The text containing placeholders for PII. + session_id (str): The session id containing mappings of placeholders + to original PII. + + Returns: + str: The text with the original PII restored. + """ + session_data = self.sensitive_data_manager.get_by_session_id(session_id) + if not session_data: + logger.warning( + "No active PII session found for given session ID. Unable to restore PII." + ) + return anonymized_text + + for uuid_placeholder, original_pii in session_data.items(): + anonymized_text = anonymized_text.replace(uuid_placeholder, original_pii) + return anonymized_text class PiiUnRedactionStep(OutputPipelineStep): @@ -136,12 +231,12 @@ class PiiUnRedactionStep(OutputPipelineStep): """ def __init__(self): - self.redacted_pattern = re.compile(r"<([0-9a-f-]{0,36})>") + self.redacted_pattern = re.compile(r"#([0-9a-f-]{0,36})#") self.complete_uuid_pattern = re.compile( r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" ) # noqa: E501 - self.marker_start = "<" - self.marker_end = ">" + self.marker_start = "#" + self.marker_end = "#" @property def name(self) -> str: @@ -151,7 +246,7 @@ def _is_complete_uuid(self, uuid_str: str) -> bool: """Check if the string is a complete UUID""" return bool(self.complete_uuid_pattern.match(uuid_str)) - async def process_chunk( + async def process_chunk( # noqa: C901 self, chunk: ModelResponse, context: OutputPipelineContext, @@ -162,6 +257,10 @@ async def process_chunk( return [chunk] content = chunk.choices[0].delta.content + session_id = input_context.sensitive.session_id + if not session_id: + logger.error("Could not get any session id, cannot process pii") + return [chunk] # Add current chunk to buffer if context.prefix_buffer: @@ -172,13 +271,13 @@ async def process_chunk( current_pos = 0 result = [] while current_pos < len(content): - start_idx = content.find("<", current_pos) + start_idx = content.find(self.marker_start, current_pos) if start_idx == -1: # No more markers!, add remaining content result.append(content[current_pos:]) break - end_idx = content.find(">", start_idx) + end_idx = content.find(self.marker_end, start_idx + 1) if end_idx == -1: # Incomplete marker, buffer the rest context.prefix_buffer = content[current_pos:] @@ -190,16 +289,18 @@ async def process_chunk( # 
Extract potential UUID if it's a valid format! uuid_marker = content[start_idx : end_idx + 1] - uuid_value = uuid_marker[1:-1] # Remove < > + uuid_value = uuid_marker[1:-1] # Remove # # if self._is_complete_uuid(uuid_value): # Get the PII manager from context metadata logger.debug(f"Valid UUID found: {uuid_value}") - pii_manager = input_context.metadata.get("pii_manager") if input_context else None - if pii_manager and pii_manager.session_store: + sensitive_data_manager = ( + input_context.metadata.get("sensitive_data_manager") if input_context else None + ) + if sensitive_data_manager and sensitive_data_manager.session_store: # Restore original value from PII manager logger.debug("Attempting to restore PII from UUID marker") - original = pii_manager.session_store.get_pii(uuid_marker) + original = sensitive_data_manager.get_original_value(session_id, uuid_marker) logger.debug(f"Restored PII: {original}") result.append(original) else: diff --git a/src/codegate/pipeline/secrets/gatecrypto.py b/src/codegate/pipeline/secrets/gatecrypto.py deleted file mode 100644 index 859b025d..00000000 --- a/src/codegate/pipeline/secrets/gatecrypto.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import time -from base64 import b64decode, b64encode - -import structlog -from cryptography.hazmat.primitives.ciphers.aead import AESGCM - -logger = structlog.get_logger("codegate") - - -class CodeGateCrypto: - """ - Manage session keys and provide encryption / decryption of tokens with replay protection. - Attributes: - session_keys (dict): A dictionary to store session keys with their associated timestamps. - SESSION_KEY_LIFETIME (int): The lifetime of a session key in seconds. - NONCE_SIZE (int): The size of the nonce used in AES GCM mode. - Methods: - generate_session_key(session_id): - Generates a session key with an associated timestamp. - get_session_key(session_id): - Retrieves a session key if it is still valid. - cleanup_expired_keys(): - Removes expired session keys from memory. - encrypt_token(token, session_id): - Encrypts a token with a session key and adds a timestamp for replay protection. - decrypt_token(encrypted_token, session_id): - Decrypts a token and validates its timestamp to prevent replay attacks. - wipe_bytearray(data): - Securely wipes a bytearray in-place. 
- """ - - def __init__(self): - self.session_keys = {} - self.SESSION_KEY_LIFETIME = 600 # 10 minutes - self.NONCE_SIZE = 12 # AES GCM recommended nonce size - - def generate_session_key(self, session_id): - """Generates a session key with an associated timestamp.""" - key = os.urandom(32) # Generate a 256-bit key - self.session_keys[session_id] = (key, time.time()) - return key - - def get_session_key(self, session_id): - """Retrieves a session key if it is still valid.""" - key_data = self.session_keys.get(session_id) - if key_data: - key, timestamp = key_data - if time.time() - timestamp < self.SESSION_KEY_LIFETIME: - return key - else: - # Key has expired - del self.session_keys[session_id] - return None - - def cleanup_expired_keys(self): - """Removes expired session keys from memory.""" - now = time.time() - expired_keys = [ - session_id - for session_id, (key, timestamp) in self.session_keys.items() - if now - timestamp >= self.SESSION_KEY_LIFETIME - ] - for session_id in expired_keys: - del self.session_keys[session_id] - - def encrypt_token(self, token, session_id): - """Encrypts a token with a session key and adds a timestamp for replay protection.""" - key = self.generate_session_key(session_id) - nonce = os.urandom(self.NONCE_SIZE) - timestamp = int(time.time()) - data = f"{token}:{timestamp}".encode() # Append timestamp to token - - aesgcm = AESGCM(key) - ciphertext = aesgcm.encrypt(nonce, data, None) # None for no associated data - - # Combine nonce and ciphertext (which includes the authentication tag) - encrypted_token = b64encode(nonce + ciphertext).decode() - return encrypted_token - - def decrypt_token(self, encrypted_token, session_id): - """Decrypts a token and validates its timestamp to prevent replay attacks.""" - key = self.get_session_key(session_id) - if not key: - raise ValueError("Session key expired or invalid.") - - encrypted_data = b64decode(encrypted_token) - nonce = encrypted_data[: self.NONCE_SIZE] - ciphertext = encrypted_data[self.NONCE_SIZE :] # Includes authentication tag - - aesgcm = AESGCM(key) - try: - decrypted_data = aesgcm.decrypt( - nonce, ciphertext, None - ).decode() # None for no associated data - except Exception as e: - raise ValueError("Decryption failed: Invalid token or tampering detected.") from e - - token, timestamp = decrypted_data.rsplit(":", 1) - if time.time() - int(timestamp) > self.SESSION_KEY_LIFETIME: - raise ValueError("Token has expired.") - - return token - - def wipe_bytearray(self, data): - """Securely wipes a bytearray in-place.""" - if not isinstance(data, bytearray): - raise ValueError("Only bytearray objects can be securely wiped.") - for i in range(len(data)): - data[i] = 0 # Overwrite each byte with 0 - logger.info("Sensitive data securely wiped from memory.") diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py deleted file mode 100644 index bef07c75..00000000 --- a/src/codegate/pipeline/secrets/manager.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import NamedTuple, Optional - -import structlog - -from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto - -logger = structlog.get_logger("codegate") - - -class SecretEntry(NamedTuple): - """Represents a stored secret""" - - original: str - encrypted: str - service: str - secret_type: str - - -class SecretsManager: - """Manages encryption, storage and retrieval of secrets""" - - def __init__(self): - self.crypto = CodeGateCrypto() - self._session_store: dict[str, dict[str, SecretEntry]] = {} - 
self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index - - def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str: - """ - Encrypts and stores a secret value. - Returns the encrypted value. - """ - if not value: - raise ValueError("Value must be provided") - if not service: - raise ValueError("Service must be provided") - if not secret_type: - raise ValueError("Secret type must be provided") - if not session_id: - raise ValueError("Session ID must be provided") - - encrypted_value = self.crypto.encrypt_token(value, session_id) - - # Store mappings - session_secrets = self._session_store.get(session_id, {}) - session_secrets[encrypted_value] = SecretEntry( - original=value, - encrypted=encrypted_value, - service=service, - secret_type=secret_type, - ) - self._session_store[session_id] = session_secrets - self._encrypted_to_session[encrypted_value] = session_id - - logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value) - - return encrypted_value - - def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[str]: - """Retrieve original value for an encrypted value""" - try: - stored_session_id = self._encrypted_to_session.get(encrypted_value) - if stored_session_id == session_id: - session_secrets = self._session_store[session_id].get(encrypted_value) - if session_secrets: - return session_secrets.original - except Exception as e: - logger.error("Error retrieving secret", error=str(e)) - return None - - def get_by_session_id(self, session_id: str) -> Optional[SecretEntry]: - """Get stored data by session ID""" - return self._session_store.get(session_id) - - def cleanup(self): - """Securely wipe sensitive data""" - try: - # Convert and wipe original values - for secrets in self._session_store.values(): - for entry in secrets.values(): - original_bytes = bytearray(entry.original.encode()) - self.crypto.wipe_bytearray(original_bytes) - - # Clear the dictionaries - self._session_store.clear() - self._encrypted_to_session.clear() - - logger.info("Secrets manager data securely wiped") - except Exception as e: - logger.error("Error during secure cleanup", error=str(e)) - - def cleanup_session(self, session_id: str): - """ - Remove a specific session's secrets and perform secure cleanup. 
- - Args: - session_id (str): The session identifier to remove - """ - try: - # Get the secret entry for the session - secrets = self._session_store.get(session_id, {}) - - for entry in secrets.values(): - # Securely wipe the original value - original_bytes = bytearray(entry.original.encode()) - self.crypto.wipe_bytearray(original_bytes) - - # Remove the encrypted value from the reverse lookup index - self._encrypted_to_session.pop(entry.encrypted, None) - - # Remove the session from the store - self._session_store.pop(session_id, None) - - logger.debug("Session secrets securely removed", session_id=session_id) - else: - logger.debug("No secrets found for session", session_id=session_id) - except Exception as e: - logger.error("Error during session cleanup", session_id=session_id, error=str(e)) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 184c3ba3..527c817f 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -16,8 +16,8 @@ PipelineStep, ) from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep -from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.signatures import CodegateSignatures, Match +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager from codegate.pipeline.systemmsg import add_or_update_system_message logger = structlog.get_logger("codegate") @@ -171,25 +171,35 @@ def obfuscate(self, text: str, snippet: Optional[CodeSnippet]) -> tuple[str, Lis class SecretsEncryptor(SecretsModifier): def __init__( self, - secrets_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, context: PipelineContext, session_id: str, ): - self._secrets_manager = secrets_manager + self._sensitive_data_manager = sensitive_data_manager self._session_id = session_id self._context = context self._name = "codegate-secrets" + super().__init__() def _hide_secret(self, match: Match) -> str: # Encrypt and store the value - encrypted_value = self._secrets_manager.store_secret( - match.value, - match.service, - match.type, - self._session_id, + if not self._session_id: + raise ValueError("Session id must be provided") + + if not match.value: + raise ValueError("Value must be provided") + if not match.service: + raise ValueError("Service must be provided") + if not match.type: + raise ValueError("Secret type must be provided") + + obj = SensitiveData(original=match.value, service=match.service, type=match.type) + uuid_placeholder = self._sensitive_data_manager.store(self._session_id, obj) + logger.debug( + "Stored secret", service=match.service, type=match.type, placeholder=uuid_placeholder ) - return f"REDACTED<${encrypted_value}>" + return f"REDACTED<{uuid_placeholder}>" def _notify_secret( self, match: Match, code_snippet: Optional[CodeSnippet], protected_text: List[str] @@ -251,7 +261,7 @@ def _redact_text( self, text: str, snippet: Optional[CodeSnippet], - secrets_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, session_id: str, context: PipelineContext, ) -> tuple[str, List[Match]]: @@ -260,14 +270,14 @@ def _redact_text( Args: text: The text to protect - secrets_manager: .. + sensitive_data_manager: .. session_id: .. 
context: The pipeline context to be able to log alerts Returns: Tuple containing protected text with encrypted values and the count of redacted secrets """ # Find secrets in the text - text_encryptor = SecretsEncryptor(secrets_manager, context, session_id) + text_encryptor = SecretsEncryptor(sensitive_data_manager, context, session_id) return text_encryptor.obfuscate(text, snippet) async def process( @@ -287,8 +297,10 @@ async def process( if "messages" not in request: return PipelineResult(request=request, context=context) - secrets_manager = context.sensitive.manager - if not secrets_manager or not isinstance(secrets_manager, SecretsManager): + sensitive_data_manager = context.sensitive.manager + if not sensitive_data_manager or not isinstance( + sensitive_data_manager, SensitiveDataManager + ): raise ValueError("Secrets manager not found in context") session_id = context.sensitive.session_id if not session_id: @@ -305,7 +317,7 @@ async def process( for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: redacted_content, secrets_matched = self._redact_message_content( - message["content"], secrets_manager, session_id, context + message["content"], sensitive_data_manager, session_id, context ) new_request["messages"][i]["content"] = redacted_content if i > last_assistant_idx: @@ -313,7 +325,7 @@ async def process( new_request = self._finalize_redaction(context, total_matches, new_request) return PipelineResult(request=new_request, context=context) - def _redact_message_content(self, message_content, secrets_manager, session_id, context): + def _redact_message_content(self, message_content, sensitive_data_manager, session_id, context): # Extract any code snippets extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client) snippets = extractor.extract_snippets(message_content) @@ -322,7 +334,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id, for snippet in snippets: redacted_snippet, secrets_matched = self._redact_text( - snippet, snippet, secrets_manager, session_id, context + snippet, snippet, sensitive_data_manager, session_id, context ) redacted_snippets[snippet.code] = redacted_snippet total_matches.extend(secrets_matched) @@ -336,7 +348,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id, if start_index > last_end: non_snippet_part = message_content[last_end:start_index] redacted_part, secrets_matched = self._redact_text( - non_snippet_part, "", secrets_manager, session_id, context + non_snippet_part, "", sensitive_data_manager, session_id, context ) non_snippet_parts.append(redacted_part) total_matches.extend(secrets_matched) @@ -347,7 +359,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id, if last_end < len(message_content): remaining_text = message_content[last_end:] redacted_remaining, secrets_matched = self._redact_text( - remaining_text, "", secrets_manager, session_id, context + remaining_text, "", sensitive_data_manager, session_id, context ) non_snippet_parts.append(redacted_remaining) total_matches.extend(secrets_matched) @@ -428,9 +440,14 @@ async def process_chunk( encrypted_value = match.group(1) if encrypted_value.startswith("$"): encrypted_value = encrypted_value[1:] + + session_id = input_context.sensitive.session_id + if not session_id: + raise ValueError("Session ID not found in context") + original_value = input_context.sensitive.manager.get_original_value( + session_id, encrypted_value, - 
input_context.sensitive.session_id, ) if original_value is None: diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py new file mode 100644 index 00000000..89506d15 --- /dev/null +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -0,0 +1,50 @@ +import json +from typing import Dict, Optional +import pydantic +import structlog +from codegate.pipeline.sensitive_data.session_store import SessionStore + +logger = structlog.get_logger("codegate") + + +class SensitiveData(pydantic.BaseModel): + """Represents sensitive data with additional metadata.""" + + original: str + service: Optional[str] = None + type: Optional[str] = None + + +class SensitiveDataManager: + """Manages encryption, storage, and retrieval of secrets""" + + def __init__(self): + self.session_store = SessionStore() + + def store(self, session_id: str, value: SensitiveData) -> Optional[str]: + if not session_id or not value.original: + return None + return self.session_store.add_mapping(session_id, value.model_dump_json()) + + def get_by_session_id(self, session_id: str) -> Optional[Dict]: + if not session_id: + return None + data = self.session_store.get_by_session_id(session_id) + return SensitiveData.model_validate_json(data) if data else None + + def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]: + if not session_id: + return None + secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder) + return ( + SensitiveData.model_validate_json(secret_entry_json).original + if secret_entry_json + else None + ) + + def cleanup_session(self, session_id: str): + if session_id: + self.session_store.cleanup_session(session_id) + + def cleanup(self): + self.session_store.cleanup() diff --git a/src/codegate/pipeline/sensitive_data/session_store.py b/src/codegate/pipeline/sensitive_data/session_store.py new file mode 100644 index 00000000..5e508847 --- /dev/null +++ b/src/codegate/pipeline/sensitive_data/session_store.py @@ -0,0 +1,33 @@ +from typing import Dict, Optional +import uuid + + +class SessionStore: + """ + A generic session store for managing data protection. 
+ """ + + def __init__(self): + self.sessions: Dict[str, Dict[str, str]] = {} + + def add_mapping(self, session_id: str, data: str) -> str: + uuid_placeholder = f"#{str(uuid.uuid4())}#" + if session_id not in self.sessions: + self.sessions[session_id] = {} + self.sessions[session_id][uuid_placeholder] = data + return uuid_placeholder + + def get_by_session_id(self, session_id: str) -> Optional[Dict]: + return self.sessions.get(session_id, None) + + def get_mapping(self, session_id: str, uuid_placeholder: str) -> Optional[str]: + return self.sessions.get(session_id, {}).get(uuid_placeholder) + + def cleanup_session(self, session_id: str): + """Clears all stored mappings for a specific session.""" + if session_id in self.sessions: + del self.sessions[session_id] + + def cleanup(self): + """Clears all stored mappings for all sessions.""" + self.sessions.clear() diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index b17e98a8..182f2731 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -17,7 +17,7 @@ from codegate.pipeline.base import PipelineContext from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.output import OutputPipelineInstance -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers.copilot.mapping import PIPELINE_ROUTES, VALIDATED_ROUTES, PipelineType from codegate.providers.copilot.pipeline import ( CopilotChatPipeline, @@ -200,7 +200,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.ca = CertificateAuthority.get_instance() self.cert_manager = TLSCertDomainManager(self.ca) self._closing = False - self.pipeline_factory = PipelineFactory(SecretsManager()) + self.pipeline_factory = PipelineFactory(SensitiveDataManager()) self.input_pipeline: Optional[CopilotPipeline] = None self.fim_pipeline: Optional[CopilotPipeline] = None # the context as provided by the pipeline diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py index 8d5a7c6e..d626b8cf 100644 --- a/tests/pipeline/pii/test_analyzer.py +++ b/tests/pipeline/pii/test_analyzer.py @@ -3,44 +3,7 @@ import pytest from presidio_analyzer import RecognizerResult -from codegate.pipeline.pii.analyzer import PiiAnalyzer, PiiSessionStore - - -class TestPiiSessionStore: - def test_init_with_session_id(self): - session_id = "test-session" - store = PiiSessionStore(session_id) - assert store.session_id == session_id - assert store.mappings == {} - - def test_init_without_session_id(self): - store = PiiSessionStore() - assert isinstance(store.session_id, str) - assert len(store.session_id) > 0 - assert store.mappings == {} - - def test_add_mapping(self): - store = PiiSessionStore() - pii = "test@example.com" - placeholder = store.add_mapping(pii) - - assert placeholder.startswith("<") - assert placeholder.endswith(">") - assert store.mappings[placeholder] == pii - - def test_get_pii_existing(self): - store = PiiSessionStore() - pii = "test@example.com" - placeholder = store.add_mapping(pii) - - result = store.get_pii(placeholder) - assert result == pii - - def test_get_pii_nonexistent(self): - store = PiiSessionStore() - placeholder = "" - result = store.get_pii(placeholder) - assert result == placeholder +from codegate.pipeline.pii.analyzer import PiiAnalyzer class TestPiiAnalyzer: @@ -104,68 +67,31 @@ def test_singleton_pattern(self): with 
pytest.raises(RuntimeError, match="Use PiiAnalyzer.get_instance()"): PiiAnalyzer() - def test_analyze_no_pii(self, analyzer, mock_analyzer_engine): - text = "Hello world" - mock_analyzer_engine.analyze.return_value = [] - - result_text, found_pii, session_store = analyzer.analyze(text) - - assert result_text == text - assert found_pii == [] - assert isinstance(session_store, PiiSessionStore) - - def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): - text = "My email is test@example.com" - email_pii = RecognizerResult( - entity_type="EMAIL_ADDRESS", - start=12, - end=28, - score=1.0, # EmailRecognizer returns a score of 1.0 - ) - mock_analyzer_engine.analyze.return_value = [email_pii] - - result_text, found_pii, session_store = analyzer.analyze(text) - - assert len(found_pii) == 1 - pii_info = found_pii[0] - assert pii_info["type"] == "EMAIL_ADDRESS" - assert pii_info["value"] == "test@example.com" - assert pii_info["score"] == 1.0 - assert pii_info["start"] == 12 - assert pii_info["end"] == 28 - assert "uuid_placeholder" in pii_info - # Verify the placeholder was used to replace the PII - placeholder = pii_info["uuid_placeholder"] - assert result_text == f"My email is {placeholder}" - # Verify the mapping was stored - assert session_store.get_pii(placeholder) == "test@example.com" - def test_restore_pii(self, analyzer): - session_store = PiiSessionStore() original_text = "test@example.com" - placeholder = session_store.add_mapping(original_text) - anonymized_text = f"My email is {placeholder}" + session_id = "session-id" - restored_text = analyzer.restore_pii(anonymized_text, session_store) + placeholder = analyzer.session_store.add_mapping(session_id, original_text) + anonymized_text = f"My email is {placeholder}" + restored_text = analyzer.restore_pii(session_id, anonymized_text) assert restored_text == f"My email is {original_text}" def test_restore_pii_multiple(self, analyzer): - session_store = PiiSessionStore() email = "test@example.com" phone = "123-456-7890" - email_placeholder = session_store.add_mapping(email) - phone_placeholder = session_store.add_mapping(phone) + session_id = "session-id" + email_placeholder = analyzer.session_store.add_mapping(session_id, email) + phone_placeholder = analyzer.session_store.add_mapping(session_id, phone) anonymized_text = f"Email: {email_placeholder}, Phone: {phone_placeholder}" - restored_text = analyzer.restore_pii(anonymized_text, session_store) + restored_text = analyzer.restore_pii(session_id, anonymized_text) assert restored_text == f"Email: {email}, Phone: {phone}" def test_restore_pii_no_placeholders(self, analyzer): - session_store = PiiSessionStore() text = "No PII here" - - restored_text = analyzer.restore_pii(text, session_store) + session_id = "session-id" + restored_text = analyzer.restore_pii(session_id, text) assert restored_text == text diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index 6578a7b6..06d2881f 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -4,9 +4,10 @@ from litellm import ChatCompletionRequest, ModelResponse from litellm.types.utils import Delta, StreamingChoices -from codegate.pipeline.base import PipelineContext +from codegate.pipeline.base import PipelineContext, PipelineSensitiveData from codegate.pipeline.output import OutputPipelineContext from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager class TestCodegatePii: @@ 
-19,8 +20,9 @@ def mock_config(self): yield mock_config @pytest.fixture - def pii_step(self, mock_config): - return CodegatePii() + def pii_step(self): + mock_sensitive_data_manager = MagicMock() + return CodegatePii(mock_sensitive_data_manager) def test_name(self, pii_step): assert pii_step.name == "codegate-pii" @@ -51,57 +53,6 @@ async def test_process_no_messages(self, pii_step): assert result.request == request assert result.context == context - @pytest.mark.asyncio - async def test_process_with_pii(self, pii_step): - original_text = "My email is test@example.com" - request = ChatCompletionRequest( - model="test-model", messages=[{"role": "user", "content": original_text}] - ) - context = PipelineContext() - - # Mock the PII manager's analyze method - placeholder = "" - pii_details = [ - { - "type": "EMAIL_ADDRESS", - "value": "test@example.com", - "score": 1.0, - "start": 12, - "end": 27, - "uuid_placeholder": placeholder, - } - ] - anonymized_text = f"My email is {placeholder}" - pii_step.pii_manager.analyze = MagicMock(return_value=(anonymized_text, pii_details)) - - result = await pii_step.process(request, context) - - # Verify the user message was anonymized - user_messages = [m for m in result.request["messages"] if m["role"] == "user"] - assert len(user_messages) == 1 - assert user_messages[0]["content"] == anonymized_text - - # Verify metadata was updated - assert result.context.metadata["redacted_pii_count"] == 1 - assert len(result.context.metadata["redacted_pii_details"]) == 1 - # The redacted text should be just the placeholder since that's what _get_redacted_snippet returns # noqa: E501 - assert result.context.metadata["redacted_text"] == placeholder - assert "pii_manager" in result.context.metadata - - # Verify system message was added - system_messages = [m for m in result.request["messages"] if m["role"] == "system"] - assert len(system_messages) == 1 - assert system_messages[0]["content"] == "PII has been redacted" - - def test_restore_pii(self, pii_step): - anonymized_text = "My email is " - original_text = "My email is test@example.com" - pii_step.pii_manager.restore_pii = MagicMock(return_value=original_text) - - restored = pii_step.restore_pii(anonymized_text) - - assert restored == original_text - class TestPiiUnRedactionStep: @pytest.fixture @@ -148,7 +99,7 @@ async def test_process_chunk_with_uuid(self, unredaction_step): StreamingChoices( finish_reason=None, index=0, - delta=Delta(content=f"Text with <{uuid}>"), + delta=Delta(content=f"Text with #{uuid}#"), logprobs=None, ) ], @@ -157,17 +108,16 @@ async def test_process_chunk_with_uuid(self, unredaction_step): object="chat.completion.chunk", ) context = OutputPipelineContext() - input_context = PipelineContext() + manager = SensitiveDataManager() + sensitive = PipelineSensitiveData(manager=manager, session_id="session-id") + input_context = PipelineContext(sensitive=sensitive) # Mock PII manager in input context - mock_pii_manager = MagicMock() - mock_session = MagicMock() - mock_session.get_pii = MagicMock(return_value="test@example.com") - mock_pii_manager.session_store = mock_session - input_context.metadata["pii_manager"] = mock_pii_manager + mock_sensitive_data_manager = MagicMock() + mock_sensitive_data_manager.get_original_value = MagicMock(return_value="test@example.com") + input_context.metadata["sensitive_data_manager"] = mock_sensitive_data_manager result = await unredaction_step.process_chunk(chunk, context, input_context) - assert result[0].choices[0].delta.content == "Text with test@example.com" 
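A minimal round-trip sketch of the SensitiveDataManager / SessionStore pair introduced in this change; the session id and the sample value are illustrative only, everything else follows the new modules added under src/codegate/pipeline/sensitive_data/:

    from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager

    manager = SensitiveDataManager()
    session_id = "example-session"  # normally a uuid4 created by PipelineSensitiveData

    # store() serializes the SensitiveData model into the per-session store and
    # returns a "#<uuid>#" placeholder that is substituted into the prompt text.
    placeholder = manager.store(
        session_id,
        SensitiveData(original="test@example.com", service="pii", type="EMAIL_ADDRESS"),
    )

    # get_original_value() resolves the placeholder back to the original value
    # when the model response is un-redacted.
    assert manager.get_original_value(session_id, placeholder) == "test@example.com"

    # cleanup_session() drops every mapping for the session once the request is done.
    manager.cleanup_session(session_id)
    assert manager.get_original_value(session_id, placeholder) is None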
diff --git a/tests/pipeline/pii/test_pii_manager.py b/tests/pipeline/pii/test_pii_manager.py deleted file mode 100644 index 229b7314..00000000 --- a/tests/pipeline/pii/test_pii_manager.py +++ /dev/null @@ -1,106 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from codegate.pipeline.pii.analyzer import PiiSessionStore -from codegate.pipeline.pii.manager import PiiManager - - -class TestPiiManager: - @pytest.fixture - def session_store(self): - """Create a session store that will be shared between the mock and manager""" - return PiiSessionStore() - - @pytest.fixture - def mock_analyzer(self, session_store): - """Create a mock analyzer with the shared session store""" - mock_instance = MagicMock() - mock_instance.analyze = MagicMock() - mock_instance.restore_pii = MagicMock() - mock_instance.session_store = session_store - return mock_instance - - @pytest.fixture - def manager(self, mock_analyzer): - """Create a PiiManager instance with the mocked analyzer""" - with patch("codegate.pipeline.pii.manager.PiiAnalyzer") as mock_analyzer_class: - # Set up the mock class to return our mock instance - mock_analyzer_class.get_instance.return_value = mock_analyzer - # Create the manager which will use our mock - return PiiManager() - - def test_init(self, manager, mock_analyzer): - assert manager.session_store is mock_analyzer.session_store - assert manager.analyzer is mock_analyzer - - def test_analyze_no_pii(self, manager, mock_analyzer): - text = "Hello CodeGate" - session_store = mock_analyzer.session_store - mock_analyzer.analyze.return_value = (text, [], session_store) - - anonymized_text, found_pii = manager.analyze(text) - - assert anonymized_text == text - assert found_pii == [] - assert manager.session_store is session_store - mock_analyzer.analyze.assert_called_once_with(text, context=None) - - def test_analyze_with_pii(self, manager, mock_analyzer): - text = "My email is test@example.com" - session_store = mock_analyzer.session_store - placeholder = "" - pii_details = [ - { - "type": "EMAIL_ADDRESS", - "value": "test@example.com", - "score": 0.85, - "start": 12, - "end": 28, # Fixed end position - "uuid_placeholder": placeholder, - } - ] - anonymized_text = f"My email is {placeholder}" - session_store.mappings[placeholder] = "test@example.com" - mock_analyzer.analyze.return_value = (anonymized_text, pii_details, session_store) - - result_text, found_pii = manager.analyze(text) - - assert "My email is <" in result_text - assert ">" in result_text - assert found_pii == pii_details - assert manager.session_store is session_store - assert manager.session_store.mappings[placeholder] == "test@example.com" - mock_analyzer.analyze.assert_called_once_with(text, context=None) - - def test_restore_pii_no_session(self, manager, mock_analyzer): - text = "Anonymized text" - # Create a new session store that's None - mock_analyzer.session_store = None - - restored_text = manager.restore_pii(text) - - assert restored_text == text - - def test_restore_pii_with_session(self, manager, mock_analyzer): - anonymized_text = "My email is " - original_text = "My email is test@example.com" - manager.session_store.mappings[""] = "test@example.com" - mock_analyzer.restore_pii.return_value = original_text - - restored_text = manager.restore_pii(anonymized_text) - - assert restored_text == original_text - mock_analyzer.restore_pii.assert_called_once_with(anonymized_text, manager.session_store) - - def test_restore_pii_multiple_placeholders(self, manager, mock_analyzer): - anonymized_text = 
"Email: , Phone: " - original_text = "Email: test@example.com, Phone: 123-456-7890" - manager.session_store.mappings[""] = "test@example.com" - manager.session_store.mappings[""] = "123-456-7890" - mock_analyzer.restore_pii.return_value = original_text - - restored_text = manager.restore_pii(anonymized_text) - - assert restored_text == original_text - mock_analyzer.restore_pii.assert_called_once_with(anonymized_text, manager.session_store) diff --git a/tests/pipeline/secrets/test_gatecrypto.py b/tests/pipeline/secrets/test_gatecrypto.py deleted file mode 100644 index b7de4b19..00000000 --- a/tests/pipeline/secrets/test_gatecrypto.py +++ /dev/null @@ -1,157 +0,0 @@ -import time - -import pytest - -from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto - - -@pytest.fixture -def crypto(): - return CodeGateCrypto() - - -def test_generate_session_key(crypto): - session_id = "test_session" - key = crypto.generate_session_key(session_id) - - assert len(key) == 32 # AES-256 key size - assert session_id in crypto.session_keys - assert isinstance(crypto.session_keys[session_id], tuple) - assert len(crypto.session_keys[session_id]) == 2 - - -def test_get_session_key(crypto): - session_id = "test_session" - original_key = crypto.generate_session_key(session_id) - retrieved_key = crypto.get_session_key(session_id) - - assert original_key == retrieved_key - - -def test_get_expired_session_key(crypto): - session_id = "test_session" - crypto.generate_session_key(session_id) - - # Manually expire the key by modifying its timestamp - key, _ = crypto.session_keys[session_id] - crypto.session_keys[session_id] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - - retrieved_key = crypto.get_session_key(session_id) - assert retrieved_key is None - assert session_id not in crypto.session_keys - - -def test_cleanup_expired_keys(crypto): - # Generate multiple session keys - session_ids = ["session1", "session2", "session3"] - for session_id in session_ids: - crypto.generate_session_key(session_id) - - # Manually expire some keys - key, _ = crypto.session_keys["session1"] - crypto.session_keys["session1"] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - key, _ = crypto.session_keys["session2"] - crypto.session_keys["session2"] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - - crypto.cleanup_expired_keys() - - assert "session1" not in crypto.session_keys - assert "session2" not in crypto.session_keys - assert "session3" in crypto.session_keys - - -def test_encrypt_decrypt_token(crypto): - session_id = "test_session" - original_token = "sensitive_data_123" - - encrypted_token = crypto.encrypt_token(original_token, session_id) - decrypted_token = crypto.decrypt_token(encrypted_token, session_id) - - assert decrypted_token == original_token - - -def test_decrypt_with_expired_session(crypto): - session_id = "test_session" - token = "sensitive_data_123" - - encrypted_token = crypto.encrypt_token(token, session_id) - - # Manually expire the session key - key, _ = crypto.session_keys[session_id] - crypto.session_keys[session_id] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - - with pytest.raises(ValueError, match="Session key expired or invalid."): - crypto.decrypt_token(encrypted_token, session_id) - - -def test_decrypt_with_invalid_session(crypto): - session_id = "test_session" - token = "sensitive_data_123" - - encrypted_token = crypto.encrypt_token(token, session_id) - - with pytest.raises(ValueError, match="Session key expired or invalid."): - 
crypto.decrypt_token(encrypted_token, "invalid_session") - - -def test_decrypt_with_expired_token(crypto, monkeypatch): - session_id = "test_session" - token = "sensitive_data_123" - current_time = time.time() - - # Mock time.time() for token encryption - monkeypatch.setattr(time, "time", lambda: current_time) - - # Generate token with current timestamp - encrypted_token = crypto.encrypt_token(token, session_id) - - # Mock time.time() to return a future timestamp for decryption - future_time = current_time + crypto.SESSION_KEY_LIFETIME + 10 - monkeypatch.setattr(time, "time", lambda: future_time) - - # Keep the original key but update its timestamp to keep it valid - key, _ = crypto.session_keys[session_id] - crypto.session_keys[session_id] = (key, future_time) - - with pytest.raises(ValueError, match="Token has expired."): - crypto.decrypt_token(encrypted_token, session_id) - - -def test_wipe_bytearray(crypto): - # Create a bytearray with sensitive data - sensitive_data = bytearray(b"sensitive_information") - original_content = sensitive_data.copy() - - # Wipe the data - crypto.wipe_bytearray(sensitive_data) - - # Verify all bytes are zeroed - assert all(byte == 0 for byte in sensitive_data) - assert sensitive_data != original_content - - -def test_wipe_bytearray_invalid_input(crypto): - # Try to wipe a string instead of bytearray - with pytest.raises(ValueError, match="Only bytearray objects can be securely wiped."): - crypto.wipe_bytearray("not a bytearray") - - -def test_encrypt_decrypt_with_special_characters(crypto): - session_id = "test_session" - special_chars_token = "!@#$%^&*()_+-=[]{}|;:,.<>?" - - encrypted_token = crypto.encrypt_token(special_chars_token, session_id) - decrypted_token = crypto.decrypt_token(encrypted_token, session_id) - - assert decrypted_token == special_chars_token - - -def test_encrypt_decrypt_multiple_tokens(crypto): - session_id = "test_session" - tokens = ["token1", "token2", "token3"] - - # Encrypt and immediately decrypt each token - for token in tokens: - encrypted = crypto.encrypt_token(token, session_id) - decrypted = crypto.decrypt_token(encrypted, session_id) - assert decrypted == token diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py deleted file mode 100644 index 177e8f3f..00000000 --- a/tests/pipeline/secrets/test_manager.py +++ /dev/null @@ -1,149 +0,0 @@ -import pytest - -from codegate.pipeline.secrets.manager import SecretsManager - - -class TestSecretsManager: - def setup_method(self): - """Setup a fresh SecretsManager for each test""" - self.manager = SecretsManager() - self.test_session = "test_session_id" - self.test_value = "super_secret_value" - self.test_service = "test_service" - self.test_type = "api_key" - - def test_store_secret(self): - """Test basic secret storage and retrieval""" - # Store a secret - encrypted = self.manager.store_secret( - self.test_value, self.test_service, self.test_type, self.test_session - ) - - # Verify the secret was stored - stored = self.manager.get_by_session_id(self.test_session) - assert isinstance(stored, dict) - assert stored[encrypted].original == self.test_value - - # Verify encrypted value can be retrieved - retrieved = self.manager.get_original_value(encrypted, self.test_session) - assert retrieved == self.test_value - - def test_get_original_value_wrong_session(self): - """Test that secrets can't be accessed with wrong session ID""" - encrypted = self.manager.store_secret( - self.test_value, self.test_service, self.test_type, self.test_session - ) - 
- # Try to retrieve with wrong session ID - wrong_session = "wrong_session_id" - retrieved = self.manager.get_original_value(encrypted, wrong_session) - assert retrieved is None - - def test_get_original_value_nonexistent(self): - """Test handling of non-existent encrypted values""" - retrieved = self.manager.get_original_value("nonexistent", self.test_session) - assert retrieved is None - - def test_cleanup_session(self): - """Test that session cleanup properly removes secrets""" - # Store multiple secrets in different sessions - session1 = "session1" - session2 = "session2" - - encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) - encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) - - # Clean up session1 - self.manager.cleanup_session(session1) - - # Verify session1 secrets are gone - assert self.manager.get_by_session_id(session1) is None - assert self.manager.get_original_value(encrypted1, session1) is None - - # Verify session2 secrets remain - assert self.manager.get_by_session_id(session2) is not None - assert self.manager.get_original_value(encrypted2, session2) == "secret2" - - def test_cleanup(self): - """Test that cleanup properly wipes all data""" - # Store multiple secrets - self.manager.store_secret("secret1", "service1", "type1", "session1") - self.manager.store_secret("secret2", "service2", "type2", "session2") - - # Perform cleanup - self.manager.cleanup() - - # Verify all data is wiped - assert len(self.manager._session_store) == 0 - assert len(self.manager._encrypted_to_session) == 0 - - def test_multiple_secrets_same_session(self): - """Test storing multiple secrets in the same session""" - # Store multiple secrets in same session - encrypted1 = self.manager.store_secret("secret1", "service1", "type1", self.test_session) - encrypted2 = self.manager.store_secret("secret2", "service2", "type2", self.test_session) - - # Latest secret should be retrievable in the session - stored = self.manager.get_by_session_id(self.test_session) - assert isinstance(stored, dict) - assert stored[encrypted1].original == "secret1" - assert stored[encrypted2].original == "secret2" - - # Both secrets should be retrievable directly - assert self.manager.get_original_value(encrypted1, self.test_session) == "secret1" - assert self.manager.get_original_value(encrypted2, self.test_session) == "secret2" - - # Both encrypted values should map to the session - assert self.manager._encrypted_to_session[encrypted1] == self.test_session - assert self.manager._encrypted_to_session[encrypted2] == self.test_session - - def test_error_handling(self): - """Test error handling in secret operations""" - # Test with None values - with pytest.raises(ValueError): - self.manager.store_secret(None, self.test_service, self.test_type, self.test_session) - - with pytest.raises(ValueError): - self.manager.store_secret(self.test_value, None, self.test_type, self.test_session) - - with pytest.raises(ValueError): - self.manager.store_secret(self.test_value, self.test_service, None, self.test_session) - - with pytest.raises(ValueError): - self.manager.store_secret(self.test_value, self.test_service, self.test_type, None) - - def test_secure_cleanup(self): - """Test that cleanup securely wipes sensitive data""" - # Store a secret - self.manager.store_secret( - self.test_value, self.test_service, self.test_type, self.test_session - ) - - # Get reference to stored data before cleanup - stored = self.manager.get_by_session_id(self.test_session) - assert len(stored) == 1 
- - # Perform cleanup - self.manager.cleanup() - - # Verify the original string was overwritten, not just removed - # This test is a bit tricky since Python strings are immutable, - # but we can at least verify the data is no longer accessible - assert self.test_value not in str(self.manager._session_store) - - def test_session_isolation(self): - """Test that sessions are properly isolated""" - session1 = "session1" - session2 = "session2" - - # Store secrets in different sessions - encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) - encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) - - # Verify cross-session access is not possible - assert self.manager.get_original_value(encrypted1, session2) is None - assert self.manager.get_original_value(encrypted2, session1) is None - - # Verify correct session access works - assert self.manager.get_original_value(encrypted1, session1) == "secret1" - assert self.manager.get_original_value(encrypted2, session2) == "secret2" diff --git a/tests/pipeline/secrets/test_secrets.py b/tests/pipeline/secrets/test_secrets.py index 759b94b0..3f272b5b 100644 --- a/tests/pipeline/secrets/test_secrets.py +++ b/tests/pipeline/secrets/test_secrets.py @@ -7,13 +7,13 @@ from codegate.pipeline.base import PipelineContext, PipelineSensitiveData from codegate.pipeline.output import OutputPipelineContext -from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.secrets import ( SecretsEncryptor, SecretsObfuscator, SecretUnredactionStep, ) from codegate.pipeline.secrets.signatures import CodegateSignatures, Match +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager class TestSecretsModifier: @@ -69,9 +69,11 @@ class TestSecretsEncryptor: def setup(self, temp_yaml_file): CodegateSignatures.initialize(temp_yaml_file) self.context = PipelineContext() - self.secrets_manager = SecretsManager() + self.sensitive_data_manager = SensitiveDataManager() self.session_id = "test_session" - self.encryptor = SecretsEncryptor(self.secrets_manager, self.context, self.session_id) + self.encryptor = SecretsEncryptor( + self.sensitive_data_manager, self.context, self.session_id + ) def test_hide_secret(self): # Create a test match @@ -87,12 +89,12 @@ def test_hide_secret(self): # Test secret hiding hidden = self.encryptor._hide_secret(match) - assert hidden.startswith("REDACTED<$") + assert hidden.startswith("REDACTED<") assert hidden.endswith(">") # Verify the secret was stored - encrypted_value = hidden[len("REDACTED<$") : -1] - original = self.secrets_manager.get_original_value(encrypted_value, self.session_id) + encrypted_value = hidden[len("REDACTED<") : -1] + original = self.sensitive_data_manager.get_original_value(self.session_id, encrypted_value) assert original == "AKIAIOSFODNN7EXAMPLE" def test_obfuscate(self): @@ -101,7 +103,7 @@ def test_obfuscate(self): protected, matched_secrets = self.encryptor.obfuscate(text, None) assert len(matched_secrets) == 1 - assert "REDACTED<$" in protected + assert "REDACTED<" in protected assert "AKIAIOSFODNN7EXAMPLE" not in protected assert "Other text" in protected @@ -171,25 +173,24 @@ def setup_method(self): """Setup fresh instances for each test""" self.step = SecretUnredactionStep() self.context = OutputPipelineContext() - self.secrets_manager = SecretsManager() + self.sensitive_data_manager = SensitiveDataManager() self.session_id = "test_session" # Setup input context with secrets manager self.input_context = 
PipelineContext() self.input_context.sensitive = PipelineSensitiveData( - manager=self.secrets_manager, session_id=self.session_id + manager=self.sensitive_data_manager, session_id=self.session_id ) @pytest.mark.asyncio async def test_complete_marker_processing(self): """Test processing of a complete REDACTED marker""" # Store a secret - encrypted = self.secrets_manager.store_secret( - "secret_value", "test_service", "api_key", self.session_id - ) + obj = SensitiveData(original="secret_value", service="test_service", type="api_key") + encrypted = self.sensitive_data_manager.store(self.session_id, obj) # Add content with REDACTED marker to buffer - self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + self.context.buffer.append(f"Here is the REDACTED<{encrypted}> in text") # Process a chunk result = await self.step.process_chunk( @@ -204,7 +205,7 @@ async def test_complete_marker_processing(self): async def test_partial_marker_buffering(self): """Test handling of partial REDACTED markers""" # Add partial marker to buffer - self.context.buffer.append("Here is REDACTED<$") + self.context.buffer.append("Here is REDACTED<") # Process a chunk result = await self.step.process_chunk( @@ -218,7 +219,7 @@ async def test_partial_marker_buffering(self): async def test_invalid_encrypted_value(self): """Test handling of invalid encrypted values""" # Add content with invalid encrypted value - self.context.buffer.append("Here is REDACTED<$invalid_value> in text") + self.context.buffer.append("Here is REDACTED in text") # Process chunk result = await self.step.process_chunk( @@ -227,7 +228,7 @@ async def test_invalid_encrypted_value(self): # Should keep the REDACTED marker for invalid values assert len(result) == 1 - assert result[0].choices[0].delta.content == "Here is REDACTED<$invalid_value> in text" + assert result[0].choices[0].delta.content == "Here is REDACTED in text" @pytest.mark.asyncio async def test_missing_context(self): @@ -271,12 +272,11 @@ async def test_no_markers(self): async def test_wrong_session(self): """Test unredaction with wrong session ID""" # Store secret with one session - encrypted = self.secrets_manager.store_secret( - "secret_value", "test_service", "api_key", "different_session" - ) + obj = SensitiveData(original="test_service", service="api_key", type="different_session") + encrypted = self.sensitive_data_manager.store("different_session", obj) # Try to unredact with different session - self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + self.context.buffer.append(f"Here is the REDACTED<{encrypted}> in text") result = await self.step.process_chunk( create_model_response("text"), self.context, self.input_context @@ -284,4 +284,4 @@ async def test_wrong_session(self): # Should keep REDACTED marker when session doesn't match assert len(result) == 1 - assert result[0].choices[0].delta.content == f"Here is the REDACTED<${encrypted}> in text" + assert result[0].choices[0].delta.content == f"Here is the REDACTED<{encrypted}> in text" diff --git a/tests/pipeline/sensitive_data/test_manager.py b/tests/pipeline/sensitive_data/test_manager.py new file mode 100644 index 00000000..6115ad14 --- /dev/null +++ b/tests/pipeline/sensitive_data/test_manager.py @@ -0,0 +1,48 @@ +import json +from unittest.mock import MagicMock, patch +import pytest +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager +from codegate.pipeline.sensitive_data.session_store import SessionStore + + +class TestSensitiveDataManager: + 
@pytest.fixture + def mock_session_store(self): + """Mock the SessionStore instance used within SensitiveDataManager.""" + return MagicMock(spec=SessionStore) + + @pytest.fixture + def manager(self, mock_session_store): + """Patch SensitiveDataManager to use the mocked SessionStore.""" + with patch.object(SensitiveDataManager, "__init__", lambda self: None): + manager = SensitiveDataManager() + manager.session_store = mock_session_store # Manually inject the mock + return manager + + def test_store_success(self, manager, mock_session_store): + """Test storing a SensitiveData object successfully.""" + session_id = "session-123" + sensitive_data = SensitiveData(original="secret_value", service="AWS", type="API_KEY") + + # Mock session store behavior + mock_session_store.add_mapping.return_value = "uuid-123" + + result = manager.store(session_id, sensitive_data) + + # Verify correct function calls + mock_session_store.add_mapping.assert_called_once_with( + session_id, sensitive_data.model_dump_json() + ) + assert result == "uuid-123" + + def test_store_invalid_session_id(self, manager): + """Test storing data with an invalid session ID (should return None).""" + sensitive_data = SensitiveData(original="secret_value", service="AWS", type="API_KEY") + result = manager.store("", sensitive_data) # Empty session ID + assert result is None + + def test_store_missing_original_value(self, manager): + """Test storing data without an original value (should return None).""" + sensitive_data = SensitiveData(original="", service="AWS", type="API_KEY") # Empty original + result = manager.store("session-123", sensitive_data) + assert result is None diff --git a/tests/pipeline/sensitive_data/test_session_store.py b/tests/pipeline/sensitive_data/test_session_store.py new file mode 100644 index 00000000..b9ab64fe --- /dev/null +++ b/tests/pipeline/sensitive_data/test_session_store.py @@ -0,0 +1,114 @@ +import uuid +import pytest +from codegate.pipeline.sensitive_data.session_store import SessionStore + + +class TestSessionStore: + @pytest.fixture + def session_store(self): + """Fixture to create a fresh SessionStore instance before each test.""" + return SessionStore() + + def test_add_mapping_creates_uuid(self, session_store): + """Test that add_mapping correctly stores data and returns a UUID.""" + session_id = "session-123" + data = "test-data" + + uuid_placeholder = session_store.add_mapping(session_id, data) + + # Ensure the returned placeholder follows the expected format + assert uuid_placeholder.startswith("#") and uuid_placeholder.endswith("#") + assert len(uuid_placeholder) > 2 # Should have a UUID inside + + # Verify data is correctly stored + stored_data = session_store.get_mapping(session_id, uuid_placeholder) + assert stored_data == data + + def test_add_mapping_creates_unique_uuids(self, session_store): + """Ensure multiple calls to add_mapping generate unique UUIDs.""" + session_id = "session-123" + data1 = "data1" + data2 = "data2" + + uuid_placeholder1 = session_store.add_mapping(session_id, data1) + uuid_placeholder2 = session_store.add_mapping(session_id, data2) + + assert uuid_placeholder1 != uuid_placeholder2 # UUIDs must be unique + + # Ensure data is correctly stored + assert session_store.get_mapping(session_id, uuid_placeholder1) == data1 + assert session_store.get_mapping(session_id, uuid_placeholder2) == data2 + + def test_get_by_session_id(self, session_store): + """Test retrieving all stored mappings for a session ID.""" + session_id = "session-123" + data1 = "data1" + data2 = 
"data2" + + uuid1 = session_store.add_mapping(session_id, data1) + uuid2 = session_store.add_mapping(session_id, data2) + + stored_session_data = session_store.get_by_session_id(session_id) + + assert uuid1 in stored_session_data + assert uuid2 in stored_session_data + assert stored_session_data[uuid1] == data1 + assert stored_session_data[uuid2] == data2 + + def test_get_by_session_id_not_found(self, session_store): + """Test get_by_session_id when session does not exist (should return None).""" + session_id = "non-existent-session" + assert session_store.get_by_session_id(session_id) is None + + def test_get_mapping_success(self, session_store): + """Test retrieving a specific mapping.""" + session_id = "session-123" + data = "test-data" + + uuid_placeholder = session_store.add_mapping(session_id, data) + + assert session_store.get_mapping(session_id, uuid_placeholder) == data + + def test_get_mapping_not_found(self, session_store): + """Test retrieving a mapping that does not exist (should return None).""" + session_id = "session-123" + uuid_placeholder = "#non-existent-uuid#" + + assert session_store.get_mapping(session_id, uuid_placeholder) is None + + def test_cleanup_session(self, session_store): + """Test that cleanup_session removes all data for a session ID.""" + session_id = "session-123" + session_store.add_mapping(session_id, "test-data") + + # Ensure session exists before cleanup + assert session_store.get_by_session_id(session_id) is not None + + session_store.cleanup_session(session_id) + + # Ensure session is removed after cleanup + assert session_store.get_by_session_id(session_id) is None + + def test_cleanup_session_non_existent(self, session_store): + """Test cleanup_session on a non-existent session (should not raise errors).""" + session_id = "non-existent-session" + session_store.cleanup_session(session_id) # Should not fail + assert session_store.get_by_session_id(session_id) is None + + def test_cleanup(self, session_store): + """Test global cleanup removes all stored sessions.""" + session_id1 = "session-1" + session_id2 = "session-2" + + session_store.add_mapping(session_id1, "data1") + session_store.add_mapping(session_id2, "data2") + + # Ensure sessions exist before cleanup + assert session_store.get_by_session_id(session_id1) is not None + assert session_store.get_by_session_id(session_id2) is not None + + session_store.cleanup() + + # Ensure all sessions are removed after cleanup + assert session_store.get_by_session_id(session_id1) is None + assert session_store.get_by_session_id(session_id2) is None diff --git a/tests/test_server.py b/tests/test_server.py index 1e06c096..aa549810 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,19 +14,13 @@ from codegate import __version__ from codegate.pipeline.factory import PipelineFactory -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers.registry import ProviderRegistry from codegate.server import init_app from src.codegate.cli import UvicornServer, cli from src.codegate.codegate_logging import LogFormat, LogLevel -@pytest.fixture -def mock_secrets_manager(): - """Create a mock secrets manager.""" - return MagicMock(spec=SecretsManager) - - @pytest.fixture def mock_provider_registry(): """Create a mock provider registry.""" @@ -96,9 +90,9 @@ def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> assert response_data["is_latest"] is False 
-@patch("codegate.pipeline.secrets.manager.SecretsManager") +@patch("codegate.pipeline.sensitive_data.manager.SensitiveDataManager") @patch("codegate.server.get_provider_registry") -def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None: +def test_provider_registration(mock_registry, mock_pipeline_factory) -> None: """Test that all providers are registered correctly.""" init_app(mock_pipeline_factory)