diff --git a/src/main/java/org/jruby/ext/openssl/SSLContext.java b/src/main/java/org/jruby/ext/openssl/SSLContext.java index a62ee73b..26cc4a69 100644 --- a/src/main/java/org/jruby/ext/openssl/SSLContext.java +++ b/src/main/java/org/jruby/ext/openssl/SSLContext.java @@ -175,6 +175,7 @@ public static void createSSLContext(final Ruby runtime, final RubyModule SSL) { SSLContext.addReadWriteAttribute(context, "session_id_context"); SSLContext.addReadWriteAttribute(context, "tmp_dh_callback"); SSLContext.addReadWriteAttribute(context, "servername_cb"); + SSLContext.addReadWriteAttribute(context, "renegotiation_cb"); SSLContext.defineAlias("ssl_timeout", "timeout"); SSLContext.defineAlias("ssl_timeout=", "timeout="); diff --git a/src/main/java/org/jruby/ext/openssl/SSLSocket.java b/src/main/java/org/jruby/ext/openssl/SSLSocket.java index 49e5a538..419acbee 100644 --- a/src/main/java/org/jruby/ext/openssl/SSLSocket.java +++ b/src/main/java/org/jruby/ext/openssl/SSLSocket.java @@ -241,6 +241,7 @@ private IRubyObject connectImpl(final ThreadContext context, final boolean block handshakeStatus = engine.getHandshakeStatus(); initialHandshake = true; } + callRenegotiationCallback(context); final IRubyObject ex = doHandshake(blocking, exception); if ( ex != null ) return ex; // :wait_readable | :wait_writable } @@ -317,6 +318,7 @@ private IRubyObject acceptImpl(final ThreadContext context, final boolean blocki handshakeStatus = engine.getHandshakeStatus(); initialHandshake = true; } + callRenegotiationCallback(context); final IRubyObject ex = doHandshake(blocking, exception); if ( ex != null ) return ex; // :wait_readable | :wait_writable } @@ -584,6 +586,18 @@ private int writeToChannel(ByteBuffer buffer, boolean blocking) throws IOExcepti private void finishInitialHandshake() { initialHandshake = false; } + + private void callRenegotiationCallback(final ThreadContext context) throws RaiseException { + IRubyObject renegotiationCallback = sslContext.getInstanceVariable("@renegotiation_cb"); + if(renegotiationCallback == null || renegotiationCallback.isNil()) { + return; + } + else { + // the return of the Proc is not important + // Can throw ruby exception to "disallow" renegotiations + renegotiationCallback.callMethod(context, "call", this); + } + } public int write(ByteBuffer src, boolean blocking) throws SSLException, IOException { if ( initialHandshake ) { diff --git a/src/test/ruby/ssl/test_ssl.rb b/src/test/ruby/ssl/test_ssl.rb index ccdcdbcf..d0faee1f 100644 --- a/src/test/ruby/ssl/test_ssl.rb +++ b/src/test/ruby/ssl/test_ssl.rb @@ -185,4 +185,17 @@ def test_connect_nonblock_would_block end end if RUBY_VERSION > '1.9' + def test_renegotiation_cb + num_handshakes = 0 + renegotiation_cb = Proc.new { |ssl| num_handshakes += 1 } + ctx_proc = Proc.new { |ctx| ctx.renegotiation_cb = renegotiation_cb } + start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true, {:ctx_proc => ctx_proc}) do |server, port| + sock = TCPSocket.new("127.0.0.1", port) + ssl = OpenSSL::SSL::SSLSocket.new(sock) + ssl.connect + assert_equal(1, num_handshakes) + ssl.close + end + end + end