Skip to content

Commit 8acd9b2

Browse files
authored
Escape table names (#73)
This PR here is somewhat a beginning of a conversation about how do we want to make proper sql identifiers when it comes to schemas, tables and columns, particularly for my use case I want this in order to use sql reserved words for table names. I'll try to split this into 2 logical parts: **How do we know how to escape depending on db variant?** A SQL 1999 rfc says that double quotes should be used in order to escape an identifier, nevertheless: - in mysql it is backticks - in h2 when you escape you suddenly loose the ability to refer to a table ignoring it's definition case: defining a table `T`, `SELECT * from t` will work, `SELECT * from "t"` will not Therefore it felt logical to make an escaping mechanism based on `Dialect`. I didn't want to propagate `DialectConfig` together with `Context` down the call chain, it seemed to me kinda redundant, so I made `DialectConfig` part of `Context`. That's how we can access escaping mechanism in various places as long as you have `Context` **How do we decide which table to escape?** With this PR I want to stick with tables, otherwise it can become too big. I see three options: 1. escape all tables 2. escape based on list of reserved words per database variant 3. let user decide to escape First options we already discarded as one that will change too much tests. As for a second option - I looked into lists of reserved words, and they are very different based on database, also it will not be fun to maintain them, that's why most of the libs resolve to escape all tables. Third option is most simple for now, since it will allow users to opt-in when they want/bump into this issue, also it will allow us to test mechanism first without changing all 1k+ tests, and perhaps we can start with option 3 and then eventually move to option 1. Therefore in this draft PR I made a flag in `Table` to opt in for escaping table name. _Optional_: At the moment `DialectConfig` is kinda part of the `Dialect`, and I don't feel entirely good about including whole `Dialect` in `Context`. I kinda want to move `DialectConfig` into case class and make a separate implicit for it as a part of `Dialect` so that when we need to construct `Context` only `DialectConfig` case class will be picked up It is a draft PR yet, so obviously I haven't written tests for all cases, only for a simple `FROM` expression. This will follow when we come to consensus about the implementation. What do you guys think? fixes #53
1 parent 0ce0c79 commit 8acd9b2

24 files changed

+358
-80
lines changed

docs/reference.md

Lines changed: 144 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6552,11 +6552,13 @@ Buyer.select
65526552
65536553
65546554
## Schema
6555-
Additional tests to ensure schema mapping produces valid SQL
6555+
6556+
If your table belongs to a schema other than the default schema of your database, you can specify this in your table definition with
6557+
`override def schemaName = "otherschema"`
6558+
65566559
### Schema.schema.select
65576560
6558-
If your table belongs to a schema other than the default schema of your database,
6559-
you can specify this in your table definition with table.schemaName
6561+
65606562
65616563
```scala
65626564
Invoice.select
@@ -6584,8 +6586,7 @@ Invoice.select
65846586
65856587
### Schema.schema.insert.columns
65866588
6587-
If your table belongs to a schema other than the default schema of your database,
6588-
you can specify this in your table definition with table.schemaName
6589+
65896590
65906591
```scala
65916592
Invoice.insert.columns(
@@ -6611,8 +6612,7 @@ Invoice.insert.columns(
66116612
66126613
### Schema.schema.insert.values
66136614
6614-
If your table belongs to a schema other than the default schema of your database,
6615-
you can specify this in your table definition with table.schemaName
6615+
66166616
66176617
```scala
66186618
Invoice.insert
@@ -6643,8 +6643,7 @@ Invoice.insert
66436643
66446644
### Schema.schema.update
66456645
6646-
If your table belongs to a schema other than the default schema of your database,
6647-
you can specify this in your table definition with table.schemaName
6646+
66486647
66496648
```scala
66506649
Invoice
@@ -6677,8 +6676,7 @@ Invoice
66776676
66786677
### Schema.schema.delete
66796678
6680-
If your table belongs to a schema other than the default schema of your database,
6681-
you can specify this in your table definition with table.schemaName
6679+
66826680
66836681
```scala
66846682
Invoice.delete(_.id === 1)
@@ -6701,8 +6699,7 @@ Invoice.delete(_.id === 1)
67016699
67026700
### Schema.schema.insert into
67036701
6704-
If your table belongs to a schema other than the default schema of your database,
6705-
you can specify this in your table definition with table.schemaName
6702+
67066703
67076704
```scala
67086705
Invoice.insert.select(
@@ -6734,8 +6731,7 @@ Invoice.insert.select(
67346731
67356732
### Schema.schema.join
67366733
6737-
If your table belongs to a schema other than the default schema of your database,
6738-
you can specify this in your table definition with table.schemaName
6734+
67396735
67406736
```scala
67416737
Invoice.select.join(Invoice)(_.id `=` _.id).map(_._1.id)
@@ -6760,6 +6756,139 @@ Invoice.select.join(Invoice)(_.id `=` _.id).map(_._1.id)
67606756
67616757
67626758
6759+
## EscapedTableName
6760+
6761+
If your table name is a reserved sql world, e.g. `order`, you can specify this in your table definition with
6762+
`override def escape = true`
6763+
6764+
### EscapedTableName.escape table name.select
6765+
6766+
6767+
6768+
```scala
6769+
Select.select
6770+
```
6771+
6772+
6773+
*
6774+
```sql
6775+
SELECT select0.id AS id, select0.name AS name
6776+
FROM "select" select0
6777+
```
6778+
6779+
6780+
6781+
*
6782+
```scala
6783+
Seq.empty[Select[Sc]]
6784+
```
6785+
6786+
6787+
6788+
### EscapedTableName.escape table name.delete
6789+
6790+
6791+
6792+
```scala
6793+
Select.delete(_ => true)
6794+
```
6795+
6796+
6797+
*
6798+
```sql
6799+
DELETE FROM "select" WHERE ?
6800+
```
6801+
6802+
6803+
6804+
*
6805+
```scala
6806+
0
6807+
```
6808+
6809+
6810+
6811+
### EscapedTableName.escape table name.join
6812+
6813+
6814+
6815+
```scala
6816+
Select.select.join(Select)(_.id `=` _.id)
6817+
```
6818+
6819+
6820+
*
6821+
```sql
6822+
SELECT
6823+
select0.id AS res_0_id,
6824+
select0.name AS res_0_name,
6825+
select1.id AS res_1_id,
6826+
select1.name AS res_1_name
6827+
FROM
6828+
"select" select0
6829+
JOIN "select" select1 ON (select0.id = select1.id)
6830+
```
6831+
6832+
6833+
6834+
*
6835+
```scala
6836+
Seq.empty[(Select[Sc], Select[Sc])]
6837+
```
6838+
6839+
6840+
6841+
### EscapedTableName.escape table name.update
6842+
6843+
6844+
6845+
```scala
6846+
Select.update(_ => true).set(_.name := "hello")
6847+
```
6848+
6849+
6850+
*
6851+
```sql
6852+
UPDATE "select" SET name = ?
6853+
```
6854+
6855+
6856+
6857+
*
6858+
```scala
6859+
0
6860+
```
6861+
6862+
6863+
6864+
### EscapedTableName.escape table name.insert
6865+
6866+
6867+
6868+
```scala
6869+
Select.insert.values(
6870+
Select[Sc](
6871+
id = 0,
6872+
name = "hello"
6873+
)
6874+
)
6875+
```
6876+
6877+
6878+
*
6879+
```sql
6880+
INSERT INTO "select" (id, name) VALUES (?, ?)
6881+
```
6882+
6883+
6884+
6885+
*
6886+
```scala
6887+
1
6888+
```
6889+
6890+
6891+
67636892
## SubQuery
67646893
Queries that explicitly use subqueries (e.g. for `JOIN`s) or require subqueries to preserve the Scala semantics of the various operators
67656894
### SubQuery.sortTakeJoin

scalasql/core/src/Context.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ trait Context {
2424
*/
2525
def config: Config
2626

27+
def dialectConfig: DialectConfig
28+
2729
def withFromNaming(fromNaming: Map[Context.From, String]): Context
2830
def withExprNaming(exprNaming: Map[Expr.Identity, SqlStr]): Context
2931
}
@@ -56,7 +58,8 @@ object Context {
5658
case class Impl(
5759
fromNaming: Map[From, String],
5860
exprNaming: Map[Expr.Identity, SqlStr],
59-
config: Config
61+
config: Config,
62+
dialectConfig: DialectConfig
6063
) extends Context {
6164
def withFromNaming(fromNaming: Map[From, String]): Context = copy(fromNaming = fromNaming)
6265

@@ -93,7 +96,7 @@ object Context {
9396
.map { case (e, s) => (e, sql"${SqlStr.raw(newFromNaming(t), Array(e))}.$s") }
9497
}
9598

96-
Context.Impl(newFromNaming, newExprNaming, prevContext.config)
99+
Context.Impl(newFromNaming, newExprNaming, prevContext.config, prevContext.dialectConfig)
97100
}
98101

99102
}

scalasql/core/src/DbApi.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,22 @@ trait DbApi extends AutoCloseable {
125125

126126
object DbApi {
127127

128-
def unpackQueryable[R, Q](query: Q, qr: Queryable[Q, R], config: Config) = {
129-
val ctx = Context.Impl(Map(), Map(), config)
128+
def unpackQueryable[R, Q](
129+
query: Q,
130+
qr: Queryable[Q, R],
131+
config: Config,
132+
dialectConfig: DialectConfig
133+
) = {
134+
val ctx = Context.Impl(Map(), Map(), config, dialectConfig)
130135
val flattened = SqlStr.flatten(qr.renderSql(query, ctx))
131136
flattened
132137
}
133138

134-
def renderSql[Q, R](query: Q, config: Config, castParams: Boolean = false)(
139+
def renderSql[Q, R](query: Q, config: Config, dialectConfig: DialectConfig)(
135140
implicit qr: Queryable[Q, R]
136141
): String = {
137-
val flattened = unpackQueryable(query, qr, config)
138-
flattened.renderSql(castParams)
142+
val flattened = unpackQueryable(query, qr, config, dialectConfig)
143+
flattened.renderSql(dialectConfig.castParams)
139144
}
140145

141146
/**
@@ -254,7 +259,7 @@ object DbApi {
254259
lineNum: sourcecode.Line
255260
): R = {
256261

257-
val flattened = unpackQueryable(query, qr, config)
262+
val flattened = unpackQueryable(query, qr, config, dialect)
258263
if (qr.isGetGeneratedKeys(query).nonEmpty)
259264
updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum)
260265
.asInstanceOf[R]
@@ -284,7 +289,7 @@ object DbApi {
284289
fileName: sourcecode.FileName,
285290
lineNum: sourcecode.Line
286291
): Generator[R] = {
287-
val flattened = unpackQueryable(query, qr, config)
292+
val flattened = unpackQueryable(query, qr, config, dialect)
288293
streamFlattened0(
289294
r => {
290295
qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match {
@@ -335,7 +340,7 @@ object DbApi {
335340
): Int = {
336341
val flattened = SqlStr.flatten(sql)
337342
runRawUpdate0(
338-
flattened.renderSql(DialectConfig.castParams(dialect)),
343+
flattened.renderSql(dialect.castParams),
339344
flattenParamPuts(flattened),
340345
fetchSize,
341346
queryTimeoutSeconds,
@@ -355,7 +360,7 @@ object DbApi {
355360
): IndexedSeq[R] = {
356361
val flattened = SqlStr.flatten(sql)
357362
runRawUpdateGetGeneratedKeys0(
358-
flattened.renderSql(DialectConfig.castParams(dialect)),
363+
flattened.renderSql(dialect.castParams),
359364
flattenParamPuts(flattened),
360365
fetchSize,
361366
queryTimeoutSeconds,
@@ -441,7 +446,7 @@ object DbApi {
441446
lineNum: sourcecode.Line
442447
) = streamRaw0(
443448
construct,
444-
flattened.renderSql(DialectConfig.castParams(dialect)),
449+
flattened.renderSql(dialect.castParams),
445450
flattenParamPuts(flattened),
446451
fetchSize,
447452
queryTimeoutSeconds,
@@ -567,7 +572,7 @@ object DbApi {
567572
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
568573
implicit qr: Queryable[Q, R]
569574
): String = {
570-
DbApi.renderSql(query, config, castParams)
575+
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
571576
}
572577

573578
val savepointStack = collection.mutable.ArrayDeque.empty[java.sql.Savepoint]

scalasql/core/src/DbClient.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ object DbClient {
6868
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
6969
implicit qr: Queryable[Q, R]
7070
): String = {
71-
DbApi.renderSql(query, config, castParams)
71+
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
7272
}
7373

7474
def transaction[T](block: DbApi.Txn => T): T = {
@@ -127,7 +127,7 @@ object DbClient {
127127
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
128128
implicit qr: Queryable[Q, R]
129129
): String = {
130-
DbApi.renderSql(query, config, castParams)
130+
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
131131
}
132132

133133
private def withConnection[T](f: DbClient.Connection => T): T = {

scalasql/core/src/DialectConfig.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package scalasql.core
22

3-
trait DialectConfig {
4-
protected def dialectCastParams: Boolean
5-
}
3+
trait DialectConfig { that =>
4+
def castParams: Boolean
5+
def escape(str: String): String
6+
7+
def withCastParams(params: Boolean) = new DialectConfig {
8+
def castParams: Boolean = params
9+
10+
def escape(str: String): String = that.escape(str)
611

7-
object DialectConfig {
8-
def castParams(d: DialectConfig) = d.dialectCastParams
12+
}
913
}

scalasql/query/src/Delete.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ object Delete {
2424
class Renderer(table: TableRef, expr: Expr[Boolean], prevContext: Context) {
2525
implicit val implicitCtx: Context = Context.compute(prevContext, Nil, Some(table))
2626
lazy val tableNameStr =
27-
SqlStr.raw(Table.resolve(table.value))
27+
SqlStr.raw(Table.fullIdentifier(table.value))
2828

2929
def render() = sql"DELETE FROM $tableNameStr WHERE $expr"
3030
}

scalasql/query/src/From.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TableRef(val value: Table.Base) extends From {
1515
def fromExprAliases(prevContext: Context): Seq[(Expr.Identity, SqlStr)] = Nil
1616

1717
def renderSql(name: SqlStr, prevContext: Context, liveExprs: LiveExprs) = {
18-
val resolvedTable = Table.resolve(value)(prevContext)
18+
val resolvedTable = Table.fullIdentifier(value)(prevContext)
1919
SqlStr.raw(resolvedTable + sql" " + name)
2020
}
2121
}

scalasql/query/src/InsertColumns.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ object InsertColumns {
2424
protected def expr: V[Column] = WithSqlExpr.get(insert)
2525

2626
private[scalasql] override def renderSql(ctx: Context) =
27-
new Renderer(columns, ctx, valuesLists, Table.resolve(table.value)(ctx)).render()
27+
new Renderer(columns, ctx, valuesLists, Table.fullIdentifier(table.value)(ctx)).render()
2828

2929
override protected def queryConstruct(args: Queryable.ResultSetIterator): Int =
3030
args.get(IntType)

0 commit comments

Comments
 (0)