Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 6fe4d93

Browse files
author
Luke Hinds
committed
Remove locking
1 parent c4d8573 commit 6fe4d93

File tree

2 files changed

+138
-154
lines changed

2 files changed

+138
-154
lines changed

src/codegate/ca/codegate_ca.py

Lines changed: 137 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import ssl
3-
import threading
43
from dataclasses import dataclass
54
from datetime import datetime, timedelta, timezone
65
from typing import Dict, Optional, Tuple
@@ -10,7 +9,7 @@
109
from cryptography.exceptions import InvalidSignature
1110
from cryptography.hazmat.backends import default_backend
1211
from cryptography.hazmat.primitives import hashes, serialization
13-
from cryptography.hazmat.primitives.asymmetric import rsa, padding
12+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
1413
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
1514

1615
from codegate.config import Config
@@ -41,20 +40,15 @@ def __init__(self, ca_provider: "CertificateAuthority"):
4140
self._cert_cache: Dict[str, CachedCertificate] = {}
4241
self._context_cache: Dict[str, ssl.SSLContext] = {}
4342

44-
# Add lock for thread-safe operations, save anything bad happening with
45-
# cache race conditions!!!
46-
self._cache_lock = threading.Lock()
47-
4843
def get_domain_context(self, server_name: str) -> ssl.SSLContext:
4944
logger.debug(f"Getting domain context for server_name: {server_name}")
50-
with self._cache_lock:
51-
if server_name not in self._context_cache:
52-
logger.debug(f"No cached SSL context for {server_name}, creating new one.")
53-
cert_path, key_path = self._ca.get_domain_certificate(server_name)
54-
context = self._create_domain_ssl_context(cert_path, key_path, server_name)
55-
self._context_cache[server_name] = context
56-
logger.debug(f"Created new SSL context for {server_name}")
57-
return self._context_cache[server_name]
45+
if server_name not in self._context_cache:
46+
logger.debug(f"No cached SSL context for {server_name}, creating new one.")
47+
cert_path, key_path = self._ca.get_domain_certificate(server_name)
48+
context = self._create_domain_ssl_context(cert_path, key_path, server_name)
49+
self._context_cache[server_name] = context
50+
logger.debug(f"Created new SSL context for {server_name}")
51+
return self._context_cache[server_name]
5852

5953
def _create_domain_ssl_context(
6054
self, cert_path: str, key_path: str, domain: str
@@ -116,8 +110,6 @@ def __init__(self):
116110
# Use a separate cache for SSL contexts
117111
self._context_cache: Dict[str, Tuple[ssl.SSLContext, datetime]] = {}
118112

119-
# Add a lock for thread-safe cache operations
120-
self._cache_lock = threading.Lock()
121113
CertificateAuthority._instance = self
122114

123115
# Load existing certificates into cache
@@ -147,7 +139,10 @@ def _load_existing_certificates(self) -> None:
147139
expiry_date = current_time + timedelta(days=TLS_GRACE_PERIOD_DAYS)
148140

149141
for filename in os.listdir(certs_dir):
150-
if filename.endswith('.crt') and not filename in [Config.get_config().ca_cert, Config.get_config().server_cert]:
142+
if (
143+
filename.endswith('.crt') and
144+
filename not in [Config.get_config().ca_cert, Config.get_config().server_cert]
145+
):
151146
cert_path = os.path.join(certs_dir, filename)
152147
key_path = os.path.join(certs_dir, filename.replace('.crt', '.key'))
153148

@@ -179,12 +174,11 @@ def _load_existing_certificates(self) -> None:
179174
# Check if certificate is still valid
180175
if cert.not_valid_after_utc > expiry_date:
181176
logger.debug(f"Loading valid certificate for {common_name}")
182-
with self._cache_lock:
183-
self._cert_cache[common_name] = CachedCertificate(
184-
cert_path=cert_path,
185-
key_path=key_path,
186-
creation_time=datetime.utcnow(),
187-
)
177+
self._cert_cache[common_name] = CachedCertificate(
178+
cert_path=cert_path,
179+
key_path=key_path,
180+
creation_time=datetime.utcnow(),
181+
)
188182
else:
189183
logger.debug(f"Skipping expired certificate for {common_name}")
190184

@@ -196,45 +190,44 @@ def _load_existing_certificates(self) -> None:
196190

197191
def _get_cached_ca_certificates(self) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]:
198192
"""Get CA certificates from cache or load them if needed."""
199-
with self._cache_lock:
200-
current_time = datetime.now(timezone.utc)
201-
202-
# Check if certificates are loaded and not expired
203-
if (
204-
self._ca_cert is not None
205-
and self._ca_key is not None
206-
and self._ca_cert_expiry is not None
207-
and current_time < self._ca_cert_expiry
208-
):
209-
return self._ca_cert, self._ca_key
193+
current_time = datetime.now(timezone.utc)
210194

