diff --git a/src/codegate/ca/codegate_ca.py b/src/codegate/ca/codegate_ca.py index f534e82e..a95a2716 100644 --- a/src/codegate/ca/codegate_ca.py +++ b/src/codegate/ca/codegate_ca.py @@ -1,18 +1,75 @@ -import datetime import os import ssl +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone from typing import Dict, Optional, Tuple import structlog from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import padding, rsa from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID from codegate.config import Config logger = structlog.get_logger("codegate") +# Add a buffer to renew certificates slightly before expiry. +TLS_GRACE_PERIOD_DAYS = 2 + + +@dataclass +class CachedCertificate: + """Hold certificate data with metadata""" + + cert_path: str + key_path: str + creation_time: datetime + + +class TLSCertDomainManager: + """ + This class manages SSL contexts for domain certificates with SNI + """ + + def __init__(self, ca_provider: "CertificateAuthority"): + self._ca = ca_provider + # Use strong references for caching + self._cert_cache: Dict[str, CachedCertificate] = {} + self._context_cache: Dict[str, ssl.SSLContext] = {} + + def get_domain_context(self, server_name: str) -> ssl.SSLContext: + cert_path, key_path = self._ca.get_domain_certificate(server_name) + context = self._create_domain_ssl_context(cert_path, key_path, server_name) + return context + + def _create_domain_ssl_context( + self, cert_path: str, key_path: str, domain: str + ) -> ssl.SSLContext: + """ + Domain SNI Context Setting + """ + + logger.debug(f"Loading cert chain from {cert_path} for domain {domain}") + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + try: + context.load_cert_chain(cert_path, key_path) + except ssl.SSLError as e: + logger.error(f"Failed to load cert chain for {domain}: {e}") + raise + + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20") + context.options |= ( + ssl.OP_NO_SSLv2 + | ssl.OP_NO_SSLv3 + | ssl.OP_NO_COMPRESSION + | ssl.OP_CIPHER_SERVER_PREFERENCE + ) + return context + class CertificateAuthority: """ @@ -41,194 +98,177 @@ def __init__(self): logger.debug("Initializing Certificate Authority class: CertificateAuthority") self._ca_cert = None self._ca_key = None - self._cert_cache: Dict[str, Tuple[str, str]] = {} - self._load_or_generate_ca() - CertificateAuthority._instance = self + self._ca_cert_expiry = None + self._ca_last_load_time = None - def _load_or_generate_ca(self): - """Load existing CA certificate and key or generate new ones""" - logger.debug("Loading or generating CA certificate and key: fn: _load_or_generate_ca") - ca_cert = os.path.join(Config.get_config().certs_dir, Config.get_config().ca_cert) - ca_key = os.path.join(Config.get_config().certs_dir, Config.get_config().ca_key) - - if os.path.exists(ca_cert) and os.path.exists(ca_key): - # Load existing CA certificate and key - with open(ca_cert, "rb") as f: - logger.debug(f"Loading CA certificate from {ca_cert}") - self._ca_cert = x509.load_pem_x509_certificate(f.read()) - with open(ca_key, "rb") as f: - logger.debug(f"Loading CA key from {ca_key}") - self._ca_key = serialization.load_pem_private_key(f.read(), password=None) - else: - # Generate new CA certificate and key - logger.debug("Generating new CA certificate and key") - self._ca_key = rsa.generate_private_key( - public_exponent=65537, - key_size=4096, - ) + # Use strong references for caching certificates + self._cert_cache: Dict[str, CachedCertificate] = {} + # Use a separate cache for SSL contexts + self._context_cache: Dict[str, Tuple[ssl.SSLContext, datetime]] = {} - name = x509.Name( - [ - x509.NameAttribute(NameOID.COMMON_NAME, "CodeGate CA"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CodeGate"), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "CodeGate"), - x509.NameAttribute(NameOID.COUNTRY_NAME, "UK"), - ] - ) - - builder = x509.CertificateBuilder() - builder = builder.subject_name(name) - builder = builder.issuer_name(name) - builder = builder.public_key(self._ca_key.public_key()) - builder = builder.serial_number(x509.random_serial_number()) - builder = builder.not_valid_before(datetime.datetime.utcnow()) - builder = builder.not_valid_after( - datetime.datetime.utcnow() + datetime.timedelta(days=3650) # 10 years - ) - - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - - builder = builder.add_extension( - x509.KeyUsage( - digital_signature=True, - content_commitment=False, - key_encipherment=True, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, # This is a CA - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), - critical=True, - ) - - self._ca_cert = builder.sign( - private_key=self._ca_key, - algorithm=hashes.SHA256(), - ) - - # Save CA certificate and key - if not os.path.exists(Config.get_config().certs_dir): - logger.debug(f"Creating directory: {Config.get_config().certs_dir}") - os.makedirs(Config.get_config().certs_dir) - - with open(ca_cert, "wb") as f: - logger.debug(f"Saving CA certificate to {ca_cert}") - f.write(self._ca_cert.public_bytes(serialization.Encoding.PEM)) + CertificateAuthority._instance = self - with open(ca_key, "wb") as f: - logger.debug(f"Saving CA key to {ca_key}") - f.write( - self._ca_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) + # Load existing certificates into cache + self._load_existing_certificates() + + def _load_existing_certificates(self) -> None: + """Load existing certificates from disk into the cache.""" + logger.debug("Loading existing certificates from disk into cache") + certs_dir = Config.get_config().certs_dir + + if not os.path.exists(certs_dir): + logger.debug(f"Certificates directory {certs_dir} does not exist") + return + + # First load the CA certificate to verify signatures + try: + ca_cert_path = self.get_cert_path(Config.get_config().ca_cert) + logger.debug(f"Loading CA certificate for verification: {ca_cert_path}") + with open(ca_cert_path, "rb") as f: + ca_cert = x509.load_pem_x509_certificate(f.read(), default_backend()) + except Exception as e: + logger.error(f"Failed to load CA certificate for verification: {e}") + return + + # Get current time for expiry checks + current_time = datetime.now(timezone.utc) + expiry_date = current_time + timedelta(days=TLS_GRACE_PERIOD_DAYS) + + for filename in os.listdir(certs_dir): + if ( + filename.endswith('.crt') and + filename not in [Config.get_config().ca_cert, Config.get_config().server_cert] + ): + cert_path = os.path.join(certs_dir, filename) + key_path = os.path.join(certs_dir, filename.replace('.crt', '.key')) + + # Skip if key file doesn't exist + if not os.path.exists(key_path): + logger.debug(f"Skipping {filename} as key file does not exist") + continue + + try: + # Load and validate certificate + with open(cert_path, "rb") as cert_file: + cert = x509.load_pem_x509_certificate(cert_file.read(), default_backend()) + + # Extract domain from common name + common_name = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value + + # Verify certificate is signed by our CA + try: + ca_cert.public_key().verify( + cert.signature, + cert.tbs_certificate_bytes, + padding.PKCS1v15(), + cert.signature_hash_algorithm, + ) + except InvalidSignature: + logger.debug(f"Skipping {filename} as it's not signed by our CA") + continue + + # Check if certificate is still valid + if cert.not_valid_after_utc > expiry_date: + logger.debug(f"Loading valid certificate for {common_name}") + self._cert_cache[common_name] = CachedCertificate( + cert_path=cert_path, + key_path=key_path, + creation_time=datetime.utcnow(), + ) + else: + logger.debug(f"Skipping expired certificate for {common_name}") + + except Exception as e: + logger.error(f"Failed to load certificate {filename}: {e}") + continue + + logger.debug(f"Loaded {len(self._cert_cache)} certificates into cache") + + def _get_cached_ca_certificates(self) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]: + """Get CA certificates from cache or load them if needed.""" + current_time = datetime.now(timezone.utc) + + # Check if certificates are loaded and not expired + if ( + self._ca_cert is not None + and self._ca_key is not None + and self._ca_cert_expiry is not None + and current_time < self._ca_cert_expiry + ): + return self._ca_cert, self._ca_key + + # Load certificates from disk + logger.debug("Loading CA certificates from disk") + ca_cert_path = self.get_cert_path(Config.get_config().ca_cert) + ca_key_path = self.get_cert_path(Config.get_config().ca_key) + + try: + with open(ca_cert_path, "rb") as f: + self._ca_cert = x509.load_pem_x509_certificate(f.read(), default_backend()) + self._ca_cert_expiry = self._ca_cert.not_valid_after_utc + + with open(ca_key_path, "rb") as f: + self._ca_key = serialization.load_pem_private_key( + f.read(), password=None, backend=default_backend() ) - def get_domain_certificate(self, domain: str) -> Tuple[str, str]: - """Get or generate a certificate for a specific domain""" - logger.debug(f"Getting domain certificate for domain: {domain} fn: get_domain_certificate") - if domain in self._cert_cache: - return self._cert_cache[domain] - - # Generate new certificate for domain - logger.debug(f"Generating private key for domain: {domain}") - key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, # 2048 bits is sufficient for domain certs - ) - - name = x509.Name( - [ - x509.NameAttribute(NameOID.COMMON_NAME, domain), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy Pilot Generated"), - ] - ) - - builder = x509.CertificateBuilder() - builder = builder.subject_name(name) - builder = builder.issuer_name(self._ca_cert.subject) - builder = builder.public_key(key.public_key()) - builder = builder.serial_number(x509.random_serial_number()) - builder = builder.not_valid_before(datetime.datetime.utcnow()) - builder = builder.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) - - # Add domain to SAN - builder = builder.add_extension( - x509.SubjectAlternativeName([x509.DNSName(domain)]), - critical=False, - ) - - # Add extended key usage - builder = builder.add_extension( - x509.ExtendedKeyUsage( - [ - ExtendedKeyUsageOID.SERVER_AUTH, - ExtendedKeyUsageOID.CLIENT_AUTH, - ] - ), - critical=False, - ) - - # Basic constraints (not a CA) - builder = builder.add_extension( - x509.BasicConstraints(ca=False, path_length=None), - critical=True, - ) - - logger.debug(f"Signing certificate for domain: {domain}") - certificate = builder.sign( - private_key=self._ca_key, - algorithm=hashes.SHA256(), - ) - - # Save domain certificate and key - logger.debug(f"Saving certificate and key for domain: {domain}") - domain_cert_path = os.path.join(Config.get_config().certs_dir, f"{domain}.crt") - domain_key_path = os.path.join(Config.get_config().certs_dir, f"{domain}.key") - - with open(domain_cert_path, "wb") as f: - logger.debug(f"Saving certificate to {domain_cert_path}") - f.write(certificate.public_bytes(serialization.Encoding.PEM)) - - with open(domain_key_path, "wb") as f: - logger.debug(f"Saving key to {domain_key_path}") - f.write( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - ) + self._ca_last_load_time = current_time + logger.debug("Successfully loaded and cached CA certificates") + return self._ca_cert, self._ca_key + + except Exception as e: + logger.error(f"Failed to load CA certificates: {e}") + # Clear cached values on error + self._ca_cert = None + self._ca_key = None + self._ca_cert_expiry = None + self._ca_last_load_time = None + raise + + def remove_certificates(self) -> None: + """Remove all cached certificates and contexts""" + logger.debug("Removing all cached certificates and contexts") + self.certs_dir = Config.get_config().certs_dir + # remove and recreate certs directory + try: + logger.debug(f"Removing certs directory: {self.certs_dir}") + os.rmdir(self.certs_dir) + os.makedirs(self.certs_dir) + # Clear CA certificate cache + self._ca_cert = None + self._ca_key = None + self._ca_cert_expiry = None + self._ca_last_load_time = None + except OSError as e: + logger.error(f"Failed to remove certs directory: {e}") + raise + + def generate_ca_certificates(self) -> None: + """ + Generate self-signed CA certificates - self._cert_cache[domain] = (domain_cert_path, domain_key_path) - return domain_cert_path, domain_key_path + Key Attributes are: + X509v3 + Key Usage + Digital Signature, Key Encipherment, Key Cert Sign, CRL Sign - def generate_certificates(self) -> Tuple[str, str]: - """Generate self-signed certificates with proper extensions for HTTPS proxy""" - logger.debug("Generating certificates fn: generate_certificates") + Expirtation: + 1 year from now - if not os.path.exists(Config.get_config().certs_dir): - logger.debug(f"Creating directory: {Config.get_config().certs_dir}") - os.makedirs(Config.get_config().certs_dir) + """ + logger.debug("Generating CA certificates fn: generate_ca_certificates") # Generate private key logger.debug("Generating private key for CA") - ca_private_key = rsa.generate_private_key( + self._ca_key = rsa.generate_private_key( public_exponent=65537, key_size=4096, ) - # Generate public key logger.debug("Generating public key for CA") - ca_public_key = ca_private_key.public_key() + self._ca_public_key = self._ca_key.public_key() - # Add name attributes + # Define certificate subject name = x509.Name( [ x509.NameAttribute(NameOID.COMMON_NAME, "CodeGate CA"), @@ -240,21 +280,20 @@ def generate_certificates(self) -> Tuple[str, str]: # Create certificate builder builder = x509.CertificateBuilder() - - # Basic certificate information builder = builder.subject_name(name) builder = builder.issuer_name(name) - builder = builder.public_key(ca_public_key) + builder = builder.public_key(self._ca_public_key) builder = builder.serial_number(x509.random_serial_number()) - builder = builder.not_valid_before(datetime.datetime.utcnow()) - builder = builder.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + builder = builder.not_valid_before(datetime.now(timezone.utc)) + builder = builder.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + # Add basic constraints builder = builder.add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, ) - # Add key usage extensions + # Add key usage builder = builder.add_extension( x509.KeyUsage( digital_signature=True, @@ -262,7 +301,7 @@ def generate_certificates(self) -> Tuple[str, str]: key_encipherment=True, data_encipherment=False, key_agreement=False, - key_cert_sign=True, # This is a CA + key_cert_sign=True, crl_sign=True, encipher_only=False, decipher_only=False, @@ -270,38 +309,154 @@ def generate_certificates(self) -> Tuple[str, str]: critical=True, ) + # Sign the certificate logger.debug("Signing CA certificate") - ca_cert = builder.sign( - private_key=ca_private_key, + self._ca_cert = builder.sign( + private_key=self._ca_key, algorithm=hashes.SHA256(), ) - # Save CA certificate and key - with open( - os.path.join(Config.get_config().certs_dir, Config.get_config().ca_cert), "wb" - ) as f: - logger.debug(f"Saving CA certificate to {Config.get_config().ca_cert}") - f.write(ca_cert.public_bytes(serialization.Encoding.PEM)) - - with open( - os.path.join(Config.get_config().certs_dir, Config.get_config().ca_key), "wb" - ) as f: - logger.debug(f"Saving CA key to {Config.get_config().ca_key}") + # Set expiry time for cache + self._ca_cert_expiry = self._ca_cert.not_valid_after_utc + self._ca_last_load_time = datetime.now(timezone.utc) + + + # Define file paths for certificate and key + ca_cert_path = self.get_cert_path(Config.get_config().ca_cert) + ca_key_path = self.get_cert_path(Config.get_config().ca_key) + + if not os.path.exists(Config.get_config().certs_dir): + logger.debug(f"Creating directory: {Config.get_config().certs_dir}") + os.makedirs(Config.get_config().certs_dir) + + with open(ca_cert_path, "wb") as f: + logger.debug(f"Saving CA certificate to {ca_cert_path}") + f.write(self._ca_cert.public_bytes(serialization.Encoding.PEM)) + + with open(ca_key_path, "wb") as f: + logger.debug(f"Saving CA key to {ca_key_path}") f.write( - ca_private_key.private_bytes( + self._ca_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) ) - # CA generated, now generate server certificate + def get_domain_certificate(self, domain: str) -> Tuple[str, str]: + """Generate or retrieve a cached certificate for a domain.""" + logger.debug(f"Getting domain certificate for domain: {domain}") + + # Use cached CA certificates + ca_cert, ca_key = self._get_cached_ca_certificates() + + cached = self._cert_cache.get(domain) + if cached: + # Validate the cached certificate's expiry + try: + with open(cached.cert_path, "rb") as domain_cert_file: + domain_cert = x509.load_pem_x509_certificate( + domain_cert_file.read(), default_backend() + ) + # Check if certificate is still valid beyond the grace period + expiry_date = datetime.now(timezone.utc) + timedelta(days=TLS_GRACE_PERIOD_DAYS) + logger.debug(f"Certificate expiry: {domain_cert.not_valid_after_utc}") + if domain_cert.not_valid_after_utc > expiry_date: + logger.debug( + f"Using cached certificate for {domain} from {cached.cert_path}" + ) # noqa: E501 + return cached.cert_path, cached.key_path + else: + logger.debug(f"Cached certificate for {domain} is expiring soon, renewing.") + except Exception as e: + logger.error(f"Failed to validate cached certificate for {domain}: {e}") + + logger.debug(f"Generating new certificate for domain: {domain}") + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + # Nothing is in the cache or its expired, generate a new one! + + name = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, domain), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CodeGate Generated"), + ] + ) + + builder = x509.CertificateBuilder() + builder = builder.subject_name(name) + builder = builder.issuer_name(ca_cert.subject) + builder = builder.public_key(key.public_key()) + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.not_valid_before(datetime.now(timezone.utc)) + builder = builder.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName(domain)]), critical=False + ) + builder = builder.add_extension( + x509.ExtendedKeyUsage( + [ExtendedKeyUsageOID.SERVER_AUTH, ExtendedKeyUsageOID.CLIENT_AUTH] + ), + critical=False, + ) + builder = builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=False + ) + + certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA256()) + + cert_dir = Config.get_config().certs_dir + domain_cert_path = os.path.join(cert_dir, f"{domain}.crt") + domain_key_path = os.path.join(cert_dir, f"{domain}.key") + + try: + os.makedirs(cert_dir, exist_ok=True) + except OSError as e: + logger.error(f"Failed to create directory {cert_dir} for {domain}: {e}") + raise + + try: + logger.debug(f"Saving certificate to {domain_cert_path} for domain {domain}") + with open(domain_cert_path, "wb") as f: + f.write(certificate.public_bytes(serialization.Encoding.PEM)) + + logger.debug(f"Saving key to {domain_key_path} for domain {domain}") + with open(domain_key_path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + except OSError as e: + logger.error(f"Failed to save certificate or key for {domain}: {e}") + raise + + logger.debug(f"Generated and cached new certificate for {domain}") + return domain_cert_path, domain_key_path + + def load_ca_certificates(self) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]: + """Load CA certificates for HTTPS proxy""" + logger.debug("Loading CA certificates fn: load_ca_certificates") + return self._get_cached_ca_certificates() + + def generate_server_certificates(self) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]: + """Generate self-signed server certificates for HTTPS proxy""" + logger.debug("Generating server certificates fn: generate_server_certificates") + try: + ca_cert, ca_key = self._get_cached_ca_certificates() + except Exception as e: + logger.error(f"Failed to load CA certificates: {e}") + raise - ## Generate new certificate for domain logger.debug("Generating private key for server") server_key = rsa.generate_private_key( public_exponent=65537, - key_size=2048, # 2048 bits is sufficient for domain certs + key_size=2048, ) name = x509.Name( @@ -311,22 +466,19 @@ def generate_certificates(self) -> Tuple[str, str]: ] ) - # Add extended key usage extension builder = x509.CertificateBuilder() builder = builder.subject_name(name) builder = builder.issuer_name(ca_cert.subject) builder = builder.public_key(server_key.public_key()) builder = builder.serial_number(x509.random_serial_number()) - builder = builder.not_valid_before(datetime.datetime.utcnow()) - builder = builder.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + builder = builder.not_valid_before(datetime.now(timezone.utc)) + builder = builder.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) - # Add domain to SAN builder = builder.add_extension( x509.SubjectAlternativeName([x509.DNSName("localhost")]), critical=False, ) - # Add extended key usage builder = builder.add_extension( x509.ExtendedKeyUsage( [ @@ -337,51 +489,63 @@ def generate_certificates(self) -> Tuple[str, str]: critical=False, ) - # Basic constraints (not a CA) builder = builder.add_extension( x509.BasicConstraints(ca=False, path_length=None), - critical=True, + critical=False, ) - logger.debug("Signing server certificate") + logger.debug("Signing server certificate with CA") server_cert = builder.sign( - private_key=ca_private_key, + private_key=ca_key, algorithm=hashes.SHA256(), ) - with open( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert), "wb" - ) as f: - logger.debug(f"Saving server certificate to {Config.get_config().server_cert}") - f.write(server_cert.public_bytes(serialization.Encoding.PEM)) + server_cert_path = os.path.join( + Config.get_config().certs_dir, Config.get_config().server_cert + ) + server_key_path = os.path.join( + Config.get_config().certs_dir, Config.get_config().server_key + ) - with open( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_key), "wb" - ) as f: - logger.debug(f"Saving server key to {Config.get_config().server_key}") - f.write( - server_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), + try: + logger.debug(f"Saving server certificate to {server_cert_path}") + with open(server_cert_path, "wb") as f: + f.write(server_cert.public_bytes(serialization.Encoding.PEM)) + + logger.debug(f"Saving server key to {server_key_path}") + with open(server_key_path, "wb") as f: + f.write( + server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) ) - ) + except OSError as e: + logger.error(f"Failed to write server certificate or key: {e}") + raise - # Print instructions for trusting the certificates - logger.debug("Certificates generated successfully") + logger.debug("Server certificates generated successfully") return server_cert, server_key - def create_ssl_context(self) -> ssl.SSLContext: + def create_server_ssl_context(self) -> ssl.SSLContext: """Create SSL context with secure configuration""" + server_cert_path = self.get_cert_path(Config.get_config().server_cert) + server_key_path = self.get_cert_path(Config.get_config().server_key) + logger.debug("Creating SSL context fn: create_ssl_context") + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - logger.debug( - f"Loading server certificate for ssl_context from: {Config.get_config().server_cert}" - ) - ssl_context.load_cert_chain( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert), - os.path.join(Config.get_config().certs_dir, Config.get_config().server_key), - ) + + logger.debug(f"Using server cert: {server_cert_path}") + + logger.debug(f"Using server cert: {server_key_path}") + try: + ssl_context.load_cert_chain(server_cert_path, server_key_path) + except ssl.SSLError as e: + logger.error(f"Failed to load cert chain: {e}") + raise + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 ssl_context.options |= ( ssl.OP_NO_SSLv2 @@ -393,29 +557,67 @@ def create_ssl_context(self) -> ssl.SSLContext: logger.debug("SSL context created successfully") return ssl_context - def check_certificates_exist(self) -> bool: - """Check if SSL certificates exist""" - logger.debug("Checking if certificates exist fn: check_certificates_exist") - return os.path.exists( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert) - ) and os.path.exists( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_key) - ) + def check_and_ensure_certificates(self) -> bool: + """Check if SSL certificates exist, ensure their presence, and validate them.""" + + logger.debug("Checking and ensuring SSL certificates exist: check_and_ensure_certificates") + + def is_certificate_valid(cert_path: str) -> bool: + """Check if a certificate is valid (not expired) using cryptography.""" + try: + with open(cert_path, "rb") as cert_file: + cert_data = cert_file.read() # Read the certificate file + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + + # Use timezone-aware expiration date + expiration_date = cert.not_valid_after_utc + current_time = datetime.now(timezone.utc) + return expiration_date > current_time + except Exception as e: + logger.error(f"Failed to validate certificate {cert_path}: {e}") + return False + + server_cert_path = self.get_cert_path(Config.get_config().server_cert) + server_key_path = self.get_cert_path(Config.get_config().server_key) + ca_cert_path = self.get_cert_path(Config.get_config().ca_cert) + ca_key_path = self.get_cert_path(Config.get_config().ca_key) + + cert_status = { + "server_cert": os.path.exists(server_cert_path) + and is_certificate_valid(server_cert_path), + "server_key": os.path.exists(server_key_path), + "ca_cert": os.path.exists(ca_cert_path) and is_certificate_valid(ca_cert_path), + "ca_key": os.path.exists(ca_key_path), + } + + for cert_name, exists in cert_status.items(): + logger.debug(f"{cert_name} exists: {exists}") + + if all(cert_status.values()): + return True + + if not cert_status["ca_cert"] or not cert_status["ca_key"]: + logger.info( + "CA certificates missing or invalid, generating new CA and server certificates." + ) + # Clear the CA certificate cache before regenerating + # with self._cache_lock: + self._ca_cert = None + self._ca_key = None + self._ca_cert_expiry = None + self._ca_last_load_time = None + + self.generate_ca_certificates() + self.generate_server_certificates() + + elif not cert_status["server_cert"] or not cert_status["server_key"]: + logger.info( + "Server certificates missing or invalid, generating new server certificates." + ) + self.generate_server_certificates() + + return False - def ensure_certificates_exist(self) -> bool: - """Ensure SSL certificates exist, generate if they don't""" - logger.debug("Ensuring certificates exist. fn ensure_certificates_exist") - if not self.check_certificates_exist(): - logger.info("Certificates not found. Generating new certificates.") - self.generate_certificates() - - def get_ssl_context(self) -> ssl.SSLContext: - """Get SSL context with certificates""" - logger.debug("Getting SSL context fn: get_ssl_context") - self.ensure_certificates_exist() - return self.create_ssl_context() - - def get_cert_files(self) -> Tuple[str, str]: - """Get certificate and key file paths""" - logger.debug("Getting certificate and key file paths fn: get_cert_files") - return Config.get_config().server_cert, Config.get_config().server_key + def get_cert_path(self, cert_name: str) -> str: + logger.debug(f"Using path: {Config.get_config().certs_dir, cert_name}") + return os.path.join(Config.get_config().certs_dir, cert_name) diff --git a/src/codegate/cli.py b/src/codegate/cli.py index e1e30634..983d9a60 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -303,7 +303,12 @@ def serve( # Check certificates and create CA if necessary logger.info("Checking certificates and creating CA if needed") ca = CertificateAuthority.get_instance() - ca.ensure_certificates_exist() + + certs_check = ca.check_and_ensure_certificates() + if certs_check: + click.echo("New Certificates generated successfully.") + else: + click.echo("Existing Certificates are already present.") # Initialize secrets manager and pipeline factory secrets_manager = SecretsManager() @@ -452,7 +457,10 @@ def restore_backup(backup_path: Path, backup_name: str) -> None: "--force-certs", is_flag=True, default=False, - help="Force the generation of certificates even if they already exist.", + help=( + "Force the generation of certificates even if they already exist. " + "Warning: this will overwrite existing certificates." + ), ) @click.option( "--log-level", @@ -488,17 +496,20 @@ def generate_certs( cli_log_format=log_format, ) setup_logging(cfg.log_level, cfg.log_format) + logger = structlog.get_logger("codegate").bind(origin="cli") ca = CertificateAuthority.get_instance() - should_generate = force_certs or not ca.check_certificates_exist() - if should_generate: - ca.generate_certificates() - click.echo("Certificates generated successfully.") - click.echo(f"Certificates saved to: {cfg.certs_dir}") - click.echo("Make sure to add the new CA certificate to the operating system trust store.") + # Remove and regenerate certificates if forced; otherwise, just ensure they exist + logger.info("Checking certificates and creating certs if needed") + if force_certs: + ca.remove_certificates() + + certs_check = ca.check_and_ensure_certificates() + if certs_check: + logger.info("New Certificates generated successfully.") else: - click.echo("Certificates already exist. Skipping generation...") + logger.info("Existing Certificates are already present.") def main() -> None: diff --git a/src/codegate/codegate_logging.py b/src/codegate/codegate_logging.py index c6bb7c37..7bf77fc0 100644 --- a/src/codegate/codegate_logging.py +++ b/src/codegate/codegate_logging.py @@ -146,7 +146,7 @@ def setup_logging( root_logger.addHandler(stderr_handler) # Set explicitly the log level for other modules - logging.getLogger("sqlalchemy").setLevel(logging.WARNING) + logging.getLogger("sqlalchemy").disabled = True logging.getLogger("uvicorn.error").disabled = True # Create a logger for our package diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index f3c57c1a..bdf75259 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -113,7 +113,8 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option """ ) recorded_request = await self._insert_pydantic_model(prompt_params, sql) - logger.debug(f"Recorded request: {recorded_request}") + # Uncomment to debug the recorded request + # logger.debug(f"Recorded request: {recorded_request}") return recorded_request async def record_outputs(self, outputs: List[Output]) -> Optional[Output]: diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 80ac119a..2bf82176 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -133,7 +133,8 @@ def add_input_request( type="fim" if is_fim_request else "chat", request=request_str, ) - logger.debug(f"Added input request to context: {self.input_request}") + # Uncomment the below to debug the input + # logger.debug(f"Added input request to context: {self.input_request}") except Exception as e: logger.warning(f"Failed to serialize input request: {normalized_request}", error=str(e)) diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index 76ef0f48..5268aeaa 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -119,7 +119,8 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi # the pipeline did modify the request, return to the user # in the original LLM format body = self.normalizer.denormalize(result.request) - logger.debug(f"Pipeline processed request: {body}") + # Uncomment the below to debug the request + # logger.debug(f"Pipeline processed request: {body}") return body, result.context except Exception as e: diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index de8fa98e..baa31ad9 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -8,7 +8,7 @@ import structlog from litellm.types.utils import Delta, ModelResponse, StreamingChoices -from codegate.ca.codegate_ca import CertificateAuthority +from codegate.ca.codegate_ca import CertificateAuthority, TLSCertDomainManager from codegate.codegate_logging import setup_logging from codegate.config import Config from codegate.pipeline.base import PipelineContext @@ -147,6 +147,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.ssl_context: Optional[ssl.SSLContext] = None self.proxy_ep: Optional[str] = None self.ca = CertificateAuthority.get_instance() + self.cert_manager = TLSCertDomainManager(self.ca) self._closing = False self.pipeline_factory = PipelineFactory(SecretsManager()) self.context_tracking: Optional[PipelineContext] = None @@ -494,8 +495,8 @@ def handle_connect(self) -> None: self.target_host, port = path.split(":") self.target_port = int(port) - cert_path, key_path = self.ca.get_domain_certificate(self.target_host) - self.ssl_context = self._create_ssl_context(cert_path, key_path) + # Get SSL context through the TLS handler + self.ssl_context = self.cert_manager.get_domain_context(self.target_host) self.is_connect = True asyncio.create_task(self.connect_to_target()) @@ -505,13 +506,6 @@ def handle_connect(self) -> None: logger.error(f"Error handling CONNECT: {e}") self.send_error_response(502, str(e).encode()) - def _create_ssl_context(self, cert_path: str, key_path: str) -> ssl.SSLContext: - """Create SSL context for CONNECT tunneling""" - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(cert_path, key_path) - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - return ssl_context - async def connect_to_target(self) -> None: """Establish connection to target for CONNECT requests""" try: @@ -616,7 +610,7 @@ async def run_proxy_server(cls) -> None: """Run the proxy server""" try: ca = CertificateAuthority.get_instance() - ssl_context = ca.create_ssl_context() + ssl_context = ca.create_server_ssl_context() config = Config.get_config() server = await cls.create_proxy_server(config.host, config.proxy_port, ssl_context)