Skip to content

Commit 67838ac

Browse files
committed
graphql: adapt "Validate a variable type" fron graphql.0
This patch is a port of tarantool/graphql.0@38fc560
1 parent 8c4d414 commit 67838ac

File tree

5 files changed

+311
-55
lines changed

5 files changed

+311
-55
lines changed

cartridge/graphql/execute.lua

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ local types = require(path .. '.types')
33
local util = require(path .. '.util')
44
local introspection = require(path .. '.introspection')
55
local query_util = require(path .. '.query_util')
6+
local validate_variables = require(path .. '.validate_variables')
67

78
local function getFieldResponseKey(field)
89
return field.alias and field.alias.name.value or field.name.value
@@ -70,48 +71,69 @@ local function defaultResolver(object, arguments, info)
7071
return object[info.fieldASTs[1].name.value]
7172
end
7273

73-
local function buildContext(schema, tree, rootValue, variables, operationName)
74-
local context = {
75-
schema = schema,
76-
rootValue = rootValue,
77-
variables = variables,
78-
operation = nil,
79-
fragmentMap = {},
80-
defaultValues = {},
81-
request_cache = {},
82-
}
74+
local function getOperation(tree, operationName)
75+
local operation
8376

84-
for _, definition in ipairs(tree.definitions) do
85-
if definition.kind == 'operation' then
86-
if not operationName and context.operation then
87-
error('Operation name must be specified if more than one operation exists.')
88-
end
77+
for _, definition in ipairs(tree.definitions) do
78+
if definition.kind == 'operation' then
79+
if not operationName and operation then
80+
error('Operation name must be specified if more than one operation exists.')
81+
end
8982

90-
if not operationName or definition.name.value == operationName then
91-
context.operation = definition
92-
end
83+
if not operationName or definition.name.value == operationName then
84+
operation = definition
85+
end
86+
end
87+
end
9388

94-
for _, variableDefinition in ipairs(definition.variableDefinitions or {}) do
95-
if variableDefinition.defaultValue ~= nil then
96-
context.defaultValues[variableDefinition.variable.name.value] =
97-
variableDefinition.defaultValue.value
89+
if not operation then
90+
if operationName then
91+
error('Unknown operation "' .. operationName .. '"')
92+
else
93+
error('Must provide an operation')
94+
end
95+
end
9896

99-
end
100-
end
101-
elseif definition.kind == 'fragmentDefinition' then
102-
context.fragmentMap[definition.name.value] = definition
97+
return operation
98+
end
99+
100+
local function getFragmentDefinitions(tree)
101+
local fragmentMap = {}
102+
103+
for _, definition in ipairs(tree.definitions) do
104+
if definition.kind == 'fragmentDefinition' then
105+
fragmentMap[definition.name.value] = definition
106+
end
103107
end
104-
end
105108

106-
if not context.operation then
107-
if operationName then
108-
error('Unknown operation "' .. operationName .. '"')
109-
else
110-
error('Must provide an operation')
109+
return fragmentMap
110+
end
111+
112+
-- Extract variableTypes from the operation.
113+
local function getVariableTypes(schema, operation)
114+
local variableTypes = {}
115+
116+
for _, definition in ipairs(operation.variableDefinitions or {}) do
117+
variableTypes[definition.variable.name.value] =
118+
query_util.typeFromAST(definition.type, schema)
111119
end
112-
end
113120

114-
return context
121+
return variableTypes
122+
end
123+
124+
local function buildContext(schema, tree, rootValue, variables, operationName)
125+
local operation = getOperation(tree, operationName)
126+
local fragmentMap = getFragmentDefinitions(tree)
127+
local variableTypes = getVariableTypes(schema, operation)
128+
return {
129+
schema = schema,
130+
rootValue = rootValue,
131+
variables = variables,
132+
operation = operation,
133+
fragmentMap = fragmentMap,
134+
variableTypes = variableTypes,
135+
request_cache = {},
136+
}
115137
end
116138

