diff --git a/graphql/execute.lua b/graphql/execute.lua index 56a9c44..8622736 100644 --- a/graphql/execute.lua +++ b/graphql/execute.lua @@ -335,6 +335,7 @@ local function getFieldEntry(objectType, object, fields, context) variableValues = context.variables, defaultValues = context.defaultValues, directives = directives, + directivesDefaultValues = context.schema.directivesDefaultValues, } local resolvedObject, err = (fieldType.resolve or defaultResolver)(object, arguments, info) @@ -352,9 +353,27 @@ evaluateSelections = function(objectType, object, selections, context) local result = {} local err local fields = collectFields(objectType, selections, {}, {}, context) + local defaultValues + + if context.defaultValues == nil then + if context.schema.defaultValues ~= nil and type(context.schema.defaultValues) == 'table' then + local operationDefaults = context.schema.defaultValues[context.operation.operation] + if operationDefaults ~= nil and type(operationDefaults) == 'table' then + defaultValues = context.schema.defaultValues[context.operation.operation] + end + end + else + defaultValues = context.defaultValues + end + for _, field in ipairs(fields) do assert(result[field.name] == nil, 'two selections into the one field: ' .. field.name) + + if defaultValues ~= nil then + context.defaultValues = defaultValues[field.name] + end + result[field.name], err = getFieldEntry(objectType, object, {field.selection}, context) if err ~= nil then diff --git a/graphql/schema.lua b/graphql/schema.lua index b49e454..cff57b4 100644 --- a/graphql/schema.lua +++ b/graphql/schema.lua @@ -8,10 +8,11 @@ end local schema = {} schema.__index = schema -function schema.create(config, name) +function schema.create(config, name, opts) assert(type(config.query) == 'table', 'must provide query object') assert(not config.mutation or type(config.mutation) == 'table', 'mutation must be a table if provided') + opts = opts or {} local self = setmetatable({}, schema) for k, v in pairs(config) do @@ -34,9 +35,101 @@ function schema.create(config, name) self:generateTypeMap(introspection.__Schema) self:generateDirectiveMap() + if opts.defaultValues == true then + self.defaultValues = {} + self.defaultValues.mutation = self:extractDefaults(self.mutation) + self.defaultValues.query = self:extractDefaults(self.query) + end + + if opts.directivesDefaultValues == true then + self.directivesDefaultValues = {} + + for directiveName, directive in pairs(self.directiveMap or {}) do + self.directivesDefaultValues[directiveName] = self:extractDefaults(directive) + end + end + return self end +function schema:extractDefaults(node) + if not node then return end + + local defaultValues + local nodeType = node.__type ~= nil and node or node.kind + + if nodeType.__type == 'NonNull' then + return self:extractDefaults(nodeType.ofType) + end + + if nodeType.__type == 'Enum' then + return node.defaultValue + end + + if nodeType.__type == 'Scalar' then + return node.defaultValue + end + + node.fields = type(node.fields) == 'function' and node.fields() or node.fields + + if nodeType.__type == 'Object' or nodeType.__type == 'InputObject' then + for fieldName, field in pairs(nodeType.fields or {}) do + local fieldDefaultValue = self:extractDefaults(field) + if fieldDefaultValue ~= nil then + defaultValues = defaultValues or {} + defaultValues[fieldName] = fieldDefaultValue + end + + for argumentName, argument in pairs(field.arguments or {}) do + -- BEGIN_HACK: resolve type names to real types + if type(argument) == 'string' then + argument = types.resolve(argument, self.name) + field.arguments[argumentName] = argument + end + + if type(argument.kind) == 'string' then + argument.kind = types.resolve(argument.kind, self.name) + end + -- END_HACK: resolve type names to real types + + local argumentDefaultValue = self:extractDefaults(argument) + if argumentDefaultValue ~= nil then + defaultValues = defaultValues or {} + defaultValues[fieldName] = defaultValues[fieldName] or {} + defaultValues[fieldName][argumentName] = argumentDefaultValue + end + end + end + return defaultValues + end + + if nodeType.__type =='Directive' then + for argumentName, argument in pairs(nodeType.arguments or {}) do + -- BEGIN_HACK: resolve type names to real types + if type(argument) == 'string' then + argument = types.resolve(argument, self.name) + nodeType.arguments[argumentName] = argument + end + + if type(argument.kind) == 'string' then + argument.kind = types.resolve(argument.kind, self.name) + end + -- END_HACK: resolve type names to real types + + local argumentDefaultValue = self:extractDefaults(argument) + if argumentDefaultValue ~= nil then + defaultValues = defaultValues or {} + defaultValues[argumentName] = argumentDefaultValue + end + end + return defaultValues + end + + if nodeType.__type == 'List' then + return self:extractDefaults(nodeType.ofType) + end +end + function schema:generateTypeMap(node) if not node or (self.typeMap[node.name] and self.typeMap[node.name] == node) then return end diff --git a/test/integration/graphql_test.lua b/test/integration/graphql_test.lua index 6a0f88d..be7d57e 100644 --- a/test/integration/graphql_test.lua +++ b/test/integration/graphql_test.lua @@ -25,7 +25,7 @@ local function check_request(query, query_schema, mutation_schema, directives, o directives = directives, } - local compiled_schema = schema.create(root, test_schema_name) + local compiled_schema = schema.create(root, test_schema_name, opts) local parsed = parse.parse(query) @@ -1871,3 +1871,183 @@ function g.test_mutation_and_directive_arguments_default_values() -- directive default values used on the same argument type t.assert_equals(compiled_schema.typeMap['Int'].defaultValue, nil) end + +g.test_propagate_defaults_to_callback = function() + local query = '{test_mutation}' + + local function callback(_, _, info) + return json.encode({ + defaultValues = info.defaultValues, + directivesDefaultValues = info.directivesDefaultValues, + }) + end + + local input_object = types.inputObject({ + name = 'test_input_object', + fields = { + nested_int_arg = { + kind = types.int, + defaultValue = 2, + }, + nested_string_arg = { + kind = types.string, + defaultValue = 'default nested value', + }, + nested_boolean_arg = { + kind = types.boolean, + defaultValue = true, + }, + nested_float_arg = { + kind = types.float, + defaultValue = 1.1, + }, + nested_long_arg = { + kind = types.long, + defaultValue = 2^50, + }, + nested_list_scalar_arg = { + kind = types.list(types.string), + -- defaultValue seems illogical + } + }, + kind = types.string, + }) + + local mutation_schema = { + ['test_mutation'] = { + kind = types.string.nonNull, + arguments = { + int_arg = { + kind = types.int, + defaultValue = 1, + }, + string_arg = { + kind = types.string, + defaultValue = 'string_arg' + }, + boolean_arg = { + kind = types.boolean, + defaultValue = false, + }, + float_arg = { + kind = types.float, + defaultValue = 1.1, + }, + long_arg = { + kind = types.long, + defaultValue = 2^50, + }, + object_arg = { + kind = input_object, + -- defaultValue seems illogical + }, + list_scalar_arg = { + kind = types.list(types.string), + -- defaultValue seems illogical + } + }, + resolve = callback, + } + } + + local directives = { + types.directive({ + schema = schema, + name = 'timeout', + description = 'Request execute timeout', + arguments = { + int_arg = { + kind = types.int, + defaultValue = 1, + }, + string_arg = { + kind = types.string, + defaultValue = 'string_arg' + }, + boolean_arg = { + kind = types.boolean, + defaultValue = false, + }, + float_arg = { + kind = types.float, + defaultValue = 1.1, + }, + long_arg = { + kind = types.long, + defaultValue = 2^50, + }, + object = input_object + }, + onField = true, + }) + } + + local result = { + defaultValues = { + boolean_arg = false, + int_arg = 1, + float_arg = 1.1, + long_arg = 2^50, + object_arg = { + nested_boolean_arg = true, + nested_float_arg = 1.1, + nested_int_arg = 2, + nested_long_arg = 2^50, + nested_string_arg = "default nested value", + }, + string_arg = "string_arg", + }, + directivesDefaultValues = { + timeout = { + boolean_arg = false, + float_arg = 1.1, + int_arg = 1, + long_arg = 2^50, + object = { + nested_boolean_arg = true, + nested_float_arg = 1.1, + nested_int_arg = 2, + nested_long_arg = 2^50, + nested_string_arg = "default nested value", + }, + string_arg = "string_arg", + }, + }, + } + + local data, errors = check_request( + query, + mutation_schema, + nil, + directives, + { defaultValues = true, directivesDefaultValues = true, } + ) + + t.assert_equals(errors, nil) + t.assert_items_equals(json.decode(data.test_mutation), result) + + query = '{prefix{test_mutation}}' + + local mutation_schema_with_prefix = { + ['prefix'] = { + kind = types.object({ + name = 'prefix', + fields = mutation_schema, + }), + arguments = {}, + resolve = function() + return {} + end, + } + } + + data, errors = check_request( + query, + mutation_schema_with_prefix, + nil, + directives, + { defaultValues = true, directivesDefaultValues = true, } + ) + t.assert_equals(errors, nil) + t.assert_items_equals(json.decode(data.prefix.test_mutation), result) +end