211-
# Load certificates from disk
212-
logger.debug("Loading CA certificates from disk")
213-
ca_cert_path = self.get_cert_path(Config.get_config().ca_cert)
214-
ca_key_path = self.get_cert_path(Config.get_config().ca_key)
195+
# Check if certificates are loaded and not expired
196+
if (
197+
self._ca_cert is not None
198+
and self._ca_key is not None
199+
and self._ca_cert_expiry is not None
200+
and current_time < self._ca_cert_expiry
201+
):
202+
return self._ca_cert, self._ca_key
203+
204+
# Load certificates from disk
205+
logger.debug("Loading CA certificates from disk")
206+
ca_cert_path = self.get_cert_path(Config.get_config().ca_cert)
207+
ca_key_path = self.get_cert_path(Config.get_config().ca_key)
215208

216-
try:
217-
with open(ca_cert_path, "rb") as f:
218-
self._ca_cert = x509.load_pem_x509_certificate(f.read(), default_backend())
219-
self._ca_cert_expiry = self._ca_cert.not_valid_after_utc
209+
try:
210+
with open(ca_cert_path, "rb") as f:
211+
self._ca_cert = x509.load_pem_x509_certificate(f.read(), default_backend())
212+
self._ca_cert_expiry = self._ca_cert.not_valid_after_utc
220213

221-
with open(ca_key_path, "rb") as f:
222-
self._ca_key = serialization.load_pem_private_key(
223-
f.read(), password=None, backend=default_backend()
224-
)
214+
with open(ca_key_path, "rb") as f:
215+
self._ca_key = serialization.load_pem_private_key(
216+
f.read(), password=None, backend=default_backend()
217+
)
225218

226-
self._ca_last_load_time = current_time
227-
logger.debug("Successfully loaded and cached CA certificates")
228-
return self._ca_cert, self._ca_key
219+
self._ca_last_load_time = current_time
220+
logger.debug("Successfully loaded and cached CA certificates")
221+
return self._ca_cert, self._ca_key
229222

230-
except Exception as e:
231-
logger.error(f"Failed to load CA certificates: {e}")
232-
# Clear cached values on error
233-
self._ca_cert = None
234-
self._ca_key = None
235-
self._ca_cert_expiry = None
236-
self._ca_last_load_time = None
237-
raise
223+
except Exception as e:
224+
logger.error(f"Failed to load CA certificates: {e}")
225+
# Clear cached values on error
226+
self._ca_cert = None
227+
self._ca_key = None
228+
self._ca_cert_expiry = None
229+
self._ca_last_load_time = None
230+
raise
238231

239232
def remove_certificates(self) -> None:
240233
"""Remove all cached certificates and contexts"""
@@ -246,11 +239,10 @@ def remove_certificates(self) -> None:
246239
os.rmdir(self.certs_dir)
247240
os.makedirs(self.certs_dir)
248241
# Clear CA certificate cache
249-
with self._cache_lock:
250-
self._ca_cert = None
251-
self._ca_key = None
252-
self._ca_cert_expiry = None
253-
self._ca_last_load_time = None
242+
self._ca_cert = None
243+
self._ca_key = None
244+
self._ca_cert_expiry = None
245+
self._ca_last_load_time = None
254246
except OSError as e:
255247
logger.error(f"Failed to remove certs directory: {e}")
256248
raise
@@ -362,103 +354,95 @@ def get_domain_certificate(self, domain: str) -> Tuple[str, str]:
362354
# Use cached CA certificates
363355
ca_cert, ca_key = self._get_cached_ca_certificates()
364356

