Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit d9b20c8

Browse files
authored
Merge pull request #768 from erizocosmico/fix/bool-coercion
sql: use custom boolean coercion logic in condition evaluation
2 parents bc23348 + 568f586 commit d9b20c8

File tree

4 files changed

+36
-18
lines changed

4 files changed

+36
-18
lines changed

engine_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,10 @@ var queries = []struct {
12301230
`SELECT (NULL+1)`,
12311231
[]sql.Row{{nil}},
12321232
},
1233+
{
1234+
`SELECT * FROM mytable WHERE NULL AND i = 3`,
1235+
[]sql.Row{},
1236+
},
12331237
}
12341238

12351239
func TestQueries(t *testing.T) {

sql/core.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package sql // import "github.com/src-d/go-mysql-server/sql"
33
import (
44
"fmt"
55
"io"
6+
"math"
7+
"time"
68

79
"gopkg.in/src-d/go-errors.v1"
810
)
@@ -218,3 +220,27 @@ type Lockable interface {
218220
// available.
219221
Unlock(ctx *Context, id uint32) error
220222
}
223+
224+
// EvaluateCondition evaluates a condition, which is an expression whose value
225+
// will be coerced to boolean.
226+
func EvaluateCondition(ctx *Context, cond Expression, row Row) (bool, error) {
227+
v, err := cond.Eval(ctx, row)
228+
if err != nil {
229+
return false, err
230+
}
231+
232+
switch b := v.(type) {
233+
case bool:
234+
return b, nil
235+
case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8:
236+
return b != 0, nil
237+
case time.Duration:
238+
return int64(b) != 0, nil
239+
case time.Time:
240+
return b.UnixNano() != 0, nil
241+
case float32, float64:
242+
return int(math.Round(v.(float64))) != 0, nil
243+
default:
244+
return false, nil
245+
}
246+
}

sql/plan/filter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ func (i *FilterIter) Next() (sql.Row, error) {
107107
return nil, err
108108
}
109109

110-
result, err := i.cond.Eval(i.ctx, row)
110+
ok, err := sql.EvaluateCondition(i.ctx, i.cond, row)
111111
if err != nil {
112112
return nil, err
113113
}
114114

115-
if result == true {
115+
if ok {
116116
return row, nil
117117
}
118118
}

sql/type.go

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -616,25 +616,13 @@ func (t booleanT) Convert(v interface{}) (interface{}, error) {
616616
case bool:
617617
return b, nil
618618
case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8:
619-
if b != 0 {
620-
return true, nil
621-
}
622-
return false, nil
619+
return b != 0, nil
623620
case time.Duration:
624-
if int64(b) != 0 {
625-
return true, nil
626-
}
627-
return false, nil
621+
return int64(b) != 0, nil
628622
case time.Time:
629-
if b.UnixNano() != 0 {
630-
return true, nil
631-
}
632-
return false, nil
623+
return b.UnixNano() != 0, nil
633624
case float32, float64:
634-
if int(math.Round(v.(float64))) != 0 {
635-
return true, nil
636-
}
637-
return false, nil
625+
return int(math.Round(v.(float64))) != 0, nil
638626
case string:
639627
return false, fmt.Errorf("unable to cast string to bool")
640628

0 commit comments

Comments
 (0)