Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/jni/duckdb_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,11 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA

res_ref->res = stmt_ref->stmt->Execute(duckdb_params, stream_results);
if (res_ref->res->HasError()) {
string error_msg = string(res_ref->res->GetError());
std::string error_msg = std::string(res_ref->res->GetError());
duckdb::ExceptionType error_type = res_ref->res->GetErrorType();
res_ref->res = nullptr;
ThrowJNI(env, error_msg.c_str());
jclass exc_type = duckdb::ExceptionType::INTERRUPT == error_type ? J_SQLTimeoutException : J_SQLException;
env->ThrowNew(exc_type, error_msg.c_str());
return nullptr;
}
return env->NewDirectByteBuffer(res_ref.release(), 0);
Expand Down
2 changes: 2 additions & 0 deletions src/jni/refs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jmethodID J_String_getBytes;
jclass J_Throwable;
jmethodID J_Throwable_getMessage;
jclass J_SQLException;
jclass J_SQLTimeoutException;

jclass J_Bool;
jclass J_Byte;
Expand Down Expand Up @@ -178,6 +179,7 @@ void create_refs(JNIEnv *env) {
J_Throwable = make_class_ref(env, "java/lang/Throwable");
J_Throwable_getMessage = get_method_id(env, J_Throwable, "getMessage", "()Ljava/lang/String;");
J_SQLException = make_class_ref(env, "java/sql/SQLException");
J_SQLTimeoutException = make_class_ref(env, "java/sql/SQLTimeoutException");

J_Bool = make_class_ref(env, "java/lang/Boolean");
J_Byte = make_class_ref(env, "java/lang/Byte");
Expand Down
1 change: 1 addition & 0 deletions src/jni/refs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ extern jmethodID J_String_getBytes;
extern jclass J_Throwable;
extern jmethodID J_Throwable_getMessage;
extern jclass J_SQLException;
extern jclass J_SQLTimeoutException;

extern jclass J_Bool;
extern jclass J_Byte;
Expand Down
9 changes: 8 additions & 1 deletion src/main/java/org/duckdb/DuckDBDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.Properties;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.logging.Logger;

public class DuckDBDriver implements java.sql.Driver {
Expand All @@ -14,11 +16,16 @@ public class DuckDBDriver implements java.sql.Driver {
public static final String DUCKDB_USER_AGENT_PROPERTY = "custom_user_agent";
public static final String JDBC_STREAM_RESULTS = "jdbc_stream_results";

static final ScheduledThreadPoolExecutor scheduler;

static {
try {
DriverManager.registerDriver(new DuckDBDriver());
ThreadFactory tf = r -> new Thread(r, "duckdb-query-cancel-scheduler-thread");
scheduler = new ScheduledThreadPoolExecutor(1, tf);
scheduler.setRemoveOnCancelPolicy(true);
} catch (SQLException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}

Expand Down
39 changes: 38 additions & 1 deletion src/main/java/org/duckdb/DuckDBPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.duckdb.StatementReturnType.*;
import static org.duckdb.io.IOUtils.*;

Expand Down Expand Up @@ -37,6 +38,7 @@
import java.util.ArrayList;
import java.util.Calendar;
import java.util.List;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

Expand All @@ -59,6 +61,8 @@ public class DuckDBPreparedStatement implements PreparedStatement {
private final List<String> batchedStatements = new ArrayList<>();
private Boolean isBatch = false;
private Boolean isPreparedStatement = false;
private int queryTimeoutSeconds = 0;
private ScheduledFuture<?> cancelQueryFuture = null;

public DuckDBPreparedStatement(DuckDBConnection conn) throws SQLException {
if (conn == null) {
Expand Down Expand Up @@ -180,7 +184,14 @@ private boolean execute(boolean startTransaction) throws SQLException {
startTransaction();
}

if (queryTimeoutSeconds > 0) {
cleanupCancelQueryTask();
cancelQueryFuture =
DuckDBDriver.scheduler.schedule(new CancelQueryTask(), queryTimeoutSeconds, SECONDS);
}

resultRef = DuckDBNative.duckdb_jdbc_execute(stmtRef, params);
cleanupCancelQueryTask();
DuckDBResultSetMetaData resultMeta = DuckDBNative.duckdb_jdbc_query_result_meta(resultRef);
selectResult = new DuckDBResultSet(conn, this, resultMeta, resultRef);
returnsResultSet = resultMeta.return_type.equals(QUERY_RESULT);
Expand Down Expand Up @@ -356,6 +367,7 @@ public void close() throws SQLException {
if (isClosed()) {
return;
}
cleanupCancelQueryTask();
if (selectResult != null) {
selectResult.close();
selectResult = null;
Expand Down Expand Up @@ -436,12 +448,16 @@ public void setEscapeProcessing(boolean enable) throws SQLException {
@Override
public int getQueryTimeout() throws SQLException {
checkOpen();
return 0;
return queryTimeoutSeconds;
}

@Override
public void setQueryTimeout(int seconds) throws SQLException {
checkOpen();
if (seconds < 0) {
throw new SQLException("Invalid negative timeout value: " + seconds);
}
this.queryTimeoutSeconds = seconds;
}

@Override
Expand Down Expand Up @@ -1244,4 +1260,25 @@ private Lock getConnRefLock() throws SQLException {
throw new SQLException(e);
}
}

private void cleanupCancelQueryTask() {
if (cancelQueryFuture != null) {
cancelQueryFuture.cancel(false);
cancelQueryFuture = null;
}
}

private class CancelQueryTask implements Runnable {
@Override
public void run() {
try {
if (DuckDBPreparedStatement.this.isClosed()) {
return;
}
DuckDBPreparedStatement.this.cancel();
} catch (SQLException e) {
// suppress
}
}
}
}
31 changes: 28 additions & 3 deletions src/test/java/org/duckdb/TestClosure.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ public static void test_statement_auto_closed_on_completion() throws Exception {
public static void test_long_query_conn_close() throws Exception {
Connection conn = DriverManager.getConnection(JDBC_URL);
Statement stmt = conn.createStatement();
stmt.execute("DROP TABLE IF EXISTS test_fib1");
stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)");
stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)");
long start = System.currentTimeMillis();
Expand Down Expand Up @@ -108,7 +107,6 @@ public static void test_long_query_conn_close() throws Exception {
public static void test_long_query_stmt_close() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL)) {
Statement stmt = conn.createStatement();
stmt.execute("DROP TABLE IF EXISTS test_fib1");
stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)");
stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)");
long start = System.currentTimeMillis();
Expand Down Expand Up @@ -272,7 +270,7 @@ public static void test_stmt_can_only_cancel_self() throws Exception {
ResultSet rs = stmt2.executeQuery(
"WITH RECURSIVE cte AS ("
+
"SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 40000) "
"SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 50000) "
+ "SELECT avg(f) FROM cte")) {
rs.next();
assertTrue(rs.getDouble(1) > 0);
Expand All @@ -285,4 +283,31 @@ public static void test_stmt_can_only_cancel_self() throws Exception {
assertFalse(stmt2.isClosed());
}
}

public static void test_stmt_query_timeout() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) {
stmt.setQueryTimeout(1);
stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)");
stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)");
long start = System.currentTimeMillis();
assertThrows(
()
-> stmt.executeQuery(
"WITH RECURSIVE cte AS ("
+
"SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) "
+ "SELECT avg(f) FROM cte"),
SQLTimeoutException.class);
long elapsed = System.currentTimeMillis() - start;
assertTrue(elapsed < 1500);
assertFalse(conn.isClosed());
assertTrue(stmt.isClosed());
assertEquals(DuckDBDriver.scheduler.getQueue().size(), 0);
}
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) {
stmt.setQueryTimeout(1);
assertThrows(() -> { stmt.execute("FAIL"); }, SQLException.class);
assertEquals(DuckDBDriver.scheduler.getQueue().size(), 0);
}
}
}
2 changes: 1 addition & 1 deletion src/test/java/org/duckdb/TestDuckDBJDBC.java
Original file line number Diff line number Diff line change
Expand Up @@ -3455,7 +3455,7 @@ public static void test_query_progress() throws Exception {
@Override
public QueryProgress call() throws Exception {
try {
Thread.sleep(1500);
Thread.sleep(2500);
QueryProgress qp = stmt.getQueryProgress();
stmt.cancel();
return qp;
Expand Down