diff --git a/cmd/gitbase/command/server.go b/cmd/gitbase/command/server.go index 13a2e0f3e..15265b7d2 100644 --- a/cmd/gitbase/command/server.go +++ b/cmd/gitbase/command/server.go @@ -56,6 +56,7 @@ type Server struct { DisableGit bool `long:"no-git" description:"disable the load of git standard repositories."` DisableSiva bool `long:"no-siva" description:"disable the load of siva files."` Verbose bool `short:"v" description:"Activates the verbose mode"` + OldUast bool `long:"old-uast-serialization" description:"serialize uast in the old format" env:"GITBASE_UAST_SERIALIZATION"` } type jaegerLogrus struct { @@ -138,12 +139,17 @@ func (c *Server) Execute(args []string) error { c.engine, gitbase.NewSessionBuilder(c.pool, gitbase.WithSkipGitErrors(c.SkipGitErrors), + gitbase.WithOldUASTSerialization(c.OldUast), ), ) if err != nil { return err } + if c.OldUast { + function.UASTExpressionType = sql.Array(sql.Blob) + } + logrus.Infof("server started and listening on %s:%d", c.Host, c.Port) return s.Start() } diff --git a/internal/function/uast.go b/internal/function/uast.go index 5a9e65b79..b0f43d178 100644 --- a/internal/function/uast.go +++ b/internal/function/uast.go @@ -4,20 +4,18 @@ import ( "fmt" "strings" - "github.com/sirupsen/logrus" - "github.com/src-d/gitbase" bblfsh "gopkg.in/bblfsh/client-go.v2" "gopkg.in/bblfsh/client-go.v2/tools" "gopkg.in/bblfsh/sdk.v1/uast" - errors "gopkg.in/src-d/go-errors.v1" "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) var ( - // ErrParseBlob is returned when the blob can't be parsed with bblfsh. - ErrParseBlob = errors.NewKind("unable to parse the given blob using bblfsh: %s") + // UASTExpressionType represents the returned SQL type by + // the functions uast, uast_mode, uast_xpath and uast_children. + UASTExpressionType sql.Type = sql.Blob ) // UAST returns an array of UAST nodes as blobs. @@ -62,7 +60,7 @@ func (f UAST) Resolved() bool { // Type implements the Expression interface. func (f UAST) Type() sql.Type { - return sql.Array(sql.Blob) + return UASTExpressionType } // Children implements the Expression interface. @@ -181,7 +179,7 @@ func (f UASTMode) Resolved() bool { // Type implements the Expression interface. func (f UASTMode) Type() sql.Type { - return sql.Array(sql.Blob) + return UASTExpressionType } // Children implements the Expression interface. @@ -290,7 +288,7 @@ func NewUASTXPath(uast, xpath sql.Expression) sql.Expression { // Type implements the Expression interface. func (UASTXPath) Type() sql.Type { - return sql.Array(sql.Blob) + return UASTExpressionType } // Eval implements the Expression interface. @@ -309,15 +307,15 @@ func (f *UASTXPath) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err er return nil, err } - if left == nil { - return nil, nil - } - - nodes, err := nodesFromBlobArray(left) + nodes, err := getNodes(ctx, left) if err != nil { return nil, err } + if nodes == nil { + return nil, nil + } + xpath, err := exprToString(ctx, f.Right, row) if err != nil { return nil, err @@ -327,58 +325,17 @@ func (f *UASTXPath) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err er return nil, nil } - var result []interface{} + var filtered []*uast.Node for _, n := range nodes { ns, err := tools.Filter(n, xpath) if err != nil { return nil, err } - m, err := marshalNodes(ns) - if err != nil { - return nil, err - } - - result = append(result, m...) - } - - return result, nil -} - -func nodesFromBlobArray(data interface{}) ([]*uast.Node, error) { - data, err := sql.Array(sql.Blob).Convert(data) - if err != nil { - return nil, err - } - - arr := data.([]interface{}) - var nodes = make([]*uast.Node, len(arr)) - for i, n := range arr { - node := uast.NewNode() - if err := node.Unmarshal(n.([]byte)); err != nil { - return nil, err - } - - nodes[i] = node - } - - return nodes, nil -} - -func marshalNodes(nodes []*uast.Node) ([]interface{}, error) { - m := make([]interface{}, 0, len(nodes)) - for _, n := range nodes { - if n != nil { - data, err := n.Marshal() - if err != nil { - return nil, err - } - - m = append(m, data) - } + filtered = append(filtered, ns...) } - return m, nil + return marshalNodes(ctx, filtered) } func (f UASTXPath) String() string { @@ -400,84 +357,6 @@ func (f UASTXPath) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) return fn(NewUASTXPath(left, right)) } -func exprToString( - ctx *sql.Context, - e sql.Expression, - r sql.Row, -) (string, error) { - if e == nil { - return "", nil - } - - x, err := e.Eval(ctx, r) - if err != nil { - return "", err - } - - if x == nil { - return "", nil - } - - x, err = sql.Text.Convert(x) - if err != nil { - return "", err - } - - return x.(string), nil -} - -func getUAST( - ctx *sql.Context, - bytes []byte, - lang, xpath string, - mode bblfsh.Mode, -) (interface{}, error) { - session, ok := ctx.Session.(*gitbase.Session) - if !ok { - return nil, gitbase.ErrInvalidGitbaseSession.New(ctx.Session) - } - - client, err := session.BblfshClient() - if err != nil { - return nil, err - } - - // If we have a language we must check if it's supported. If we don't, bblfsh - // is the one that will have to identify the language. - if lang != "" { - ok, err = client.IsLanguageSupported(ctx, lang) - if err != nil { - return nil, err - } - - if !ok { - return nil, nil - } - } - - resp, err := client.ParseWithMode(ctx, mode, lang, bytes) - if err != nil { - logrus.Warn(ErrParseBlob.New(err)) - return nil, nil - } - - if len(resp.Errors) > 0 { - logrus.Warn(ErrParseBlob.New(strings.Join(resp.Errors, "\n"))) - } - - var nodes []*uast.Node - if xpath == "" { - nodes = []*uast.Node{resp.UAST} - } else { - nodes, err = tools.Filter(resp.UAST, xpath) - if err != nil { - return nil, err - } - } - - return marshalNodes(nodes) -} - // UASTExtract extracts keys from an UAST. type UASTExtract struct { expression.BinaryExpression @@ -514,15 +393,15 @@ func (u *UASTExtract) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err return nil, err } - if left == nil { - return nil, nil - } - - nodes, err := nodesFromBlobArray(left) + nodes, err := getNodes(ctx, left) if err != nil { return nil, err } + if nodes == nil { + return nil, nil + } + key, err := exprToString(ctx, u.Right, row) if err != nil { return nil, err @@ -609,7 +488,7 @@ func (u *UASTChildren) String() string { // Type implements the sql.Expression interface. func (u *UASTChildren) Type() sql.Type { - return sql.Array(sql.Blob) + return UASTExpressionType } // TransformUp implements the sql.Expression interface. @@ -638,17 +517,17 @@ func (u *UASTChildren) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err return nil, err } - if child == nil { - return nil, nil - } - - nodes, err := nodesFromBlobArray(child) + nodes, err := getNodes(ctx, child) if err != nil { return nil, err } + if nodes == nil { + return nil, nil + } + children := flattenChildren(nodes) - return marshalNodes(children) + return marshalNodes(ctx, children) } func flattenChildren(nodes []*uast.Node) []*uast.Node { diff --git a/internal/function/uast_test.go b/internal/function/uast_test.go index 3b1d757c6..ef54afc89 100644 --- a/internal/function/uast_test.go +++ b/internal/function/uast_test.go @@ -10,7 +10,6 @@ import ( "gopkg.in/bblfsh/client-go.v2/tools" "gopkg.in/bblfsh/sdk.v1/protocol" "gopkg.in/bblfsh/sdk.v1/uast" - errors "gopkg.in/src-d/go-errors.v1" fixtures "gopkg.in/src-d/go-git-fixtures.v3" "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression" @@ -39,31 +38,34 @@ func TestUASTMode(t *testing.T) { Lang: expression.NewGetField(2, sql.Text, "", false), } - u, _ := bblfshFixtures(t, ctx) - - testCases := []struct { - name string - fn *UASTMode - row sql.Row - expected interface{} - }{ - {"annotated", mode, sql.NewRow("annotated", []byte(testCode), "Python"), u["annotated"]}, - {"semantic", mode, sql.NewRow("semantic", []byte(testCode), "Python"), u["semantic"]}, - {"native", mode, sql.NewRow("native", []byte(testCode), "Python"), u["native"]}, - } + oldSerialization := []bool{true, false} + for _, s := range oldSerialization { + session, ok := ctx.Session.(*gitbase.Session) + require.True(t, ok) + + session.OldUASTSerialization = s + + u, _ := bblfshFixtures(t, ctx) + testCases := []struct { + name string + fn *UASTMode + row sql.Row + expected interface{} + }{ + {"annotated", mode, sql.NewRow("annotated", []byte(testCode), "Python"), u["annotated"]}, + {"semantic", mode, sql.NewRow("semantic", []byte(testCode), "Python"), u["semantic"]}, + {"native", mode, sql.NewRow("native", []byte(testCode), "Python"), u["native"]}, + } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - result, err := tt.fn.Eval(ctx, tt.row) - require.NoError(err) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := tt.fn.Eval(ctx, tt.row) + require.NoError(err) - if _, ok := tt.expected.([]interface{}); ok { - assertUASTBlobs(t, tt.expected, result) - } else { - require.Equal(tt.expected, result) - } - }) + assertUASTBlobs(t, ctx, tt.expected, result) + }) + } } } @@ -86,37 +88,41 @@ func TestUAST(t *testing.T) { XPath: expression.NewGetField(2, sql.Text, "", false), } - u, f := bblfshFixtures(t, ctx) - uast := u["semantic"] - filteredNodes := f["semantic"] - - testCases := []struct { - name string - fn *UAST - row sql.Row - expected interface{} - }{ - {"blob is nil", fn3, sql.NewRow(nil, nil, nil), nil}, - {"lang is nil", fn3, sql.NewRow([]byte{}, nil, nil), nil}, - {"xpath is nil", fn3, sql.NewRow([]byte{}, "Ruby", nil), nil}, - {"only blob, can't infer language", fn1, sql.NewRow([]byte(testCode)), nil}, - {"blob with unsupported lang", fn2, sql.NewRow([]byte(testCode), "YAML"), nil}, - {"blob with lang", fn2, sql.NewRow([]byte(testCode), "Python"), uast}, - {"blob with lang and xpath", fn3, sql.NewRow([]byte(testCode), "Python", testXPathSemantic), filteredNodes}, - } + oldSerialization := []bool{true, false} + for _, s := range oldSerialization { + session, ok := ctx.Session.(*gitbase.Session) + require.True(t, ok) + + session.OldUASTSerialization = s + + u, f := bblfshFixtures(t, ctx) + uast := u["semantic"] + filteredNodes := f["semantic"] + + testCases := []struct { + name string + fn *UAST + row sql.Row + expected interface{} + }{ + {"blob is nil", fn3, sql.NewRow(nil, nil, nil), nil}, + {"lang is nil", fn3, sql.NewRow([]byte{}, nil, nil), nil}, + {"xpath is nil", fn3, sql.NewRow([]byte{}, "Ruby", nil), nil}, + {"only blob, can't infer language", fn1, sql.NewRow([]byte(testCode)), nil}, + {"blob with unsupported lang", fn2, sql.NewRow([]byte(testCode), "YAML"), nil}, + {"blob with lang", fn2, sql.NewRow([]byte(testCode), "Python"), uast}, + {"blob with lang and xpath", fn3, sql.NewRow([]byte(testCode), "Python", testXPathSemantic), filteredNodes}, + } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - result, err := tt.fn.Eval(ctx, tt.row) - require.NoError(err) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := tt.fn.Eval(ctx, tt.row) + require.NoError(err) - if _, ok := tt.expected.([]interface{}); ok { - assertUASTBlobs(t, tt.expected, result) - } else { - require.Equal(tt.expected, result) - } - }) + assertUASTBlobs(t, ctx, tt.expected, result) + }) + } } } @@ -129,38 +135,36 @@ func TestUASTXPath(t *testing.T) { expression.NewGetField(1, sql.Text, "", false), ) - u, f := bblfshFixtures(t, ctx) - - testCases := []struct { - name string - row sql.Row - expected interface{} - err *errors.Kind - }{ - {"left is nil", sql.NewRow(nil, "foo"), nil, nil}, - {"right is nil", sql.NewRow(u["semantic"], nil), nil, nil}, - {"both given", sql.NewRow(u["semantic"], testXPathSemantic), f["semantic"], nil}, - {"native", sql.NewRow(u["native"], testXPathNative), f["native"], nil}, - {"annotated", sql.NewRow(u["annotated"], testXPathAnnotated), f["annotated"], nil}, - } + oldSerialization := []bool{true, false} + for _, s := range oldSerialization { + session, ok := ctx.Session.(*gitbase.Session) + require.True(t, ok) + + session.OldUASTSerialization = s + + u, f := bblfshFixtures(t, ctx) + + testCases := []struct { + name string + row sql.Row + expected interface{} + }{ + {"left is nil", sql.NewRow(nil, "foo"), nil}, + {"right is nil", sql.NewRow(u["semantic"], nil), nil}, + {"both given", sql.NewRow(u["semantic"], testXPathSemantic), f["semantic"]}, + {"native", sql.NewRow(u["native"], testXPathNative), f["native"]}, + {"annotated", sql.NewRow(u["annotated"], testXPathAnnotated), f["annotated"]}, + } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - result, err := fn.Eval(ctx, tt.row) - if tt.err != nil { - require.Error(err) - require.True(tt.err.Is(err)) - } else { + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := fn.Eval(ctx, tt.row) require.NoError(err) - if _, ok := tt.expected.([]interface{}); ok { - assertUASTBlobs(t, tt.expected, result) - } else { - require.Equal(tt.expected, result) - } - } - }) + assertUASTBlobs(t, ctx, tt.expected, result) + }) + } } } @@ -263,21 +267,37 @@ func TestUASTExtract(t *testing.T) { }, } - _, filteredNodes := bblfshFixtures(t, ctx) + oldSerialization := []bool{true, false} + for _, s := range oldSerialization { + session, ok := ctx.Session.(*gitbase.Session) + require.True(t, ok) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - row := sql.NewRow(filteredNodes["annotated"], test.key) + session.OldUASTSerialization = s - fn := NewUASTExtract( - expression.NewGetField(0, sql.Array(sql.Blob), "", false), - expression.NewLiteral(test.key, sql.Text), - ) + var UASTType sql.Type + session.OldUASTSerialization = s + if s { + UASTType = sql.Array(sql.Blob) + } else { + UASTType = sql.Blob + } + + _, filteredNodes := bblfshFixtures(t, ctx) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + row := sql.NewRow(filteredNodes["annotated"], test.key) + + fn := NewUASTExtract( + expression.NewGetField(0, UASTType, "", false), + expression.NewLiteral(test.key, sql.Text), + ) - foo, err := fn.Eval(ctx, row) - require.NoError(t, err) - require.ElementsMatch(t, test.expected, foo) - }) + foo, err := fn.Eval(ctx, row) + require.NoError(t, err) + require.ElementsMatch(t, test.expected, foo) + }) + } } } @@ -287,77 +307,73 @@ func TestUASTChildren(t *testing.T) { ctx, cleanup := setup(t) defer cleanup() - uasts, _ := bblfshFixtures(t, ctx) modes := []string{"semantic", "annotated", "native"} - - for _, mode := range modes { - root, ok := uasts[mode] + oldSerialization := []bool{true, false} + for _, s := range oldSerialization { + session, ok := ctx.Session.(*gitbase.Session) require.True(ok) - nodes, err := nodesFromBlobArray(root) - require.NoError(err) - require.Len(nodes, 1) - expected := nodes[0].Children + var UASTType sql.Type + session.OldUASTSerialization = s + if s { + UASTType = sql.Array(sql.Blob) + } else { + UASTType = sql.Blob + } - row := sql.NewRow(root) + uasts, _ := bblfshFixtures(t, ctx) + for _, mode := range modes { + root, ok := uasts[mode] + require.True(ok) - fn := NewUASTChildren( - expression.NewGetField(0, sql.Array(sql.Blob), "", false), - ) + nodes, err := getNodes(ctx, root) + require.NoError(err) + require.Len(nodes, 1) + expected := nodes[0].Children - children, err := fn.Eval(ctx, row) - require.NoError(err) + row := sql.NewRow(root) - nodes, err = nodesFromBlobArray(children) - require.NoError(err) - require.Len(nodes, len(expected)) - for i, n := range nodes { - require.Equal( - n.InternalType, - expected[i].InternalType, + fn := NewUASTChildren( + expression.NewGetField(0, UASTType, "", false), ) + + children, err := fn.Eval(ctx, row) + require.NoError(err) + + nodes, err = getNodes(ctx, children) + require.NoError(err) + require.Len(nodes, len(expected)) + for i, n := range nodes { + require.Equal( + n.InternalType, + expected[i].InternalType, + ) + } } } } -func assertUASTBlobs(t *testing.T, a, b interface{}) { +func assertUASTBlobs(t *testing.T, ctx *sql.Context, a, b interface{}) { t.Helper() - require := require.New(t) - - expected, ok := a.([]interface{}) - require.True(ok) - - result, ok := b.([]interface{}) - require.True(ok) - - require.Equal(len(expected), len(result)) + var require = require.New(t) - var expectedNodes = make([]*uast.Node, len(expected)) - var resultNodes = make([]*uast.Node, len(result)) + expected, err := getNodes(ctx, a) + require.NoError(err) - for i, n := range expected { - node := uast.NewNode() - require.NoError(node.Unmarshal(n.([]byte))) - expectedNodes[i] = node - } + result, err := getNodes(ctx, b) + require.NoError(err) - for i, n := range result { - node := uast.NewNode() - require.NoError(node.Unmarshal(n.([]byte))) - resultNodes[i] = node - } - - require.Equal(expectedNodes, resultNodes) + require.Equal(expected, result) } func bblfshFixtures( t *testing.T, ctx *sql.Context, -) (map[string][]interface{}, map[string][]interface{}) { +) (map[string]interface{}, map[string]interface{}) { t.Helper() - uast := make(map[string][]interface{}) - filteredNodes := make(map[string][]interface{}) + uasts := make(map[string]interface{}) + filteredNodes := make(map[string]interface{}) modes := []struct { n string @@ -382,24 +398,20 @@ func bblfshFixtures( require.NoError(t, err) require.Equal(t, protocol.Ok, resp.Status, "errors: %v", resp.Errors) - testUAST, err := resp.UAST.Marshal() - require.NoError(t, err) idents, err := tools.Filter(resp.UAST, mode.x) require.NoError(t, err) - var identBlobs []interface{} - for _, id := range idents { - i, err := id.Marshal() - require.NoError(t, err) - identBlobs = append(identBlobs, i) - } + testUAST, err := marshalNodes(ctx, []*uast.Node{resp.UAST}) + require.NoError(t, err) + uasts[mode.n] = testUAST - uast[mode.n] = []interface{}{testUAST} - filteredNodes[mode.n] = identBlobs + testIdents, err := marshalNodes(ctx, idents) + require.NoError(t, err) + filteredNodes[mode.n] = testIdents } - return uast, filteredNodes + return uasts, filteredNodes } func setup(t *testing.T) (*sql.Context, func()) { diff --git a/internal/function/uast_utils.go b/internal/function/uast_utils.go new file mode 100644 index 000000000..ab51834b0 --- /dev/null +++ b/internal/function/uast_utils.go @@ -0,0 +1,245 @@ +package function + +import ( + "bytes" + "encoding/binary" + "io" + "strings" + + "github.com/sirupsen/logrus" + "github.com/src-d/gitbase" + bblfsh "gopkg.in/bblfsh/client-go.v2" + "gopkg.in/bblfsh/client-go.v2/tools" + "gopkg.in/bblfsh/sdk.v1/uast" + errors "gopkg.in/src-d/go-errors.v1" + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +var ( + // ErrParseBlob is returned when the blob can't be parsed with bblfsh. + ErrParseBlob = errors.NewKind("unable to parse the given blob using bblfsh: %s") + + // ErrUnmarshalUAST is returned when an error arises unmarshaling UASTs. + ErrUnmarshalUAST = errors.NewKind("error unmarshaling UAST: %s") + + // ErrMarshalUAST is returned when an error arises marshaling UASTs. + ErrMarshalUAST = errors.NewKind("error marshaling uast node: %s") +) + +func exprToString( + ctx *sql.Context, + e sql.Expression, + r sql.Row, +) (string, error) { + if e == nil { + return "", nil + } + + x, err := e.Eval(ctx, r) + if err != nil { + return "", err + } + + if x == nil { + return "", nil + } + + x, err = sql.Text.Convert(x) + if err != nil { + return "", err + } + + return x.(string), nil +} + +func getUAST( + ctx *sql.Context, + bytes []byte, + lang, xpath string, + mode bblfsh.Mode, +) (interface{}, error) { + session, ok := ctx.Session.(*gitbase.Session) + if !ok { + return nil, gitbase.ErrInvalidGitbaseSession.New(ctx.Session) + } + + client, err := session.BblfshClient() + if err != nil { + return nil, err + } + + // If we have a language we must check if it's supported. If we don't, bblfsh + // is the one that will have to identify the language. + if lang != "" { + ok, err = client.IsLanguageSupported(ctx, lang) + if err != nil { + return nil, err + } + + if !ok { + return nil, nil + } + } + + resp, err := client.ParseWithMode(ctx, mode, lang, bytes) + if err != nil { + logrus.Warn(ErrParseBlob.New(err)) + return nil, nil + } + + if len(resp.Errors) > 0 { + logrus.Warn(ErrParseBlob.New(strings.Join(resp.Errors, "\n"))) + } + + var nodes []*uast.Node + if xpath == "" { + nodes = []*uast.Node{resp.UAST} + } else { + nodes, err = tools.Filter(resp.UAST, xpath) + if err != nil { + return nil, err + } + } + + return marshalNodes(ctx, nodes) +} + +func marshalNodes(ctx *sql.Context, nodes []*uast.Node) (data interface{}, err error) { + session, ok := ctx.Session.(*gitbase.Session) + if !ok { + return nil, gitbase.ErrInvalidGitbaseSession.New(ctx.Session) + } + + if session.OldUASTSerialization { + data, err = marshalAsListNodes(nodes) + } else { + data, err = marshalAsBlobNodes(nodes) + } + + return data, err +} + +func marshalAsListNodes(nodes []*uast.Node) ([]interface{}, error) { + m := make([]interface{}, 0, len(nodes)) + for _, n := range nodes { + if n != nil { + data, err := n.Marshal() + if err != nil { + return nil, err + } + + m = append(m, data) + } + } + + return m, nil +} + +func marshalAsBlobNodes(nodes []*uast.Node) (out []byte, err error) { + defer func() { + if r := recover(); r != nil { + out, err = nil, r.(error) + } + }() + + buf := &bytes.Buffer{} + for _, n := range nodes { + if n != nil { + data, err := n.Marshal() + if err != nil { + return nil, err + } + + if err := binary.Write( + buf, binary.BigEndian, int32(len(data)), + ); err != nil { + return nil, err + } + + n, _ := buf.Write(data) + if n != len(data) { + return nil, ErrMarshalUAST.New("couldn't write all the data") + } + } + } + + return buf.Bytes(), nil +} + +func getNodes(ctx *sql.Context, data interface{}) (nodes []*uast.Node, err error) { + session, ok := ctx.Session.(*gitbase.Session) + if !ok { + return nil, gitbase.ErrInvalidGitbaseSession.New(ctx.Session) + } + + if session.OldUASTSerialization { + nodes, err = nodesFromBlobArray(data) + } else { + nodes, err = nodesFromBlob(data) + } + + return nodes, err +} + +func nodesFromBlobArray(data interface{}) ([]*uast.Node, error) { + if data == nil { + return nil, nil + } + + data, err := sql.Array(sql.Blob).Convert(data) + if err != nil { + return nil, err + } + + arr := data.([]interface{}) + var nodes = make([]*uast.Node, len(arr)) + for i, n := range arr { + node := uast.NewNode() + if err := node.Unmarshal(n.([]byte)); err != nil { + return nil, err + } + + nodes[i] = node + } + + return nodes, nil +} + +func nodesFromBlob(data interface{}) ([]*uast.Node, error) { + if data == nil { + return nil, nil + } + + raw, ok := data.([]byte) + if !ok { + return nil, ErrUnmarshalUAST.New("wrong underlying UAST format") + } + + return unmarshalNodes(raw) +} + +func unmarshalNodes(data []byte) ([]*uast.Node, error) { + nodes := []*uast.Node{} + buf := bytes.NewBuffer(data) + for { + var nodeLen int32 + if err := binary.Read( + buf, binary.BigEndian, &nodeLen, + ); err != nil { + if err == io.EOF { + break + } + + return nil, ErrUnmarshalUAST.New(err) + } + + node := uast.NewNode() + if err := node.Unmarshal(buf.Next(int(nodeLen))); err != nil { + return nil, ErrUnmarshalUAST.New(err) + } + + nodes = append(nodes, node) + } + + return nodes, nil +} diff --git a/session.go b/session.go index ef4fd9612..9dd2d9e26 100644 --- a/session.go +++ b/session.go @@ -25,7 +25,8 @@ type Session struct { bblfshEndpoint string bblfshClient *BblfshClient - SkipGitErrors bool + SkipGitErrors bool + OldUASTSerialization bool } // getSession returns the gitbase session from a context or an error if there @@ -65,6 +66,13 @@ func WithSkipGitErrors(enabled bool) SessionOption { } } +// WithOldUASTSerialization set the way UASTs must be serialized. +func WithOldUASTSerialization(enabled bool) SessionOption { + return func(s *Session) { + s.OldUASTSerialization = enabled + } +} + // NewSession creates a new Session. It requires a repository pool and any // number of session options can be passed to configure the session. func NewSession(pool *RepositoryPool, opts ...SessionOption) *Session {