Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 51 additions & 31 deletions src/main/java/org/duckdb/DuckDBArrayResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,25 @@ public boolean wasNull() throws SQLException {
return wasNull;
}

private <T> T getValue(String columnLabel, SqlValueGetter<T> getter) throws SQLException {
return getValue(findColumn(columnLabel), getter);
}

private <T> T getValue(int columnIndex, SqlValueGetter<T> getter) throws SQLException {
if (columnIndex == 1) {
throw new IllegalArgumentException(
"The first element of Array-backed ResultSet can only be retrieved with getInt()");
throw new SQLException(
"The first element of Array-backed ResultSet can only be retrieved with numeric getters");
}
if (columnIndex != 2) {
throw new IllegalArgumentException("Array-backed ResultSet can only have two columns");
throw new SQLException("Array-backed ResultSet can only have two columns");
}
T value = getter.getValue(offset + currentValueIndex);

wasNull = value == null;
return value;
}

private int getIndexColumnValue() {
wasNull = false;
return currentValueIndex + 1;
}

@Override
public String getString(int columnIndex) throws SQLException {
return getValue(columnIndex, vector::getLazyString);
Expand All @@ -71,40 +72,57 @@ public boolean getBoolean(int columnIndex) throws SQLException {

@Override
public byte getByte(int columnIndex) throws SQLException {
if (columnIndex == 1) {
return (byte) getIndexColumnValue();
}
return getValue(columnIndex, vector::getByte);
}

@Override
public short getShort(int columnIndex) throws SQLException {
if (columnIndex == 1) {
return (short) getIndexColumnValue();
}
return getValue(columnIndex, vector::getShort);
}

@Override
public int getInt(int columnIndex) throws SQLException {
if (columnIndex == 1) {
wasNull = false;
return currentValueIndex + 1;
return getIndexColumnValue();
}
return getValue(columnIndex, vector::getInt);
}

@Override
public long getLong(int columnIndex) throws SQLException {
return getInt(columnIndex);
if (columnIndex == 1) {
return getIndexColumnValue();
}
return getValue(columnIndex, vector::getLong);
}

@Override
public float getFloat(int columnIndex) throws SQLException {
if (columnIndex == 1) {
return getIndexColumnValue();
}
return getValue(columnIndex, vector::getFloat);
}

@Override
public double getDouble(int columnIndex) throws SQLException {
if (columnIndex == 1) {
return getIndexColumnValue();
}
return getValue(columnIndex, vector::getDouble);
}

@Override
public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException {
if (columnIndex == 1) {
return BigDecimal.valueOf(getIndexColumnValue());
}
return getValue(columnIndex, vector::getBigDecimal);
}

Expand Down Expand Up @@ -145,51 +163,47 @@ public InputStream getBinaryStream(int columnIndex) throws SQLException {

@Override
public String getString(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getLazyString);
return getString(findColumn(columnLabel));
}

@Override
public boolean getBoolean(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getBoolean);
return getBoolean(findColumn(columnLabel));
}

@Override
public byte getByte(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getByte);
return getByte(findColumn(columnLabel));
}

@Override
public short getShort(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getShort);
return getShort(findColumn(columnLabel));
}

@Override
public int getInt(String columnLabel) throws SQLException {
int columnIndex = findColumn(columnLabel);
if (columnIndex == 1) {
return currentValueIndex;
}
return getValue(columnIndex, vector::getInt);
return getInt(findColumn(columnLabel));
}

@Override
public long getLong(String columnLabel) throws SQLException {
return getInt(columnLabel);
return getLong(findColumn(columnLabel));
}

@Override
public float getFloat(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getFloat);
return getFloat(findColumn(columnLabel));
}

@Override
public double getDouble(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getDouble);
return getDouble(findColumn(columnLabel));
}

@Override
public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException {
return getValue(columnLabel, vector::getBigDecimal);
return getBigDecimal(findColumn(columnLabel), scale);
}

@Override
Expand All @@ -199,17 +213,17 @@ public byte[] getBytes(String columnLabel) throws SQLException {

@Override
public Date getDate(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getDate);
return getDate(findColumn(columnLabel));
}

@Override
public Time getTime(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getTime);
return getTime(findColumn(columnLabel));
}

@Override
public Timestamp getTimestamp(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getTimestamp);
return getTimestamp(findColumn(columnLabel));
}

@Override
Expand Down Expand Up @@ -254,12 +268,18 @@ public Object getObject(int columnIndex) throws SQLException {

@Override
public Object getObject(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getTimestamp);
return getObject(findColumn(columnLabel));
}

@Override
public int findColumn(String columnLabel) throws SQLException {
return Integer.parseInt(columnLabel);
if ("INDEX".equalsIgnoreCase(columnLabel)) {
return 1;
}
if ("VALUE".equalsIgnoreCase(columnLabel)) {
return 2;
}
throw new SQLException("Could not find column with label " + columnLabel);
}

@Override
Expand All @@ -279,7 +299,7 @@ public BigDecimal getBigDecimal(int columnIndex) throws SQLException {

@Override
public BigDecimal getBigDecimal(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getBigDecimal);
return getBigDecimal(findColumn(columnLabel));
}