365-
with self._cache_lock:
366-
cached = self._cert_cache.get(domain)
367-
if cached:
368-
# Validate the cached certificate's expiry
369-
try:
370-
with open(cached.cert_path, "rb") as domain_cert_file:
371-
domain_cert = x509.load_pem_x509_certificate(
372-
domain_cert_file.read(), default_backend()
373-
)
374-
# Check if certificate is still valid beyond the grace period
375-
expiry_date = datetime.now(timezone.utc) + timedelta(days=TLS_GRACE_PERIOD_DAYS)
376-
logger.debug(f"Expiry date: {expiry_date}")
377-
logger.debug(f"Certificate expiry: {domain_cert.not_valid_after}")
378-
if domain_cert.not_valid_after_utc > expiry_date:
379-
logger.debug(
380-
f"Using cached certificate for {domain} from {cached.cert_path}"
381-
) # noqa: E501
382-
return cached.cert_path, cached.key_path
383-
else:
384-
logger.debug(f"Cached certificate for {domain} is expiring soon, renewing.")
385-
except Exception as e:
386-
logger.error(f"Failed to validate cached certificate for {domain}: {e}")
357+
cached = self._cert_cache.get(domain)
358+
if cached:
359+
# Validate the cached certificate's expiry
360+
try:
361+
with open(cached.cert_path, "rb") as domain_cert_file:
362+
domain_cert = x509.load_pem_x509_certificate(
363+
domain_cert_file.read(), default_backend()
364+
)
365+
# Check if certificate is still valid beyond the grace period
366+
expiry_date = datetime.now(timezone.utc) + timedelta(days=TLS_GRACE_PERIOD_DAYS)
367+
logger.debug(f"Expiry date: {expiry_date}")
368+
logger.debug(f"Certificate expiry: {domain_cert.not_valid_after}")
369+
if domain_cert.not_valid_after_utc > expiry_date:
370+
logger.debug(
371+
f"Using cached certificate for {domain} from {cached.cert_path}"
372+
) # noqa: E501
373+
return cached.cert_path, cached.key_path
374+
else:
375+
logger.debug(f"Cached certificate for {domain} is expiring soon, renewing.")
376+
except Exception as e:
377+
logger.error(f"Failed to validate cached certificate for {domain}: {e}")
387378

388-
logger.debug(f"Generating new certificate for domain: {domain}")
389-
key = rsa.generate_private_key(
390-
public_exponent=65537,
391-
key_size=2048,
392-
)
379+
logger.debug(f"Generating new certificate for domain: {domain}")
380+
key = rsa.generate_private_key(
381+
public_exponent=65537,
382+
key_size=2048,
383+
)
393384

394-
# Nothing is in the cache or its expired, generate a new one!
385+
# Nothing is in the cache or its expired, generate a new one!
395386

396-
name = x509.Name(
397-
[
398-
x509.NameAttribute(NameOID.COMMON_NAME, domain),
399-
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CodeGate Generated"),
400-
]
401-
)
387+
name = x509.Name(
388+
[
389+
x509.NameAttribute(NameOID.COMMON_NAME, domain),
390+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CodeGate Generated"),
391+
]
392+
)
402393

403-
builder = x509.CertificateBuilder()
404-
builder = builder.subject_name(name)
405-
builder = builder.issuer_name(ca_cert.subject)
406-
builder = builder.public_key(key.public_key())
407-
builder = builder.serial_number(x509.random_serial_number())
408-
builder = builder.not_valid_before(datetime.now(timezone.utc))
409-
builder = builder.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
410-
builder = builder.add_extension(
411-
x509.SubjectAlternativeName([x509.DNSName(domain)]), critical=False
412-
)
413-
builder = builder.add_extension(
414-
x509.ExtendedKeyUsage(
415-
[ExtendedKeyUsageOID.SERVER_AUTH, ExtendedKeyUsageOID.CLIENT_AUTH]
416-
),
417-
critical=False,
418-
)
419-
builder = builder.add_extension(
420-
x509.BasicConstraints(ca=False, path_length=None), critical=False
421-
)
394+
builder = x509.CertificateBuilder()
395+
builder = builder.subject_name(name)
396+
builder = builder.issuer_name(ca_cert.subject)
397+
builder = builder.public_key(key.public_key())
398+
builder = builder.serial_number(x509.random_serial_number())
399+
builder = builder.not_valid_before(datetime.now(timezone.utc))
400+
builder = builder.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
401+
builder = builder.add_extension(
402+
x509.SubjectAlternativeName([x509.DNSName(domain)]), critical=False
403+
)
404+
builder = builder.add_extension(
405+
x509.ExtendedKeyUsage(
406+
[ExtendedKeyUsageOID.SERVER_AUTH, ExtendedKeyUsageOID.CLIENT_AUTH]
407+
),
408+
critical=False,
409+
)
410+
builder = builder.add_extension(
411+
x509.BasicConstraints(ca=False, path_length=None), critical=False
412+
)
422413

