From e61b89c56c4381abd68629a8ca151c8773c4f344 Mon Sep 17 00:00:00 2001 From: Alex Kasko Date: Fri, 4 Apr 2025 00:42:14 +0100 Subject: [PATCH] Make ResultSet getters case insensitive This change brings in the #43 PR with additional changes to the `Array` result set requested by the reviewer in #43. For the `Array` result set the column names `INDEX` and `VALUE` are taken from the implementation in Postgres' `pgjdbc` driver. Testing: test from #43 is included, test for `Array` result set is enhanced with additional checks. Fixes: #40 Co-authored-by: jghoman --- .../java/org/duckdb/DuckDBArrayResultSet.java | 82 ++++++++++++------- src/main/java/org/duckdb/DuckDBResultSet.java | 2 +- src/test/java/org/duckdb/TestDuckDBJDBC.java | 71 +++++++++++++++- 3 files changed, 119 insertions(+), 36 deletions(-) diff --git a/src/main/java/org/duckdb/DuckDBArrayResultSet.java b/src/main/java/org/duckdb/DuckDBArrayResultSet.java index a77305bcb..2f427a410 100644 --- a/src/main/java/org/duckdb/DuckDBArrayResultSet.java +++ b/src/main/java/org/duckdb/DuckDBArrayResultSet.java @@ -41,17 +41,13 @@ public boolean wasNull() throws SQLException { return wasNull; } - private T getValue(String columnLabel, SqlValueGetter getter) throws SQLException { - return getValue(findColumn(columnLabel), getter); - } - private T getValue(int columnIndex, SqlValueGetter getter) throws SQLException { if (columnIndex == 1) { - throw new IllegalArgumentException( - "The first element of Array-backed ResultSet can only be retrieved with getInt()"); + throw new SQLException( + "The first element of Array-backed ResultSet can only be retrieved with numeric getters"); } if (columnIndex != 2) { - throw new IllegalArgumentException("Array-backed ResultSet can only have two columns"); + throw new SQLException("Array-backed ResultSet can only have two columns"); } T value = getter.getValue(offset + currentValueIndex); @@ -59,6 +55,11 @@ private T getValue(int columnIndex, SqlValueGetter getter) throws SQLExce return value; } + private int getIndexColumnValue() { + wasNull = false; + return currentValueIndex + 1; + } + @Override public String getString(int columnIndex) throws SQLException { return getValue(columnIndex, vector::getLazyString); @@ -71,40 +72,57 @@ public boolean getBoolean(int columnIndex) throws SQLException { @Override public byte getByte(int columnIndex) throws SQLException { + if (columnIndex == 1) { + return (byte) getIndexColumnValue(); + } return getValue(columnIndex, vector::getByte); } @Override public short getShort(int columnIndex) throws SQLException { + if (columnIndex == 1) { + return (short) getIndexColumnValue(); + } return getValue(columnIndex, vector::getShort); } @Override public int getInt(int columnIndex) throws SQLException { if (columnIndex == 1) { - wasNull = false; - return currentValueIndex + 1; + return getIndexColumnValue(); } return getValue(columnIndex, vector::getInt); } @Override public long getLong(int columnIndex) throws SQLException { - return getInt(columnIndex); + if (columnIndex == 1) { + return getIndexColumnValue(); + } + return getValue(columnIndex, vector::getLong); } @Override public float getFloat(int columnIndex) throws SQLException { + if (columnIndex == 1) { + return getIndexColumnValue(); + } return getValue(columnIndex, vector::getFloat); } @Override public double getDouble(int columnIndex) throws SQLException { + if (columnIndex == 1) { + return getIndexColumnValue(); + } return getValue(columnIndex, vector::getDouble); } @Override public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException { + if (columnIndex == 1) { + return BigDecimal.valueOf(getIndexColumnValue()); + } return getValue(columnIndex, vector::getBigDecimal); } @@ -145,51 +163,47 @@ public InputStream getBinaryStream(int columnIndex) throws SQLException { @Override public String getString(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getLazyString); + return getString(findColumn(columnLabel)); } @Override public boolean getBoolean(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getBoolean); + return getBoolean(findColumn(columnLabel)); } @Override public byte getByte(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getByte); + return getByte(findColumn(columnLabel)); } @Override public short getShort(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getShort); + return getShort(findColumn(columnLabel)); } @Override public int getInt(String columnLabel) throws SQLException { - int columnIndex = findColumn(columnLabel); - if (columnIndex == 1) { - return currentValueIndex; - } - return getValue(columnIndex, vector::getInt); + return getInt(findColumn(columnLabel)); } @Override public long getLong(String columnLabel) throws SQLException { - return getInt(columnLabel); + return getLong(findColumn(columnLabel)); } @Override public float getFloat(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getFloat); + return getFloat(findColumn(columnLabel)); } @Override public double getDouble(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getDouble); + return getDouble(findColumn(columnLabel)); } @Override public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException { - return getValue(columnLabel, vector::getBigDecimal); + return getBigDecimal(findColumn(columnLabel), scale); } @Override @@ -199,17 +213,17 @@ public byte[] getBytes(String columnLabel) throws SQLException { @Override public Date getDate(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getDate); + return getDate(findColumn(columnLabel)); } @Override public Time getTime(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getTime); + return getTime(findColumn(columnLabel)); } @Override public Timestamp getTimestamp(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getTimestamp); + return getTimestamp(findColumn(columnLabel)); } @Override @@ -254,12 +268,18 @@ public Object getObject(int columnIndex) throws SQLException { @Override public Object getObject(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getTimestamp); + return getObject(findColumn(columnLabel)); } @Override public int findColumn(String columnLabel) throws SQLException { - return Integer.parseInt(columnLabel); + if ("INDEX".equalsIgnoreCase(columnLabel)) { + return 1; + } + if ("VALUE".equalsIgnoreCase(columnLabel)) { + return 2; + } + throw new SQLException("Could not find column with label " + columnLabel); } @Override @@ -279,7 +299,7 @@ public BigDecimal getBigDecimal(int columnIndex) throws SQLException { @Override public BigDecimal getBigDecimal(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getBigDecimal); + return getBigDecimal(findColumn(columnLabel)); } @Override @@ -671,7 +691,7 @@ public Array getArray(int columnIndex) throws SQLException { @Override public Object getObject(String columnLabel, Map> map) throws SQLException { - return getValue(columnLabel, vector::getObject); + return getObject(findColumn(columnLabel)); } @Override @@ -691,7 +711,7 @@ public Clob getClob(String columnLabel) throws SQLException { @Override public Array getArray(String columnLabel) throws SQLException { - return getValue(columnLabel, vector::getArray); + return getArray(findColumn(columnLabel)); } @Override diff --git a/src/main/java/org/duckdb/DuckDBResultSet.java b/src/main/java/org/duckdb/DuckDBResultSet.java index 5b6e0b622..2e77318a7 100644 --- a/src/main/java/org/duckdb/DuckDBResultSet.java +++ b/src/main/java/org/duckdb/DuckDBResultSet.java @@ -308,7 +308,7 @@ public int findColumn(String columnLabel) throws SQLException { throw new SQLException("ResultSet was closed"); } for (int col_idx = 0; col_idx < meta.column_count; col_idx++) { - if (meta.column_names[col_idx].contentEquals(columnLabel)) { + if (meta.column_names[col_idx].equalsIgnoreCase(columnLabel)) { return col_idx + 1; } } diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 7b4152006..ef5a28743 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -3748,7 +3748,17 @@ public static void test_array_resultset() throws Exception { ResultSet arrayResultSet = rs.getArray(1).getResultSet(); assertTrue(arrayResultSet.next()); assertEquals(arrayResultSet.getInt(1), 1); + assertEquals(arrayResultSet.getInt("index"), 1); + assertEquals(arrayResultSet.getInt("Index"), 1); + assertEquals(arrayResultSet.getInt("INDEX"), 1); + assertEquals(arrayResultSet.getByte(2), (byte) 42); + assertEquals(arrayResultSet.getShort(2), (short) 42); assertEquals(arrayResultSet.getInt(2), 42); + assertEquals(arrayResultSet.getLong(2), (long) 42); + assertEquals(arrayResultSet.getFloat(2), (float) 42); + assertEquals(arrayResultSet.getDouble(2), (double) 42); + assertEquals(arrayResultSet.getBigDecimal(2), BigDecimal.valueOf(42)); + assertEquals(arrayResultSet.getInt("value"), 42); assertTrue(arrayResultSet.next()); assertEquals(arrayResultSet.getInt(1), 2); assertEquals(arrayResultSet.getInt(2), 69); @@ -3760,10 +3770,18 @@ public static void test_array_resultset() throws Exception { ResultSet arrayResultSet = rs.getArray(1).getResultSet(); assertTrue(arrayResultSet.next()); assertEquals(arrayResultSet.getInt(1), 1); - Array subArray = arrayResultSet.getArray(2); - assertNotNull(subArray); - ResultSet subArrayResultSet = subArray.getResultSet(); - assertFalse(subArrayResultSet.next()); // empty array + { + Array subArray = arrayResultSet.getArray(2); + assertNotNull(subArray); + ResultSet subArrayResultSet = subArray.getResultSet(); + assertFalse(subArrayResultSet.next()); // empty array + } + { + Array subArray = arrayResultSet.getArray("value"); + assertNotNull(subArray); + ResultSet subArrayResultSet = subArray.getResultSet(); + assertFalse(subArrayResultSet.next()); // empty array + } assertTrue(arrayResultSet.next()); assertEquals(arrayResultSet.getInt(1), 2); @@ -3845,6 +3863,13 @@ public static void test_array_resultset() throws Exception { assertEquals(arrayResultSet2.getInt(2), 69); assertFalse(arrayResultSet2.next()); } + + try (ResultSet rs = statement.executeQuery("select [" + Integer.MAX_VALUE + "::BIGINT + 1]")) { + assertTrue(rs.next()); + ResultSet arrayResultSet = rs.getArray(1).getResultSet(); + assertTrue(arrayResultSet.next()); + assertEquals(arrayResultSet.getLong(2), ((long) Integer.MAX_VALUE) + 1); + } } } @@ -4577,6 +4602,44 @@ public static void test_get_bytes() throws Exception { } } + public static void test_case_insensitivity() throws Exception { + try (Connection connection = DriverManager.getConnection("jdbc:duckdb:")) { + try (Statement s = connection.createStatement()) { + s.execute("CREATE TABLE someTable (lowercase INT, mixedCASE INT, UPPERCASE INT)"); + s.execute("INSERT INTO someTable VALUES (0, 1, 2)"); + } + + String[] tableNameVariations = new String[] {"sometable", "someTable", "SOMETABLE"}; + String[][] columnNameVariations = new String[][] {{"lowercase", "mixedcase", "uppercase"}, + {"lowerCASE", "mixedCASE", "upperCASE"}, + {"LOWERCASE", "MIXEDCASE", "UPPERCASE"}}; + + int totalTestsRun = 0; + + // Test every combination of upper, lower and mixedcase column and table names. + for (String tableName : tableNameVariations) { + for (int columnVariation = 0; columnVariation < columnNameVariations.length; columnVariation++) { + try (Statement s = connection.createStatement()) { + String query = String.format("SELECT %s, %s, %s from %s;", columnNameVariations[0][0], + columnNameVariations[0][1], columnNameVariations[0][2], tableName); + + ResultSet resultSet = s.executeQuery(query); + assertTrue(resultSet.next()); + for (int i = 0; i < columnNameVariations[0].length; i++) { + assertEquals(resultSet.getInt(columnNameVariations[columnVariation][i]), i, + "Query " + query + " did not get correct result back for column number " + i); + totalTestsRun++; + } + } + } + } + + assertEquals(totalTestsRun, + tableNameVariations.length * columnNameVariations.length * columnNameVariations[0].length, + "Number of test cases actually run did not match number expected to be run."); + } + } + public static void test_fractional_time() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL); PreparedStatement stmt = conn.prepareStatement("SELECT '01:02:03.123'::TIME");