@@ -3,9 +3,10 @@ package io.github.smyrgeorge.sqlx4k.processor
33import net.sf.jsqlparser.parser.CCJSqlParserUtil
44import net.sf.jsqlparser.statement.alter.Alter
55import net.sf.jsqlparser.statement.create.table.CreateTable
6+ import net.sf.jsqlparser.statement.drop.Drop
67import org.apache.calcite.adapter.java.JavaTypeFactory
78import org.apache.calcite.config.CalciteConnectionConfigImpl
8- import org.apache.calcite.config.CalciteConnectionProperty
9+ import org.apache.calcite.config.Lex
910import org.apache.calcite.jdbc.CalciteSchema
1011import org.apache.calcite.jdbc.JavaTypeFactoryImpl
1112import org.apache.calcite.prepare.CalciteCatalogReader
@@ -14,27 +15,54 @@ import org.apache.calcite.rel.type.RelDataTypeFactory
1415import org.apache.calcite.schema.impl.AbstractTable
1516import org.apache.calcite.sql.`fun`.SqlStdOperatorTable
1617import org.apache.calcite.sql.parser.SqlParser
17- import org.apache.calcite.sql.validate.SqlConformanceEnum
18+ import org.apache.calcite.sql.type.SqlTypeName
1819import org.apache.calcite.sql.validate.SqlValidator
1920import org.apache.calcite.sql.validate.SqlValidatorCatalogReader
2021import org.apache.calcite.sql.validate.SqlValidatorUtil
2122import org.apache.calcite.sql.validate.SqlValidatorWithHints
2223import java.io.File
23- import java.math.BigDecimal
2424import java.util.*
2525
2626object QueryValidator {
27- private lateinit var schema: Map <String , TableDef >
2827 private lateinit var tables: List <TableDef >
2928 private lateinit var validator: SqlValidator
3029
30+ fun validateQuerySyntax (fn : String , sql : String ) {
31+ try {
32+ CCJSqlParserUtil .parse(sql)
33+ } catch (e: Exception ) {
34+ val cause = e.message?.removePrefix(" net.sf.jsqlparser.parser.ParseException: " )
35+ error(" Invalid SQL in function $fn : $cause " )
36+ }
37+ }
38+
39+ fun validateQuerySchema (sql : String ) {
40+ // Use a lex that preserves/lowercases identifiers (avoid automatic UPPER-casing)
41+ val config = SqlParser .Config .DEFAULT .withLex(Lex .JAVA )
42+ validator.validate(SqlParser .create(sql, config).parseStmt())
43+ }
44+
3145 fun load (path : String ) {
46+ fun parseFileName (name : String ): Long {
47+ val fileNamePattern = Regex (""" ^\s*(\d+)_([A-Za-z0-9._-]+)\.sql\s*$""" )
48+ val name = name.trim()
49+ val match = fileNamePattern.matchEntire(name)
50+ ? : error(" Migration filename must be <version>_<name>.sql, got $name " )
51+ val (versionStr, _) = match.destructured
52+ val version = versionStr.toLongOrNull()
53+ ? : error(" Invalid version prefix in migration filename: $name " )
54+ return version
55+ }
56+
3257 val dir = File (path)
3358 if (! dir.exists()) error(" Schema directory does not exist: $path " )
3459 if (! dir.isDirectory) error(" Schema directory is not a directory: $path " )
3560
3661 val files = dir.listFiles()
3762 ?.filter { it.isFile && it.extension == " sql" }
63+ ?.map { parseFileName(it.name) to it }
64+ ?.sortedBy { it.first }
65+ ?.map { it.second }
3866 ? : error(" Cound not list schema files in directory: $path " )
3967
4068 val schema = mutableMapOf<String , TableDef >()
@@ -46,11 +74,11 @@ object QueryValidator {
4674 when (stmt) {
4775 is CreateTable -> {
4876 val cols = stmt.columnDefinitions.map { ColumnDef (it.columnName, it.colDataType.dataType) }
49- schema[stmt.table.name.lowercase() ] = TableDef (stmt.table.name, cols.toMutableList())
77+ schema[stmt.table.name] = TableDef (stmt.table.name, cols.toMutableList())
5078 }
5179
5280 is Alter -> {
53- val tableName = stmt.table.name.lowercase()
81+ val tableName = stmt.table.name
5482 val table = schema[tableName] ? : error(" ALTER TABLE on unknown table $tableName " )
5583
5684 stmt.alterExpressions.forEach { expr ->
@@ -76,67 +104,102 @@ object QueryValidator {
76104 }
77105 }
78106
107+ is Drop -> {
108+ val tableName = stmt.name.name
109+ val res = schema.remove(tableName)
110+ if (res == null ) println (" ⚠️ DROP TABLE on unknown table $tableName " )
111+ }
112+
79113 else -> {
80114 println (" ⚠️ Skipping unsupported statement: ${stmt.javaClass.simpleName} " )
81115 }
82116 }
83117 }
84118 }
85119
86- this .schema = schema
87120 this .tables = schema.values.toList()
88121 this .validator = createCalciteValidator()
89122 }
90123
91124 private fun createCalciteValidator (): SqlValidatorWithHints {
92125 val rootSchema = CalciteSchema .createRootSchema(true ).apply {
93- val root = this
94126 tables.forEach { t ->
95- root .add(t.name, MigrationTable (t.columns))
127+ this .add(t.name, MigrationTable (t.columns))
96128 }
97129 }
98130
99131 val typeFactory: JavaTypeFactory = JavaTypeFactoryImpl ()
100132
101- val props = Properties ().apply {
102- this [CalciteConnectionProperty .CASE_SENSITIVE .camelName()] = " false"
103- }
104-
105- val config = CalciteConnectionConfigImpl (props)
106-
107133 val catalogReader: SqlValidatorCatalogReader = CalciteCatalogReader (
108134 /* rootSchema = */ rootSchema,
109135 /* defaultSchema = */ listOf (), // search path (empty = root)
110136 /* typeFactory = */ typeFactory,
111- /* config = */ config
137+ /* config = */ CalciteConnectionConfigImpl ( Properties ())
112138 )
113139
114140 val validator: SqlValidatorWithHints = SqlValidatorUtil .newValidator(
115141 /* opTab = */ SqlStdOperatorTable .instance(),
116142 /* catalogReader = */ catalogReader,
117143 /* typeFactory = */ typeFactory,
118144 /* config = */ SqlValidator .Config .DEFAULT
119- .withConformance(SqlConformanceEnum .STRICT_2003 )
120- .withTypeCoercionEnabled(false ) // disable implicit casts
121145 )
122146 return validator
123147 }
124148
125- fun validateQuery (sql : String ) {
126- validator.validate(SqlParser .create(sql).parseStmt())
127- }
128-
129149 data class ColumnDef (val name : String , val type : String )
130150 data class TableDef (val name : String , val columns : MutableList <ColumnDef >)
131- class MigrationTable (private val columns : List <ColumnDef >) : AbstractTable() {
151+ data class MigrationTable (private val columns : List <ColumnDef >) : AbstractTable() {
132152 override fun getRowType (typeFactory : RelDataTypeFactory ): RelDataType {
133153 val builder = typeFactory.builder()
134154 for (col in columns) {
135- val type = when (col.type.uppercase()) {
136- " INT" , " INTEGER" -> typeFactory.createJavaType(Int ::class .java)
137- " VARCHAR" , " TEXT" -> typeFactory.createJavaType(String ::class .java)
138- " DECIMAL" -> typeFactory.createJavaType(BigDecimal ::class .java)
139- else -> typeFactory.createJavaType(String ::class .java) // fallback
155+ val type = when (val t = col.type.trim().uppercase()) {
156+ // Integer family
157+ " INT" , " INTEGER" , " INT4" -> typeFactory.createSqlType(SqlTypeName .INTEGER )
158+ " SMALLINT" , " INT2" -> typeFactory.createSqlType(SqlTypeName .SMALLINT )
159+ " BIGINT" , " INT8" -> typeFactory.createSqlType(SqlTypeName .BIGINT )
160+ " TINYINT" -> typeFactory.createSqlType(SqlTypeName .TINYINT )
161+ " MEDIUMINT" -> typeFactory.createSqlType(SqlTypeName .INTEGER )
162+ " SERIAL" -> typeFactory.createSqlType(SqlTypeName .INTEGER )
163+ " BIGSERIAL" -> typeFactory.createSqlType(SqlTypeName .BIGINT )
164+
165+ // Boolean
166+ " BOOLEAN" , " BOOL" -> typeFactory.createSqlType(SqlTypeName .BOOLEAN )
167+
168+ // Text/char family
169+ " CHAR" , " CHARACTER" , " NCHAR" -> typeFactory.createSqlType(SqlTypeName .CHAR , 1_000 )
170+ " VARCHAR" , " CHARACTER VARYING" , " NVARCHAR" , " TEXT" , " TINYTEXT" , " MEDIUMTEXT" , " LONGTEXT" , " CLOB" ->
171+ typeFactory.createSqlType(SqlTypeName .VARCHAR , Integer .MAX_VALUE )
172+
173+ // UUID
174+ " UUID" -> typeFactory.createSqlType(SqlTypeName .CHAR , 36 )
175+
176+ // Decimal / numeric
177+ " DECIMAL" , " NUMERIC" -> typeFactory.createSqlType(SqlTypeName .DECIMAL , 38 , 19 )
178+
179+ // Floating point
180+ " REAL" , " FLOAT4" -> typeFactory.createSqlType(SqlTypeName .REAL )
181+ " FLOAT" -> typeFactory.createSqlType(SqlTypeName .FLOAT )
182+ " DOUBLE" , " DOUBLE PRECISION" , " FLOAT8" -> typeFactory.createSqlType(SqlTypeName .DOUBLE )
183+
184+ // Temporal
185+ " DATE" -> typeFactory.createSqlType(SqlTypeName .DATE )
186+ " TIME" -> typeFactory.createSqlType(SqlTypeName .TIME )
187+ " TIMESTAMP" -> typeFactory.createSqlType(SqlTypeName .TIMESTAMP )
188+ " TIMESTAMPTZ" , " TIMESTAMP WITH TIME ZONE" -> typeFactory.createSqlType(SqlTypeName .TIMESTAMP_WITH_LOCAL_TIME_ZONE )
189+
190+ // Binary / blob
191+ " BYTEA" , " BLOB" , " LONGBLOB" , " MEDIUMBLOB" , " TINYBLOB" , " BINARY" , " VARBINARY" ->
192+ typeFactory.createSqlType(SqlTypeName .VARBINARY , Integer .MAX_VALUE )
193+
194+ // JSON and similar complex types mapped to text
195+ " JSON" , " JSONB" , " ENUM" -> typeFactory.createSqlType(SqlTypeName .VARCHAR , Integer .MAX_VALUE )
196+
197+ // Money and another numeric-ish
198+ " MONEY" -> typeFactory.createSqlType(SqlTypeName .DECIMAL , 19 , 4 )
199+ else -> {
200+ println (" ⚠️ Unsupported column type (fallback to string): $t " )
201+ typeFactory.createSqlType(SqlTypeName .VARCHAR , 1_000 )
202+ }
140203 }
141204 builder.add(col.name, type)
142205 }
0 commit comments