Skip to content

Commit e2ac3aa

Browse files
committed
feat: start mgo option to use with CA cert
1 parent 9351137 commit e2ac3aa

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

testing/mgo.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ type MgoSuite struct {
157157
SkipTestCleanup bool
158158
}
159159

160-
// generatePEM receives server certificate and the server private key
161-
// and creates a PEM file in the given path.
160+
// generatePEM receives server certificate and optionally the
161+
// server private key and creates a PEM file in the given path.
162162
func generatePEM(path string, serverCert *x509.Certificate, serverKey *rsa.PrivateKey) error {
163163
pemFile, err := os.Create(path)
164164
if err != nil {
@@ -172,6 +172,9 @@ func generatePEM(path string, serverCert *x509.Certificate, serverKey *rsa.Priva
172172
if err != nil {
173173
return fmt.Errorf("failed to write cert to %q: %v", path, err)
174174
}
175+
if serverKey == nil {
176+
return nil
177+
}
175178
err = pem.Encode(pemFile, &pem.Block{
176179
Type: "RSA PRIVATE KEY",
177180
Bytes: x509.MarshalPKCS1PrivateKey(serverKey),
@@ -257,6 +260,10 @@ func (inst *MgoInstance) Start(certs *Certs) error {
257260
if err = generatePEM(pemPath, certs.ServerCert, certs.ServerKey); err != nil {
258261
return fmt.Errorf("cannot write cert/key PEM: %v", err)
259262
}
263+
caPath := filepath.Join(dbdir, "ca.crt")
264+
if err = generatePEM(caPath, certs.CACert, nil); err != nil {
265+
return fmt.Errorf("cannot write ca cert PEM: %v", err)
266+
}
260267
inst.certs = certs
261268
}
262269

@@ -324,9 +331,11 @@ func (inst *MgoInstance) run(vers version.Number) error {
324331
}
325332
if inst.certs != nil {
326333
mgoargs = append(mgoargs,
327-
"--sslMode", "requireSSL",
328-
"--sslPEMKeyFile", filepath.Join(inst.dir, "server.pem"),
329-
"--sslPEMKeyPassword=ignored")
334+
"--tlsMode", "requireTLS",
335+
"--tlsCertificateKeyFile", filepath.Join(inst.dir, "server.pem"),
336+
"--tlsCAFile", filepath.Join(inst.dir, "ca.crt"),
337+
"--tlsCertificateKeyFilePassword=ignored",
338+
"--tlsAllowInvalidHostnames=true")
330339
}
331340

332341
mongopath, version, err := installedMongod.Get()
@@ -713,9 +722,15 @@ func MgoDialInfo(certs *Certs, addrs ...string) *mgo.DialInfo {
713722
if certs != nil {
714723
pool := x509.NewCertPool()
715724
pool.AddCert(certs.CACert)
725+
// For testing, we'll just use the server cert.
726+
clientCert := tls.Certificate{
727+
Certificate: [][]byte{certs.ServerCert.Raw},
728+
PrivateKey: certs.ServerKey,
729+
}
716730
tlsConfig := &tls.Config{
717731
RootCAs: pool,
718732
ServerName: "anything",
733+
Certificates: []tls.Certificate{clientCert},
719734
}
720735
dial = func(addr net.Addr) (net.Conn, error) {
721736
conn, err := tls.Dial("tcp", addr.String(), tlsConfig)

0 commit comments

Comments
 (0)