Skip to content

Commit 3e930ce

Browse files
authored
Add TLS support for TCP sockets (#211)
TLS tests requires a TLS enabled memcached server. In order to get one, you must compile memcached with `--enable-tls`
2 parents d2a51bd + f086033 commit 3e930ce

File tree

7 files changed

+95
-13
lines changed

7 files changed

+95
-13
lines changed

bmemcached/client/distributed.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ class DistributedClient(ClientMixin):
1212
It tries to distribute keys over the specified servers using `HashRing` consistent hash.
1313
"""
1414
def __init__(self, servers=('127.0.0.1:11211',), username=None, password=None, compression=None,
15-
socket_timeout=SOCKET_TIMEOUT, pickle_protocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler):
15+
socket_timeout=SOCKET_TIMEOUT, pickle_protocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler,
16+
tls_context=None):
1617
super(DistributedClient, self).__init__(servers, username, password, compression, socket_timeout,
17-
pickle_protocol, pickler, unpickler)
18+
pickle_protocol, pickler, unpickler, tls_context)
1819
self._ring = HashRing(self._servers)
1920

2021
def _get_server(self, key):

bmemcached/client/mixin.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class ClientMixin(object):
2828
:type pickler: function
2929
:param unpickler: Use this to replace the object deserialization mechanism.
3030
:type unpickler: function
31+
:param tls_context: A TLS context in order to connect to TLS enabled
32+
memcached servers.
33+
:type tls_context: ssl.SSLContext
3134
"""
3235
def __init__(self, servers=('127.0.0.1:11211',),
3336
username=None,
@@ -36,14 +39,16 @@ def __init__(self, servers=('127.0.0.1:11211',),
3639
socket_timeout=SOCKET_TIMEOUT,
3740
pickle_protocol=PICKLE_PROTOCOL,
3841
pickler=pickle.Pickler,
39-
unpickler=pickle.Unpickler):
42+
unpickler=pickle.Unpickler,
43+
tls_context=None):
4044
self.username = username
4145
self.password = password
4246
self.compression = compression
4347
self.socket_timeout = socket_timeout
4448
self.pickle_protocol = pickle_protocol
4549
self.pickler = pickler
4650
self.unpickler = unpickler
51+
self.tls_context = tls_context
4752
self.set_servers(servers)
4853

4954
@property
@@ -73,6 +78,7 @@ def set_servers(self, servers):
7378
pickle_protocol=self.pickle_protocol,
7479
pickler=self.pickler,
7580
unpickler=self.unpickler,
81+
tls_context=self.tls_context,
7682
) for server in servers]
7783

7884
def flush_all(self, time=0):

bmemcached/protocol.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class Protocol(threading.local):
9999
COMPRESSION_THRESHOLD = 128
100100