423-
certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA256())
414+
certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA256())
424415

425-
cert_dir = Config.get_config().certs_dir
426-
domain_cert_path = os.path.join(cert_dir, f"{domain}.crt")
427-
domain_key_path = os.path.join(cert_dir, f"{domain}.key")
416+
cert_dir = Config.get_config().certs_dir
417+
domain_cert_path = os.path.join(cert_dir, f"{domain}.crt")
418+
domain_key_path = os.path.join(cert_dir, f"{domain}.key")
428419

429-
try:
430-
os.makedirs(cert_dir, exist_ok=True)
431-
except OSError as e:
432-
logger.error(f"Failed to create directory {cert_dir} for {domain}: {e}")
433-
raise
420+
try:
421+
os.makedirs(cert_dir, exist_ok=True)
422+
except OSError as e:
423+
logger.error(f"Failed to create directory {cert_dir} for {domain}: {e}")
424+
raise
434425

435-
try:
436-
logger.debug(f"Saving certificate to {domain_cert_path} for domain {domain}")
437-
with open(domain_cert_path, "wb") as f:
438-
f.write(certificate.public_bytes(serialization.Encoding.PEM))
439-
440-
logger.debug(f"Saving key to {domain_key_path} for domain {domain}")
441-
with open(domain_key_path, "wb") as f:
442-
f.write(
443-
key.private_bytes(
444-
encoding=serialization.Encoding.PEM,
445-
format=serialization.PrivateFormat.PKCS8,
446-
encryption_algorithm=serialization.NoEncryption(),
447-
)
426+
try:
427+
logger.debug(f"Saving certificate to {domain_cert_path} for domain {domain}")
428+
with open(domain_cert_path, "wb") as f:
429+
f.write(certificate.public_bytes(serialization.Encoding.PEM))
430+
431+
logger.debug(f"Saving key to {domain_key_path} for domain {domain}")
432+
with open(domain_key_path, "wb") as f:
433+
f.write(
434+
key.private_bytes(
435+
encoding=serialization.Encoding.PEM,
436+
format=serialization.PrivateFormat.PKCS8,
437+
encryption_algorithm=serialization.NoEncryption(),
448438
)
449-
except OSError as e:
450-
logger.error(f"Failed to save certificate or key for {domain}: {e}")
451-
raise
452-
453-
with self._cache_lock:
454-
self._cert_cache[domain] = CachedCertificate(
455-
cert_path=domain_cert_path,
456-
key_path=domain_key_path,
457-
creation_time=datetime.utcnow(),
458439
)
440+
except OSError as e:
441+
logger.error(f"Failed to save certificate or key for {domain}: {e}")
442+
raise
459443

460-
logger.debug(f"Generated and cached new certificate for {domain}")
461-
return domain_cert_path, domain_key_path
444+
logger.debug(f"Generated and cached new certificate for {domain}")
445+
return domain_cert_path, domain_key_path
462446

463447
def load_ca_certificates(self) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]:
464448
"""Load CA certificates for HTTPS proxy"""
@@ -622,11 +606,11 @@ def is_certificate_valid(cert_path: str) -> bool:
622606
"CA certificates missing or invalid, generating new CA and server certificates."
623607
)
624608
# Clear the CA certificate cache before regenerating
625-
with self._cache_lock:
626-
self._ca_cert = None
627-
self._ca_key = None
628-
self._ca_cert_expiry = None
629-
self._ca_last_load_time = None
609+
# with self._cache_lock:
610+
self._ca_cert = None
611+
self._ca_key = None
612+
self._ca_cert_expiry = None
613+
self._ca_last_load_time = None
630614

631615
self.generate_ca_certificates()
632616
self.generate_server_certificates()

src/codegate/providers/copilot/provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
CopilotFimPipeline,
2222
CopilotPipeline,
2323
)
24-
from codegate.providers.copilot.streaming import SSEProcessor4
24+
from codegate.providers.copilot.streaming import SSEProcessor
2525

2626
setup_logging()
2727
logger = structlog.get_logger("codegate").bind(origin="copilot_proxy")

0 commit comments

Comments
 (0)