Skip to content

Commit 198bc5d

Browse files
authored
Merge pull request #1355 from dolthub/zachmu/create-as
Allow any select statement for `CREATE TABLE AS SELECT ...`
2 parents 49134f1 + 7a9fa5a commit 198bc5d

File tree

5 files changed

+108
-22
lines changed

5 files changed

+108
-22
lines changed

enginetest/memory_engine_test.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,19 +182,23 @@ func TestSingleScript(t *testing.T) {
182182

183183
var scripts = []queries.ScriptTest{
184184
{
185-
Name: "enums with default, case-sensitive collation (utf8mb4_0900_bin)",
185+
Name: "create table as select distinct",
186186
SetUpScript: []string{
187-
"CREATE TABLE enumtest1 (pk int primary key, e enum('abc', 'XYZ'));",
188-
"CREATE TABLE enumtest2 (pk int PRIMARY KEY, e enum('x ', 'X ', 'y', 'Y'));",
187+
"CREATE TABLE t1 (a int, b varchar(10));",
188+
"insert into t1 values (1, 'a'), (2, 'b'), (2, 'b'), (3, 'c');",
189189
},
190190
Assertions: []queries.ScriptTestAssertion{
191191
{
192-
Query: "select data_type, column_type from information_schema.columns where table_name='enumtest1' and column_name='e';",
193-
Expected: []sql.Row{{"enum('abc','XYZ')", "enum('abc','XYZ')"}},
192+
Query: "create table t2 as select distinct b, a from t1;",
193+
Expected: []sql.Row{{sql.OkResult{RowsAffected: 3}}},
194194
},
195195
{
196-
Query: "select data_type, column_type from information_schema.columns where table_name='enumtest2' and column_name='e';",
197-
Expected: []sql.Row{{"enum('x','X','y','Y')", "enum('x','X','y','Y')"}},
196+
Query: "select * from t2 order by a;",
197+
Expected: []sql.Row{
198+
{"a", 1},
199+
{"b", 2},
200+
{"c", 3},
201+
},
198202
},
199203
},
200204
},

enginetest/queries/charset_collation_engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{
313313
{
314314
Query: "SHOW CREATE TABLE test4;",
315315
Expected: []sql.Row{
316-
{"test4", "CREATE TABLE `test4` (\n `pk` bigint NOT NULL,\n `v1` varchar(255) COLLATE utf8mb4_unicode_ci,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
316+
{"test4", "CREATE TABLE `test4` (\n `pk` bigint NOT NULL,\n `v1` varchar(255) COLLATE utf8mb4_unicode_ci\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
317317
},
318318
},
319319
{

enginetest/queries/create_table_queries.go

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ var CreateTableQueries = []WriteQueryTest{
8585
WriteQuery: `CREATE TABLE t1 SELECT * from mytable`,
8686
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
8787
SelectQuery: "SHOW CREATE TABLE t1",
88-
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `i` bigint NOT NULL,\n `s` varchar(20) NOT NULL COMMENT 'column s',\n PRIMARY KEY (`i`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
88+
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `i` bigint NOT NULL,\n `s` varchar(20) NOT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
8989
},
9090
{
9191
WriteQuery: `CREATE TABLE mydb.t1 (a INTEGER NOT NULL PRIMARY KEY, b VARCHAR(10) NOT NULL)`,
@@ -150,4 +150,59 @@ var CreateTableQueries = []WriteQueryTest{
150150
SelectQuery: `SHOW CREATE TABLE t1`,
151151
ExpectedSelect: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `i` int NOT NULL,\n `j` varchar(16383),\n PRIMARY KEY (`i`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
152152
},
153+
{
154+
WriteQuery: `CREATE TABLE t1 as select * from mytable`,
155+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
156+
SelectQuery: `select * from t1 order by i`,
157+
ExpectedSelect: []sql.Row{{1, "first row"}, {2, "second row"}, {3, "third row"}},
158+
},
159+
{
160+
WriteQuery: `CREATE TABLE t1 as select * from mytable`,
161+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
162+
SelectQuery: `show create table t1`,
163+
ExpectedSelect: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `i` bigint NOT NULL,\n `s` varchar(20) NOT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
164+
},
165+
{
166+
WriteQuery: `CREATE TABLE t1 as select s, i from mytable`,
167+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
168+
SelectQuery: `select * from t1 order by i`,
169+
ExpectedSelect: []sql.Row{{"first row", 1}, {"second row", 2}, {"third row", 3}},
170+
},
171+
{
172+
WriteQuery: `CREATE TABLE t1 as select distinct s, i from mytable`,
173+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
174+
SelectQuery: `select * from t1 order by i`,
175+
ExpectedSelect: []sql.Row{{"first row", 1}, {"second row", 2}, {"third row", 3}},
176+
},
177+
{
178+
WriteQuery: `CREATE TABLE t1 as select s, i from mytable order by s`,
179+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
180+
SelectQuery: `select * from t1 order by i`,
181+
ExpectedSelect: []sql.Row{{"first row", 1}, {"second row", 2}, {"third row", 3}},
182+
},
183+
// TODO: the second column should be named `sum(i)` but is `SUM(mytable.i)`
184+
{
185+
WriteQuery: `CREATE TABLE t1 as select s, sum(i) from mytable group by s`,
186+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
187+
SelectQuery: `select * from t1 order by s`, // other column is named `SUM(mytable.i)`
188+
ExpectedSelect: []sql.Row{{"first row", 1}, {"second row", 2}, {"third row", 3}},
189+
},
190+
{
191+
WriteQuery: `CREATE TABLE t1 as select s, sum(i) from mytable group by s having sum(i) > 2`,
192+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(1)}},
193+
SelectQuery: "select * from t1",
194+
ExpectedSelect: []sql.Row{{"third row", 3}},
195+
},
196+
{
197+
WriteQuery: `CREATE TABLE t1 as select s, i from mytable order by s limit 1`,
198+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(1)}},
199+
SelectQuery: `select * from t1 order by i`,
200+
ExpectedSelect: []sql.Row{{"first row", 1}},
201+
},
202+
{
203+
WriteQuery: `CREATE TABLE t1 as select concat("new", s), i from mytable`,
204+
ExpectedWriteResult: []sql.Row{{sql.NewOkResult(3)}},
205+
SelectQuery: `select * from t1 order by i`,
206+
ExpectedSelect: []sql.Row{{"newfirst row", 1}, {"newsecond row", 2}, {"newthird row", 3}},
207+
},
153208
}

sql/analyzer/resolve_create_select.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope
1919

2020
// Get the correct schema of the CREATE TABLE based on the select query
2121
inputSpec := ct.TableSpec()
22-
selectSchema := analyzedSelect.Schema()
22+
23+
// We don't want to carry any information about keys, constraints, defaults, etc. from a `create table as select`
24+
// statement. When the underlying select node is a table, we must remove all such info from its schema. The only
25+
// exception is NOT NULL constraints, which we leave alone.
26+
selectSchema := stripSchema(analyzedSelect.Schema())
2327
mergedSchema := mergeSchemas(inputSpec.Schema.Schema, selectSchema)
2428
newSch := make(sql.Schema, len(mergedSchema))
2529

@@ -49,6 +53,20 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope
4953
return plan.NewTableCopier(ct.Database(), StripPassthroughNodes(analyzedCreate), StripPassthroughNodes(analyzedSelect), plan.CopierProps{}), transform.NewTree, nil
5054
}
5155

56+
// stripSchema removes all non-type information from a schema, such as the key info, default value, etc.
57+
func stripSchema(schema sql.Schema) sql.Schema {
58+
sch := make(sql.Schema, len(schema))
59+
for i := range schema {
60+
sch[i] = schema[i].Copy()
61+
sch[i].Default = nil
62+
sch[i].AutoIncrement = false
63+
sch[i].PrimaryKey = false
64+
sch[i].Source = ""
65+
sch[i].Comment = ""
66+
}
67+
return sch
68+
}
69+
5270
// mergeSchemas takes in the table spec of the CREATE TABLE and merges it with the schema used by the
5371
// select query. The ultimate structure for the new table will be [CREATE TABLE exclusive columns, columns with the same
5472
// name, SELECT exclusive columns]

sql/plan/ddl.go

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error
267267
if !ok {
268268
return sql.RowsToRowIter(), sql.ErrTemporaryTableNotSupported.New()
269269
}
270-
vd = maybePrivDb.(sql.ViewDatabase)
270+
vd, _ = maybePrivDb.(sql.ViewDatabase)
271271

272272
if err := c.validateDefaultPosition(); err != nil {
273273
return sql.RowsToRowIter(), err
@@ -283,7 +283,7 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error
283283
if !ok {
284284
return sql.RowsToRowIter(), sql.ErrCreateTableNotSupported.New(c.db.Name())
285285
}
286-
vd = maybePrivDb.(sql.ViewDatabase)
286+
vd, _ = maybePrivDb.(sql.ViewDatabase)
287287

288288
if err := c.validateDefaultPosition(); err != nil {
289289
return sql.RowsToRowIter(), err
@@ -295,12 +295,15 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error
295295
return sql.RowsToRowIter(), err
296296
}
297297

298-
_, ok, err := vd.GetView(ctx, c.name)
299-
if err != nil {
300-
return nil, err
301-
}
302-
if ok {
303-
return nil, sql.ErrTableAlreadyExists.New(c.name)
298+
if vd != nil {
299+
_, ok, err := vd.GetView(ctx, c.name)
300+
if err != nil {
301+
return nil, err
302+
}
303+
304+
if ok {
305+
return nil, sql.ErrTableAlreadyExists.New(c.name)
306+
}
304307
}
305308

306309
//TODO: in the event that foreign keys or indexes aren't supported, you'll be left with a created table and no foreign keys/indexes
@@ -470,11 +473,10 @@ func (c CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) {
470473
} else if len(children) == 1 {
471474
child := children[0]
472475

473-
switch child.(type) {
474-
case *Project, *Limit:
475-
c.selectNode = child
476-
default:
476+
if c.like != nil {
477477
c.like = child
478+
} else {
479+
c.selectNode = child
478480
}
479481

480482
return &c, nil
@@ -508,6 +510,13 @@ func (c *CreateTable) DebugString() string {
508510
ifNotExists = "if not exists "
509511
}
510512
p := sql.NewTreePrinter()
513+
514+
if c.selectNode != nil {
515+
p.WriteNode("Create table %s%s as", ifNotExists, c.name)
516+
p.WriteChildren(sql.DebugString(c.selectNode))
517+
return p.String()
518+
}
519+
511520
p.WriteNode("Create table %s%s", ifNotExists, c.name)
512521

513522
var children []string

0 commit comments

Comments
 (0)