diff --git a/src/main/java/org/jruby/ext/openssl/PKeyRSA.java b/src/main/java/org/jruby/ext/openssl/PKeyRSA.java index 6eef4588..665923cb 100644 --- a/src/main/java/org/jruby/ext/openssl/PKeyRSA.java +++ b/src/main/java/org/jruby/ext/openssl/PKeyRSA.java @@ -648,17 +648,19 @@ public synchronized IRubyObject get_q() { return getRuntime().getNil(); } - @JRubyMethod(name="e") - public synchronized IRubyObject get_e() { - RSAPublicKey key; - BigInteger e; - if ((key = publicKey) != null) { - e = key.getPublicExponent(); - } else if(privateKey != null) { - e = privateKey.getPublicExponent(); + private BigInteger getPublicExponent() { + if (publicKey != null) { + return publicKey.getPublicExponent(); + } else if (privateKey != null) { + return privateKey.getPublicExponent(); } else { - e = rsa_e; + return rsa_e; } + } + + @JRubyMethod(name="e") + public synchronized IRubyObject get_e() { + BigInteger e = getPublicExponent(); if (e != null) { return BN.newBN(getRuntime(), e); } @@ -679,17 +681,19 @@ public synchronized IRubyObject set_e(final ThreadContext context, IRubyObject v return value; } - @JRubyMethod(name="n") - public synchronized IRubyObject get_n() { - RSAPublicKey key; - BigInteger n; - if ((key = publicKey) != null) { - n = key.getModulus(); - } else if(privateKey != null) { - n = privateKey.getModulus(); + private BigInteger getModulus() { + if (publicKey != null) { + return publicKey.getModulus(); + } else if (privateKey != null) { + return privateKey.getModulus(); } else { - n = rsa_n; + return rsa_n; } + } + + @JRubyMethod(name="n") + public synchronized IRubyObject get_n() { + BigInteger n = getModulus(); if (n != null) { return BN.newBN(getRuntime(), n); } @@ -715,8 +719,12 @@ private void generatePublicKeyIfParams(final ThreadContext context) { if ( publicKey != null ) throw newRSAError(runtime, "illegal modification"); - BigInteger e, n; - if ( (e = rsa_e) != null && (n = rsa_n) != null ) { + // Don't access the rsa_n and rsa_e fields directly. They may have + // already been consumed and cleared by generatePrivateKeyIfParams. + BigInteger _rsa_n = getModulus(); + BigInteger _rsa_e = getPublicExponent(); + + if (_rsa_n != null && _rsa_e != null) { final KeyFactory rsaFactory; try { rsaFactory = SecurityHelper.getKeyFactory("RSA"); @@ -726,7 +734,7 @@ private void generatePublicKeyIfParams(final ThreadContext context) { } try { - publicKey = (RSAPublicKey) rsaFactory.generatePublic(new RSAPublicKeySpec(n, e)); + publicKey = (RSAPublicKey) rsaFactory.generatePublic(new RSAPublicKeySpec(_rsa_n, _rsa_e)); } catch (InvalidKeySpecException ex) { throw newRSAError(runtime, "invalid parameters"); @@ -741,7 +749,12 @@ private void generatePrivateKeyIfParams(final ThreadContext context) { if ( privateKey != null ) throw newRSAError(runtime, "illegal modification"); - if (rsa_e != null && rsa_n != null && rsa_p != null && rsa_q != null && rsa_d != null && rsa_dmp1 != null && rsa_dmq1 != null && rsa_iqmp != null) { + // Don't access the rsa_n and rsa_e fields directly. They may have + // already been consumed and cleared by generatePublicKeyIfParams. + BigInteger _rsa_n = getModulus(); + BigInteger _rsa_e = getPublicExponent(); + + if (_rsa_n != null && _rsa_e != null && rsa_p != null && rsa_q != null && rsa_d != null && rsa_dmp1 != null && rsa_dmq1 != null && rsa_iqmp != null) { final KeyFactory rsaFactory; try { rsaFactory = SecurityHelper.getKeyFactory("RSA"); @@ -752,7 +765,7 @@ private void generatePrivateKeyIfParams(final ThreadContext context) { try { privateKey = (RSAPrivateCrtKey) rsaFactory.generatePrivate( - new RSAPrivateCrtKeySpec(rsa_n, rsa_e, rsa_d, rsa_p, rsa_q, rsa_dmp1, rsa_dmq1, rsa_iqmp) + new RSAPrivateCrtKeySpec(_rsa_n, _rsa_e, rsa_d, rsa_p, rsa_q, rsa_dmp1, rsa_dmq1, rsa_iqmp) ); } catch (InvalidKeySpecException e) { diff --git a/src/test/ruby/rsa/test_rsa.rb b/src/test/ruby/rsa/test_rsa.rb index dffd4dad..58ce1faa 100644 --- a/src/test/ruby/rsa/test_rsa.rb +++ b/src/test/ruby/rsa/test_rsa.rb @@ -47,4 +47,47 @@ def test_rsa_public_encrypt # } end + def test_rsa_param_accessors + key_file = File.join(File.dirname(__FILE__), 'private_key.pem') + key = OpenSSL::PKey::RSA.new(File.read(key_file)) + + [:e, :n, :d, :p, :q, :iqmp, :dmp1, :dmq1].each do |param| + rsa = OpenSSL::PKey::RSA.new + assert_nil(rsa.send(param)) + value = key.send(param) + rsa.send("#{param}=", value) + assert_equal(value, rsa.send(param), param) + end + end + + def test_rsa_from_params_public_first + key_file = File.join(File.dirname(__FILE__), 'private_key.pem') + key = OpenSSL::PKey::RSA.new(File.read(key_file)) + + rsa = OpenSSL::PKey::RSA.new + rsa.e, rsa.n = key.e, key.n + assert_nothing_raised { rsa.public_encrypt('Test string') } + [:e, :n].each {|param| assert_equal(key.send(param), rsa.send(param)) } + + rsa.d, rsa.p, rsa.q, rsa.iqmp, rsa.dmp1, rsa.dmq1 = key.d, key.p, key.q, key.iqmp, key.dmp1, key.dmq1 + assert_nothing_raised { rsa.private_encrypt('Test string') } + [:e, :n, :d, :p, :q, :iqmp, :dmp1, :dmq1].each do |param| + assert_equal(key.send(param), rsa.send(param), param) + end + end + + def test_rsa_from_params_private_first + key_file = File.join(File.dirname(__FILE__), 'private_key.pem') + key = OpenSSL::PKey::RSA.new(File.read(key_file)) + + rsa = OpenSSL::PKey::RSA.new + rsa.d, rsa.p, rsa.q, rsa.iqmp, rsa.dmp1, rsa.dmq1 = key.d, key.p, key.q, key.iqmp, key.dmp1, key.dmq1 + rsa.e, rsa.n = key.e, key.n + assert_nothing_raised { rsa.public_encrypt('Test string') } + assert_nothing_raised { rsa.private_encrypt('Test string') } + [:e, :n, :d, :p, :q, :iqmp, :dmp1, :dmq1].each do |param| + assert_equal(key.send(param), rsa.send(param), param) + end + end + end