117139
local function collectFields(objectType, selections, visitedFragments, result, context)
@@ -297,6 +319,8 @@ local function execute(schema, tree, rootValue, variables, operationName)
297319
error('Unsupported operation "' .. context.operation.operation .. '"')
298320
end
299321

322+
validate_variables.validate_variables(context)
323+
300324
return evaluateSelections(rootType, rootValue, context.operation.selectionSet.selections, context)
301325
end
302326

cartridge/graphql/types.lua

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
local ffi = require('ffi')
2+
local checks = require('checks')
3+
local errors = require('errors')
14
local util = require('cartridge.graphql.util')
25
local vars = require('cartridge.vars').new('cartridge.graphql.types')
36
local utils = require('cartridge.utils')
4-
local checks = require('checks')
5-
local errors = require('errors')
67

78
local type_not_found = errors.new_class("type_not_found")
89
local validation_error = errors.new_class("validation_error")
@@ -58,6 +59,15 @@ function types.list(kind)
5859
return instance
5960
end
6061

62+
function types.nullable(kind)
63+
assert(type(kind) == 'table', 'kind must be a table, got ' .. type(kind))
64+
65+
if kind.__type ~= 'NonNull' then return kind end
66+
67+
assert(kind.ofType ~= nil, 'kind.ofType must not be nil')
68+
return types.nullable(kind.ofType)
69+
end
70+
6171
function types.scalar(config)
6272
assert(type(config.name) == 'string', 'type name must be provided as a string')
6373
assert(type(config.serialize) == 'function', 'serialize must be a function')
@@ -74,7 +84,8 @@ function types.scalar(config)
7484
description = config.description,
7585
serialize = config.serialize,
7686
parseValue = config.parseValue,
77-
parseLiteral = config.parseLiteral
87+
parseLiteral = config.parseLiteral,
88+
isValueOfTheType = config.isValueOfTheType,
7889
}
7990

8091
instance.nonNull = types.nonNull(instance)
@@ -216,28 +227,58 @@ function types.inputObject(config)
216227
return instance
217228
end
218229

219-
local coerceInt = function(value)
220-
value = tonumber(value)
221-
222-
if not value then return end
230+
-- Based on the code from tarantool/checks.
231+
local function isInt(value)
232+
if type(value) == 'number' then
233+
return value >= -2^31 and value < 2^31 and math.floor(value) == value
234+
end
223235

224-
if value == value and value < 2 ^ 32 and value >= -2 ^ 32 then
225-
return value < 0 and math.ceil(value) or math.floor(value)
236+
if type(value) == 'cdata' then
237+
if ffi.istype('int64_t', value) then
238+
return value >= -2^31 and value < 2^31
239+
elseif ffi.istype('uint64_t', value) then
240+
return value < 2^31
241+
end
226242
end
243+
244+
return false
227245
end
228246

229-
local coerceLong = function(value)
230-
value = tonumber64(value)
247+
-- The code from tarantool/checks.
248+
local function isLong(value)
249+
if type(value) == 'number' then
250+
-- Double floating point format has 52 fraction bits. If we want to keep
251+
-- integer precision, the number must be less than 2^53.
252+
return value > -2^53 and value < 2^53 and math.floor(value) == value
253+
end
231254

232255
if type(value) == 'cdata' then
233-
return value
256+
if ffi.istype('int64_t', value) then
257+
return true
258+
elseif ffi.istype('uint64_t', value) then
259+
return value < 2^63
260+
end
234261
end
235262

236-
if not value then return end
263+
return false
264+
end
265+
266+
local function coerceInt(value)
267+
local value = tonumber(value)
237268

238-
if value == value and value < 2 ^ 52 and value >= -2 ^ 52 then
239-
return value < 0 and math.ceil(value) or math.floor(value)
240-
end
269+
if value == nil then return end
270+
if not isInt(value) then return end
271+
272+
return value
273+
end
274+
275+
local function coerceLong(value)
276+
local value = tonumber64(value)
277+
278+
if value == nil then return end
279+
if not isLong(value) then return end
280+
281+
return value
241282
end
242283

