diff --git a/lib/mcp/configuration.rb b/lib/mcp/configuration.rb index 39eeebe..5a03d24 100644 --- a/lib/mcp/configuration.rb +++ b/lib/mcp/configuration.rb @@ -4,12 +4,18 @@ module MCP class Configuration DEFAULT_PROTOCOL_VERSION = "2024-11-05" - attr_writer :exception_reporter, :instrumentation_callback, :protocol_version + attr_writer :exception_reporter, :instrumentation_callback, :protocol_version, :validate_tool_call_arguments - def initialize(exception_reporter: nil, instrumentation_callback: nil, protocol_version: nil) + def initialize(exception_reporter: nil, instrumentation_callback: nil, protocol_version: nil, + validate_tool_call_arguments: true) @exception_reporter = exception_reporter @instrumentation_callback = instrumentation_callback @protocol_version = protocol_version + unless validate_tool_call_arguments.is_a?(TrueClass) || validate_tool_call_arguments.is_a?(FalseClass) + raise ArgumentError, "validate_tool_call_arguments must be a boolean" + end + + @validate_tool_call_arguments = validate_tool_call_arguments end def protocol_version @@ -36,6 +42,12 @@ def instrumentation_callback? !@instrumentation_callback.nil? end + attr_reader :validate_tool_call_arguments + + def validate_tool_call_arguments? + !!@validate_tool_call_arguments + end + def merge(other) return self if other.nil? @@ -54,11 +66,13 @@ def merge(other) else @protocol_version end + validate_tool_call_arguments = other.validate_tool_call_arguments Configuration.new( exception_reporter:, instrumentation_callback:, protocol_version:, + validate_tool_call_arguments:, ) end diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index 6dde615..1bd9430 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -213,6 +213,15 @@ def call_tool(request) ) end + if configuration.validate_tool_call_arguments && tool.input_schema + begin + tool.input_schema.validate_arguments(arguments) + rescue Tool::InputSchema::ValidationError => e + add_instrumentation_data(error: :invalid_schema) + raise RequestHandlerError.new(e.message, request, error_type: :invalid_schema) + end + end + begin call_params = tool_call_parameters(tool) diff --git a/lib/mcp/tool/input_schema.rb b/lib/mcp/tool/input_schema.rb index 4683b7e..b56b640 100644 --- a/lib/mcp/tool/input_schema.rb +++ b/lib/mcp/tool/input_schema.rb @@ -1,13 +1,18 @@ # frozen_string_literal: true +require "json-schema" + module MCP class Tool class InputSchema + class ValidationError < StandardError; end + attr_reader :properties, :required def initialize(properties: {}, required: []) @properties = properties @required = required.map(&:to_sym) + validate_schema! end def to_h @@ -21,6 +26,42 @@ def missing_required_arguments?(arguments) def missing_required_arguments(arguments) (required - arguments.keys.map(&:to_sym)) end + + def validate_arguments(arguments) + errors = JSON::Validator.fully_validate(to_h, arguments) + if errors.any? + raise ValidationError, "Invalid arguments: #{errors.join(", ")}" + end + end + + private + + def validate_schema! + check_for_refs! + schema = to_h + schema_reader = JSON::Schema::Reader.new( + accept_uri: false, + accept_file: ->(path) { path.to_s.start_with?(Gem.loaded_specs["json-schema"].full_gem_path) }, + ) + metaschema = JSON::Validator.validator_for_name("draft4").metaschema + errors = JSON::Validator.fully_validate(metaschema, schema, schema_reader: schema_reader) + if errors.any? + raise ArgumentError, "Invalid JSON Schema: #{errors.join(", ")}" + end + end + + def check_for_refs!(obj = properties) + case obj + when Hash + if obj.key?("$ref") || obj.key?(:$ref) + raise ArgumentError, "Invalid JSON Schema: $ref is not allowed in tool input schemas" + end + + obj.each_value { |value| check_for_refs!(value) } + when Array + obj.each { |item| check_for_refs!(item) } + end + end end end end diff --git a/mcp.gemspec b/mcp.gemspec index c598f45..7543bfc 100644 --- a/mcp.gemspec +++ b/mcp.gemspec @@ -28,6 +28,7 @@ Gem::Specification.new do |spec| spec.require_paths = ["lib"] spec.add_dependency("json_rpc_handler", "~> 0.1") + spec.add_dependency("json-schema", "~> 4.1") spec.add_development_dependency("activesupport") spec.add_development_dependency("sorbet-static-and-runtime") end diff --git a/test/mcp/configuration_test.rb b/test/mcp/configuration_test.rb index 71a1c27..9a5ad08 100644 --- a/test/mcp/configuration_test.rb +++ b/test/mcp/configuration_test.rb @@ -61,5 +61,45 @@ class ConfigurationTest < ActiveSupport::TestCase merged = config3.merge(config1) assert_equal "2025-03-27", merged.protocol_version end + + test "defaults validate_tool_call_arguments to true" do + config = Configuration.new + assert config.validate_tool_call_arguments + end + + test "can set validate_tool_call_arguments to false" do + config = Configuration.new(validate_tool_call_arguments: false) + refute config.validate_tool_call_arguments + end + + test "validate_tool_call_arguments? returns false when set" do + config = Configuration.new(validate_tool_call_arguments: false) + refute config.validate_tool_call_arguments? + end + + test "validate_tool_call_arguments? returns true when not set" do + config = Configuration.new + assert config.validate_tool_call_arguments? + end + + test "merge preserves validate_tool_call_arguments from other config" do + config1 = Configuration.new(validate_tool_call_arguments: false) + config2 = Configuration.new + merged = config1.merge(config2) + assert merged.validate_tool_call_arguments? + end + + test "merge preserves validate_tool_call_arguments from self when other not set" do + config1 = Configuration.new(validate_tool_call_arguments: false) + config2 = Configuration.new + merged = config2.merge(config1) + refute merged.validate_tool_call_arguments + end + + test "raises ArgumentError when validate_tool_call_arguments is not a boolean" do + assert_raises(ArgumentError) do + Configuration.new(validate_tool_call_arguments: "true") + end + end end end diff --git a/test/mcp/server_test.rb b/test/mcp/server_test.rb index 6dc56ee..267f4ca 100644 --- a/test/mcp/server_test.rb +++ b/test/mcp/server_test.rb @@ -839,5 +839,123 @@ def call(message:, server_context: nil) refute_includes server_without_resources.capabilities, :resources end + + test "tools/call validates arguments against input schema when validate_tool_call_arguments is true" do + server = Server.new( + tools: [TestTool], + configuration: Configuration.new(validate_tool_call_arguments: true), + ) + + response = server.handle( + { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "test_tool", + arguments: { message: 123 }, + }, + }, + ) + + assert_equal "2.0", response[:jsonrpc] + assert_equal 1, response[:id] + assert_equal(-32603, response[:error][:code]) + assert_includes response[:error][:data], "Invalid arguments" + end + + test "tools/call skips argument validation when validate_tool_call_arguments is false" do + server = Server.new( + tools: [TestTool], + configuration: Configuration.new(validate_tool_call_arguments: false), + ) + + response = server.handle( + { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "test_tool", + arguments: { message: 123 }, + }, + }, + ) + + assert_equal "2.0", response[:jsonrpc] + assert_equal 1, response[:id] + assert response[:result], "Expected result key in response" + assert_equal "text", response[:result][:content][0][:type] + assert_equal "OK", response[:result][:content][0][:content] + end + + test "tools/call validates arguments with complex types" do + server = Server.new( + tools: [ComplexTypesTool], + configuration: Configuration.new(validate_tool_call_arguments: true), + ) + + response = server.handle( + { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "complex_types_tool", + arguments: { + numbers: [1, 2, 3], + strings: ["a", "b", "c"], + objects: [{ name: "test" }], + }, + }, + }, + ) + + assert_equal "2.0", response[:jsonrpc] + assert_equal 1, response[:id] + assert response[:result], "Expected result key in response" + assert_equal "text", response[:result][:content][0][:type] + assert_equal "OK", response[:result][:content][0][:content] + end + + class TestTool < Tool + tool_name "test_tool" + description "a test tool for testing" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + def call(message:, server_context: nil) + Tool::Response.new([{ type: "text", content: "OK" }]) + end + end + end + + class ComplexTypesTool < Tool + tool_name "complex_types_tool" + description "a test tool with complex types" + input_schema({ + properties: { + numbers: { type: "array", items: { type: "number" } }, + strings: { type: "array", items: { type: "string" } }, + objects: { + type: "array", + items: { + type: "object", + properties: { + name: { type: "string" }, + }, + required: ["name"], + }, + }, + }, + required: ["numbers", "strings", "objects"], + }) + + class << self + def call(numbers:, strings:, objects:, server_context: nil) + Tool::Response.new([{ type: "text", content: "OK" }]) + end + end + end end end diff --git a/test/mcp/tool/input_schema_test.rb b/test/mcp/tool/input_schema_test.rb index 56463bd..791f955 100644 --- a/test/mcp/tool/input_schema_test.rb +++ b/test/mcp/tool/input_schema_test.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true require "test_helper" +require "mcp/tool/input_schema" module MCP class Tool @@ -27,6 +28,51 @@ class InputSchemaTest < ActiveSupport::TestCase input_schema = InputSchema.new(properties: { message: { type: "string" } }, required: [:message]) assert_empty input_schema.missing_required_arguments({ message: "Hello, world!" }) end + + test "valid schema initialization" do + schema = InputSchema.new(properties: { foo: { type: "string" } }, required: [:foo]) + assert_equal({ type: "object", properties: { foo: { type: "string" } }, required: [:foo] }, schema.to_h) + end + + test "invalid schema raises argument error" do + assert_raises(ArgumentError) do + InputSchema.new(properties: { foo: { type: "invalid_type" } }, required: [:foo]) + end + end + + test "validate arguments with valid data" do + schema = InputSchema.new(properties: { foo: { type: "string" } }, required: [:foo]) + assert_nil(schema.validate_arguments({ foo: "bar" })) + end + + test "validate arguments with invalid data" do + schema = InputSchema.new(properties: { foo: { type: "string" } }, required: [:foo]) + assert_raises(InputSchema::ValidationError) do + schema.validate_arguments({ foo: 123 }) + end + end + + test "unexpected errors bubble up from validate_arguments" do + schema = InputSchema.new(properties: { foo: { type: "string" } }, required: [:foo]) + + JSON::Validator.stub(:fully_validate, ->(*) { raise "unexpected error" }) do + assert_raises(RuntimeError) do + schema.validate_arguments({ foo: "bar" }) + end + end + end + + test "rejects schemas with $ref references" do + assert_raises(ArgumentError) do + InputSchema.new(properties: { foo: { "$ref" => "#/definitions/bar" } }, required: [:foo]) + end + end + + test "rejects schemas with symbol $ref references" do + assert_raises(ArgumentError) do + InputSchema.new(properties: { foo: { :$ref => "#/definitions/bar" } }, required: [:foo]) + end + end end end end diff --git a/test/mcp/tool_test.rb b/test/mcp/tool_test.rb index f53352b..e87c826 100644 --- a/test/mcp/tool_test.rb +++ b/test/mcp/tool_test.rb @@ -82,6 +82,23 @@ class InputSchemaTool < Tool assert_equal expected, tool.input_schema.to_h end + test "raises detailed error message for invalid schema" do + error = assert_raises(ArgumentError) do + Class.new(MCP::Tool) do + input_schema( + properties: { + count: { type: "integer", minimum: "not a number" }, + }, + required: [:count], + ) + end + end + + assert_includes error.message, "Invalid JSON Schema" + assert_includes error.message, "#/properties/count/minimum" + assert_includes error.message, "string did not match the following type: number" + end + test ".define allows definition of simple tools with a block" do tool = Tool.define(name: "mock_tool", description: "a mock tool for testing") do |_| Tool::Response.new([{ type: "text", content: "OK" }]) @@ -226,5 +243,23 @@ def call(message:, server_context: nil) assert_equal response.content, [{ type: "text", content: "OK" }] assert_equal response.is_error, false end + + test "input_schema rejects any $ref in schema" do + schema_with_ref = { + properties: { + foo: { "$ref" => "#/definitions/bar" }, + }, + required: ["foo"], + definitions: { + bar: { type: "string" }, + }, + } + error = assert_raises(ArgumentError) do + Class.new(MCP::Tool) do + input_schema schema_with_ref + end + end + assert_match(/Invalid JSON Schema/, error.message) + end end end