Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 06ebb75

Browse files
committed
sql/analyzer: add rule to fix field's indexes in inner joins
Signed-off-by: Manuel Carmona <[email protected]>
1 parent ad35890 commit 06ebb75

File tree

4 files changed

+80
-7
lines changed

4 files changed

+80
-7
lines changed

engine_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,32 @@ func TestQueries(t *testing.T) {
206206
}
207207
}
208208

209+
func TestJoinsFields(t *testing.T) {
210+
const (
211+
query = `SELECT * FROM tabletest, mytable mt INNER JOIN othertable ot ON mt.i = ot.i2`
212+
expectedRows = 9
213+
)
214+
215+
expectedSchema := sql.Schema{
216+
&sql.Column{Name: "text", Type: sql.Text, Default: interface{}(nil), Nullable: false, Source: "tabletest"},
217+
&sql.Column{Name: "number", Type: sql.Int32, Default: interface{}(nil), Nullable: false, Source: "tabletest"},
218+
&sql.Column{Name: "i", Type: sql.Int64, Default: interface{}(nil), Nullable: false, Source: "mytable"},
219+
&sql.Column{Name: "s", Type: sql.Text, Default: interface{}(nil), Nullable: false, Source: "mytable"},
220+
&sql.Column{Name: "s2", Type: sql.Text, Default: interface{}(nil), Nullable: false, Source: "othertable"},
221+
&sql.Column{Name: "i2", Type: sql.Int64, Default: interface{}(nil), Nullable: false, Source: "othertable"},
222+
}
223+
224+
e := newEngine(t)
225+
session := sql.NewEmptyContext()
226+
schema, rowIter, err := e.Query(session, query)
227+
require.NoError(t, err)
228+
require.Exactly(t, expectedSchema, schema)
229+
230+
rows, err := sql.RowIterToRows(rowIter)
231+
require.NoError(t, err)
232+
require.Len(t, rows, expectedRows)
233+
}
234+
209235
func TestOrderByColumns(t *testing.T) {
210236
require := require.New(t)
211237
e := newEngine(t)

sql/analyzer/analyzer_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ func TestAddRule(t *testing.T) {
208208
require := require.New(t)
209209

210210
a := New(nil)
211-
require.Len(a.Rules, 12)
211+
require.Len(a.Rules, len(DefaultRules))
212212
a.AddRule("foo", pushdown)
213-
require.Len(a.Rules, 13)
213+
require.Len(a.Rules, len(DefaultRules)+1)
214214
}
215215

216216
func TestAddValidationRule(t *testing.T) {

sql/analyzer/rules.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ var DefaultRules = []Rule{
2424
{"pushdown", pushdown},
2525
{"optimize_distinct", optimizeDistinct},
2626
{"erase_projection", eraseProjection},
27+
{"fix_join_fields", fixJoinFields},
2728
}
2829

2930
var (
@@ -809,3 +810,48 @@ func fixFieldIndexes(schema sql.Schema, exp sql.Expression) (sql.Expression, err
809810
return e, nil
810811
})
811812
}
813+
814+
func fixJoinFields(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) {
815+
span, ctx := ctx.Span("fix join fields")
816+
defer span.Finish()
817+
818+
a.Log("fix join fields, node of type: %T", node)
819+
return node.TransformUp(func(n sql.Node) (sql.Node, error) {
820+
a.Log("transforming node of type: %T", n)
821+
if !n.Resolved() {
822+
return n, nil
823+
}
824+
825+
_, ok := n.(*plan.InnerJoin)
826+
if !ok {
827+
return n, nil
828+
}
829+
830+
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
831+
a.Log("transforming expression of type: %T", e)
832+
if !e.Resolved() {
833+
return e, nil
834+
}
835+
836+
field, ok := e.(*expression.GetField)
837+
if !ok {
838+
return e, nil
839+
}
840+
841+
idx, ok := n.Schema().Contains(field.Name())
842+
if !ok {
843+
return nil, ErrColumnNotFound.New(field.Name())
844+
}
845+
846+
fixedField := expression.NewGetField(
847+
idx,
848+
field.Type(),
849+
field.Name(),
850+
field.IsNullable(),
851+
)
852+
853+
a.Log("fixed field for %T", n)
854+
return fixedField, nil
855+
})
856+
})
857+
}

sql/type.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,15 @@ func (s Schema) CheckRow(row Row) error {
6363
return nil
6464
}
6565

66-
// Contains returns whether the schema contains a column with the given name.
67-
func (s Schema) Contains(column string) bool {
68-
for _, col := range s {
66+
// Contains returns whether the schema contains a column with the given name
67+
// and the index of that column in the schema.
68+
func (s Schema) Contains(column string) (int, bool) {
69+
for i, col := range s {
6970
if col.Name == column {
70-
return true
71+
return i, true
7172
}
7273
}
73-
return false
74+
return 0, false
7475
}
7576

7677
// Equals checks whether the given schema is equal to this one.

0 commit comments

Comments
 (0)