@Override
Expand Down Expand Up @@ -671,7 +691,7 @@ public Array getArray(int columnIndex) throws SQLException {

@Override
public Object getObject(String columnLabel, Map<String, Class<?>> map) throws SQLException {
return getValue(columnLabel, vector::getObject);
return getObject(findColumn(columnLabel));
}

@Override
Expand All @@ -691,7 +711,7 @@ public Clob getClob(String columnLabel) throws SQLException {

@Override
public Array getArray(String columnLabel) throws SQLException {
return getValue(columnLabel, vector::getArray);
return getArray(findColumn(columnLabel));
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/duckdb/DuckDBResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ public int findColumn(String columnLabel) throws SQLException {
throw new SQLException("ResultSet was closed");
}
for (int col_idx = 0; col_idx < meta.column_count; col_idx++) {
if (meta.column_names[col_idx].contentEquals(columnLabel)) {
if (meta.column_names[col_idx].equalsIgnoreCase(columnLabel)) {
return col_idx + 1;
}
}
Expand Down
71 changes: 67 additions & 4 deletions src/test/java/org/duckdb/TestDuckDBJDBC.java
Original file line number Diff line number Diff line change
Expand Up @@ -3748,7 +3748,17 @@ public static void test_array_resultset() throws Exception {
ResultSet arrayResultSet = rs.getArray(1).getResultSet();
assertTrue(arrayResultSet.next());
assertEquals(arrayResultSet.getInt(1), 1);
assertEquals(arrayResultSet.getInt("index"), 1);
assertEquals(arrayResultSet.getInt("Index"), 1);
assertEquals(arrayResultSet.getInt("INDEX"), 1);
assertEquals(arrayResultSet.getByte(2), (byte) 42);
assertEquals(arrayResultSet.getShort(2), (short) 42);
assertEquals(arrayResultSet.getInt(2), 42);
assertEquals(arrayResultSet.getLong(2), (long) 42);
assertEquals(arrayResultSet.getFloat(2), (float) 42);
assertEquals(arrayResultSet.getDouble(2), (double) 42);
assertEquals(arrayResultSet.getBigDecimal(2), BigDecimal.valueOf(42));
assertEquals(arrayResultSet.getInt("value"), 42);
assertTrue(arrayResultSet.next());
assertEquals(arrayResultSet.getInt(1), 2);
assertEquals(arrayResultSet.getInt(2), 69);
Expand All @@ -3760,10 +3770,18 @@ public static void test_array_resultset() throws Exception {
ResultSet arrayResultSet = rs.getArray(1).getResultSet();
assertTrue(arrayResultSet.next());
assertEquals(arrayResultSet.getInt(1), 1);
Array subArray = arrayResultSet.getArray(2);
assertNotNull(subArray);
ResultSet subArrayResultSet = subArray.getResultSet();
assertFalse(subArrayResultSet.next()); // empty array
{
Array subArray = arrayResultSet.getArray(2);
assertNotNull(subArray);
ResultSet subArrayResultSet = subArray.getResultSet();
assertFalse(subArrayResultSet.next()); // empty array
}
{
Array subArray = arrayResultSet.getArray("value");
assertNotNull(subArray);
ResultSet subArrayResultSet = subArray.getResultSet();
assertFalse(subArrayResultSet.next()); // empty array
}

assertTrue(arrayResultSet.next());
assertEquals(arrayResultSet.getInt(1), 2);
Expand Down Expand Up @@ -3845,6 +3863,13 @@ public static void test_array_resultset() throws Exception {
assertEquals(arrayResultSet2.getInt(2), 69);
assertFalse(arrayResultSet2.next());
}

try (ResultSet rs = statement.executeQuery("select [" + Integer.MAX_VALUE + "::BIGINT + 1]")) {
assertTrue(rs.next());
ResultSet arrayResultSet = rs.getArray(1).getResultSet();
assertTrue(arrayResultSet.next());
assertEquals(arrayResultSet.getLong(2), ((long) Integer.MAX_VALUE) + 1);
}
}
}

Expand Down Expand Up @@ -4577,6 +4602,44 @@ public static void test_get_bytes() throws Exception {
}
}

public static void test_case_insensitivity() throws Exception {
try (Connection connection = DriverManager.getConnection("jdbc:duckdb:")) {
try (Statement s = connection.createStatement()) {
s.execute("CREATE TABLE someTable (lowercase INT, mixedCASE INT, UPPERCASE INT)");
s.execute("INSERT INTO someTable VALUES (0, 1, 2)");
}

String[] tableNameVariations = new String[] {"sometable", "someTable", "SOMETABLE"};
String[][] columnNameVariations = new String[][] {{"lowercase", "mixedcase", "uppercase"},
{"lowerCASE", "mixedCASE", "upperCASE"},
{"LOWERCASE", "MIXEDCASE", "UPPERCASE"}};

int totalTestsRun = 0;

// Test every combination of upper, lower and mixedcase column and table names.
for (String tableName : tableNameVariations) {
for (int columnVariation = 0; columnVariation < columnNameVariations.length; columnVariation++) {
try (Statement s = connection.createStatement()) {
String query = String.format("SELECT %s, %s, %s from %s;", columnNameVariations[0][0],
columnNameVariations[0][1], columnNameVariations[0][2], tableName);

ResultSet resultSet = s.executeQuery(query);
assertTrue(resultSet.next());
for (int i = 0; i < columnNameVariations[0].length; i++) {
assertEquals(resultSet.getInt(columnNameVariations[columnVariation][i]), i,
"Query " + query + " did not get correct result back for column number " + i);
totalTestsRun++;
}
}
}
}

assertEquals(totalTestsRun,
tableNameVariations.length * columnNameVariations.length * columnNameVariations[0].length,
"Number of test cases actually run did not match number expected to be run.");
}
}

public static void test_fractional_time() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL);
PreparedStatement stmt = conn.prepareStatement("SELECT '01:02:03.123'::TIME");
Expand Down
Loading