diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f42fad86..e5772f965 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -439,7 +439,7 @@ set(DUCKDB_SRC_FILES src/duckdb/extension/json/json_serializer.cpp src/duckdb/ub_extension_json_json_functions.cpp) -set(JEMACLLOC_SRC_FILES +set(JEMALLOC_SRC_FILES src/duckdb/extension/jemalloc/jemalloc_extension.cpp src/duckdb/extension/jemalloc/jemalloc/src/jemalloc.c src/duckdb/extension/jemalloc/jemalloc/src/arena.c @@ -553,7 +553,7 @@ add_jar(duckdb_jdbc_tests ${JAVA_TEST_FILES} INCLUDE_JARS duckdb_jdbc) if(MSVC) list(APPEND DUCKDB_SRC_FILES duckdb_java.def) else() - list(APPEND DUCKDB_SRC_FILES ${JEMACLLOC_SRC_FILES}) + list(APPEND DUCKDB_SRC_FILES ${JEMALLOC_SRC_FILES}) endif() add_library(duckdb_java SHARED diff --git a/CMakeLists.txt.in b/CMakeLists.txt.in index 0b45d9f44..3baeb7a66 100644 --- a/CMakeLists.txt.in +++ b/CMakeLists.txt.in @@ -46,7 +46,7 @@ set(DUCKDB_DEFINITIONS set(DUCKDB_SRC_FILES ${SOURCES}) -set(JEMACLLOC_SRC_FILES +set(JEMALLOC_SRC_FILES ${JEMALLOC_SOURCES}) @@ -95,7 +95,7 @@ add_jar(duckdb_jdbc_tests ${JAVA_TEST_FILES} INCLUDE_JARS duckdb_jdbc) if(MSVC) list(APPEND DUCKDB_SRC_FILES duckdb_java.def) else() - list(APPEND DUCKDB_SRC_FILES ${JEMACLLOC_SRC_FILES}) + list(APPEND DUCKDB_SRC_FILES ${JEMALLOC_SRC_FILES}) endif() add_library(duckdb_java SHARED diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 985ce437d..4822448c5 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -13,6 +13,7 @@ #include "duckdb/main/extension_util.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "functions.hpp" +#include "holders.hpp" #include "refs.hpp" #include "types.hpp" #include "util.hpp" @@ -59,40 +60,6 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { delete_global_refs(env); } -/** - * Associates a duckdb::Connection with a duckdb::DuckDB. The DB may be shared amongst many ConnectionHolders, but the - * Connection is unique to this holder. Every Java DuckDBConnection has exactly 1 of these holders, and they are never - * shared. The holder is freed when the DuckDBConnection is closed. When the last holder sharing a DuckDB is freed, the - * DuckDB is released as well. - */ -struct ConnectionHolder { - const duckdb::shared_ptr db; - const duckdb::unique_ptr connection; - - ConnectionHolder(duckdb::shared_ptr _db) - : db(_db), connection(make_uniq(*_db)) { - } -}; - -/** - * Throws a SQLException and returns nullptr if a valid Connection can't be retrieved from the buffer. - */ -static Connection *get_connection(JNIEnv *env, jobject conn_ref_buf) { - if (!conn_ref_buf) { - throw ConnectionException("Invalid connection"); - } - auto conn_holder = (ConnectionHolder *)env->GetDirectBufferAddress(conn_ref_buf); - if (!conn_holder) { - throw ConnectionException("Invalid connection"); - } - auto conn_ref = conn_holder->connection.get(); - if (!conn_ref || !conn_ref->context) { - throw ConnectionException("Invalid connection"); - } - - return conn_ref; -} - //! The database instance cache, used so that multiple connections to the same file point to the same database object duckdb::DBInstanceCache instance_cache; @@ -189,10 +156,6 @@ void _duckdb_jdbc_disconnect(JNIEnv *env, jclass, jobject conn_ref_buf) { } } -struct StatementHolder { - duckdb::unique_ptr stmt; -}; - #include "utf8proc_wrapper.hpp" jobject _duckdb_jdbc_prepare(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArray query_j) { @@ -233,11 +196,6 @@ jobject _duckdb_jdbc_prepare(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArr return env->NewDirectByteBuffer(stmt_ref, 0); } -struct ResultHolder { - duckdb::unique_ptr res; - duckdb::unique_ptr chunk; -}; - Value ToValue(JNIEnv *env, jobject param, duckdb::shared_ptr context) { param = env->CallStaticObjectMethod(J_Timestamp, J_Timestamp_valueOf, param); @@ -930,6 +888,7 @@ static ProfilerPrintFormat GetProfilerPrintFormat(JNIEnv *env, jobject format) { if (env->IsSameObject(format, J_ProfilerPrintFormat_GRAPHVIZ)) { return ProfilerPrintFormat::GRAPHVIZ; } + throw InvalidInputException("Invalid profiling format"); } jstring _duckdb_jdbc_get_profiling_information(JNIEnv *env, jclass, jobject conn_ref_buf, jobject j_format) { diff --git a/src/jni/functions.cpp b/src/jni/functions.cpp index 4bda7b88d..dbf43a74e 100644 --- a/src/jni/functions.cpp +++ b/src/jni/functions.cpp @@ -404,5 +404,6 @@ JNIEXPORT jstring JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1profil duckdb::ErrorData error(e); ThrowJNI(env, error.Message().c_str()); + return nullptr; } } diff --git a/src/jni/holders.hpp b/src/jni/holders.hpp new file mode 100644 index 000000000..48a49cea2 --- /dev/null +++ b/src/jni/holders.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include "duckdb.hpp" + +#include + +/** + * Associates a duckdb::Connection with a duckdb::DuckDB. The DB may be shared amongst many ConnectionHolders, but the + * Connection is unique to this holder. Every Java DuckDBConnection has exactly 1 of these holders, and they are never + * shared. The holder is freed when the DuckDBConnection is closed. When the last holder sharing a DuckDB is freed, the + * DuckDB is released as well. + */ +struct ConnectionHolder { + const duckdb::shared_ptr db; + const duckdb::unique_ptr connection; + + ConnectionHolder(duckdb::shared_ptr _db) + : db(_db), connection(duckdb::make_uniq(*_db)) { + } +}; + +struct StatementHolder { + duckdb::unique_ptr stmt; +}; + +struct ResultHolder { + duckdb::unique_ptr res; + duckdb::unique_ptr chunk; +}; + +/** + * Throws a SQLException and returns nullptr if a valid Connection can't be retrieved from the buffer. + */ +inline duckdb::Connection *get_connection(JNIEnv *env, jobject conn_ref_buf) { + if (!conn_ref_buf) { + throw duckdb::ConnectionException("Invalid connection"); + } + auto conn_holder = (ConnectionHolder *)env->GetDirectBufferAddress(conn_ref_buf); + if (!conn_holder) { + throw duckdb::ConnectionException("Invalid connection"); + } + auto conn_ref = conn_holder->connection.get(); + if (!conn_ref || !conn_ref->context) { + throw duckdb::ConnectionException("Invalid connection"); + } + + return conn_ref; +} diff --git a/src/main/java/org/duckdb/DuckDBAppender.java b/src/main/java/org/duckdb/DuckDBAppender.java index 08eb5ef35..fd79f9726 100644 --- a/src/main/java/org/duckdb/DuckDBAppender.java +++ b/src/main/java/org/duckdb/DuckDBAppender.java @@ -16,7 +16,7 @@ public DuckDBAppender(DuckDBConnection con, String schemaName, String tableName) throw new SQLException("Invalid connection"); } appender_ref = DuckDBNative.duckdb_jdbc_create_appender( - con.conn_ref, schemaName.getBytes(StandardCharsets.UTF_8), tableName.getBytes(StandardCharsets.UTF_8)); + con.connRef, schemaName.getBytes(StandardCharsets.UTF_8), tableName.getBytes(StandardCharsets.UTF_8)); } public void beginRow() throws SQLException { diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 79a37891f..504d5a236 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -1,8 +1,9 @@ package org.duckdb; +import static java.nio.charset.StandardCharsets.UTF_8; + import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.sql.Array; import java.sql.Blob; import java.sql.CallableStatement; @@ -20,10 +21,10 @@ import java.sql.Savepoint; import java.sql.Statement; import java.sql.Struct; -import java.util.HashMap; -import java.util.Map; -import java.util.Properties; +import java.util.*; import java.util.concurrent.Executor; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import org.duckdb.user.DuckDBMap; import org.duckdb.user.DuckDBUserArray; import org.duckdb.user.DuckDBUserStruct; @@ -33,7 +34,11 @@ public final class DuckDBConnection implements java.sql.Connection { /** Name of the DuckDB default schema. */ public static final String DEFAULT_SCHEMA = "main"; - ByteBuffer conn_ref; + ByteBuffer connRef; + final Lock connRefLock = new ReentrantLock(); + final LinkedHashSet preparedStatements = new LinkedHashSet<>(); + volatile boolean closing = false; + boolean autoCommit = true; boolean transactionRunning; final String url; @@ -48,13 +53,12 @@ public static DuckDBConnection newConnection(String url, boolean readOnly, Prope if (db_dir.length() == 0) { db_dir = ":memory:"; } - ByteBuffer nativeReference = - DuckDBNative.duckdb_jdbc_startup(db_dir.getBytes(StandardCharsets.UTF_8), readOnly, properties); + ByteBuffer nativeReference = DuckDBNative.duckdb_jdbc_startup(db_dir.getBytes(UTF_8), readOnly, properties); return new DuckDBConnection(nativeReference, url, readOnly); } private DuckDBConnection(ByteBuffer connectionReference, String url, boolean readOnly) throws SQLException { - conn_ref = connectionReference; + this.connRef = connectionReference; this.url = url; this.readOnly = readOnly; DuckDBNative.duckdb_jdbc_set_auto_commit(connectionReference, true); @@ -62,9 +66,7 @@ private DuckDBConnection(ByteBuffer connectionReference, String url, boolean rea public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { - if (isClosed()) { - throw new SQLException("Connection was closed"); - } + checkOpen(); if (resultSetConcurrency == ResultSet.CONCUR_READ_ONLY && resultSetType == ResultSet.TYPE_FORWARD_ONLY) { return new DuckDBPreparedStatement(this); } @@ -73,9 +75,7 @@ public Statement createStatement(int resultSetType, int resultSetConcurrency, in public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { - if (isClosed()) { - throw new SQLException("Connection was closed"); - } + checkOpen(); if (resultSetConcurrency == ResultSet.CONCUR_READ_ONLY && resultSetType == ResultSet.TYPE_FORWARD_ONLY) { return new DuckDBPreparedStatement(this, sql); } @@ -87,10 +87,14 @@ public Statement createStatement() throws SQLException { } public Connection duplicate() throws SQLException { - if (isClosed()) { - throw new SQLException("Connection is closed"); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + return new DuckDBConnection(DuckDBNative.duckdb_jdbc_connect(connRef), url, readOnly); + } finally { + connRefLock.unlock(); } - return new DuckDBConnection(DuckDBNative.duckdb_jdbc_connect(conn_ref), url, readOnly); } public void commit() throws SQLException { @@ -111,15 +115,46 @@ protected void finalize() throws Throwable { close(); } - public synchronized void close() throws SQLException { - if (conn_ref != null) { - DuckDBNative.duckdb_jdbc_disconnect(conn_ref); - conn_ref = null; + public void close() throws SQLException { + if (isClosed()) { + return; + } + connRefLock.lock(); + try { + if (isClosed()) { + return; + } + + // Mark this instance as 'closing' to skip untrack call in + // prepared statements, that requires connection lock and can + // cause a deadlock when the statement closure is caused by the + // connection interrupt called by us. + this.closing = true; + + // Interrupt running query if any + try { + interrupt(); + } catch (SQLException e) { + // suppress + } + + // Last statement created is first deleted + List psList = new ArrayList<>(preparedStatements); + Collections.reverse(psList); + for (DuckDBPreparedStatement ps : psList) { + ps.close(); + } + preparedStatements.clear(); + + DuckDBNative.duckdb_jdbc_disconnect(connRef); + connRef = null; + } finally { + connRefLock.unlock(); } } public boolean isClosed() throws SQLException { - return conn_ref == null; + return connRef == null; } public boolean isValid(int timeout) throws SQLException { @@ -197,19 +232,47 @@ public DatabaseMetaData getMetaData() throws SQLException { } public void setCatalog(String catalog) throws SQLException { - DuckDBNative.duckdb_jdbc_set_catalog(conn_ref, catalog); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + DuckDBNative.duckdb_jdbc_set_catalog(connRef, catalog); + } finally { + connRefLock.unlock(); + } } public String getCatalog() throws SQLException { - return DuckDBNative.duckdb_jdbc_get_catalog(conn_ref); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + return DuckDBNative.duckdb_jdbc_get_catalog(connRef); + } finally { + connRefLock.unlock(); + } } public void setSchema(String schema) throws SQLException { - DuckDBNative.duckdb_jdbc_set_schema(conn_ref, schema); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + DuckDBNative.duckdb_jdbc_set_schema(connRef, schema); + } finally { + connRefLock.unlock(); + } } public String getSchema() throws SQLException { - return DuckDBNative.duckdb_jdbc_get_schema(conn_ref); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + return DuckDBNative.duckdb_jdbc_get_schema(connRef); + } finally { + connRefLock.unlock(); + } } @Override @@ -381,11 +444,49 @@ private static long getArrowStreamAddress(Object arrow_array_stream) { } public void registerArrowStream(String name, Object arrow_array_stream) { - long array_stream_address = getArrowStreamAddress(arrow_array_stream); - DuckDBNative.duckdb_jdbc_arrow_register(conn_ref, array_stream_address, name.getBytes(StandardCharsets.UTF_8)); + try { + checkOpen(); + long array_stream_address = getArrowStreamAddress(arrow_array_stream); + connRefLock.lock(); + try { + checkOpen(); + DuckDBNative.duckdb_jdbc_arrow_register(connRef, array_stream_address, name.getBytes(UTF_8)); + } finally { + connRefLock.unlock(); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } } public String getProfilingInformation(ProfilerPrintFormat format) throws SQLException { - return DuckDBNative.duckdb_jdbc_get_profiling_information(conn_ref, format); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + return DuckDBNative.duckdb_jdbc_get_profiling_information(connRef, format); + } finally { + connRefLock.unlock(); + } + } + + void checkOpen() throws SQLException { + if (isClosed()) { + throw new SQLException("Connection was closed"); + } + } + + /** + * This function calls the underlying C++ interrupt function which aborts the query running on this connection. + */ + void interrupt() throws SQLException { + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + DuckDBNative.duckdb_jdbc_interrupt(connRef); + } finally { + connRefLock.unlock(); + } } } diff --git a/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java b/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java index ed3269314..fb965371a 100644 --- a/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java +++ b/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java @@ -14,8 +14,6 @@ import java.util.Arrays; import java.util.Map; import java.util.stream.Collectors; -import javax.sql.rowset.CachedRowSet; -import javax.sql.rowset.RowSetProvider; public class DuckDBDatabaseMetaData implements DatabaseMetaData { @@ -173,74 +171,74 @@ public String getIdentifierQuoteString() throws SQLException { @Override public String getSQLKeywords() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery("SELECT keyword_name FROM duckdb_keywords()"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery("SELECT keyword_name FROM duckdb_keywords()")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getNumericFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery("SELECT DISTINCT function_name FROM duckdb_functions() " - + "WHERE parameter_types[1] ='DECIMAL'" - + "OR parameter_types[1] ='DOUBLE'" - + "OR parameter_types[1] ='SMALLINT'" - + "OR parameter_types[1] = 'BIGINT'"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery("SELECT DISTINCT function_name FROM duckdb_functions() " + + "WHERE parameter_types[1] ='DECIMAL'" + + "OR parameter_types[1] ='DOUBLE'" + + "OR parameter_types[1] ='SMALLINT'" + + "OR parameter_types[1] = 'BIGINT'")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getStringFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery( - "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] = 'VARCHAR'"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery( + "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] = 'VARCHAR'")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getSystemFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery( - "SELECT DISTINCT function_name FROM duckdb_functions() WHERE length(parameter_types) = 0"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery( + "SELECT DISTINCT function_name FROM duckdb_functions() WHERE length(parameter_types) = 0")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getTimeDateFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery( - "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] LIKE 'TIME%'"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery( + "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] LIKE 'TIME%'")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override diff --git a/src/main/java/org/duckdb/DuckDBNative.java b/src/main/java/org/duckdb/DuckDBNative.java index 6d4c3fb1a..4277785e7 100644 --- a/src/main/java/org/duckdb/DuckDBNative.java +++ b/src/main/java/org/duckdb/DuckDBNative.java @@ -13,7 +13,7 @@ import java.sql.SQLException; import java.util.Properties; -class DuckDBNative { +final class DuckDBNative { static { try { String os_name = ""; @@ -70,108 +70,92 @@ class DuckDBNative { */ // results ConnectionHolder reference object - protected static native ByteBuffer duckdb_jdbc_startup(byte[] path, boolean read_only, Properties props) - throws SQLException; + static native ByteBuffer duckdb_jdbc_startup(byte[] path, boolean read_only, Properties props) throws SQLException; // returns conn_ref connection reference object - protected static native ByteBuffer duckdb_jdbc_connect(ByteBuffer db_ref) throws SQLException; + static native ByteBuffer duckdb_jdbc_connect(ByteBuffer db_ref) throws SQLException; - protected static native void duckdb_jdbc_set_auto_commit(ByteBuffer conn_ref, boolean auto_commit) - throws SQLException; + static native void duckdb_jdbc_set_auto_commit(ByteBuffer conn_ref, boolean auto_commit) throws SQLException; - protected static native boolean duckdb_jdbc_get_auto_commit(ByteBuffer conn_ref) throws SQLException; + static native boolean duckdb_jdbc_get_auto_commit(ByteBuffer conn_ref) throws SQLException; - protected static native void duckdb_jdbc_disconnect(ByteBuffer conn_ref); + static native void duckdb_jdbc_disconnect(ByteBuffer conn_ref); - protected static native void duckdb_jdbc_set_schema(ByteBuffer conn_ref, String schema); + static native void duckdb_jdbc_set_schema(ByteBuffer conn_ref, String schema); - protected static native void duckdb_jdbc_set_catalog(ByteBuffer conn_ref, String catalog); + static native void duckdb_jdbc_set_catalog(ByteBuffer conn_ref, String catalog); - protected static native String duckdb_jdbc_get_schema(ByteBuffer conn_ref); + static native String duckdb_jdbc_get_schema(ByteBuffer conn_ref); - protected static native String duckdb_jdbc_get_catalog(ByteBuffer conn_ref); + static native String duckdb_jdbc_get_catalog(ByteBuffer conn_ref); // returns stmt_ref result reference object - protected static native ByteBuffer duckdb_jdbc_prepare(ByteBuffer conn_ref, byte[] query) throws SQLException; + static native ByteBuffer duckdb_jdbc_prepare(ByteBuffer conn_ref, byte[] query) throws SQLException; - protected static native void duckdb_jdbc_release(ByteBuffer stmt_ref); + static native void duckdb_jdbc_release(ByteBuffer stmt_ref); - protected static native DuckDBResultSetMetaData duckdb_jdbc_query_result_meta(ByteBuffer result_ref) - throws SQLException; + static native DuckDBResultSetMetaData duckdb_jdbc_query_result_meta(ByteBuffer result_ref) throws SQLException; - protected static native DuckDBResultSetMetaData duckdb_jdbc_prepared_statement_meta(ByteBuffer stmt_ref) - throws SQLException; + static native DuckDBResultSetMetaData duckdb_jdbc_prepared_statement_meta(ByteBuffer stmt_ref) throws SQLException; // returns res_ref result reference object - protected static native ByteBuffer duckdb_jdbc_execute(ByteBuffer stmt_ref, Object[] params) throws SQLException; + static native ByteBuffer duckdb_jdbc_execute(ByteBuffer stmt_ref, Object[] params) throws SQLException; - protected static native void duckdb_jdbc_free_result(ByteBuffer res_ref); + static native void duckdb_jdbc_free_result(ByteBuffer res_ref); - protected static native DuckDBVector[] duckdb_jdbc_fetch(ByteBuffer res_ref, ByteBuffer conn_ref) - throws SQLException; + static native DuckDBVector[] duckdb_jdbc_fetch(ByteBuffer res_ref, ByteBuffer conn_ref) throws SQLException; - protected static native int duckdb_jdbc_fetch_size(); + static native int duckdb_jdbc_fetch_size(); - protected static native long duckdb_jdbc_arrow_stream(ByteBuffer res_ref, long batch_size); + static native long duckdb_jdbc_arrow_stream(ByteBuffer res_ref, long batch_size); - protected static native void duckdb_jdbc_arrow_register(ByteBuffer conn_ref, long arrow_array_stream_pointer, - byte[] name); + static native void duckdb_jdbc_arrow_register(ByteBuffer conn_ref, long arrow_array_stream_pointer, byte[] name); - protected static native ByteBuffer duckdb_jdbc_create_appender(ByteBuffer conn_ref, byte[] schema_name, - byte[] table_name) throws SQLException; + static native ByteBuffer duckdb_jdbc_create_appender(ByteBuffer conn_ref, byte[] schema_name, byte[] table_name) + throws SQLException; - protected static native void duckdb_jdbc_appender_begin_row(ByteBuffer appender_ref) throws SQLException; + static native void duckdb_jdbc_appender_begin_row(ByteBuffer appender_ref) throws SQLException; - protected static native void duckdb_jdbc_appender_end_row(ByteBuffer appender_ref) throws SQLException; + static native void duckdb_jdbc_appender_end_row(ByteBuffer appender_ref) throws SQLException; - protected static native void duckdb_jdbc_appender_flush(ByteBuffer appender_ref) throws SQLException; + static native void duckdb_jdbc_appender_flush(ByteBuffer appender_ref) throws SQLException; - protected static native void duckdb_jdbc_interrupt(ByteBuffer conn_ref); + static native void duckdb_jdbc_interrupt(ByteBuffer conn_ref); - protected static native void duckdb_jdbc_appender_close(ByteBuffer appender_ref) throws SQLException; + static native void duckdb_jdbc_appender_close(ByteBuffer appender_ref) throws SQLException; - protected static native void duckdb_jdbc_appender_append_boolean(ByteBuffer appender_ref, boolean value) - throws SQLException; + static native void duckdb_jdbc_appender_append_boolean(ByteBuffer appender_ref, boolean value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_byte(ByteBuffer appender_ref, byte value) - throws SQLException; + static native void duckdb_jdbc_appender_append_byte(ByteBuffer appender_ref, byte value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_short(ByteBuffer appender_ref, short value) - throws SQLException; + static native void duckdb_jdbc_appender_append_short(ByteBuffer appender_ref, short value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_int(ByteBuffer appender_ref, int value) - throws SQLException; + static native void duckdb_jdbc_appender_append_int(ByteBuffer appender_ref, int value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_long(ByteBuffer appender_ref, long value) - throws SQLException; + static native void duckdb_jdbc_appender_append_long(ByteBuffer appender_ref, long value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_float(ByteBuffer appender_ref, float value) - throws SQLException; + static native void duckdb_jdbc_appender_append_float(ByteBuffer appender_ref, float value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_double(ByteBuffer appender_ref, double value) - throws SQLException; + static native void duckdb_jdbc_appender_append_double(ByteBuffer appender_ref, double value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_string(ByteBuffer appender_ref, byte[] value) - throws SQLException; + static native void duckdb_jdbc_appender_append_string(ByteBuffer appender_ref, byte[] value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_bytes(ByteBuffer appender_ref, byte[] value) - throws SQLException; + static native void duckdb_jdbc_appender_append_bytes(ByteBuffer appender_ref, byte[] value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_timestamp(ByteBuffer appender_ref, long value) - throws SQLException; + static native void duckdb_jdbc_appender_append_timestamp(ByteBuffer appender_ref, long value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_decimal(ByteBuffer appender_ref, BigDecimal value) + static native void duckdb_jdbc_appender_append_decimal(ByteBuffer appender_ref, BigDecimal value) throws SQLException; - protected static native void duckdb_jdbc_appender_append_null(ByteBuffer appender_ref) throws SQLException; + static native void duckdb_jdbc_appender_append_null(ByteBuffer appender_ref) throws SQLException; - protected static native void duckdb_jdbc_create_extension_type(ByteBuffer conn_ref) throws SQLException; + static native void duckdb_jdbc_create_extension_type(ByteBuffer conn_ref) throws SQLException; protected static native String duckdb_jdbc_get_profiling_information(ByteBuffer conn_ref, ProfilerPrintFormat format) throws SQLException; public static void duckdb_jdbc_create_extension_type(DuckDBConnection conn) throws SQLException { - duckdb_jdbc_create_extension_type(conn.conn_ref); + duckdb_jdbc_create_extension_type(conn.connRef); } } diff --git a/src/main/java/org/duckdb/DuckDBPreparedStatement.java b/src/main/java/org/duckdb/DuckDBPreparedStatement.java index b46a4760a..cfbed506d 100644 --- a/src/main/java/org/duckdb/DuckDBPreparedStatement.java +++ b/src/main/java/org/duckdb/DuckDBPreparedStatement.java @@ -1,13 +1,14 @@ package org.duckdb; -import java.io.ByteArrayOutputStream; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.StatementReturnType.*; + import java.io.IOException; import java.io.InputStream; import java.io.Reader; import java.math.BigDecimal; import java.net.URL; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.sql.Array; import java.sql.Blob; import java.sql.Clob; @@ -31,24 +32,28 @@ import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.util.ArrayList; -import java.util.Arrays; import java.util.Calendar; import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; public class DuckDBPreparedStatement implements PreparedStatement { - private static Logger logger = Logger.getLogger(DuckDBPreparedStatement.class.getName()); + private static final Logger logger = Logger.getLogger(DuckDBPreparedStatement.class.getName()); private DuckDBConnection conn; - private ByteBuffer stmt_ref = null; - private DuckDBResultSet select_result = null; - private int update_result = 0; + private ByteBuffer stmtRef = null; + final Lock stmtRefLock = new ReentrantLock(); + volatile boolean closeOnCompletion = false; + + private DuckDBResultSet selectResult = null; + private int updateResult = 0; + private boolean returnsChangedRows = false; private boolean returnsNothing = false; private boolean returnsResultSet = false; - boolean closeOnCompletion = false; private Object[] params = new Object[0]; private DuckDBResultSetMetaData meta = null; private final List batchedParams = new ArrayList<>(); @@ -76,49 +81,63 @@ public DuckDBPreparedStatement(DuckDBConnection conn, String sql) throws SQLExce } private void startTransaction() throws SQLException { - if (this.conn.autoCommit || this.conn.transactionRunning) { - return; - } - - this.conn.transactionRunning = true; - - // Start transaction via Statement - try (Statement s = conn.createStatement()) { - s.execute("BEGIN TRANSACTION;"); + checkOpen(); + try { + if (this.conn.autoCommit || this.conn.transactionRunning) { + return; + } + this.conn.transactionRunning = true; + // Start transaction via Statement + try (Statement s = conn.createStatement()) { + s.execute("BEGIN TRANSACTION;"); + } + } catch (NullPointerException e) { + throw new SQLException(e); } } private void prepare(String sql) throws SQLException { - if (isClosed()) { - throw new SQLException("Statement was closed"); - } + checkOpen(); if (sql == null) { throw new SQLException("sql query parameter cannot be null"); } - // In case the statement is reused, release old one first - if (stmt_ref != null) { - DuckDBNative.duckdb_jdbc_release(stmt_ref); - stmt_ref = null; - } - - meta = null; - params = null; + stmtRefLock.lock(); + try { + checkOpen(); - if (select_result != null) { - select_result.close(); - } - select_result = null; - update_result = 0; + // In case the statement is reused, release old one first + if (stmtRef != null) { + DuckDBNative.duckdb_jdbc_release(stmtRef); + stmtRef = null; + } - try { - stmt_ref = DuckDBNative.duckdb_jdbc_prepare(conn.conn_ref, sql.getBytes(StandardCharsets.UTF_8)); - meta = DuckDBNative.duckdb_jdbc_prepared_statement_meta(stmt_ref); + meta = null; params = new Object[0]; + + if (selectResult != null) { + selectResult.close(); + } + selectResult = null; + updateResult = 0; + + // Lock connection while still holding statement lock + conn.connRefLock.lock(); + try { + conn.checkOpen(); + stmtRef = DuckDBNative.duckdb_jdbc_prepare(conn.connRef, sql.getBytes(UTF_8)); + // Track prepared statement inside the parent connection + conn.preparedStatements.add(this); + } finally { + conn.connRefLock.unlock(); + } + + meta = DuckDBNative.duckdb_jdbc_prepared_statement_meta(stmtRef); } catch (SQLException e) { - // Delete stmt_ref as it might already be allocated close(); - throw new SQLException(e); + throw e; + } finally { + stmtRefLock.unlock(); } } @@ -128,47 +147,52 @@ public boolean execute() throws SQLException { } private boolean execute(boolean startTransaction) throws SQLException { - if (isClosed()) { - throw new SQLException("Statement was closed"); - } - if (stmt_ref == null) { - throw new SQLException("Prepare something first"); - } + checkOpen(); + checkPrepared(); - ByteBuffer result_ref = null; - if (select_result != null) { - select_result.close(); - } - select_result = null; + ByteBuffer resultRef = null; + stmtRefLock.lock(); try { + checkOpen(); + checkPrepared(); + + if (selectResult != null) { + selectResult.close(); + } + selectResult = null; + if (startTransaction) { startTransaction(); } - result_ref = DuckDBNative.duckdb_jdbc_execute(stmt_ref, params); - DuckDBResultSetMetaData result_meta = DuckDBNative.duckdb_jdbc_query_result_meta(result_ref); - select_result = new DuckDBResultSet(this, result_meta, result_ref, conn.conn_ref); - returnsResultSet = result_meta.return_type.equals(StatementReturnType.QUERY_RESULT); - returnsChangedRows = result_meta.return_type.equals(StatementReturnType.CHANGED_ROWS); - returnsNothing = result_meta.return_type.equals(StatementReturnType.NOTHING); + + resultRef = DuckDBNative.duckdb_jdbc_execute(stmtRef, params); + DuckDBResultSetMetaData resultMeta = DuckDBNative.duckdb_jdbc_query_result_meta(resultRef); + selectResult = new DuckDBResultSet(conn, this, resultMeta, resultRef); + returnsResultSet = resultMeta.return_type.equals(QUERY_RESULT); + returnsChangedRows = resultMeta.return_type.equals(CHANGED_ROWS); + returnsNothing = resultMeta.return_type.equals(NOTHING); + } catch (SQLException e) { - // Delete stmt_ref as it cannot be used anymore and - // result_ref as it might be allocated - if (select_result != null) { - select_result.close(); - } else if (result_ref != null) { - DuckDBNative.duckdb_jdbc_free_result(result_ref); - result_ref = null; + // Delete result set that might have been allocated + if (selectResult != null) { + selectResult.close(); + } else if (resultRef != null) { + DuckDBNative.duckdb_jdbc_free_result(resultRef); + resultRef = null; } close(); throw e; + + } finally { + stmtRefLock.unlock(); } if (returnsChangedRows) { - if (select_result.next()) { - update_result = select_result.getInt(1); + if (selectResult.next()) { + updateResult = selectResult.getInt(1); } - select_result.close(); + selectResult.close(); } return returnsResultSet; @@ -181,7 +205,7 @@ public ResultSet executeQuery() throws SQLException { if (!returnsResultSet) { throw new SQLException("executeQuery() can only be used with queries that return a ResultSet"); } - return select_result; + return selectResult; } @Override @@ -216,33 +240,27 @@ public int executeUpdate(String sql) throws SQLException { @Override public ResultSetMetaData getMetaData() throws SQLException { - if (isClosed()) { - throw new SQLException("Statement was closed"); - } - if (meta == null) { - throw new SQLException("Prepare something first"); - } + checkOpen(); + checkPrepared(); return meta; } @Override public ParameterMetaData getParameterMetaData() throws SQLException { - if (isClosed()) { - throw new SQLException("Statement was closed"); - } - if (stmt_ref == null) { - throw new SQLException("Prepare something first"); - } + checkOpen(); + checkPrepared(); return meta.param_meta; } @Override public void setObject(int parameterIndex, Object x) throws SQLException { - if (parameterIndex < 1 || parameterIndex > getParameterMetaData().getParameterCount()) { + checkOpen(); + int paramsCount = getParameterMetaData().getParameterCount(); + if (parameterIndex < 1 || parameterIndex > paramsCount) { throw new SQLException("Parameter index out of bounds"); } if (params.length == 0) { - params = new Object[getParameterMetaData().getParameterCount()]; + params = new Object[paramsCount]; } params[parameterIndex - 1] = x; } @@ -294,20 +312,53 @@ public void setString(int parameterIndex, String x) throws SQLException { @Override public void clearParameters() throws SQLException { + checkOpen(); params = new Object[0]; } @Override public void close() throws SQLException { - if (select_result != null) { - select_result.close(); - select_result = null; + if (isClosed()) { + return; } - if (stmt_ref != null) { - DuckDBNative.duckdb_jdbc_release(stmt_ref); - stmt_ref = null; + stmtRefLock.lock(); + try { + if (isClosed()) { + return; + } + if (selectResult != null) { + selectResult.close(); + selectResult = null; + } + if (stmtRef != null) { + // Delete prepared statement + DuckDBNative.duckdb_jdbc_release(stmtRef); + + // Untrack prepared statement from parent connection, + // if 'closing' flag is set it means that the parent connection itself + // is being closed and we don't need to call untrack from the statement. + if (!conn.closing) { + conn.connRefLock.lock(); + try { + conn.preparedStatements.remove(this); + } finally { + conn.connRefLock.unlock(); + } + } + + stmtRef = null; + } + conn = null; // we use this as a check for closed-ness + } finally { + stmtRefLock.unlock(); } - conn = null; // we use this as a check for closed-ness + } + + @Override + public boolean isClosed() throws SQLException { + // Cannot check stmtRef here because it is created only + // when prepare() is called. + return conn == null || conn.connRef == null; } protected void finalize() throws Throwable { @@ -347,14 +398,16 @@ public void setQueryTimeout(int seconds) throws SQLException { logger.log(Level.FINE, "setQueryTimeout not supported"); } - /** - * This function calls the underlying C++ interrupt function which aborts the query running on that connection. - * It is not safe to call this function when the connection is already closed. - */ @Override - public synchronized void cancel() throws SQLException { - if (conn.conn_ref != null) { - DuckDBNative.duckdb_jdbc_interrupt(conn.conn_ref); + public void cancel() throws SQLException { + try { + // Cancel is intended to be called concurrently with execute, + // thus we cannot take the statement lock that is held while + // query is running. NPE may be thrown if connection is closed + // concurrently. + conn.interrupt(); + } catch (NullPointerException e) { + throw new SQLException(e); } } @@ -380,7 +433,7 @@ public ResultSet getResultSet() throws SQLException { if (isClosed()) { throw new SQLException("Statement was closed"); } - if (stmt_ref == null) { + if (stmtRef == null) { throw new SQLException("Prepare something first"); } @@ -389,8 +442,8 @@ public ResultSet getResultSet() throws SQLException { } // getResultSet can only be called once per result - ResultSet to_return = select_result; - this.select_result = null; + ResultSet to_return = selectResult; + this.selectResult = null; return to_return; } @@ -398,21 +451,21 @@ private Integer getUpdateCountInternal() throws SQLException { if (isClosed()) { throw new SQLException("Statement was closed"); } - if (stmt_ref == null) { + if (stmtRef == null) { throw new SQLException("Prepare something first"); } - if (returnsResultSet || returnsNothing || select_result.isFinished()) { + if (returnsResultSet || returnsNothing || selectResult.isFinished()) { return -1; } - return update_result; + return updateResult; } @Override public int getUpdateCount() throws SQLException { // getUpdateCount can only be called once per result int to_return = getUpdateCountInternal(); - update_result = -1; + updateResult = -1; return to_return; } @@ -557,11 +610,6 @@ public int getResultSetHoldability() throws SQLException { throw new SQLFeatureNotSupportedException("getResultSetHoldability"); } - @Override - public boolean isClosed() throws SQLException { - return conn == null; - } - @Override public void setPoolable(boolean poolable) throws SQLException { throw new SQLFeatureNotSupportedException("setPoolable"); @@ -574,15 +622,15 @@ public boolean isPoolable() throws SQLException { @Override public void closeOnCompletion() throws SQLException { - if (isClosed()) - throw new SQLException("Statement is closed"); + checkOpen(); + ; closeOnCompletion = true; } @Override public boolean isCloseOnCompletion() throws SQLException { - if (isClosed()) - throw new SQLException("Statement is closed"); + checkOpen(); + ; return closeOnCompletion; } @@ -633,6 +681,8 @@ public void setBinaryStream(int parameterIndex, InputStream x, int length) throw @Override public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + checkOpen(); + if (x == null) { setNull(parameterIndex, targetSqlType); return; @@ -769,6 +819,7 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ @Override public void addBatch() throws SQLException { + checkOpen(); batchedParams.add(params); clearParameters(); this.isBatch = true; @@ -939,4 +990,16 @@ private void requireNonPreparedStatement() throws SQLException { throw new SQLException("Cannot add batched SQL statement to PreparedStatement"); } } + + private void checkOpen() throws SQLException { + if (isClosed()) { + throw new SQLException("Statement was closed"); + } + } + + private void checkPrepared() throws SQLException { + if (stmtRef == null) { + throw new SQLException("Prepare something first"); + } + } } diff --git a/src/main/java/org/duckdb/DuckDBResultSet.java b/src/main/java/org/duckdb/DuckDBResultSet.java index 2e77318a7..dc8895889 100644 --- a/src/main/java/org/duckdb/DuckDBResultSet.java +++ b/src/main/java/org/duckdb/DuckDBResultSet.java @@ -32,73 +32,84 @@ import java.time.OffsetDateTime; import java.time.OffsetTime; import java.util.*; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; public class DuckDBResultSet implements ResultSet { + private final DuckDBConnection conn; private final DuckDBPreparedStatement stmt; private final DuckDBResultSetMetaData meta; /** * {@code null} if this result set is closed. */ - private ByteBuffer result_ref; - private DuckDBVector[] current_chunk = {}; - private int chunk_idx = 0; + private ByteBuffer resultRef; + private final Lock resultRefLock = new ReentrantLock(); + + private DuckDBVector[] currentChunk = {}; + private int chunkIdx = 0; private boolean finished = false; - private boolean was_null; - private final ByteBuffer conn_ref; + private boolean wasNull; - public DuckDBResultSet(DuckDBPreparedStatement stmt, DuckDBResultSetMetaData meta, ByteBuffer result_ref, - ByteBuffer conn_ref) throws SQLException { - this.stmt = Objects.requireNonNull(stmt); - this.result_ref = Objects.requireNonNull(result_ref); - this.meta = Objects.requireNonNull(meta); - this.conn_ref = Objects.requireNonNull(conn_ref); + public DuckDBResultSet(DuckDBConnection conn, DuckDBPreparedStatement stmt, DuckDBResultSetMetaData meta, + ByteBuffer resultRef) throws SQLException { + try { + this.conn = Objects.requireNonNull(conn); + this.stmt = Objects.requireNonNull(stmt); + this.resultRef = Objects.requireNonNull(resultRef); + this.meta = Objects.requireNonNull(meta); + } catch (NullPointerException e) { + throw new SQLException(e); + } } public Statement getStatement() throws SQLException { - if (isClosed()) { - throw new SQLException("ResultSet was closed"); - } + checkOpen(); return stmt; } public ResultSetMetaData getMetaData() throws SQLException { - if (isClosed()) { - throw new SQLException("ResultSet was closed"); - } + checkOpen(); return meta; } - public synchronized boolean next() throws SQLException { - if (isClosed()) { - throw new SQLException("ResultSet was closed"); - } + public boolean next() throws SQLException { + checkOpen(); if (finished) { return false; } - chunk_idx++; - if (current_chunk.length == 0 || chunk_idx > current_chunk[0].length) { - current_chunk = DuckDBNative.duckdb_jdbc_fetch(result_ref, conn_ref); - chunk_idx = 1; + chunkIdx++; + if (currentChunk.length == 0 || chunkIdx > currentChunk[0].length) { + currentChunk = fetchChunk(); + chunkIdx = 1; } - if (current_chunk.length == 0) { + if (currentChunk.length == 0) { finished = true; return false; } return true; } - public synchronized void close() throws SQLException { - if (result_ref != null) { - DuckDBNative.duckdb_jdbc_free_result(result_ref); + public void close() throws SQLException { + if (isClosed()) { + return; + } + resultRefLock.lock(); + try { + if (isClosed()) { + return; + } + DuckDBNative.duckdb_jdbc_free_result(resultRef); // Nullness is used to determine whether we're closed - result_ref = null; + resultRef = null; + } finally { + resultRefLock.unlock(); + } - // isCloseOnCompletion() throws if already closed, and we can't check for isClosed() because it could change - // between when we check and call isCloseOnCompletion, so access the field directly. - if (stmt.closeOnCompletion) { - stmt.close(); - } + // isCloseOnCompletion() throws if already closed, and we can't check for isClosed() because it could change + // between when we check and call isCloseOnCompletion, so access the field directly. + if (stmt.closeOnCompletion) { + stmt.close(); } } @@ -106,14 +117,12 @@ protected void finalize() throws Throwable { close(); } - public synchronized boolean isClosed() throws SQLException { - return result_ref == null; + public boolean isClosed() throws SQLException { + return resultRef == null; } private void check(int columnIndex) throws SQLException { - if (isClosed()) { - throw new SQLException("ResultSet was closed"); - } + checkOpen(); if (columnIndex < 1 || columnIndex > meta.column_count) { throw new SQLException("Column index out of bounds"); } @@ -128,16 +137,14 @@ private void check(int columnIndex) throws SQLException { */ public synchronized Object arrowExportStream(Object arrow_buffer_allocator, long arrow_batch_size) throws SQLException { - if (isClosed()) { - throw new SQLException("Result set is closed"); - } + checkOpen(); try { Class buffer_allocator_class = Class.forName("org.apache.arrow.memory.BufferAllocator"); if (!buffer_allocator_class.isInstance(arrow_buffer_allocator)) { throw new RuntimeException("Need to pass an Arrow BufferAllocator"); } - Long stream_pointer = DuckDBNative.duckdb_jdbc_arrow_stream(result_ref, arrow_batch_size); + Long stream_pointer = DuckDBNative.duckdb_jdbc_arrow_stream(resultRef, arrow_batch_size); Class arrow_array_stream_class = Class.forName("org.apache.arrow.c.ArrowArrayStream"); Object arrow_array_stream = arrow_array_stream_class.getMethod("wrap", long.class).invoke(null, stream_pointer); @@ -153,38 +160,38 @@ public synchronized Object arrowExportStream(Object arrow_buffer_allocator, long } public Object getObject(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getObject(chunk_idx - 1); + return currentChunk[columnIndex - 1].getObject(chunkIdx - 1); } public Struct getStruct(int columnIndex) throws SQLException { - return check_and_null(columnIndex) ? null : current_chunk[columnIndex - 1].getStruct(chunk_idx - 1); + return checkAndNull(columnIndex) ? null : currentChunk[columnIndex - 1].getStruct(chunkIdx - 1); } public OffsetTime getOffsetTime(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getOffsetTime(chunk_idx - 1); + return currentChunk[columnIndex - 1].getOffsetTime(chunkIdx - 1); } public boolean wasNull() throws SQLException { if (isClosed()) { throw new SQLException("ResultSet was closed"); } - return was_null; + return wasNull; } - private boolean check_and_null(int columnIndex) throws SQLException { + private boolean checkAndNull(int columnIndex) throws SQLException { check(columnIndex); try { - was_null = current_chunk[columnIndex - 1].check_and_null(chunk_idx - 1); + wasNull = currentChunk[columnIndex - 1].check_and_null(chunkIdx - 1); } catch (ArrayIndexOutOfBoundsException e) { throw new SQLException("No row in context", e); } - return was_null; + return wasNull; } public JsonNode getJsonObject(int columnIndex) throws SQLException { @@ -193,14 +200,14 @@ public JsonNode getJsonObject(int columnIndex) throws SQLException { } public String getLazyString(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getLazyString(chunk_idx - 1); + return currentChunk[columnIndex - 1].getLazyString(chunkIdx - 1); } public String getString(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } @@ -213,100 +220,98 @@ public String getString(int columnIndex) throws SQLException { } public boolean getBoolean(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return false; } - return current_chunk[columnIndex - 1].getBoolean(chunk_idx - 1); + return currentChunk[columnIndex - 1].getBoolean(chunkIdx - 1); } public byte getByte(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getByte(chunk_idx - 1); + return currentChunk[columnIndex - 1].getByte(chunkIdx - 1); } public short getShort(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getShort(chunk_idx - 1); + return currentChunk[columnIndex - 1].getShort(chunkIdx - 1); } public int getInt(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getInt(chunk_idx - 1); + return currentChunk[columnIndex - 1].getInt(chunkIdx - 1); } private short getUint8(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getUint8(chunk_idx - 1); + return currentChunk[columnIndex - 1].getUint8(chunkIdx - 1); } private int getUint16(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getUint16(chunk_idx - 1); + return currentChunk[columnIndex - 1].getUint16(chunkIdx - 1); } private long getUint32(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getUint32(chunk_idx - 1); + return currentChunk[columnIndex - 1].getUint32(chunkIdx - 1); } private BigInteger getUint64(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return BigInteger.ZERO; } - return current_chunk[columnIndex - 1].getUint64(chunk_idx - 1); + return currentChunk[columnIndex - 1].getUint64(chunkIdx - 1); } public long getLong(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return 0; } - return current_chunk[columnIndex - 1].getLong(chunk_idx - 1); + return currentChunk[columnIndex - 1].getLong(chunkIdx - 1); } public BigInteger getHugeint(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return BigInteger.ZERO; } - return current_chunk[columnIndex - 1].getHugeint(chunk_idx - 1); + return currentChunk[columnIndex - 1].getHugeint(chunkIdx - 1); } public BigInteger getUhugeint(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return BigInteger.ZERO; } - return current_chunk[columnIndex - 1].getUhugeint(chunk_idx - 1); + return currentChunk[columnIndex - 1].getUhugeint(chunkIdx - 1); } public float getFloat(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return Float.NaN; } - return current_chunk[columnIndex - 1].getFloat(chunk_idx - 1); + return currentChunk[columnIndex - 1].getFloat(chunkIdx - 1); } public double getDouble(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return Double.NaN; } - return current_chunk[columnIndex - 1].getDouble(chunk_idx - 1); + return currentChunk[columnIndex - 1].getDouble(chunkIdx - 1); } public int findColumn(String columnLabel) throws SQLException { - if (isClosed()) { - throw new SQLException("ResultSet was closed"); - } + checkOpen(); for (int col_idx = 0; col_idx < meta.column_count; col_idx++) { if (meta.column_names[col_idx].equalsIgnoreCase(columnLabel)) { return col_idx + 1; @@ -356,15 +361,15 @@ public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException } public byte[] getBytes(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getBytes(chunk_idx - 1); + return currentChunk[columnIndex - 1].getBytes(chunkIdx - 1); } public Date getDate(int columnIndex) throws SQLException { - return check_and_null(columnIndex) ? null : current_chunk[columnIndex - 1].getDate(chunk_idx - 1); + return checkAndNull(columnIndex) ? null : currentChunk[columnIndex - 1].getDate(chunkIdx - 1); } public Time getTime(int columnIndex) throws SQLException { @@ -372,31 +377,31 @@ public Time getTime(int columnIndex) throws SQLException { } public Timestamp getTimestamp(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getTimestamp(chunk_idx - 1); + return currentChunk[columnIndex - 1].getTimestamp(chunkIdx - 1); } private LocalDateTime getLocalDateTime(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getLocalDateTime(chunk_idx - 1); + return currentChunk[columnIndex - 1].getLocalDateTime(chunkIdx - 1); } private OffsetDateTime getOffsetDateTime(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getOffsetDateTime(chunk_idx - 1); + return currentChunk[columnIndex - 1].getOffsetDateTime(chunkIdx - 1); } public UUID getUuid(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getUuid(chunk_idx - 1); + return currentChunk[columnIndex - 1].getUuid(chunkIdx - 1); } public static class DuckDBBlobResult implements Blob { @@ -505,14 +510,14 @@ public int hashCode() { return Objects.hash(buffer); } - private ByteBuffer buffer; + private final ByteBuffer buffer; } public Blob getBlob(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getBlob(chunk_idx - 1); + return currentChunk[columnIndex - 1].getBlob(chunkIdx - 1); } public Blob getBlob(String columnLabel) throws SQLException { @@ -584,10 +589,10 @@ public Reader getCharacterStream(String columnLabel) throws SQLException { } public BigDecimal getBigDecimal(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getBigDecimal(chunk_idx - 1); + return currentChunk[columnIndex - 1].getBigDecimal(chunkIdx - 1); } public BigDecimal getBigDecimal(String columnLabel) throws SQLException { @@ -643,31 +648,36 @@ public boolean previous() throws SQLException { } public void setFetchDirection(int direction) throws SQLException { + checkOpen(); if (direction != ResultSet.FETCH_FORWARD && direction != ResultSet.FETCH_UNKNOWN) { throw new SQLFeatureNotSupportedException("setFetchDirection"); } } public int getFetchDirection() throws SQLException { + checkOpen(); return ResultSet.FETCH_FORWARD; } public void setFetchSize(int rows) throws SQLException { + checkOpen(); if (rows < 0) { throw new SQLException("Fetch size has to be >= 0"); } - // whatevs } public int getFetchSize() throws SQLException { + checkOpen(); return DuckDBNative.duckdb_jdbc_fetch_size(); } public int getType() throws SQLException { + checkOpen(); return ResultSet.TYPE_FORWARD_ONLY; } public int getConcurrency() throws SQLException { + checkOpen(); return ResultSet.CONCUR_READ_ONLY; } @@ -876,10 +886,10 @@ public Clob getClob(int columnIndex) throws SQLException { } public Array getArray(int columnIndex) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getArray(chunk_idx - 1); + return currentChunk[columnIndex - 1].getArray(chunkIdx - 1); } public Object getObject(String columnLabel, Map> map) throws SQLException { @@ -907,10 +917,10 @@ public Date getDate(String columnLabel, Calendar cal) throws SQLException { } public Time getTime(int columnIndex, Calendar cal) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getTime(chunk_idx - 1, cal); + return currentChunk[columnIndex - 1].getTime(chunkIdx - 1, cal); } public Time getTime(String columnLabel, Calendar cal) throws SQLException { @@ -918,10 +928,10 @@ public Time getTime(String columnLabel, Calendar cal) throws SQLException { } public Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException { - if (check_and_null(columnIndex)) { + if (checkAndNull(columnIndex)) { return null; } - return current_chunk[columnIndex - 1].getTimestamp(chunk_idx - 1, cal); + return currentChunk[columnIndex - 1].getTimestamp(chunkIdx - 1, cal); } public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLException { @@ -1161,6 +1171,8 @@ private boolean isTimestamp(DuckDBColumnType sqlType) { } public T getObject(int columnIndex, Class type) throws SQLException { + checkOpen(); + if (type == null) { throw new SQLException("type is null"); } @@ -1320,4 +1332,30 @@ public boolean isWrapperFor(Class iface) { boolean isFinished() { return finished; } + + private void checkOpen() throws SQLException { + if (isClosed()) { + throw new SQLException("ResultSet was closed"); + } + } + + private DuckDBVector[] fetchChunk() throws SQLException { + // Take both result set and connection locks for fetching + resultRefLock.lock(); + try { + checkOpen(); + conn.connRefLock.lock(); + try { + conn.checkOpen(); + return DuckDBNative.duckdb_jdbc_fetch(resultRef, conn.connRef); + } finally { + conn.connRefLock.unlock(); + } + } catch (SQLException e) { + close(); + throw e; + } finally { + resultRefLock.unlock(); + } + } } diff --git a/src/test/java/org/duckdb/TestClosure.java b/src/test/java/org/duckdb/TestClosure.java new file mode 100644 index 000000000..1baaf52be --- /dev/null +++ b/src/test/java/org/duckdb/TestClosure.java @@ -0,0 +1,223 @@ +package org.duckdb; + +import static org.duckdb.TestDuckDBJDBC.JDBC_URL; +import static org.duckdb.test.Assertions.*; + +import java.io.File; +import java.sql.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class TestClosure { + + // https://github.com/duckdb/duckdb-java/issues/101 + public static void test_unclosed_statement_does_not_hang() throws Exception { + String dbName = "test_issue_101.db"; + String url = JDBC_URL + dbName; + Connection conn = DriverManager.getConnection(url); + Statement stmt = conn.createStatement(); + stmt.execute("select 42"); + // statement not closed explicitly + conn.close(); + assertTrue(stmt.isClosed()); + Connection connOther = DriverManager.getConnection(url); + connOther.close(); + assertTrue(new File(dbName).delete()); + } + + public static void test_result_set_auto_closed() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + Statement stmt = conn.createStatement(); + ResultSet rs1 = stmt.executeQuery("select 42"); + ResultSet rs2 = stmt.executeQuery("select 43"); + assertTrue(rs1.isClosed()); + stmt.close(); + assertTrue(rs2.isClosed()); + } + } + + public static void test_statements_auto_closed_on_conn_close() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt1 = conn.createStatement(); + stmt1.execute("select 42"); + PreparedStatement stmt2 = conn.prepareStatement("select 43"); + stmt2.execute(); + Statement stmt3 = conn.createStatement(); + stmt3.execute("select 44"); + stmt3.close(); + conn.close(); + assertTrue(stmt1.isClosed()); + assertTrue(stmt2.isClosed()); + } + + public static void test_results_auto_closed_on_conn_close() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("select 42"); + rs.next(); + conn.close(); + assertTrue(rs.isClosed()); + assertTrue(stmt.isClosed()); + } + + public static void test_statement_auto_closed_on_completion() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + Statement stmt = conn.createStatement(); + stmt.closeOnCompletion(); + assertTrue(stmt.isCloseOnCompletion()); + try (ResultSet rs = stmt.executeQuery("select 42")) { + rs.next(); + } + assertTrue(stmt.isClosed()); + } + } + + public static void test_long_query_conn_close() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + stmt.execute("DROP TABLE IF EXISTS test_fib1"); + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(1000); + conn.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + assertThrows( + () + -> stmt.executeQuery( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte"), + SQLException.class); + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 2000); + assertTrue(stmt.isClosed()); + assertTrue(conn.isClosed()); + } + + public static void test_long_query_stmt_close() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + Statement stmt = conn.createStatement(); + stmt.execute("DROP TABLE IF EXISTS test_fib1"); + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(1000); + stmt.cancel(); + stmt.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + assertThrows( + () + -> stmt.executeQuery( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte"), + SQLException.class); + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 2000); + assertTrue(stmt.isClosed()); + assertFalse(conn.isClosed()); + } + } + + public static void test_conn_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + for (int i = 0; i < 1 << 7; i++) { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + Future future = executor.submit(() -> { + try { + conn.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + stmt.executeQuery("select 42"); + } catch (SQLException e) { + } + future.get(); + } + } + + public static void test_stmt_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + for (int i = 0; i < 1 << 10; i++) { + Statement stmt = conn.createStatement(); + Future future = executor.submit(() -> { + try { + stmt.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + stmt.executeQuery("select 42"); + } catch (SQLException e) { + } + future.get(); + } + } + } + + public static void test_results_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + for (int i = 0; i < 1 << 12; i++) { + ResultSet rs = stmt.executeQuery("select 42"); + Future future = executor.submit(() -> { + try { + rs.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + rs.next(); + } catch (SQLException e) { + } + future.get(); + } + } + } + + public static void test_results_close_prepared_stmt_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); + PreparedStatement stmt = conn.prepareStatement("select 42")) { + for (int i = 0; i < 1 << 12; i++) { + ResultSet rs = stmt.executeQuery(); + Future future = executor.submit(() -> { + try { + rs.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + rs.next(); + } catch (SQLException e) { + } + future.get(); + } + } + } +} diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 805c0bb70..4c5b0707b 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -3085,7 +3085,7 @@ public static void test_get_schema() throws Exception { conn.getSchema(); fail(); } catch (SQLException e) { - assertEquals(e.getMessage(), "Connection Error: Invalid connection"); + assertEquals(e.getMessage(), "Connection was closed"); } } @@ -4895,7 +4895,7 @@ public static void main(String[] args) throws Exception { } else { // extension installation fails on CI, Spatial test is temporary disabled statusCode = runTests(args, TestDuckDBJDBC.class, TestExtensionTypes.class /*, TestSpatial.class */, - TestParameterMetadata.class); + TestParameterMetadata.class, TestClosure.class); } System.exit(statusCode); } diff --git a/src/test/java/org/duckdb/test/Assertions.java b/src/test/java/org/duckdb/test/Assertions.java index d72114dc1..9a036e709 100644 --- a/src/test/java/org/duckdb/test/Assertions.java +++ b/src/test/java/org/duckdb/test/Assertions.java @@ -71,12 +71,12 @@ public static void assertEquals(double a, double b, double epsilon) throws Excep assertTrue(Math.abs(a - b) < epsilon); } - public static void fail() throws Exception { + public static void fail() { fail(null); } - public static void fail(String s) throws Exception { - throw new Exception(s); + public static void fail(String s) { + throw new RuntimeException(s); } public static String assertThrows(Thrower thrower, Class exception) throws Exception {