Skip to content

Commit e6deab2

Browse files
committed
Minor improvements for /query endpoint
Signed-off-by: Carlos Martín <[email protected]>
1 parent a68e070 commit e6deab2

File tree

2 files changed

+96
-37
lines changed

2 files changed

+96
-37
lines changed

server/handler/query.go

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io/ioutil"
88
"net/http"
9+
"strings"
910

1011
"github.com/src-d/gitbase-playground/server/serializer"
1112
"gopkg.in/bblfsh/sdk.v1/uast"
@@ -20,20 +21,22 @@ type queryRequest struct {
2021

2122
// genericVals returns a slice of interface{}, each one a pointer to the proper
2223
// type for each column
23-
func genericVals(colTypes []*sql.ColumnType) []interface{} {
24+
func genericVals(colTypes []string) []interface{} {
2425
columnValsPtr := make([]interface{}, len(colTypes))
2526

2627
for i, colType := range colTypes {
27-
switch colType.DatabaseTypeName() {
28+
switch colType {
2829
case "BIT":
2930
columnValsPtr[i] = new(sql.NullBool)
30-
case "TIMESTAMP":
31+
case "TIMESTAMP", "DATE", "DATETIME":
3132
columnValsPtr[i] = new(mysql.NullTime)
32-
case "INT":
33+
case "INT", "MEDIUMINT", "BIGINT", "SMALLINT", "TINYINT":
3334
columnValsPtr[i] = new(sql.NullInt64)
35+
case "DOUBLE", "FLOAT":
36+
columnValsPtr[i] = new(sql.NullFloat64)
3437
case "JSON":
3538
columnValsPtr[i] = new([]byte)
36-
default: // "TEXT" and any others
39+
default: // All the text and binary variations
3740
columnValsPtr[i] = new(sql.NullString)
3841
}
3942
}
@@ -47,7 +50,6 @@ func Query(db *sql.DB) RequestProcessFunc {
4750
return func(r *http.Request) (*serializer.Response, error) {
4851
var queryRequest queryRequest
4952
body, err := ioutil.ReadAll(r.Body)
50-
defer r.Body.Close()
5153
if err == nil {
5254
err = json.Unmarshal(body, &queryRequest)
5355
}
@@ -56,26 +58,26 @@ func Query(db *sql.DB) RequestProcessFunc {
5658
return nil, err
5759
}
5860

59-
// TODO (carlosms) this only works if the query does not end in ;
60-
// and does not have a limit. It will also fail for queries like
61-
// DESCRIBE TABLE
62-
query := fmt.Sprintf("%s LIMIT %d", queryRequest.Query, queryRequest.Limit)
61+
query := addLimit(queryRequest.Query, queryRequest.Limit)
6362
rows, err := db.Query(query)
6463
if err != nil {
64+
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
65+
return nil, serializer.NewMySQLError(
66+
http.StatusBadRequest,
67+
mysqlErr.Number,
68+
mysqlErr.Message)
69+
}
70+
6571
return nil, serializer.NewHTTPError(http.StatusBadRequest, err.Error())
6672
}
6773
defer rows.Close()
6874

69-
columnNames, err := rows.Columns()
75+
columnNames, columnTypes, err := columnsInfo(rows)
7076
if err != nil {
7177
return nil, err
7278
}
7379

74-
colTypes, err := rows.ColumnTypes()
75-
if err != nil {
76-
return nil, err
77-
}
78-
columnValsPtr := genericVals(colTypes)
80+
columnValsPtr := genericVals(columnTypes)
7981

8082
tableData := make([]map[string]interface{}, 0)
8183

@@ -84,7 +86,7 @@ func Query(db *sql.DB) RequestProcessFunc {
8486
return nil, err
8587
}
8688

87-
colData := make(map[string]interface{})
89+
colData := make(map[string]interface{}, len(columnTypes))
8890

8991
for i, val := range columnValsPtr {
9092
colData[columnNames[i]] = nil
@@ -111,23 +113,21 @@ func Query(db *sql.DB) RequestProcessFunc {
111113
colData[columnNames[i]] = sqlVal.String
112114
}
113115
case *[]byte:
114-
// TODO (carlosms) this may not be an array always
115-
var protobufs [][]byte
116-
if err := json.Unmarshal(*val.(*[]byte), &protobufs); err != nil {
117-
return nil, err
118-
}
119-
120-
nodes := make([]*uast.Node, len(protobufs))
121-
122-
for i, v := range protobufs {
123-
node := uast.NewNode()
124-
if err = node.Unmarshal(v); err != nil {
116+
// TODO (carlosms) this may not always be a JSON array, it could
117+
// be a JSON object
118+
119+
// DatabaseTypeName JSON is used for arrays, but we don't know the
120+
// type of each element. We try with uast node first and text later
121+
nodes, err := unmarshallUAST(val)
122+
if err == nil {
123+
colData[columnNames[i]] = nodes
124+
} else {
125+
var strings []string
126+
if err := json.Unmarshal(*val.(*[]byte), &strings); err != nil {
125127
return nil, err
126128
}
127-
nodes[i] = node
129+
colData[columnNames[i]] = strings
128130
}
129-
130-
colData[columnNames[i]] = nodes
131131
}
132132
}
133133

@@ -138,6 +138,58 @@ func Query(db *sql.DB) RequestProcessFunc {
138138
return nil, err
139139
}
140140

141-
return serializer.NewQueryResponse(tableData, columnNames), nil
141+
return serializer.NewQueryResponse(tableData, columnNames, columnTypes), nil
142+
}
143+
}
144+
145+
// columnsInfo returns the column names and column types, or error
146+
func columnsInfo(rows *sql.Rows) ([]string, []string, error) {
147+
names, err := rows.Columns()
148+
if err != nil {
149+
return nil, nil, err
150+
}
151+
152+
types, err := rows.ColumnTypes()
153+
if err != nil {
154+
return nil, nil, err
155+
}
156+
157+
typesStr := make([]string, len(types))
158+
for i, colType := range types {
159+
typesStr[i] = colType.DatabaseTypeName()
160+
}
161+
162+
return names, typesStr, nil
163+
}
164+
165+
// unmarshallUAST tries to cast data as [][]byte and unmarshall uast nodes
166+
func unmarshallUAST(data interface{}) ([]*uast.Node, error) {
167+
var protobufs [][]byte
168+
if err := json.Unmarshal(*data.(*[]byte), &protobufs); err != nil {
169+
return nil, err
170+
}
171+
172+
nodes := make([]*uast.Node, len(protobufs))
173+
174+
for i, v := range protobufs {
175+
node := uast.NewNode()
176+
if err := node.Unmarshal(v); err != nil {
177+
return nil, err
178+
}
179+
nodes[i] = node
180+
}
181+
182+
return nodes, nil
183+
}
184+
185+
// addLimit adds LIMIT to the query, performing basic tests to skip it
186+
// for DESCRIBE TABLE, SHOW TABLES, and avoid '; limit'
187+
func addLimit(query string, limit int) string {
188+
upper := strings.ToUpper(query)
189+
if strings.Contains(upper, "DESCRIBE") || strings.Contains(upper, "SHOW") {
190+
return query
142191
}
192+
193+
trimmed := strings.TrimRight(strings.TrimSpace(query), ";")
194+
return fmt.Sprintf("%s LIMIT %d", trimmed, limit)
143195
}

server/serializer/serializers.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ type Response struct {
2020
}
2121

2222
type httpError struct {
23-
Status int `json:"status"`
24-
Title string `json:"title"`
25-
Details string `json:"details,omitempty"`
23+
Status int `json:"status"`
24+
Title string `json:"title"`
25+
Details string `json:"details,omitempty"`
26+
MySQLCode uint16 `json:"mysqlCode,omitempty"`
2627
}
2728

2829
// StatusCode returns the Status of the httpError
@@ -48,6 +49,11 @@ func NewHTTPError(statusCode int, msg ...string) HTTPError {
4849
return httpError{Status: statusCode, Title: strings.Join(msg, " ")}
4950
}
5051

52+
// NewHTTPError returns an Error with the MySQL error code
53+
func NewMySQLError(statusCode int, mysqlCode uint16, msg ...string) HTTPError {
54+
return httpError{Status: statusCode, MySQLCode: mysqlCode, Title: strings.Join(msg, " ")}
55+
}
56+
5157
func newResponse(data interface{}, meta interface{}) *Response {
5258
if data == nil {
5359
return &Response{
@@ -78,9 +84,10 @@ func NewVersionResponse(version string) *Response {
7884

7985
type queryMetaResponse struct {
8086
Headers []string `json:"headers"`
87+
Types []string `json:"types"`
8188
}
8289

8390
// NewQueryResponse returns a Response with table headers and row contents
84-
func NewQueryResponse(rows []map[string]interface{}, columnNames []string) *Response {
85-
return newResponse(rows, queryMetaResponse{columnNames})
91+
func NewQueryResponse(rows []map[string]interface{}, columnNames, columnTypes []string) *Response {
92+
return newResponse(rows, queryMetaResponse{columnNames, columnTypes})
8693
}

0 commit comments

Comments
 (0)