diff --git a/graphql/execute.lua b/graphql/execute.lua index e04688b..a2b9d56 100644 --- a/graphql/execute.lua +++ b/graphql/execute.lua @@ -286,6 +286,36 @@ local function getFieldEntry(objectType, object, fields, context) arguments = setmetatable(arguments, {__index=positions}) + local directiveMap = {} + for _, directive in ipairs(firstField.directives or {}) do + directiveMap[directive.name.value] = directive + end + + local directives = {} + + if next(directiveMap) ~= nil then + util.map_name(context.schema.directives or {}, function(directive, directive_name) + local supplied_directive = directiveMap[directive_name] + if supplied_directive == nil then + return nil + end + + local directiveArgumentMap = {} + for _, argument in ipairs(supplied_directive.arguments or {}) do + directiveArgumentMap[argument.name.value] = argument + end + + directives[directive_name] = util.map(directive.arguments or {}, function(argument, name) + local supplied = directiveArgumentMap[name] and directiveArgumentMap[name].value + if argument.kind then argument = argument.kind end + return util.coerceValue(supplied, argument, context.variables, { + strict_non_null = true, + defaultValues = defaultValues, + }) + end) + end) + end + local info = { context = context, fieldName = fieldName, @@ -298,6 +328,7 @@ local function getFieldEntry(objectType, object, fields, context) operation = context.operation, variableValues = context.variables, defaultValues = context.defaultValues, + directives = directives, } local resolvedObject, err = (fieldType.resolve or defaultResolver)(object, arguments, info) diff --git a/graphql/introspection.lua b/graphql/introspection.lua index 35cb7ea..985e489 100644 --- a/graphql/introspection.lua +++ b/graphql/introspection.lua @@ -109,6 +109,18 @@ __Directive = types.object({ if directive.onFragmentDefinition then table.insert(res, 'FRAGMENT_DEFINITION') end if directive.onFragmentSpread then table.insert(res, 'FRAGMENT_SPREAD') end if directive.onInlineFragment then table.insert(res, 'INLINE_FRAGMENT') end + if directive.onVariableDefinition then table.insert(res, 'VARIABLE_DEFINITION') end + if directive.onSchema then table.insert(res, 'SCHEMA') end + if directive.onScalar then table.insert(res, 'SCALAR') end + if directive.onObject then table.insert(res, 'OBJECT') end + if directive.onFieldDefinition then table.insert(res, 'FIELD_DEFINITION') end + if directive.onArgumentDefinition then table.insert(res, 'ARGUMENT_DEFINITION') end + if directive.onInterface then table.insert(res, 'INTERFACE') end + if directive.onUnion then table.insert(res, 'UNION') end + if directive.onEnum then table.insert(res, 'ENUM') end + if directive.onEnumValue then table.insert(res, 'ENUM_VALUE') end + if directive.onInputObject then table.insert(res, 'INPUT_OBJECT') end + if directive.onInputFieldDefinition then table.insert(res, 'INPUT_FIELD_DEFINITION') end return res end, @@ -118,6 +130,13 @@ __Directive = types.object({ kind = types.nonNull(types.list(types.nonNull(__InputValue))), resolve = resolveArgs, }, + + isRepeatable = { + kind = types.nonNull(types.boolean), + resolve = function(directive) + return directive.isRepeatable == true + end, + }, } end, }) @@ -160,6 +179,66 @@ __DirectiveLocation = types.enum({ value = 'INLINE_FRAGMENT', description = 'Location adjacent to an inline fragment.', }, + + VARIABLE_DEFINITION = { + value = 'VARIABLE_DEFINITION', + description = 'Location adjacent to a variable definition.', + }, + + SCHEMA = { + value = 'SCHEMA', + description = 'Location adjacent to schema.', + }, + + SCALAR = { + value = 'SCALAR', + description = 'Location adjacent to a scalar.', + }, + + OBJECT = { + value = 'OBJECT', + description = 'Location adjacent to an object.', + }, + + FIELD_DEFINITION = { + value = 'FIELD_DEFINITION', + description = 'Location adjacent to a field definition.', + }, + + ARGUMENT_DEFINITION = { + value = 'ARGUMENT_DEFINITION', + description = 'Location adjacent to an argument definition.', + }, + + INTERFACE = { + value = 'INTERFACE', + description = 'Location adjacent to an interface.', + }, + + UNION = { + value = 'UNION', + description = 'Location adjacent to an union.', + }, + + ENUM = { + value = 'ENUM', + description = 'Location adjacent to an enum.', + }, + + ENUM_VALUE = { + value = 'ENUM_VALUE', + description = 'Location adjacent to an enum value.', + }, + + INPUT_OBJECT = { + value = 'INPUT_OBJECT', + description = 'Location adjacent to an input object.', + }, + + INPUT_FIELD_DEFINITION = { + value = 'INPUT_FIELD_DEFINITION', + description = 'Location adjacent to an input field definition.', + }, }, }) diff --git a/graphql/schema.lua b/graphql/schema.lua index 88ae56a..a031817 100644 --- a/graphql/schema.lua +++ b/graphql/schema.lua @@ -99,6 +99,28 @@ end function schema:generateDirectiveMap() for _, directive in ipairs(self.directives) do self.directiveMap[directive.name] = directive + if directive.arguments ~= nil then + for name, argument in pairs(directive.arguments) do + + -- BEGIN_HACK: resolve type names to real types + if type(argument) == 'string' then + argument = types.resolve(argument, self.name) + directive.arguments[name] = argument + end + + if type(argument.kind) == 'string' then + argument.kind = types.resolve(argument.kind, self.name) + end + -- END_HACK: resolve type names to real types + + local argumentType = argument.__type and argument or argument.kind + if argumentType == nil then + error('Must supply type for argument "' .. name .. '" on "' .. directive.name .. '"') + end + argumentType.defaultValue = argument.defaultValue + self:generateTypeMap(argumentType) + end + end end end diff --git a/graphql/types.lua b/graphql/types.lua index 74a40a0..908b392 100644 --- a/graphql/types.lua +++ b/graphql/types.lua @@ -419,6 +419,19 @@ function types.directive(config) onFragmentDefinition = config.onFragmentDefinition, onFragmentSpread = config.onFragmentSpread, onInlineFragment = config.onInlineFragment, + onVariableDefinition = config.onVariableDefinition, + onSchema = config.onSchema, + onScalar = config.onScalar, + onObject = config.onObject, + onFieldDefinition = config.onFieldDefinition, + onArgumentDefinition = config.onArgumentDefinition, + onInterface = config.onInterface, + onUnion = config.onUnion, + onEnum = config.onEnum, + onEnumValue = config.onEnumValue, + onInputObject = config.onInputObject, + onInputFieldDefinition = config.onInputFieldDefinition, + isRepeatable = config.isRepeatable or false, } return instance diff --git a/graphql/util.lua b/graphql/util.lua index 20a093b..b6329b4 100644 --- a/graphql/util.lua +++ b/graphql/util.lua @@ -11,6 +11,16 @@ local function map(t, fn) return res end +local function map_name(t, fn) + local res = {} + for _, v in pairs(t or {}) do + if v.name then + res[v.name] = fn(v, v.name) + end + end + return res +end + local function find(t, fn) for k, v in pairs(t) do if fn(v, k) then return v end @@ -270,6 +280,7 @@ end return { map = map, + map_name = map_name, find = find, filter = filter, values = values, diff --git a/test/integration/graphql_test.lua b/test/integration/graphql_test.lua index b7a0841..11673fc 100644 --- a/test/integration/graphql_test.lua +++ b/test/integration/graphql_test.lua @@ -4,24 +4,27 @@ local schema = require('graphql.schema') local parse = require('graphql.parse') local validate = require('graphql.validate') local execute = require('graphql.execute') +local introspection = require('test.integration.introspection') local t = require('luatest') local g = t.group('integration') -local function check_request(query, query_schema, opts) +local test_schema_name = 'default' +local function check_request(query, query_schema, mutation_schema, directives, opts) opts = opts or {} local root = { query = types.object({ name = 'Query', - fields = query_schema, + fields = query_schema or {}, }), mutation = types.object({ name = 'Mutation', - fields = {}, + fields = mutation_schema or {}, }), + directives = directives, } - local compiled_schema = schema.create(root, 'default') + local compiled_schema = schema.create(root, test_schema_name) local parsed = parse.parse(query) @@ -111,7 +114,7 @@ function g.test_variables() } -- Positive test - t.assert_equals(check_request(query, query_schema, {variables = variables}), {test = 'B22'}) + t.assert_equals(check_request(query, query_schema, nil, nil, {variables = variables}), {test = 'B22'}) -- Negative tests local query = [[ @@ -121,7 +124,7 @@ function g.test_variables() t.assert_error_msg_equals( 'Variable "arg2" expected to be non-null', function() - check_request(query, query_schema, {variables = {}}) + check_request(query, query_schema, nil, nil, {variables = {}}) end ) @@ -134,7 +137,7 @@ function g.test_variables() ' the variable type "String" is not compatible' .. ' with the argument type "NonNull(String)"', function() - check_request(query, query_schema, {variables = {}}) + check_request(query, query_schema, nil, nil, {variables = {}}) end ) @@ -157,7 +160,7 @@ function g.test_variables() function() check_request([[ query { test(arg: "") } - ]], query_schema, { variables = {unknown_arg = ''}}) + ]], query_schema, nil, nil, { variables = {unknown_arg = ''}}) end ) @@ -336,7 +339,7 @@ function g.test_enum_input() query($arg: simple_input_object) { simple_enum(arg: $arg) } - ]], query_schema, {variables = {arg = {field = 'a'}}}), {simple_enum = 'a'}) + ]], query_schema, nil, nil, {variables = {arg = {field = 'a'}}}), {simple_enum = 'a'}) t.assert_error_msg_equals( 'Wrong variable "arg.field" for the Enum "simple_enum" with value "d"', @@ -345,7 +348,7 @@ function g.test_enum_input() query($arg: simple_input_object) { simple_enum(arg: $arg) } - ]], query_schema, {variables = {arg = {field = 'd'}}}) + ]], query_schema, nil, nil, {variables = {arg = {field = 'd'}}}) end ) end @@ -459,7 +462,7 @@ function g.test_nested_input() servers: [{ field: $field }] ) } - ]], query_schema, {variables = {field = 'echo'}}), {test_nested_InputObject = 'echo'}) + ]], query_schema, nil, nil, {variables = {field = 'echo'}}), {test_nested_InputObject = 'echo'}) t.assert_error_msg_equals( 'Unused variable "field"', @@ -470,7 +473,7 @@ function g.test_nested_input() servers: [{ field: "not-variable" }] ) } - ]], query_schema, {variables = {field = 'echo'}}) + ]], query_schema, nil, nil, {variables = {field = 'echo'}}) end ) @@ -480,7 +483,7 @@ function g.test_nested_input() servers: [$field] ) } - ]], query_schema, {variables = {field = 'echo'}}), {test_nested_list = 'echo'}) + ]], query_schema, nil, nil, {variables = {field = 'echo'}}), {test_nested_list = 'echo'}) t.assert_equals(check_request([[ query($field: String! $field2: String! $upvalue: String!) { @@ -492,7 +495,7 @@ function g.test_nested_input() } ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {field = 'echo', field2 = 'field', upvalue = 'upvalue'}, }), {test_nested_InputObject_complex = 'upvalue+field+echo'}) end @@ -599,7 +602,7 @@ function g.test_custom_type_scalar_variables() field: $field ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {field = '{"test": 123}'}, }), {test_json_type = '{"test":123}'}) @@ -609,7 +612,7 @@ function g.test_custom_type_scalar_variables() field: $field ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {field = box.NULL}, }), {test_json_type = 'null'}) @@ -619,7 +622,7 @@ function g.test_custom_type_scalar_variables() field: "null" ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {}, }), {test_json_type = 'null'}) @@ -629,7 +632,7 @@ function g.test_custom_type_scalar_variables() field: $field ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {field = 'echo'}, }), {test_custom_type_scalar = 'echo'}) @@ -644,7 +647,7 @@ function g.test_custom_type_scalar_variables() field: $field ) } - ]], query_schema, {variables = {field = 'echo'}}) + ]], query_schema, nil, nil, {variables = {field = 'echo'}}) end ) @@ -654,7 +657,7 @@ function g.test_custom_type_scalar_variables() fields: [$field] ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {field = 'echo'}, }), {test_custom_type_scalar_list = 'echo'}) @@ -683,7 +686,7 @@ function g.test_custom_type_scalar_variables() fields: [$field] ) } - ]], query_schema, {variables = {field = 'echo'}}) + ]], query_schema, nil, nil, {variables = {field = 'echo'}}) end ) @@ -693,7 +696,7 @@ function g.test_custom_type_scalar_variables() fields: $fields ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {fields = {'echo'}}, }), {test_custom_type_scalar_list = 'echo'}) @@ -708,7 +711,7 @@ function g.test_custom_type_scalar_variables() fields: $fields ) } - ]], query_schema, {variables = {fields = {'echo'}}}) + ]], query_schema, nil, nil, {variables = {fields = {'echo'}}}) end ) @@ -723,7 +726,7 @@ function g.test_custom_type_scalar_variables() fields: $fields ) } - ]], query_schema, {variables = {fields = {'echo'}}}) + ]], query_schema, nil, nil, {variables = {fields = {'echo'}}}) end ) @@ -733,7 +736,7 @@ function g.test_custom_type_scalar_variables() object: { nested_object: { field: $field } } ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {field = 'echo'}, }), {test_custom_type_scalar_inputObject = 'echo'}) @@ -748,7 +751,7 @@ function g.test_custom_type_scalar_variables() object: { nested_object: { field: $field } } ) } - ]], query_schema, {variables = {fields = {'echo'}}}) + ]], query_schema, nil, nil, {variables = {fields = {'echo'}}}) end ) end @@ -960,7 +963,7 @@ function g.test_default_values() query($arg: String = "default_value") { test_default_value(arg: $arg) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {}, }), {test_default_value = 'default_value'}) @@ -968,7 +971,7 @@ function g.test_default_values() query($arg: String = "default_value") { test_default_value(arg: $arg) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {arg = box.NULL}, }), {test_default_value = 'nil'}) @@ -976,7 +979,7 @@ function g.test_default_values() query($arg: [String] = ["default_value"]) { test_default_list(arg: $arg) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {}, }), {test_default_list = 'default_value'}) @@ -984,7 +987,7 @@ function g.test_default_values() query($arg: [String] = ["default_value"]) { test_default_list(arg: $arg) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {arg = box.NULL}, }), {test_default_list = 'nil'}) @@ -992,7 +995,7 @@ function g.test_default_values() query($arg: default_input_object = {field: "default_value"}) { test_default_object(arg: $arg) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {}, }), {test_default_object = 'default_value'}) @@ -1000,7 +1003,7 @@ function g.test_default_values() query($arg: default_input_object = {field: "default_value"}) { test_default_object(arg: $arg) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {arg = box.NULL}, }), {test_default_object = 'nil'}) @@ -1010,7 +1013,7 @@ function g.test_default_values() field: $field ) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {}, }), {test_json_type = '{"test":123}'}) @@ -1018,7 +1021,7 @@ function g.test_default_values() query($arg: String = null, $is_null: Boolean) { test_null(arg: $arg is_null: $is_null) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {arg = 'abc'}, }), {test_null = 'abc'}) @@ -1026,7 +1029,7 @@ function g.test_default_values() query($arg: String = null, $is_null: Boolean) { test_null(arg: $arg is_null: $is_null) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {arg = box.NULL, is_null = true}, }), {test_null = 'is_null'}) @@ -1034,7 +1037,7 @@ function g.test_default_values() query($arg: String = null, $is_null: Boolean) { test_null(arg: $arg is_null: $is_null) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {is_null = false}, }), {test_null = 'not is_null'}) end @@ -1071,7 +1074,7 @@ function g.test_null() query { test_null_nullable(arg: null) } - ]], query_schema, { + ]], query_schema, nil, nil, { variables = {}, }), {test_null_nullable = 'nil'}) @@ -1204,3 +1207,221 @@ function g.test_both_data_and_error_result() {message = "Simple error from external resolver"}, }, "Errors from each resolver were returned") end + +function g.test_introspection() + local function callback(_, _) + return nil + end + + local query_schema = { + ['test'] = { + kind = types.string.nonNull, + arguments = { + arg = types.string.nonNull, + arg2 = types.string, + arg3 = types.int, + arg4 = types.long, + }, + resolve = callback, + } + } + + local mutation_schema = { + ['test_mutation'] = { + kind = types.string.nonNull, + arguments = { + arg = types.string.nonNull, + arg2 = types.string, + arg3 = types.int, + arg4 = types.long, + }, + resolve = callback, + } + } + + local directives = { + types.directive({ + name = 'custom', + arguments = {}, + onQuery = true, + onMutation = true, + onField = true, + onFragmentDefinition = true, + onFragmentSpread = true, + onInlineFragment = true, + onVariableDefinition = true, + onSchema = true, + onScalar = true, + onObject = true, + onFieldDefinition = true, + onArgumentDefinition = true, + onInterface = true, + onUnion = true, + onEnum = true, + onEnumValue = true, + onInputObject = true, + onInputFieldDefinition = true, + isRepeatable = true, + }) + } + + local data, errors = check_request(introspection.query, query_schema, mutation_schema, directives) + t.assert_type(data, 'table') + t.assert_equals(errors, nil) +end + +function g.test_custom_directives() + -- simple string directive + local function callback(_, _, info) + return require('json').encode(info.directives) + end + + local query_schema = { + ['prefix'] = { + kind = types.object({ + name = 'prefix', + fields = { + ['test'] = { + kind = types.string, + arguments = { + arg = types.string.nonNull, + arg2 = types.string, + arg3 = types.int, + arg4 = types.long, + }, + resolve = callback, + } + }, + }), + arguments = {}, + resolve = function() + return {} + end, + } + } + + local directives = { + types.directive({ + name = 'custom', + arguments = { + arg = types.string.nonNull, + }, + onQuery = true, + onMutation = true, + onField = true, + onFragmentDefinition = true, + onFragmentSpread = true, + onInlineFragment = true, + onVariableDefinition = true, + onSchema = true, + onScalar = true, + onObject = true, + onFieldDefinition = true, + onArgumentDefinition = true, + onInterface = true, + onUnion = true, + onEnum = true, + onEnumValue = true, + onInputObject = true, + onInputFieldDefinition = true, + isRepeatable = true, + }) + } + + local simple_query = [[query TEST{ + prefix { + test_A: test(arg: "A")@custom(arg: "a") + } + }]] + local data, errors = check_request(simple_query, query_schema, nil, directives) + t.assert_equals(data, { prefix = { test_A = '{"custom":{"arg":"a"}}' }}) + t.assert_equals(errors, nil) + + local var_query = [[query TEST($arg: String){ + prefix { + test_B: test(arg: "B")@custom(arg: $arg) + } + }]] + data, errors = check_request(var_query, query_schema, nil, directives, + {variables = {arg = 'echo'}}) + t.assert_equals(data, { prefix = { test_B = '{"custom":{"arg":"echo"}}' }}) + t.assert_equals(errors, nil) + + -- InputObject directives + local Entity = types.inputObject({ + name = 'Entity', + fields = { + num = types.int, + str = types.string, + }, + schema = test_schema_name, -- add type to schema registry so it may be called by name + }) + + local function callback(_, args, info) + local obj = args['arg'] + local dir = info.directives + + if dir ~= nil then + local override = dir.override_v2 or dir.override or {} + for k, v in pairs(override['arg']) do + obj[k] = v + end + end + + return require('json').encode(obj) + end + + query_schema = { + ['test'] = { + kind = types.string, + arguments = { + arg = Entity, + }, + resolve = callback, + } + } + + directives = { + types.directive({ + name = 'override', + arguments = { + arg = Entity, + }, + onInputObject = true, + }) + } + + local query = [[query TEST{ + test_C: test(arg: { num: 2, str: "s" })@override(arg: { num: 3, str: "s1" }) + test_D: test(arg: { num: 2, str: "s" })@override(arg: { num: 3 }) + test_E: test(arg: { num: 2, str: "s" })@override(arg: { str: "s1" }) + test_F: test(arg: { num: 2, str: "s" })@override(arg: { }) + }]] + data, errors = check_request(query, query_schema, nil, directives) + t.assert_equals(data, { + test_C = '{"num":3,"str":"s1"}', + test_D = '{"num":3,"str":"s"}', + test_E = '{"num":2,"str":"s1"}', + test_F = '{"num":2,"str":"s"}', + }) + t.assert_equals(errors, nil) + + -- Check internal type resolver + directives = { + types.directive({ + name = 'override_v2', + arguments = { + arg = 'Entity', -- refer to type by name + }, + onInputObject = true, + }) + } + + query = [[query TEST{ + test_G: test(arg: { num: 2, str: "s" })@override_v2(arg: { num: 5, str: "news" }) + }]] + + data, errors = check_request(query, query_schema, nil, directives) + t.assert_equals(data, { test_G = '{"num":5,"str":"news"}' }) + t.assert_equals(errors, nil) +end diff --git a/test/integration/introspection.lua b/test/integration/introspection.lua new file mode 100644 index 0000000..7829eaa --- /dev/null +++ b/test/integration/introspection.lua @@ -0,0 +1,98 @@ +return { + query = [[ + query IntrospectionQuery { + __schema { + queryType { name } + mutationType { name } + subscriptionType { name } + types { + ...FullType + } + directives { + name + description + isRepeatable + locations + args { + ...InputValue + } + } + } + } + + fragment FullType on __Type { + kind + name + description + specifiedByUrl + fields(includeDeprecated: true) { + name + description + args { + ...InputValue + } + type { + ...TypeRef + } + isDeprecated + deprecationReason + } + inputFields { + ...InputValue + } + interfaces { + ...TypeRef + } + enumValues(includeDeprecated: true) { + name + description + isDeprecated + deprecationReason + } + possibleTypes { + ...TypeRef + } + } + + fragment InputValue on __InputValue { + name + description + type { ...TypeRef } + defaultValue + } + + fragment TypeRef on __Type { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + } + } + } + } + } + } + } + } + ]], + variables = {} +} diff --git a/test/unit/graphql_test.lua b/test/unit/graphql_test.lua index 1e10a9a..e67e192 100644 --- a/test/unit/graphql_test.lua +++ b/test/unit/graphql_test.lua @@ -1,10 +1,11 @@ local t = require('luatest') -local g = t.group() +local g = t.group('unit') local parse = require('graphql.parse').parse local types = require('graphql.types') local schema = require('graphql.schema') local validate = require('graphql.validate').validate +local util = require('graphql.util') function g.test_parse_comments() t.assert_error(parse('{a(b:"#")}').definitions, {}) @@ -1057,3 +1058,14 @@ function g.test_boolean_coerce() t.assert_error_msg_contains('Could not coerce value "value" with type "string" to type boolean', validate, test_schema, parse([[ { test_boolean(value: "value") } ]])) end + +function g.test_util_map_name() + local res = util.map_name(nil, nil) + t.assert_equals(res, {}) + + res = util.map_name({ { name = 'a' }, { name = 'b' }, }, function(v) return v end) + t.assert_equals(res, {a = {name = 'a'}, b = {name = 'b'}}) + + res = util.map_name({ entry_a = { name = 'a' }, entry_b = { name = 'b' }, }, function(v) return v end) + t.assert_equals(res, {a = {name = 'a'}, b = {name = 'b'}}) +end