diff --git a/README.md b/README.md index 367660788..5139316c9 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ We support and actively test against certain third-party clients to ensure compa |`IFNULL(expr1, expr2)`|If expr1 is not NULL, IFNULL() returns expr1; otherwise it returns expr2.| |`IS_BINARY(blob)`|Returns whether a BLOB is a binary file or not.| |`JSON_EXTRACT(json_doc, path, ...)`|Extracts data from a json document using json paths.| +|`JSON_UNQUOTE(json)`|Unquotes JSON value and returns the result as a utf8mb4 string.| |`LEAST(...)`|Returns the smaller numeric or string value.| |`LENGTH(str)`|Return the length of the string in bytes.| |`LN(X)`|Return the natural logarithm of X.| diff --git a/SUPPORTED.md b/SUPPORTED.md index 9cc357053..40a31dd98 100644 --- a/SUPPORTED.md +++ b/SUPPORTED.md @@ -98,6 +98,7 @@ - IS_BINARY - IS_BINARY - JSON_EXTRACT +- JSON_UNQUOTE - LEAST - LN - LOG10 diff --git a/engine_test.go b/engine_test.go index dfda34bb2..abfc6f3d5 100644 --- a/engine_test.go +++ b/engine_test.go @@ -659,6 +659,22 @@ var queries = []struct { `SELECT JSON_EXTRACT("foo", "$")`, []sql.Row{{"foo"}}, }, + { + `SELECT JSON_UNQUOTE('"foo"')`, + []sql.Row{{"foo"}}, + }, + { + `SELECT JSON_UNQUOTE('[1, 2, 3]')`, + []sql.Row{{"[1, 2, 3]"}}, + }, + { + `SELECT JSON_UNQUOTE('"\\t\\u0032"')`, + []sql.Row{{"\t2"}}, + }, + { + `SELECT JSON_UNQUOTE('"\t\\u0032"')`, + []sql.Row{{"\t2"}}, + }, { `SELECT CONNECTION_ID()`, []sql.Row{{uint32(1)}}, diff --git a/sql/expression/function/json_unquote.go b/sql/expression/function/json_unquote.go new file mode 100644 index 000000000..4b5715c4e --- /dev/null +++ b/sql/expression/function/json_unquote.go @@ -0,0 +1,139 @@ +package function + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "reflect" + "unicode/utf8" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// JSONUnquote unquotes JSON value and returns the result as a utf8mb4 string. +// Returns NULL if the argument is NULL. +// An error occurs if the value starts and ends with double quotes but is not a valid JSON string literal. +type JSONUnquote struct { + expression.UnaryExpression +} + +// NewJSONUnquote creates a new JSONUnquote UDF. +func NewJSONUnquote(json sql.Expression) sql.Expression { + return &JSONUnquote{expression.UnaryExpression{Child: json}} +} + +func (js *JSONUnquote) String() string { + return fmt.Sprintf("JSON_UNQUOTE(%s)", js.Child) +} + +// Type implements the Expression interface. +func (*JSONUnquote) Type() sql.Type { + return sql.Text +} + +// TransformUp implements the Expression interface. +func (js *JSONUnquote) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + json, err := js.Child.TransformUp(f) + if err != nil { + return nil, err + } + return f(NewJSONUnquote(json)) +} + +// Eval implements the Expression interface. +func (js *JSONUnquote) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + json, err := js.Child.Eval(ctx, row) + if json == nil || err != nil { + return json, err + } + + ex, err := sql.Text.Convert(json) + if err != nil { + return nil, err + } + str, ok := ex.(string) + if !ok { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String()) + } + + return unquote(str) +} + +// The implementation is taken from TiDB +// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L89 +func unquote(s string) (string, error) { + ret := new(bytes.Buffer) + for i := 0; i < len(s); i++ { + if s[i] == '\\' { + i++ + if i == len(s) { + return "", fmt.Errorf("Missing a closing quotation mark in string") + } + switch s[i] { + case '"': + ret.WriteByte('"') + case 'b': + ret.WriteByte('\b') + case 'f': + ret.WriteByte('\f') + case 'n': + ret.WriteByte('\n') + case 'r': + ret.WriteByte('\r') + case 't': + ret.WriteByte('\t') + case '\\': + ret.WriteByte('\\') + case 'u': + if i+4 > len(s) { + return "", fmt.Errorf("Invalid unicode: %s", s[i+1:]) + } + char, size, err := decodeEscapedUnicode([]byte(s[i+1 : i+5])) + if err != nil { + return "", err + } + ret.Write(char[0:size]) + i += 4 + default: + // For all other escape sequences, backslash is ignored. + ret.WriteByte(s[i]) + } + } else { + ret.WriteByte(s[i]) + } + } + + str := ret.String() + strlen := len(str) + // Remove prefix and suffix '"'. + if strlen > 1 { + head, tail := str[0], str[strlen-1] + if head == '"' && tail == '"' { + return str[1 : strlen-1], nil + } + } + return str, nil +} + +// decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629. +// According RFC 3629, the max length of utf8 characters is 4 bytes. +// And MySQL use 4 bytes to represent the unicode which must be in [0, 65536). +// The implementation is taken from TiDB: +// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L136 +func decodeEscapedUnicode(s []byte) (char [4]byte, size int, err error) { + size, err = hex.Decode(char[0:2], s) + if err != nil || size != 2 { + // The unicode must can be represented in 2 bytes. + return char, 0, err + } + var unicode uint16 + err = binary.Read(bytes.NewReader(char[0:2]), binary.BigEndian, &unicode) + if err != nil { + return char, 0, err + } + size = utf8.RuneLen(rune(unicode)) + utf8.EncodeRune(char[0:size], rune(unicode)) + return +} diff --git a/sql/expression/function/json_unquote_test.go b/sql/expression/function/json_unquote_test.go new file mode 100644 index 000000000..d5d054f10 --- /dev/null +++ b/sql/expression/function/json_unquote_test.go @@ -0,0 +1,37 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestJSONUnquote(t *testing.T) { + require := require.New(t) + js := NewJSONUnquote(expression.NewGetField(0, sql.Text, "json", false)) + + testCases := []struct { + row sql.Row + expected interface{} + err bool + }{ + {sql.Row{nil}, nil, false}, + {sql.Row{"\"abc\""}, `abc`, false}, + {sql.Row{"[1, 2, 3]"}, `[1, 2, 3]`, false}, + {sql.Row{"\"\t\u0032\""}, "\t2", false}, + {sql.Row{"\\"}, nil, true}, + } + + for _, tt := range testCases { + result, err := js.Eval(sql.NewEmptyContext(), tt.row) + + if !tt.err { + require.NoError(err) + require.Equal(tt.expected, result) + } else { + require.NotNil(err) + } + } +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index c1712ee5b..faa5acdd8 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -60,6 +60,7 @@ var Defaults = []sql.Function{ sql.Function0{Name: "connection_id", Fn: NewConnectionID}, sql.Function1{Name: "soundex", Fn: NewSoundex}, sql.FunctionN{Name: "json_extract", Fn: NewJSONExtract}, + sql.Function1{Name: "json_unquote", Fn: NewJSONUnquote}, sql.Function1{Name: "ln", Fn: NewLogBaseFunc(float64(math.E))}, sql.Function1{Name: "log2", Fn: NewLogBaseFunc(float64(2))}, sql.Function1{Name: "log10", Fn: NewLogBaseFunc(float64(10))}, diff --git a/sql/functionregistry.go b/sql/functionregistry.go index b366dedfb..c22e391c6 100644 --- a/sql/functionregistry.go +++ b/sql/functionregistry.go @@ -1,8 +1,8 @@ package sql import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/similartext" + "gopkg.in/src-d/go-errors.v1" ) // ErrFunctionAlreadyRegistered is thrown when a function is already registered @@ -13,7 +13,7 @@ var ErrFunctionNotFound = errors.NewKind("A function: '%s' not found.") // ErrInvalidArgumentNumber is returned when the number of arguments to call a // function is different from the function arity. -var ErrInvalidArgumentNumber = errors.NewKind("A function: '%s' expected %d arguments, %d received.") +var ErrInvalidArgumentNumber = errors.NewKind("A function: '%s' expected %v arguments, %v received.") // Function is a function defined by the user that can be applied in a SQL query. type Function interface {