101101
def __init__(self, server, username=None, password=None, compression=None, socket_timeout=None,
102-
pickle_protocol=None, pickler=None, unpickler=None):
102+
pickle_protocol=None, pickler=None, unpickler=None, tls_context=None):
103103
super(Protocol, self).__init__()
104104
self.server = server
105105
self._username = username
@@ -112,6 +112,7 @@ def __init__(self, server, username=None, password=None, compression=None, socke
112112
self.pickle_protocol = pickle_protocol
113113
self.pickler = pickler
114114
self.unpickler = unpickler
115+
self.tls_context = tls_context
115116

116117
self.reconnects_deferred_until = None
117118

@@ -144,6 +145,12 @@ def _open_connection(self):
144145
self.connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
145146
self.connection.settimeout(self.socket_timeout)
146147
self.connection.connect((self.host, self.port))
148+
149+
if self.tls_context:
150+
self.connection = self.tls_context.wrap_socket(
151+
self.connection,
152+
server_hostname=self.host,
153+
)
147154
else:
148155
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
149156
self.connection.connect(self.server)

requirements_test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pytest-cov==2.7.1
33
mock==2.0.0
44
flake8==3.7.7
55
bumpversion==0.5.3
6+
trustme==0.6.0

test/conftest.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,38 @@
55
import pytest
66

77

8-
os.environ.setdefault('MEMCACHED_HOST', '127.0.0.1')
8+
os.environ.setdefault("MEMCACHED_HOST", "localhost")
99

1010

11-
@pytest.yield_fixture(scope='session', autouse=True)
11+
@pytest.yield_fixture(scope="session", autouse=True)
1212
def memcached_standard_port():
13-
p = subprocess.Popen(['memcached'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13+
p = subprocess.Popen(
14+
["memcached"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
15+
)
1416
time.sleep(0.1)
1517
yield p
1618
p.kill()
1719
p.wait()
1820

1921

20-
@pytest.yield_fixture(scope='session', autouse=True)
22+
@pytest.yield_fixture(scope="session", autouse=True)
2123
def memcached_other_port():
22-
p = subprocess.Popen(['memcached', '-p5000'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
24+
p = subprocess.Popen(
25+
["memcached", "-p5000"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
26+
)
2327
time.sleep(0.1)
2428
yield p
2529
p.kill()
2630
p.wait()
2731

2832

29-
@pytest.yield_fixture(scope='session', autouse=True)
33+
@pytest.yield_fixture(scope="session", autouse=True)
3034
def memcached_socket():
31-
p = subprocess.Popen(['memcached', '-s/tmp/memcached.sock'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
35+
p = subprocess.Popen(
36+
["memcached", "-s/tmp/memcached.sock"],
37+
stdout=subprocess.PIPE,
38+
stderr=subprocess.PIPE,
39+
)
3240
time.sleep(0.1)
3341
yield p
3442
p.kill()

test/test_server_parsing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def testNoPortGiven(self):
2727
self.assertEqual(server.port, 11211)
2828

2929
def testInvalidPort(self):
30-
server = bmemcached.protocol.Protocol('127.0.0.1:blah')
30+
server = bmemcached.protocol.Protocol('{}:blah'.format(os.environ['MEMCACHED_HOST']))
3131
self.assertEqual(server.host, os.environ['MEMCACHED_HOST'])
3232
self.assertEqual(server.port, 11211)
3333

3434
def testNonStandardPort(self):
35-
server = bmemcached.protocol.Protocol('127.0.0.1:5000')
35+
server = bmemcached.protocol.Protocol('{}:5000'.format(os.environ['MEMCACHED_HOST']))
3636
self.assertEqual(server.host, os.environ['MEMCACHED_HOST'])
3737
self.assertEqual(server.port, 5000)
3838

test/test_tls.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import os
2+
import pytest
3+
import subprocess
4+
import ssl
5+
import time
6+
import trustme
7+
8+
import bmemcached
9+
import test_simple_functions
10+
11+
12+
ca = trustme.CA()
13+
server_cert = ca.issue_cert(os.environ["MEMCACHED_HOST"] + u"")
14+
15+
16+
@pytest.yield_fixture(scope="module", autouse=True)
17+
def memcached_tls():
18+
key = server_cert.private_key_pem
19+
cert = server_cert.cert_chain_pems[0]
20+
21+
with cert.tempfile() as c, key.tempfile() as k:
22+
p = subprocess.Popen(
23+
[
24+
"memcached",
25+
"-p5001",
26+
"-Z",
27+
"-o",
28+
"ssl_key={}".format(k),
29+
"-o",
30+
"ssl_chain_cert={}".format(c),
31+
"-o",
32+
"ssl_verify_mode=1",
33+
],
34+
stdout=subprocess.PIPE,
35+
stderr=subprocess.PIPE,
36+
)
37+
time.sleep(0.1)
38+
39+
if p.poll() is not None:
40+
pytest.skip("Memcached server is not built with TLS support.")
41+
42+
yield p
43+
p.kill()
44+
p.wait()
45+
46+
47+
class TLSMemcachedTests(test_simple_functions.MemcachedTests):
48+
"""
49+
Same tests as above, just make sure it works with TLS.
50+
"""
51+
52+
def setUp(self):
53+
ctx = ssl.create_default_context()
54+
55+
ca.configure_trust(ctx)
56+
57+
self.server = "{}:5001".format(os.environ["MEMCACHED_HOST"])
58+
self.client = bmemcached.Client(self.server, tls_context=ctx)
59+
self.reset()

0 commit comments

Comments
 (0)