diff --git a/.rubocop_gradual.lock b/.rubocop_gradual.lock index eaf00efa..d22e4a36 100644 --- a/.rubocop_gradual.lock +++ b/.rubocop_gradual.lock @@ -68,7 +68,7 @@ "spec/oauth2/response_spec.rb:2248532534": [ [3, 1, 31, "RSpec/SpecFilePathFormat: Spec path should end with `o_auth2/response*_spec.rb`.", 3190869319] ], - "spec/oauth2/strategy/assertion_spec.rb:793170256": [ + "spec/oauth2/strategy/assertion_spec.rb:3524328522": [ [6, 1, 42, "RSpec/SpecFilePathFormat: Spec path should end with `o_auth2/strategy/assertion*_spec.rb`.", 3665690869] ], "spec/oauth2/strategy/auth_code_spec.rb:142083698": [ diff --git a/lib/oauth2/strategy/assertion.rb b/lib/oauth2/strategy/assertion.rb index 800a4a78..6396fd6d 100644 --- a/lib/oauth2/strategy/assertion.rb +++ b/lib/oauth2/strategy/assertion.rb @@ -95,7 +95,10 @@ def build_request(assertion, request_opts = {}) def build_assertion(claims, encoding_opts) raise ArgumentError.new(message: "Please provide an encoding_opts hash with :algorithm and :key") if !encoding_opts.is_a?(Hash) || (%i[algorithm key] - encoding_opts.keys).any? - JWT.encode(claims, encoding_opts[:key], encoding_opts[:algorithm]) + headers = {} + headers[:kid] = encoding_opts[:kid] if encoding_opts.key?(:kid) + + JWT.encode(claims, encoding_opts[:key], encoding_opts[:algorithm], headers) end end end diff --git a/spec/oauth2/strategy/assertion_spec.rb b/spec/oauth2/strategy/assertion_spec.rb index 38a35dd0..d8d3af46 100644 --- a/spec/oauth2/strategy/assertion_spec.rb +++ b/spec/oauth2/strategy/assertion_spec.rb @@ -164,6 +164,24 @@ expect { client_assertion.get_token(claims, encoding_opts) }.to raise_error(ArgumentError, /encoding_opts/) end end + + context "when including a Key ID (kid)" do + let(:algorithm) { "HS256" } + let(:key) { "new_secret_key" } + let(:kid) { "my_super_secure_key_id_123" } + + before do + client_assertion.get_token(claims, algorithm: algorithm, key: key, kid: kid) + raise "No request made!" if @request_body.nil? + end + + it_behaves_like "encodes the JWT" + + it "includes the kid in the JWT header" do + expect(header).not_to be_nil + expect(header["kid"]).to eq(kid) + end + end end end