1
1
import os
2
2
import ssl
3
- import threading
4
3
from dataclasses import dataclass
5
4
from datetime import datetime , timedelta , timezone
6
5
from typing import Dict , Optional , Tuple
10
9
from cryptography .exceptions import InvalidSignature
11
10
from cryptography .hazmat .backends import default_backend
12
11
from cryptography .hazmat .primitives import hashes , serialization
13
- from cryptography .hazmat .primitives .asymmetric import rsa , padding
12
+ from cryptography .hazmat .primitives .asymmetric import padding , rsa
14
13
from cryptography .x509 .oid import ExtendedKeyUsageOID , NameOID
15
14
16
15
from codegate .config import Config
@@ -41,20 +40,15 @@ def __init__(self, ca_provider: "CertificateAuthority"):
41
40
self ._cert_cache : Dict [str , CachedCertificate ] = {}
42
41
self ._context_cache : Dict [str , ssl .SSLContext ] = {}
43
42
44
- # Add lock for thread-safe operations, save anything bad happening with
45
- # cache race conditions!!!
46
- self ._cache_lock = threading .Lock ()
47
-
48
43
def get_domain_context (self , server_name : str ) -> ssl .SSLContext :
49
44
logger .debug (f"Getting domain context for server_name: { server_name } " )
50
- with self ._cache_lock :
51
- if server_name not in self ._context_cache :
52
- logger .debug (f"No cached SSL context for { server_name } , creating new one." )
53
- cert_path , key_path = self ._ca .get_domain_certificate (server_name )
54
- context = self ._create_domain_ssl_context (cert_path , key_path , server_name )
55
- self ._context_cache [server_name ] = context
56
- logger .debug (f"Created new SSL context for { server_name } " )
57
- return self ._context_cache [server_name ]
45
+ if server_name not in self ._context_cache :
46
+ logger .debug (f"No cached SSL context for { server_name } , creating new one." )
47
+ cert_path , key_path = self ._ca .get_domain_certificate (server_name )
48
+ context = self ._create_domain_ssl_context (cert_path , key_path , server_name )
49
+ self ._context_cache [server_name ] = context
50
+ logger .debug (f"Created new SSL context for { server_name } " )
51
+ return self ._context_cache [server_name ]
58
52
59
53
def _create_domain_ssl_context (
60
54
self , cert_path : str , key_path : str , domain : str
@@ -116,8 +110,6 @@ def __init__(self):
116
110
# Use a separate cache for SSL contexts
117
111
self ._context_cache : Dict [str , Tuple [ssl .SSLContext , datetime ]] = {}
118
112
119
- # Add a lock for thread-safe cache operations
120
- self ._cache_lock = threading .Lock ()
121
113
CertificateAuthority ._instance = self
122
114
123
115
# Load existing certificates into cache
@@ -147,7 +139,10 @@ def _load_existing_certificates(self) -> None:
147
139
expiry_date = current_time + timedelta (days = TLS_GRACE_PERIOD_DAYS )
148
140
149
141
for filename in os .listdir (certs_dir ):
150
- if filename .endswith ('.crt' ) and not filename in [Config .get_config ().ca_cert , Config .get_config ().server_cert ]:
142
+ if (
143
+ filename .endswith ('.crt' ) and
144
+ filename not in [Config .get_config ().ca_cert , Config .get_config ().server_cert ]
145
+ ):
151
146
cert_path = os .path .join (certs_dir , filename )
152
147
key_path = os .path .join (certs_dir , filename .replace ('.crt' , '.key' ))
153
148
@@ -179,12 +174,11 @@ def _load_existing_certificates(self) -> None:
179
174
# Check if certificate is still valid
180
175
if cert .not_valid_after_utc > expiry_date :
181
176
logger .debug (f"Loading valid certificate for { common_name } " )
182
- with self ._cache_lock :
183
- self ._cert_cache [common_name ] = CachedCertificate (
184
- cert_path = cert_path ,
185
- key_path = key_path ,
186
- creation_time = datetime .utcnow (),
187
- )
177
+ self ._cert_cache [common_name ] = CachedCertificate (
178
+ cert_path = cert_path ,
179
+ key_path = key_path ,
180
+ creation_time = datetime .utcnow (),
181
+ )
188
182
else :
189
183
logger .debug (f"Skipping expired certificate for { common_name } " )
190
184
@@ -196,45 +190,44 @@ def _load_existing_certificates(self) -> None:
196
190
197
191
def _get_cached_ca_certificates (self ) -> Tuple [x509 .Certificate , rsa .RSAPrivateKey ]:
198
192
"""Get CA certificates from cache or load them if needed."""
199
- with self ._cache_lock :
200
- current_time = datetime .now (timezone .utc )
201
-
202
- # Check if certificates are loaded and not expired
203
- if (
204
- self ._ca_cert is not None
205
- and self ._ca_key is not None
206
- and self ._ca_cert_expiry is not None
207
- and current_time < self ._ca_cert_expiry
208
- ):
209
- return self ._ca_cert , self ._ca_key
193
+ current_time = datetime .now (timezone .utc )
210
194
211
- # Load certificates from disk
212
- logger .debug ("Loading CA certificates from disk" )
213
- ca_cert_path = self .get_cert_path (Config .get_config ().ca_cert )
214
- ca_key_path = self .get_cert_path (Config .get_config ().ca_key )
195
+ # Check if certificates are loaded and not expired
196
+ if (
197
+ self ._ca_cert is not None
198
+ and self ._ca_key is not None
199
+ and self ._ca_cert_expiry is not None
200
+ and current_time < self ._ca_cert_expiry
201
+ ):
202
+ return self ._ca_cert , self ._ca_key
203
+
204
+ # Load certificates from disk
205
+ logger .debug ("Loading CA certificates from disk" )
206
+ ca_cert_path = self .get_cert_path (Config .get_config ().ca_cert )
207
+ ca_key_path = self .get_cert_path (Config .get_config ().ca_key )
215
208
216
- try :
217
- with open (ca_cert_path , "rb" ) as f :
218
- self ._ca_cert = x509 .load_pem_x509_certificate (f .read (), default_backend ())
219
- self ._ca_cert_expiry = self ._ca_cert .not_valid_after_utc
209
+ try :
210
+ with open (ca_cert_path , "rb" ) as f :
211
+ self ._ca_cert = x509 .load_pem_x509_certificate (f .read (), default_backend ())
212
+ self ._ca_cert_expiry = self ._ca_cert .not_valid_after_utc
220
213
221
- with open (ca_key_path , "rb" ) as f :
222
- self ._ca_key = serialization .load_pem_private_key (
223
- f .read (), password = None , backend = default_backend ()
224
- )
214
+ with open (ca_key_path , "rb" ) as f :
215
+ self ._ca_key = serialization .load_pem_private_key (
216
+ f .read (), password = None , backend = default_backend ()
217
+ )
225
218
226
- self ._ca_last_load_time = current_time
227
- logger .debug ("Successfully loaded and cached CA certificates" )
228
- return self ._ca_cert , self ._ca_key
219
+ self ._ca_last_load_time = current_time
220
+ logger .debug ("Successfully loaded and cached CA certificates" )
221
+ return self ._ca_cert , self ._ca_key
229
222
230
- except Exception as e :
231
- logger .error (f"Failed to load CA certificates: { e } " )
232
- # Clear cached values on error
233
- self ._ca_cert = None
234
- self ._ca_key = None
235
- self ._ca_cert_expiry = None
236
- self ._ca_last_load_time = None
237
- raise
223
+ except Exception as e :
224
+ logger .error (f"Failed to load CA certificates: { e } " )
225
+ # Clear cached values on error
226
+ self ._ca_cert = None
227
+ self ._ca_key = None
228
+ self ._ca_cert_expiry = None
229
+ self ._ca_last_load_time = None
230
+ raise
238
231
239
232
def remove_certificates (self ) -> None :
240
233
"""Remove all cached certificates and contexts"""
@@ -246,11 +239,10 @@ def remove_certificates(self) -> None:
246
239
os .rmdir (self .certs_dir )
247
240
os .makedirs (self .certs_dir )
248
241
# Clear CA certificate cache
249
- with self ._cache_lock :
250
- self ._ca_cert = None
251
- self ._ca_key = None
252
- self ._ca_cert_expiry = None
253
- self ._ca_last_load_time = None
242
+ self ._ca_cert = None
243
+ self ._ca_key = None
244
+ self ._ca_cert_expiry = None
245
+ self ._ca_last_load_time = None
254
246
except OSError as e :
255
247
logger .error (f"Failed to remove certs directory: { e } " )
256
248
raise
@@ -362,103 +354,95 @@ def get_domain_certificate(self, domain: str) -> Tuple[str, str]:
362
354
# Use cached CA certificates
363
355
ca_cert , ca_key = self ._get_cached_ca_certificates ()
364
356
365
- with self ._cache_lock :
366
- cached = self ._cert_cache .get (domain )
367
- if cached :
368
- # Validate the cached certificate's expiry
369
- try :
370
- with open (cached .cert_path , "rb" ) as domain_cert_file :
371
- domain_cert = x509 .load_pem_x509_certificate (
372
- domain_cert_file .read (), default_backend ()
373
- )
374
- # Check if certificate is still valid beyond the grace period
375
- expiry_date = datetime .now (timezone .utc ) + timedelta (days = TLS_GRACE_PERIOD_DAYS )
376
- logger .debug (f"Expiry date: { expiry_date } " )
377
- logger .debug (f"Certificate expiry: { domain_cert .not_valid_after } " )
378
- if domain_cert .not_valid_after_utc > expiry_date :
379
- logger .debug (
380
- f"Using cached certificate for { domain } from { cached .cert_path } "
381
- ) # noqa: E501
382
- return cached .cert_path , cached .key_path
383
- else :
384
- logger .debug (f"Cached certificate for { domain } is expiring soon, renewing." )
385
- except Exception as e :
386
- logger .error (f"Failed to validate cached certificate for { domain } : { e } " )
357
+ cached = self ._cert_cache .get (domain )
358
+ if cached :
359
+ # Validate the cached certificate's expiry
360
+ try :
361
+ with open (cached .cert_path , "rb" ) as domain_cert_file :
362
+ domain_cert = x509 .load_pem_x509_certificate (
363
+ domain_cert_file .read (), default_backend ()
364
+ )
365
+ # Check if certificate is still valid beyond the grace period
366
+ expiry_date = datetime .now (timezone .utc ) + timedelta (days = TLS_GRACE_PERIOD_DAYS )
367
+ logger .debug (f"Expiry date: { expiry_date } " )
368
+ logger .debug (f"Certificate expiry: { domain_cert .not_valid_after } " )
369
+ if domain_cert .not_valid_after_utc > expiry_date :
370
+ logger .debug (
371
+ f"Using cached certificate for { domain } from { cached .cert_path } "
372
+ ) # noqa: E501
373
+ return cached .cert_path , cached .key_path
374
+ else :
375
+ logger .debug (f"Cached certificate for { domain } is expiring soon, renewing." )
376
+ except Exception as e :
377
+ logger .error (f"Failed to validate cached certificate for { domain } : { e } " )
387
378
388
- logger .debug (f"Generating new certificate for domain: { domain } " )
389
- key = rsa .generate_private_key (
390
- public_exponent = 65537 ,
391
- key_size = 2048 ,
392
- )
379
+ logger .debug (f"Generating new certificate for domain: { domain } " )
380
+ key = rsa .generate_private_key (
381
+ public_exponent = 65537 ,
382
+ key_size = 2048 ,
383
+ )
393
384
394
- # Nothing is in the cache or its expired, generate a new one!
385
+ # Nothing is in the cache or its expired, generate a new one!
395
386
396
- name = x509 .Name (
397
- [
398
- x509 .NameAttribute (NameOID .COMMON_NAME , domain ),
399
- x509 .NameAttribute (NameOID .ORGANIZATION_NAME , "CodeGate Generated" ),
400
- ]
401
- )
387
+ name = x509 .Name (
388
+ [
389
+ x509 .NameAttribute (NameOID .COMMON_NAME , domain ),
390
+ x509 .NameAttribute (NameOID .ORGANIZATION_NAME , "CodeGate Generated" ),
391
+ ]
392
+ )
402
393
403
- builder = x509 .CertificateBuilder ()
404
- builder = builder .subject_name (name )
405
- builder = builder .issuer_name (ca_cert .subject )
406
- builder = builder .public_key (key .public_key ())
407
- builder = builder .serial_number (x509 .random_serial_number ())
408
- builder = builder .not_valid_before (datetime .now (timezone .utc ))
409
- builder = builder .not_valid_after (datetime .now (timezone .utc ) + timedelta (days = 365 ))
410
- builder = builder .add_extension (
411
- x509 .SubjectAlternativeName ([x509 .DNSName (domain )]), critical = False
412
- )
413
- builder = builder .add_extension (
414
- x509 .ExtendedKeyUsage (
415
- [ExtendedKeyUsageOID .SERVER_AUTH , ExtendedKeyUsageOID .CLIENT_AUTH ]
416
- ),
417
- critical = False ,
418
- )
419
- builder = builder .add_extension (
420
- x509 .BasicConstraints (ca = False , path_length = None ), critical = False
421
- )
394
+ builder = x509 .CertificateBuilder ()
395
+ builder = builder .subject_name (name )
396
+ builder = builder .issuer_name (ca_cert .subject )
397
+ builder = builder .public_key (key .public_key ())
398
+ builder = builder .serial_number (x509 .random_serial_number ())
399
+ builder = builder .not_valid_before (datetime .now (timezone .utc ))
400
+ builder = builder .not_valid_after (datetime .now (timezone .utc ) + timedelta (days = 365 ))
401
+ builder = builder .add_extension (
402
+ x509 .SubjectAlternativeName ([x509 .DNSName (domain )]), critical = False
403
+ )
404
+ builder = builder .add_extension (
405
+ x509 .ExtendedKeyUsage (
406
+ [ExtendedKeyUsageOID .SERVER_AUTH , ExtendedKeyUsageOID .CLIENT_AUTH ]
407
+ ),
408
+ critical = False ,
409
+ )
410
+ builder = builder .add_extension (
411
+ x509 .BasicConstraints (ca = False , path_length = None ), critical = False
412
+ )
422
413
423
- certificate = builder .sign (private_key = ca_key , algorithm = hashes .SHA256 ())
414
+ certificate = builder .sign (private_key = ca_key , algorithm = hashes .SHA256 ())
424
415
425
- cert_dir = Config .get_config ().certs_dir
426
- domain_cert_path = os .path .join (cert_dir , f"{ domain } .crt" )
427
- domain_key_path = os .path .join (cert_dir , f"{ domain } .key" )
416
+ cert_dir = Config .get_config ().certs_dir
417
+ domain_cert_path = os .path .join (cert_dir , f"{ domain } .crt" )
418
+ domain_key_path = os .path .join (cert_dir , f"{ domain } .key" )
428
419
429
- try :
430
- os .makedirs (cert_dir , exist_ok = True )
431
- except OSError as e :
432
- logger .error (f"Failed to create directory { cert_dir } for { domain } : { e } " )
433
- raise
420
+ try :
421
+ os .makedirs (cert_dir , exist_ok = True )
422
+ except OSError as e :
423
+ logger .error (f"Failed to create directory { cert_dir } for { domain } : { e } " )
424
+ raise
434
425
435
- try :
436
- logger .debug (f"Saving certificate to { domain_cert_path } for domain { domain } " )
437
- with open (domain_cert_path , "wb" ) as f :
438
- f .write (certificate .public_bytes (serialization .Encoding .PEM ))
439
-
440
- logger .debug (f"Saving key to { domain_key_path } for domain { domain } " )
441
- with open (domain_key_path , "wb" ) as f :
442
- f .write (
443
- key .private_bytes (
444
- encoding = serialization .Encoding .PEM ,
445
- format = serialization .PrivateFormat .PKCS8 ,
446
- encryption_algorithm = serialization .NoEncryption (),
447
- )
426
+ try :
427
+ logger .debug (f"Saving certificate to { domain_cert_path } for domain { domain } " )
428
+ with open (domain_cert_path , "wb" ) as f :
429
+ f .write (certificate .public_bytes (serialization .Encoding .PEM ))
430
+
431
+ logger .debug (f"Saving key to { domain_key_path } for domain { domain } " )
432
+ with open (domain_key_path , "wb" ) as f :
433
+ f .write (
434
+ key .private_bytes (
435
+ encoding = serialization .Encoding .PEM ,
436
+ format = serialization .PrivateFormat .PKCS8 ,
437
+ encryption_algorithm = serialization .NoEncryption (),
448
438
)
449
- except OSError as e :
450
- logger .error (f"Failed to save certificate or key for { domain } : { e } " )
451
- raise
452
-
453
- with self ._cache_lock :
454
- self ._cert_cache [domain ] = CachedCertificate (
455
- cert_path = domain_cert_path ,
456
- key_path = domain_key_path ,
457
- creation_time = datetime .utcnow (),
458
439
)
440
+ except OSError as e :
441
+ logger .error (f"Failed to save certificate or key for { domain } : { e } " )
442
+ raise
459
443
460
- logger .debug (f"Generated and cached new certificate for { domain } " )
461
- return domain_cert_path , domain_key_path
444
+ logger .debug (f"Generated and cached new certificate for { domain } " )
445
+ return domain_cert_path , domain_key_path
462
446
463
447
def load_ca_certificates (self ) -> Tuple [x509 .Certificate , rsa .RSAPrivateKey ]:
464
448
"""Load CA certificates for HTTPS proxy"""
@@ -622,11 +606,11 @@ def is_certificate_valid(cert_path: str) -> bool:
622
606
"CA certificates missing or invalid, generating new CA and server certificates."
623
607
)
624
608
# Clear the CA certificate cache before regenerating
625
- with self ._cache_lock :
626
- self ._ca_cert = None
627
- self ._ca_key = None
628
- self ._ca_cert_expiry = None
629
- self ._ca_last_load_time = None
609
+ # with self._cache_lock:
610
+ self ._ca_cert = None
611
+ self ._ca_key = None
612
+ self ._ca_cert_expiry = None
613
+ self ._ca_last_load_time = None
630
614
631
615
self .generate_ca_certificates ()
632
616
self .generate_server_certificates ()
0 commit comments