diff --git a/sql/expression/function/json_extract.go b/sql/expression/function/json_extract.go index b49796c5c..5f5876b0f 100644 --- a/sql/expression/function/json_extract.go +++ b/sql/expression/function/json_extract.go @@ -69,10 +69,12 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - result[i], err = jsonpath.JsonPathLookup(doc, path.(string)) + c, err := jsonpath.Compile(path.(string)) if err != nil { return nil, err } + + result[i], _ = c.Lookup(doc) // err ignored } if len(result) == 1 { diff --git a/sql/expression/function/json_extract_test.go b/sql/expression/function/json_extract_test.go index 648f99e1a..7e7b9d703 100644 --- a/sql/expression/function/json_extract_test.go +++ b/sql/expression/function/json_extract_test.go @@ -1,6 +1,7 @@ package function import ( + "errors" "testing" "github.com/stretchr/testify/require" @@ -46,22 +47,30 @@ func TestJSONExtract(t *testing.T) { f sql.Expression row sql.Row expected interface{} + err error }{ - {f2, sql.Row{json, "$.b.c"}, "foo"}, - {f3, sql.Row{json, "$.b.c", "$.b.d"}, []interface{}{"foo", true}}, + {f2, sql.Row{json, "FOO"}, nil, errors.New("should start with '$'")}, + {f2, sql.Row{nil, "$.b.c"}, nil, nil}, + {f2, sql.Row{json, "$.foo"}, nil, nil}, + {f2, sql.Row{json, "$.b.c"}, "foo", nil}, + {f3, sql.Row{json, "$.b.c", "$.b.d"}, []interface{}{"foo", true}, nil}, {f4, sql.Row{json, "$.b.c", "$.b.d", "$.e[0][*]"}, []interface{}{ "foo", true, []interface{}{1., 2.}, - }}, + }, nil}, } for _, tt := range testCases { t.Run(tt.f.String(), func(t *testing.T) { require := require.New(t) - result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - require.NoError(err) + if tt.err == nil { + require.NoError(err) + } else { + require.Equal(err.Error(), tt.err.Error()) + } + require.Equal(tt.expected, result) }) }