Skip to content

Commit dc16b2a

Browse files
committed
Add support for typed null parameters in prepared statements across all drivers
- Introduced `TypedNull` wrapper to specify type information for null parameters. - Added `bindNull` methods to `Statement` for positional and named parameters. - Updated SQL rendering logic for MySQL, PostgreSQL, and SQLite to handle `TypedNull`. - Extended test suite to cover typed null bindings with various SQL column types.
1 parent 2d4ce32 commit dc16b2a

File tree

9 files changed

+177
-16
lines changed

9 files changed

+177
-16
lines changed

sqlx4k-mysql/src/jvmMain/kotlin/io/github/smyrgeorge/sqlx4k/mysql/MySQL.kt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import io.github.smyrgeorge.sqlx4k.ValueEncoderRegistry
2525
import io.github.smyrgeorge.sqlx4k.impl.migrate.Migration
2626
import io.github.smyrgeorge.sqlx4k.impl.migrate.MigrationFile
2727
import io.github.smyrgeorge.sqlx4k.impl.migrate.Migrator
28+
import io.github.smyrgeorge.sqlx4k.impl.types.TypedNull
2829
import io.netty.buffer.ByteBuf
2930
import io.netty.buffer.ByteBufAllocator
3031
import io.r2dbc.pool.ConnectionPool as NativeR2dbcConnectionPool
@@ -449,8 +450,7 @@ class MySQL(
449450
statement: Statement,
450451
encoders: ValueEncoderRegistry
451452
): io.r2dbc.spi.Statement {
452-
fun Any?.toR2dbc(): Any? = when (this) {
453-
null -> this
453+
fun Any.toR2dbc(): Any = when (this) {
454454
is Char -> toString()
455455
// MySQL DATETIME is timezone-agnostic; convert to UTC LocalDateTime to
456456
// avoid session-timezone shifts when the driver binds java.time.Instant.
@@ -469,9 +469,11 @@ class MySQL(
469469
val query = statement.renderNativeQuery(Dialect.MySQL, encoders)
470470
val stmt = createStatement(query.sql)
471471
query.values.forEachIndexed { index, value ->
472-
val converted = value?.toR2dbc()
473-
if (converted == null) stmt.bindNull(index, Any::class.java)
474-
else stmt.bind(index, converted)
472+
when (value) {
473+
is TypedNull -> stmt.bindNull(index, value.type.java)
474+
null -> stmt.bindNull(index, Any::class.java)
475+
else -> stmt.bind(index, value.toR2dbc())
476+
}
475477
}
476478
return stmt
477479
}

sqlx4k-postgres/src/commonTest/kotlin/io/github/smyrgeorge/sqlx4k/postgres/CommonPostgreSQLPreparedStatementTests.kt

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,19 @@ import assertk.assertions.isNull
1010
import assertk.assertions.isTrue
1111
import io.github.smyrgeorge.sqlx4k.Statement
1212
import io.github.smyrgeorge.sqlx4k.impl.extensions.asBoolean
13+
import io.github.smyrgeorge.sqlx4k.impl.extensions.asBooleanOrNull
1314
import io.github.smyrgeorge.sqlx4k.impl.extensions.asDouble
1415
import io.github.smyrgeorge.sqlx4k.impl.extensions.asFloat
1516
import io.github.smyrgeorge.sqlx4k.impl.extensions.asInstant
17+
import io.github.smyrgeorge.sqlx4k.impl.extensions.asInstantOrNull
1618
import io.github.smyrgeorge.sqlx4k.impl.extensions.asInt
1719
import io.github.smyrgeorge.sqlx4k.impl.extensions.asLocalDate
1820
import io.github.smyrgeorge.sqlx4k.impl.extensions.asLocalDateTime
1921
import io.github.smyrgeorge.sqlx4k.impl.extensions.asLocalTime
2022
import io.github.smyrgeorge.sqlx4k.impl.extensions.asLong
2123
import io.github.smyrgeorge.sqlx4k.impl.extensions.asShort
2224
import io.github.smyrgeorge.sqlx4k.impl.extensions.asUuid
25+
import io.github.smyrgeorge.sqlx4k.impl.extensions.asUuidOrNull
2326
import io.github.smyrgeorge.sqlx4k.impl.statement.ExtendedStatement
2427
import io.github.smyrgeorge.sqlx4k.impl.types.NoWrappingTuple
2528
import kotlin.random.Random
@@ -141,34 +144,82 @@ class CommonPostgreSQLPreparedStatementTests(
141144
val table = newTable()
142145
try {
143146
db.execute(
144-
"create table $table(id serial primary key, v_text text, v_int4 int4)"
147+
"""
148+
create table $table(
149+
id serial primary key,
150+
v_text text,
151+
v_int4 int4,
152+
v_bool bool,
153+
v_uuid uuid,
154+
v_tstz timestamptz
155+
)
156+
""".trimIndent()
145157
).getOrThrow()
146158

147-
// Positional null
159+
// Positional null — text column
148160
val insertPos = Statement.create("insert into $table(v_text, v_int4) values (?, ?)")
149161
.bind(0, null)
150162
.bind(1, 1)
151163
db.execute(insertPos).getOrThrow()
152164

153-
// Named null
165+
// Named null — int4 column
154166
val insertNamed = Statement.create("insert into $table(v_text, v_int4) values (:text, :int4)")
155167
.bind("text", "present")
156168
.bind("int4", null)
157169
db.execute(insertNamed).getOrThrow()
158170

159-
// Verify positional null
171+
// Positional null — uuid column (most likely to break with Any::class.java)
172+
val insertUuid = Statement.create("insert into $table(v_int4, v_uuid) values (?, ?)")
173+
.bind(0, 3)
174+
.bind(1, null)
175+
db.execute(insertUuid).getOrThrow()
176+
177+
// Named null — bool column
178+
val insertBool = Statement.create("insert into $table(v_int4, v_bool) values (:id, :flag)")
179+
.bind("id", 4)
180+
.bind("flag", null)
181+
db.execute(insertBool).getOrThrow()
182+
183+
// Positional null — timestamptz column
184+
val insertTstz = Statement.create("insert into $table(v_int4, v_tstz) values (?, ?)")
185+
.bind(0, 5)
186+
.bind(1, null)
187+
db.execute(insertTstz).getOrThrow()
188+
189+
// Verify positional null (text)
160190
val row1 = db.fetchAll(
161191
Statement.create("select v_text, v_int4 from $table where v_int4 = ?").bind(0, 1)
162192
).getOrThrow().first()
163193
assertThat(row1.get(0).asStringOrNull()).isNull()
164194
assertThat(row1.get(1).asInt()).isEqualTo(1)
165195

166-
// Verify named null
196+
// Verify named null (int4)
167197
val row2 = db.fetchAll(
168198
Statement.create("select v_text, v_int4 from $table where v_text = :t").bind("t", "present")
169199
).getOrThrow().first()
170200
assertThat(row2.get(0).asString()).isEqualTo("present")
171201
assertThat(row2.get(1).asStringOrNull()).isNull()
202+
203+
// Verify positional null (uuid)
204+
val row3 = db.fetchAll(
205+
Statement.create("select v_int4, v_uuid from $table where v_int4 = ?").bind(0, 3)
206+
).getOrThrow().first()
207+
assertThat(row3.get(0).asInt()).isEqualTo(3)
208+
assertThat(row3.get(1).asUuidOrNull()).isNull()
209+
210+
// Verify named null (bool)
211+
val row4 = db.fetchAll(
212+
Statement.create("select v_int4, v_bool from $table where v_int4 = ?").bind(0, 4)
213+
).getOrThrow().first()
214+
assertThat(row4.get(0).asInt()).isEqualTo(4)
215+
assertThat(row4.get(1).asBooleanOrNull()).isNull()
216+
217+
// Verify positional null (timestamptz)
218+
val row5 = db.fetchAll(
219+
Statement.create("select v_int4, v_tstz from $table where v_int4 = ?").bind(0, 5)
220+
).getOrThrow().first()
221+
assertThat(row5.get(0).asInt()).isEqualTo(5)
222+
assertThat(row5.get(1).asInstantOrNull()).isNull()
172223
} finally {
173224
runCatching { db.execute("drop table if exists $table") }
174225
}

sqlx4k-postgres/src/jvmMain/kotlin/io/github/smyrgeorge/sqlx4k/postgres/PostgreSQLImpl.kt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import io.github.smyrgeorge.sqlx4k.ValueEncoderRegistry
1515
import io.github.smyrgeorge.sqlx4k.impl.migrate.Migration
1616
import io.github.smyrgeorge.sqlx4k.impl.migrate.MigrationFile
1717
import io.github.smyrgeorge.sqlx4k.impl.migrate.Migrator
18+
import io.github.smyrgeorge.sqlx4k.impl.types.TypedNull
1819
import io.r2dbc.pool.ConnectionPool as NativeR2dbcConnectionPool
1920
import io.r2dbc.postgresql.PostgresqlConnectionFactory
2021
import io.r2dbc.postgresql.api.Notification as NativeR2dbcNotification
@@ -502,8 +503,7 @@ class PostgreSQLImpl(
502503
statement: Statement,
503504
encoders: ValueEncoderRegistry
504505
): io.r2dbc.spi.Statement {
505-
fun Any?.toR2dbc(): Any? = when (this) {
506-
null -> this
506+
fun Any.toR2dbc(): Any = when (this) {
507507
is Instant -> toJavaInstant()
508508
is LocalDate -> toJavaLocalDate()
509509
is LocalTime -> toJavaLocalTime()
@@ -515,9 +515,11 @@ class PostgreSQLImpl(
515515
val query = statement.renderNativeQuery(Dialect.PostgreSQL, encoders)
516516
val stmt = createStatement(query.sql)
517517
query.values.forEachIndexed { index, value ->
518-
val converted = value?.toR2dbc()
519-
if (converted == null) stmt.bindNull(index, Any::class.java)
520-
else stmt.bind(index, converted)
518+
when (value) {
519+
is TypedNull -> stmt.bindNull(index, value.type.java)
520+
null -> stmt.bindNull(index, Any::class.java)
521+
else -> stmt.bind(index, value.toR2dbc())
522+
}
521523
}
522524
return stmt
523525
}

sqlx4k-sqlite/src/jvmMain/kotlin/io/github/smyrgeorge/sqlx4k/sqlite/SQLite.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import io.github.smyrgeorge.sqlx4k.impl.migrate.Migrator
1919
import io.github.smyrgeorge.sqlx4k.impl.pool.ConnectionPoolImpl
2020
import io.github.smyrgeorge.sqlx4k.impl.pool.PooledConnection
2121
import io.github.smyrgeorge.sqlx4k.impl.pool.PooledTransaction
22+
import io.github.smyrgeorge.sqlx4k.impl.types.TypedNull
2223
import java.sql.Connection as NativeJdbcConnection
2324
import java.sql.DriverManager
2425
import java.sql.ResultSet as NativeJdbcResultSet
@@ -428,7 +429,11 @@ class SQLite(
428429
val query = statement.renderNativeQuery(Dialect.SQLite, encoders)
429430
val stmt = prepareStatement(query.sql)
430431
query.values.forEachIndexed { index, value ->
431-
stmt.setObject(index + 1, value?.toJdbc())
432+
when (value) {
433+
is TypedNull -> stmt.setNull(index + 1, java.sql.Types.NULL)
434+
null -> stmt.setNull(index + 1, java.sql.Types.NULL)
435+
else -> stmt.setObject(index + 1, value.toJdbc())
436+
}
432437
}
433438
return stmt
434439
}

sqlx4k/src/commonMain/kotlin/io/github/smyrgeorge/sqlx4k/Statement.kt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.github.smyrgeorge.sqlx4k
22

33
import io.github.smyrgeorge.sqlx4k.impl.statement.SimpleStatement
4+
import kotlin.reflect.KClass
45

56
/**
67
* Represents a statement that allows binding of positional and named parameters.
@@ -60,6 +61,20 @@ interface Statement {
6061
*/
6162
fun bind(index: Int, value: Any?): Statement
6263

64+
/**
65+
* Binds a typed null value to a positional parameter in the statement.
66+
*
67+
* This method is used to explicitly bind a `null` value to a positional parameter
68+
* while specifying the intended SQL type. It is particularly useful when interacting
69+
* with strongly-typed SQL columns where the type information is required to correctly
70+
* handle the `null` value (e.g., `uuid`, `timestamptz`, typed arrays).
71+
*
72+
* @param index The zero-based index of the positional parameter to bind the null value to.
73+
* @param type The Kotlin class corresponding to the intended SQL type of the parameter.
74+
* @return The current `Statement` instance to allow for method chaining.
75+
*/
76+
fun bindNull(index: Int, type: KClass<*>): Statement
77+
6378
/**
6479
* Binds a value to a named parameter in the statement.
6580
*
@@ -69,6 +84,21 @@ interface Statement {
6984
*/
7085
fun bind(parameter: String, value: Any?): Statement
7186

87+
/**
88+
* Binds a typed null value to a named parameter in the statement.
89+
*
90+
* This method is used to explicitly bind a `null` value to a named parameter
91+
* while specifying the intended SQL type. It is beneficial in cases where
92+
* the type information is required to properly handle the `null` value,
93+
* such as in the case of strongly-typed SQL columns (e.g., `uuid`, `timestamptz`,
94+
* typed arrays).
95+
*
96+
* @param parameter The name of the parameter to bind the null value to.
97+
* @param type The Kotlin class representing the intended SQL type of the parameter.
98+
* @return The current `Statement` instance to allow for method chaining.
99+
*/
100+
fun bindNull(parameter: String, type: KClass<*>): Statement
101+
72102
/**
73103
* Renders the SQL statement by replacing placeholders for positional and named parameters
74104
* with their respective bound values.

sqlx4k/src/commonMain/kotlin/io/github/smyrgeorge/sqlx4k/impl/extensions/encode.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package io.github.smyrgeorge.sqlx4k.impl.extensions
44

55
import io.github.smyrgeorge.sqlx4k.Dialect
66
import io.github.smyrgeorge.sqlx4k.SQLError
7+
import io.github.smyrgeorge.sqlx4k.impl.types.TypedNull
78
import io.github.smyrgeorge.sqlx4k.ValueEncoderRegistry
89
import io.github.smyrgeorge.sqlx4k.impl.types.NoWrappingTuple
910
import kotlin.time.Instant
@@ -32,6 +33,7 @@ import kotlinx.datetime.toLocalDateTime
3233
internal fun Any?.encodeValue(encoders: ValueEncoderRegistry): String {
3334
return when (this) {
3435
null -> "null"
36+
is TypedNull -> "null"
3537
is String -> {
3638
// Fast path: if no single quote present, avoid replace allocation
3739
if (indexOf('\'') < 0) return "'${this}'"
@@ -124,6 +126,7 @@ internal fun Instant.toTimestampString(timeZone: TimeZone = TimeZone.UTC): Strin
124126
internal fun Any?.resolveNativeValue(encoders: ValueEncoderRegistry): Any? {
125127
return when (this) {
126128
null -> null
129+
is TypedNull -> this
127130
is String, is Char, is Boolean, is Byte, is Short, is Int, is Long, is Float, is Double -> this
128131
is Instant, is LocalDate, is LocalTime, is LocalDateTime -> this
129132
is Uuid -> this

sqlx4k/src/commonMain/kotlin/io/github/smyrgeorge/sqlx4k/impl/statement/AbstractStatement.kt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package io.github.smyrgeorge.sqlx4k.impl.statement
33
import io.github.smyrgeorge.sqlx4k.Dialect
44
import io.github.smyrgeorge.sqlx4k.SQLError
55
import io.github.smyrgeorge.sqlx4k.Statement
6+
import io.github.smyrgeorge.sqlx4k.impl.types.TypedNull
67
import io.github.smyrgeorge.sqlx4k.ValueEncoderRegistry
8+
import kotlin.reflect.KClass
79
import io.github.smyrgeorge.sqlx4k.impl.extensions.appendNativeValue
810
import io.github.smyrgeorge.sqlx4k.impl.extensions.encodeValue
911
import io.github.smyrgeorge.sqlx4k.impl.extensions.isIdentPart
@@ -108,6 +110,44 @@ abstract class AbstractStatement(
108110
return this
109111
}
110112

113+
/**
114+
* Binds a typed null to a positional parameter.
115+
*
116+
* @param index The zero-based index of the positional parameter.
117+
* @param type The Kotlin class of the intended SQL type.
118+
* @return The current [Statement] instance to allow for method chaining.
119+
* @throws SQLError if the given index is out of bounds.
120+
*/
121+
override fun bindNull(index: Int, type: KClass<*>): AbstractStatement {
122+
if (index !in 0..<extractedPositionalParameters) {
123+
SQLError(
124+
code = SQLError.Code.PositionalParameterOutOfBounds,
125+
message = "Index '$index' out of bounds."
126+
).raise()
127+
}
128+
positionalParametersValues[index] = TypedNull(type)
129+
return this
130+
}
131+
132+
/**
133+
* Binds a typed null to a named parameter.
134+
*
135+
* @param parameter The name of the parameter.
136+
* @param type The Kotlin class of the intended SQL type.
137+
* @return The current [Statement] instance to allow for method chaining.
138+
* @throws SQLError if the specified named parameter is not found.
139+
*/
140+
override fun bindNull(parameter: String, type: KClass<*>): AbstractStatement {
141+
if (parameter !in extractedNamedParameters) {
142+
SQLError(
143+
code = SQLError.Code.NamedParameterNotFound,
144+
message = "Parameter '$parameter' not found."
145+
).raise()
146+
}
147+
namedParametersValues[parameter] = TypedNull(type)
148+
return this
149+
}
150+
111151
/**
112152
* Renders the SQL statement by replacing positional and named parameter placeholders
113153
* with their corresponding bound values using the provided encoder registry.

sqlx4k/src/commonMain/kotlin/io/github/smyrgeorge/sqlx4k/impl/statement/ExtendedStatement.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package io.github.smyrgeorge.sqlx4k.impl.statement
33
import io.github.smyrgeorge.sqlx4k.Dialect
44
import io.github.smyrgeorge.sqlx4k.SQLError
55
import io.github.smyrgeorge.sqlx4k.Statement
6+
import io.github.smyrgeorge.sqlx4k.impl.types.TypedNull
67
import io.github.smyrgeorge.sqlx4k.ValueEncoderRegistry
78
import io.github.smyrgeorge.sqlx4k.impl.extensions.encodeValue
89
import io.github.smyrgeorge.sqlx4k.impl.extensions.resolveNativeValue
10+
import kotlin.reflect.KClass
911

1012
/**
1113
* The `ExtendedStatement` class provides an implementation that extends the functionality
@@ -57,6 +59,17 @@ class ExtendedStatement(sql: String) : AbstractStatement(sql) {
5759
return this
5860
}
5961

62+
override fun bindNull(index: Int, type: KClass<*>): ExtendedStatement {
63+
if (index < 0 || index >= pgParameters.size) {
64+
SQLError(
65+
code = SQLError.Code.PositionalParameterOutOfBounds,
66+
message = "Index '$index' out of bounds."
67+
).raise()
68+
}
69+
pgParametersValues[index] = TypedNull(type)
70+
return this
71+
}
72+
6073
/**
6174
* Renders a native SQL query by resolving positional parameters and encoding their values
6275
* using the specified encoder registry for the given SQL dialect.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package io.github.smyrgeorge.sqlx4k.impl.types
2+
3+
import kotlin.reflect.KClass
4+
5+
/**
6+
* Wraps a null value with type information for use with [io.github.smyrgeorge.sqlx4k.Statement.bindNull].
7+
*
8+
* When binding a null parameter to a prepared statement, the database driver
9+
* typically needs the target SQL type to determine the correct OID or wire format.
10+
* This wrapper preserves that type alongside the null so it can be forwarded to the
11+
* driver's `bindNull` call.
12+
*
13+
* @property type The Kotlin class representing the intended parameter type.
14+
*/
15+
data class TypedNull(val type: KClass<*>)

0 commit comments

Comments
 (0)