243284
types.int = types.scalar({
@@ -249,7 +290,8 @@ types.int = types.scalar({
249290
if node.kind == 'int' then
250291
return coerceInt(node.value)
251292
end
252-
end
293+
end,
294+
isValueOfTheType = isInt,
253295
})
254296

255297
types.long = types.scalar({
@@ -261,7 +303,8 @@ types.long = types.scalar({
261303
if node.kind == 'long' or node.kind == 'int' then
262304
return coerceLong(node.value)
263305
end
264-
end
306+
end,
307+
isValueOfTheType = isLong,
265308
})
266309

267310
types.float = types.scalar({
@@ -272,7 +315,10 @@ types.float = types.scalar({
272315
if node.kind == 'float' or node.kind == 'int' then
273316
return tonumber(node.value)
274317
end
275-
end
318+
end,
319+
isValueOfTheType = function(value)
320+
return type(value) == 'number'
321+
end,
276322
})
277323

278324
types.string = types.scalar({
@@ -284,7 +330,10 @@ types.string = types.scalar({
284330
if node.kind == 'string' then
285331
return node.value
286332
end
287-
end
333+
end,
334+
isValueOfTheType = function(value)
335+
return type(value) == 'string'
336+
end,
288337
})
289338

290339
local function toboolean(x)
@@ -302,7 +351,10 @@ types.boolean = types.scalar({
302351
else
303352
return nil
304353
end
305-
end
354+
end,
355+
isValueOfTheType = function(value)
356+
return type(value) == 'boolean'
357+
end,
306358
})
307359

308360
types.id = types.scalar({

cartridge/graphql/util.lua

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,64 @@ local function coerceValue(node, schemaType, variables, opts)
148148
end
149149
end
150150

151+
--- Check whether passed value has one of listed types.
152+
---
153+
--- @param obj value to check
154+
---
155+
--- @tparam string obj_name name of the value to form an error
156+
---
157+
--- @tparam string type_1
158+
--- @tparam[opt] string type_2
159+
--- @tparam[opt] string type_3
160+
---
161+
--- @return nothing
162+
local function check(obj, obj_name, type_1, type_2, type_3)
163+
if type(obj) == type_1 or type(obj) == type_2 or type(obj) == type_3 then
164+
return
165+
end
166+
167+
if type_3 ~= nil then
168+
error(('%s must be a %s or a % or a %s, got %s'):format(obj_name,
169+
type_1, type_2, type_3, type(obj)))
170+
elseif type_2 ~= nil then
171+
error(('%s must be a %s or a %s, got %s'):format(obj_name, type_1,
172+
type_2, type(obj)))
173+
else
174+
error(('%s must be a %s, got %s'):format(obj_name, type_1, type(obj)))
175+
end
176+
end
177+
178+
--- Check whether table is an array.
179+
---
180+
--- Based on [that][1] implementation.
181+
--- [1]: https://github.com/mpx/lua-cjson/blob/db122676/lua/cjson/util.lua
182+
---
183+
--- @tparam table table to check
184+
--- @return[1] `true` if passed table is an array (includes the empty table
185+
--- case)
186+
--- @return[2] `false` otherwise
187+
local function is_array(table)
188+
check(table, 'table', 'table')
189+
190+
local max = 0
191+
local count = 0
192+
for k, _ in pairs(table) do
193+
if type(k) == 'number' then
194+
if k > max then
195+
max = k
196+
end
197+
count = count + 1
198+
else
199+
return false
200+
end
201+
end
202+
if max > count * 2 then
203+
return false
204+
end
205+
206+
return max >= 0
207+
end
208+
151209
return {
152210
map = map,
153211
find = find,
@@ -158,4 +216,7 @@ return {
158216
trim = trim,
159217
getTypeName = getTypeName,
160218
coerceValue = coerceValue,
219+
220+
is_array = is_array,
221+
check = check,
161222
}

0 commit comments

Comments
 (0)