6
6
"fmt"
7
7
"io/ioutil"
8
8
"net/http"
9
+ "strings"
9
10
10
11
"github.com/src-d/gitbase-playground/server/serializer"
11
12
"gopkg.in/bblfsh/sdk.v1/uast"
@@ -20,20 +21,22 @@ type queryRequest struct {
20
21
21
22
// genericVals returns a slice of interface{}, each one a pointer to the proper
22
23
// type for each column
23
- func genericVals (colTypes []* sql. ColumnType ) []interface {} {
24
+ func genericVals (colTypes []string ) []interface {} {
24
25
columnValsPtr := make ([]interface {}, len (colTypes ))
25
26
26
27
for i , colType := range colTypes {
27
- switch colType . DatabaseTypeName () {
28
+ switch colType {
28
29
case "BIT" :
29
30
columnValsPtr [i ] = new (sql.NullBool )
30
- case "TIMESTAMP" :
31
+ case "TIMESTAMP" , "DATE" , "DATETIME" :
31
32
columnValsPtr [i ] = new (mysql.NullTime )
32
- case "INT" :
33
+ case "INT" , "MEDIUMINT" , "BIGINT" , "SMALLINT" , "TINYINT" :
33
34
columnValsPtr [i ] = new (sql.NullInt64 )
35
+ case "DOUBLE" , "FLOAT" :
36
+ columnValsPtr [i ] = new (sql.NullFloat64 )
34
37
case "JSON" :
35
38
columnValsPtr [i ] = new ([]byte )
36
- default : // "TEXT" and any others
39
+ default : // All the text and binary variations
37
40
columnValsPtr [i ] = new (sql.NullString )
38
41
}
39
42
}
@@ -47,7 +50,6 @@ func Query(db *sql.DB) RequestProcessFunc {
47
50
return func (r * http.Request ) (* serializer.Response , error ) {
48
51
var queryRequest queryRequest
49
52
body , err := ioutil .ReadAll (r .Body )
50
- defer r .Body .Close ()
51
53
if err == nil {
52
54
err = json .Unmarshal (body , & queryRequest )
53
55
}
@@ -56,26 +58,26 @@ func Query(db *sql.DB) RequestProcessFunc {
56
58
return nil , err
57
59
}
58
60
59
- // TODO (carlosms) this only works if the query does not end in ;
60
- // and does not have a limit. It will also fail for queries like
61
- // DESCRIBE TABLE
62
- query := fmt .Sprintf ("%s LIMIT %d" , queryRequest .Query , queryRequest .Limit )
61
+ query := addLimit (queryRequest .Query , queryRequest .Limit )
63
62
rows , err := db .Query (query )
64
63
if err != nil {
64
+ if mysqlErr , ok := err .(* mysql.MySQLError ); ok {
65
+ return nil , serializer .NewMySQLError (
66
+ http .StatusBadRequest ,
67
+ mysqlErr .Number ,
68
+ mysqlErr .Message )
69
+ }
70
+
65
71
return nil , serializer .NewHTTPError (http .StatusBadRequest , err .Error ())
66
72
}
67
73
defer rows .Close ()
68
74
69
- columnNames , err := rows . Columns ( )
75
+ columnNames , columnTypes , err := columnsInfo ( rows )
70
76
if err != nil {
71
77
return nil , err
72
78
}
73
79
74
- colTypes , err := rows .ColumnTypes ()
75
- if err != nil {
76
- return nil , err
77
- }
78
- columnValsPtr := genericVals (colTypes )
80
+ columnValsPtr := genericVals (columnTypes )
79
81
80
82
tableData := make ([]map [string ]interface {}, 0 )
81
83
@@ -84,7 +86,7 @@ func Query(db *sql.DB) RequestProcessFunc {
84
86
return nil , err
85
87
}
86
88
87
- colData := make (map [string ]interface {})
89
+ colData := make (map [string ]interface {}, len ( columnTypes ) )
88
90
89
91
for i , val := range columnValsPtr {
90
92
colData [columnNames [i ]] = nil
@@ -111,23 +113,21 @@ func Query(db *sql.DB) RequestProcessFunc {
111
113
colData [columnNames [i ]] = sqlVal .String
112
114
}
113
115
case * []byte :
114
- // TODO (carlosms) this may not be an array always
115
- var protobufs [][] byte
116
- if err := json . Unmarshal ( * val .( * [] byte ), & protobufs ); err != nil {
117
- return nil , err
118
- }
119
-
120
- nodes := make ([] * uast. Node , len ( protobufs ))
121
-
122
- for i , v := range protobufs {
123
- node := uast . NewNode ()
124
- if err = node .Unmarshal (v ); err != nil {
116
+ // TODO (carlosms) this may not always be a JSON array, it could
117
+ // be a JSON object
118
+
119
+ // DatabaseTypeName JSON is used for arrays, but we don't know the
120
+ // type of each element. We try with uast node first and text later
121
+ nodes , err := unmarshallUAST ( val )
122
+ if err == nil {
123
+ colData [ columnNames [ i ]] = nodes
124
+ } else {
125
+ var strings [] string
126
+ if err := json .Unmarshal (* val .( * [] byte ), & strings ); err != nil {
125
127
return nil , err
126
128
}
127
- nodes [ i ] = node
129
+ colData [ columnNames [ i ]] = strings
128
130
}
129
-
130
- colData [columnNames [i ]] = nodes
131
131
}
132
132
}
133
133
@@ -138,6 +138,58 @@ func Query(db *sql.DB) RequestProcessFunc {
138
138
return nil , err
139
139
}
140
140
141
- return serializer .NewQueryResponse (tableData , columnNames ), nil
141
+ return serializer .NewQueryResponse (tableData , columnNames , columnTypes ), nil
142
+ }
143
+ }
144
+
145
+ // columnsInfo returns the column names and column types, or error
146
+ func columnsInfo (rows * sql.Rows ) ([]string , []string , error ) {
147
+ names , err := rows .Columns ()
148
+ if err != nil {
149
+ return nil , nil , err
150
+ }
151
+
152
+ types , err := rows .ColumnTypes ()
153
+ if err != nil {
154
+ return nil , nil , err
155
+ }
156
+
157
+ typesStr := make ([]string , len (types ))
158
+ for i , colType := range types {
159
+ typesStr [i ] = colType .DatabaseTypeName ()
160
+ }
161
+
162
+ return names , typesStr , nil
163
+ }
164
+
165
+ // unmarshallUAST tries to cast data as [][]byte and unmarshall uast nodes
166
+ func unmarshallUAST (data interface {}) ([]* uast.Node , error ) {
167
+ var protobufs [][]byte
168
+ if err := json .Unmarshal (* data .(* []byte ), & protobufs ); err != nil {
169
+ return nil , err
170
+ }
171
+
172
+ nodes := make ([]* uast.Node , len (protobufs ))
173
+
174
+ for i , v := range protobufs {
175
+ node := uast .NewNode ()
176
+ if err := node .Unmarshal (v ); err != nil {
177
+ return nil , err
178
+ }
179
+ nodes [i ] = node
180
+ }
181
+
182
+ return nodes , nil
183
+ }
184
+
185
+ // addLimit adds LIMIT to the query, performing basic tests to skip it
186
+ // for DESCRIBE TABLE, SHOW TABLES, and avoid '; limit'
187
+ func addLimit (query string , limit int ) string {
188
+ upper := strings .ToUpper (query )
189
+ if strings .Contains (upper , "DESCRIBE" ) || strings .Contains (upper , "SHOW" ) {
190
+ return query
142
191
}
192
+
193
+ trimmed := strings .TrimRight (strings .TrimSpace (query ), ";" )
194
+ return fmt .Sprintf ("%s LIMIT %d" , trimmed , limit )
143
195
}
0 commit comments