diff --git a/src/main/java/org/jruby/ext/openssl/SSLContext.java b/src/main/java/org/jruby/ext/openssl/SSLContext.java index f57c2de1..cc61b17c 100644 --- a/src/main/java/org/jruby/ext/openssl/SSLContext.java +++ b/src/main/java/org/jruby/ext/openssl/SSLContext.java @@ -45,6 +45,7 @@ import javax.net.ssl.KeyManager; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSessionContext; import javax.net.ssl.TrustManager; import javax.net.ssl.X509ExtendedKeyManager; @@ -53,6 +54,7 @@ import org.jruby.Ruby; import org.jruby.RubyArray; import org.jruby.RubyClass; +import org.jruby.RubyString; import org.jruby.RubyFixnum; import org.jruby.RubyHash; import org.jruby.RubyInteger; @@ -60,6 +62,7 @@ import org.jruby.RubyNumeric; import org.jruby.RubyObject; import org.jruby.RubySymbol; +import org.jruby.RubyProc; import org.jruby.anno.JRubyMethod; import org.jruby.common.IRubyWarnings.ID; import org.jruby.runtime.Arity; @@ -242,6 +245,8 @@ public static void createSSLContext(final Ruby runtime, final RubyModule SSL) { SSLContext.addReadWriteAttribute(context, "tmp_dh_callback"); SSLContext.addReadWriteAttribute(context, "servername_cb"); SSLContext.addReadWriteAttribute(context, "renegotiation_cb"); + SSLContext.addReadWriteAttribute(context, "alpn_protocols"); + SSLContext.addReadWriteAttribute(context, "alpn_select_cb"); SSLContext.defineAlias("ssl_timeout", "timeout"); SSLContext.defineAlias("ssl_timeout=", "timeout="); @@ -451,6 +456,29 @@ public IRubyObject setup(final ThreadContext context) { // SSL_CTX_set_tlsext_servername_callback(ctx, ssl_servername_cb); } + final String[] alpnProtocols; + + value = getInstanceVariable("@alpn_protocols"); + if ( value != null && ! value.isNil() ) { + IRubyObject[] alpn_protocols = ((RubyArray) value).toJavaArrayMaybeUnsafe(); + String[] protocols = new String[alpn_protocols.length]; + for(int i = 0; i < protocols.length; i++) { + protocols[i] = alpn_protocols[i].convertToString().asJavaString(); + } + alpnProtocols = protocols; + } else { + alpnProtocols = null; + } + + final RubyProc alpnSelectCb; + value = getInstanceVariable("@alpn_select_cb"); + if ( value != null && ! value.isNil() ) { + alpnSelectCb = (RubyProc) value; + } else { + alpnSelectCb = null; + } + + // NOTE: no API under javax.net to support session get/new/remove callbacks /* val = ossl_sslctx_get_sess_id_ctx(self); @@ -477,7 +505,8 @@ public IRubyObject setup(final ThreadContext context) { */ try { - internalContext = createInternalContext(context, cert, key, store, clientCert, extraChainCert, verifyMode, timeout); + internalContext = createInternalContext(context, cert, key, store, clientCert, extraChainCert, + verifyMode, timeout, alpnProtocols, alpnSelectCb); } catch (GeneralSecurityException e) { throw newSSLError(runtime, e); @@ -505,7 +534,7 @@ private RubyArray matchedCiphersWithCache(final ThreadContext context) { private RubyArray matchedCiphers(final ThreadContext context) { final Ruby runtime = context.runtime; try { - final String[] supported = getSupportedCipherSuites(protocol); + final String[] supported = getSupportedCipherSuites(runtime, protocol); final Collection cipherDefs = CipherStrings.matchingCiphers(this.ciphers, supported, false); @@ -688,14 +717,48 @@ private static class CipherListCache { } } - private static String[] getSupportedCipherSuites(final String protocol) + void setApplicationProtocolsOrSelector(final SSLEngine engine) { + setApplicationProtocolSelector(engine); + setApplicationProtocols(engine); + } + + private void setApplicationProtocolSelector(final SSLEngine engine) { + final RubyProc alpn_select_cb = internalContext.alpnSelectCallback; + if (alpn_select_cb != null) { + engine.setHandshakeApplicationProtocolSelector((_engine, protocols) -> { + final Ruby runtime = getRuntime(); + IRubyObject[] rubyProtocols = new IRubyObject[protocols.size()]; + int i = 0; for (String protocol : protocols) { + rubyProtocols[i++] = runtime.newString(protocol); + } + + IRubyObject[] args = new IRubyObject[] { RubyArray.newArray(runtime, rubyProtocols) }; + IRubyObject selected_protocol = alpn_select_cb.call(runtime.getCurrentContext(), args); + if (selected_protocol != null && !selected_protocol.isNil()) { + return ((RubyString) selected_protocol).asJavaString(); + } + return null; // callback returned nil - none of the advertised names are acceptable + }); + } + } + + private void setApplicationProtocols(final SSLEngine engine) { + final String[] alpn_protocols = internalContext.alpnProtocols; + if (alpn_protocols != null) { + SSLParameters params = engine.getSSLParameters(); + params.setApplicationProtocols(alpn_protocols); + engine.setSSLParameters(params); + } + } + + private static String[] getSupportedCipherSuites(Ruby runtime, final String protocol) throws GeneralSecurityException { - return dummySSLEngine(protocol).getSupportedCipherSuites(); + return dummySSLEngine(runtime, protocol).getSupportedCipherSuites(); } - private static SSLEngine dummySSLEngine(final String protocol) throws GeneralSecurityException { + private static SSLEngine dummySSLEngine(Ruby runtime, final String protocol) throws GeneralSecurityException { javax.net.ssl.SSLContext sslContext = SecurityHelper.getSSLContext(protocol); - sslContext.init(null, null, null); + sslContext.init(null, null, OpenSSL.getSecureRandom(runtime)); return sslContext.createSSLEngine(); } @@ -899,8 +962,9 @@ static RubyClass _SSLContext(final Ruby runtime) { private InternalContext createInternalContext(ThreadContext context, final X509Cert xCert, final PKey pKey, final Store store, final List clientCert, final List extraChainCert, - final int verifyMode, final int timeout) throws NoSuchAlgorithmException, KeyManagementException { - InternalContext internalContext = new InternalContext(xCert, pKey, store, clientCert, extraChainCert, verifyMode, timeout); + final int verifyMode, final int timeout, + final String[] alpnProtocols, final RubyProc alpnSelectCb) throws NoSuchAlgorithmException, KeyManagementException { + InternalContext internalContext = new InternalContext(xCert, pKey, store, clientCert, extraChainCert, verifyMode, timeout, alpnProtocols, alpnSelectCb); internalContext.initSSLContext(context); return internalContext; } @@ -917,7 +981,9 @@ private class InternalContext { final List clientCert, final List extraChainCert, final int verifyMode, - final int timeout) throws NoSuchAlgorithmException { + final int timeout, + final String[] alpnProtocols, + final RubyProc alpnSelectCallback) throws NoSuchAlgorithmException { if ( pKey != null && xCert != null ) { this.privateKey = pKey.getPrivateKey(); @@ -935,6 +1001,8 @@ private class InternalContext { this.extraChainCert = extraChainCert; this.verifyMode = verifyMode; this.timeout = timeout; + this.alpnProtocols = alpnProtocols; + this.alpnSelectCallback = alpnSelectCallback; // initialize SSL context : @@ -982,6 +1050,9 @@ void initSSLContext(final ThreadContext context) throws KeyManagementException { private final int timeout; + private final String[] alpnProtocols; + private final RubyProc alpnSelectCallback; + private final javax.net.ssl.SSLContext sslContext; // part of ssl_verify_cert_chain diff --git a/src/main/java/org/jruby/ext/openssl/SSLSocket.java b/src/main/java/org/jruby/ext/openssl/SSLSocket.java index 147d79e1..c40d6d96 100644 --- a/src/main/java/org/jruby/ext/openssl/SSLSocket.java +++ b/src/main/java/org/jruby/ext/openssl/SSLSocket.java @@ -229,6 +229,9 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean serv dummy = ByteBuffer.allocate(0); this.engine = engine; copySessionSetupIfSet(context); + + sslContext.setApplicationProtocolsOrSelector(engine); + return engine; } @@ -238,6 +241,12 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean serv @JRubyMethod(name = "context") public final SSLContext context() { return this.sslContext; } + @JRubyMethod(name = "alpn_protocol") + public IRubyObject alpn_protocol(final ThreadContext context) { + final String protocol = engine.getApplicationProtocol(); + return protocol == null ? context.nil : RubyString.newString(context.runtime, protocol); + } + @JRubyMethod(name = "sync") public IRubyObject sync(final ThreadContext context) { final CallSite[] sites = getMetaClass().getExtraCallSites(); @@ -283,7 +292,7 @@ private IRubyObject connectImpl(final ThreadContext context, final boolean block try { if ( ! initialHandshake ) { - SSLEngine engine = ossl_ssl_setup(context, true); + SSLEngine engine = ossl_ssl_setup(context, false); engine.setUseClientMode(true); engine.beginHandshake(); handshakeStatus = engine.getHandshakeStatus(); @@ -343,7 +352,7 @@ private IRubyObject acceptImpl(final ThreadContext context, final boolean blocki try { if ( ! initialHandshake ) { - final SSLEngine engine = ossl_ssl_setup(context, false); + final SSLEngine engine = ossl_ssl_setup(context, true); engine.setUseClientMode(false); final IRubyObject verify_mode = verify_mode(context); if ( verify_mode != context.nil ) { diff --git a/src/test/ruby/ssl/test_session.rb b/src/test/ruby/ssl/test_session.rb index 5f8a4c6d..71d39a3f 100644 --- a/src/test/ruby/ssl/test_session.rb +++ b/src/test/ruby/ssl/test_session.rb @@ -30,6 +30,27 @@ def test_session end end + def test_alpn_protocol_selection_ary + advertised = ["h2", "http/1.1"] + ctx_proc = Proc.new { |ctx| + ctx.alpn_select_cb = -> (protocols) { + assert_equal Array, protocols.class + assert_equal advertised, protocols + protocols.first + } + } + start_server0(PORT, OpenSSL::SSL::VERIFY_NONE, true, ctx_proc: ctx_proc) do |server, port| + sock = TCPSocket.new("127.0.0.1", port) + ctx = OpenSSL::SSL::SSLContext.new("TLSv1_2") + ctx.alpn_protocols = advertised + ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx) + ssl.sync_close = true + ssl.connect + assert_equal("h2", ssl.alpn_protocol) + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + end + end + def test_exposes_session_error OpenSSL::SSL::Session::SessionError end