diff --git a/CMakeLists.txt b/CMakeLists.txt index 153402d5b..dd0043b96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,7 +145,13 @@ if(MSVC) set(DUCKDB_SYSTEM_LIBS ${DUCKDB_SYSTEM_LIBS} ws2_32 rstrtmgr bcrypt) endif() -add_library(duckdb_java SHARED src/jni/duckdb_java.cpp src/jni/functions.cpp ${DUCKDB_SRC_FILES}) +add_library(duckdb_java SHARED + src/jni/config.cpp + src/jni/duckdb_java.cpp + src/jni/functions.cpp + src/jni/refs.cpp + src/jni/util.cpp + ${DUCKDB_SRC_FILES}) target_compile_options(duckdb_java PRIVATE -fexceptions) target_link_libraries(duckdb_java duckdb-native ) target_link_libraries(duckdb_java ${DUCKDB_SYSTEM_LIBS}) diff --git a/CMakeLists.txt.in b/CMakeLists.txt.in index d16640de7..c767d00b1 100644 --- a/CMakeLists.txt.in +++ b/CMakeLists.txt.in @@ -145,7 +145,13 @@ if(MSVC) set(DUCKDB_SYSTEM_LIBS ${DUCKDB_SYSTEM_LIBS} ws2_32 rstrtmgr bcrypt) endif() -add_library(duckdb_java SHARED src/jni/duckdb_java.cpp src/jni/functions.cpp ${DUCKDB_SRC_FILES}) +add_library(duckdb_java SHARED + src/jni/config.cpp + src/jni/duckdb_java.cpp + src/jni/functions.cpp + src/jni/refs.cpp + src/jni/util.cpp + ${DUCKDB_SRC_FILES}) target_compile_options(duckdb_java PRIVATE -fexceptions) target_link_libraries(duckdb_java duckdb-native ${LIBRARY_FILES}) target_link_libraries(duckdb_java ${DUCKDB_SYSTEM_LIBS}) diff --git a/scripts/format.py b/scripts/format.py index 91b20f684..37e399c65 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -11,7 +11,13 @@ if args.check: template += ['--dry-run', '--Werror'] -for name in ['src/jni/duckdb_java.cpp'] + glob('src/**/*.java', recursive=True): +hpp_files = set(glob('src/jni/*.hpp')) +hpp_files.remove('src/jni/functions.hpp') +cpp_files = set(glob('src/jni/*.cpp')) +cpp_files.remove('src/jni/functions.cpp') +java_files = set(glob('src/**/*.java', recursive=True)) + +for name in [*hpp_files] + [*cpp_files] + [*java_files]: print('Formatting', name) check_call(template + [name]) diff --git a/src/jni/config.cpp b/src/jni/config.cpp new file mode 100644 index 000000000..215aeef5e --- /dev/null +++ b/src/jni/config.cpp @@ -0,0 +1,125 @@ +#include "config.hpp" + +#include "duckdb/common/virtual_file_system.hpp" +#include "refs.hpp" +#include "util.hpp" + +#include +#include + +static duckdb::Value jobj_to_value(JNIEnv *env, const std::string &key, jobject jval) { + // On the right in comments are all types that are currently present + // in DuckDB config. + if (nullptr == jval) { + return duckdb::Value(); + + } else if (env->IsInstanceOf(jval, J_Bool)) { // BOOLEAN + jboolean val = env->CallBooleanMethod(jval, J_Bool_booleanValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::BOOLEAN(val); + + } else if (env->IsInstanceOf(jval, J_Byte)) { // UBIGINT + jbyte val = env->CallByteMethod(jval, J_Byte_byteValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::TINYINT(val); + + } else if (env->IsInstanceOf(jval, J_Short)) { // UBIGINT + jshort val = env->CallShortMethod(jval, J_Short_shortValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::SMALLINT(val); + + } else if (env->IsInstanceOf(jval, J_Int)) { // UBIGINT + jint val = env->CallIntMethod(jval, J_Int_intValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::INTEGER(val); + + } else if (env->IsInstanceOf(jval, J_Long)) { // UBIGINT + jlong val = env->CallLongMethod(jval, J_Long_longValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::BIGINT(val); + + } else if (env->IsInstanceOf(jval, J_Float)) { // FLOAT + jfloat val = env->CallFloatMethod(jval, J_Float_floatValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::FLOAT(val); + + } else if (env->IsInstanceOf(jval, J_Double)) { // DOUBLE + jdouble val = env->CallDoubleMethod(jval, J_Double_doubleValue); + check_java_exception_and_rethrow(env); + return duckdb::Value::DOUBLE(val); + + } else if (env->IsInstanceOf(jval, J_String)) { // VARCHAR + std::string val = jstring_to_string(env, reinterpret_cast(jval)); + return duckdb::Value(val); + + } else if (env->IsInstanceOf(jval, J_List)) { // VARCHAR[] + jobject iterator = env->CallObjectMethod(jval, J_List_iterator); + check_java_exception_and_rethrow(env); + + duckdb::vector vec; + while (env->CallBooleanMethod(iterator, J_Iterator_hasNext)) { + check_java_exception_and_rethrow(env); + jobject list_entry = env->CallObjectMethod(iterator, J_Iterator_next); + check_java_exception_and_rethrow(env); + // all list entries are coalesced to string + jstring jstr = reinterpret_cast(env->CallObjectMethod(list_entry, J_Object_toString)); + check_java_exception_and_rethrow(env); + std::string sval = jstring_to_string(env, jstr); + duckdb::Value val(std::move(sval)); + vec.push_back(std::move(val)); + } + return duckdb::Value::LIST(duckdb::LogicalType::VARCHAR, std::move(vec)); + + } else { + // coalesce to string the entry with an unknown type + jstring jstr = reinterpret_cast(env->CallObjectMethod(jval, J_Object_toString)); + check_java_exception_and_rethrow(env); + std::string str = jstring_to_string(env, jstr); + return duckdb::Value(str); + } +} + +std::unique_ptr create_db_config(JNIEnv *env, jboolean read_only, jobject java_config) { + auto config = std::unique_ptr(new duckdb::DBConfig()); + // Required for setting like 'allowed_directories' that use + // file separator when checking the property value. + config->file_system = duckdb::make_uniq(); + config->SetOptionByName("duckdb_api", "java"); + config->AddExtensionOption( + "jdbc_stream_results", + "Whether to stream results. Only one ResultSet on a connection can be open at once when true", + duckdb::LogicalType::BOOLEAN); + if (read_only) { + config->options.access_mode = duckdb::AccessMode::READ_ONLY; + } + jobject entry_set = env->CallObjectMethod(java_config, J_Map_entrySet); + check_java_exception_and_rethrow(env); + jobject iterator = env->CallObjectMethod(entry_set, J_Set_iterator); + check_java_exception_and_rethrow(env); + + while (env->CallBooleanMethod(iterator, J_Iterator_hasNext)) { + check_java_exception_and_rethrow(env); + jobject pair = env->CallObjectMethod(iterator, J_Iterator_next); + check_java_exception_and_rethrow(env); + jobject key = env->CallObjectMethod(pair, J_Entry_getKey); + check_java_exception_and_rethrow(env); + jobject value = env->CallObjectMethod(pair, J_Entry_getValue); + check_java_exception_and_rethrow(env); + + jstring key_jstr = reinterpret_cast(env->CallObjectMethod(key, J_Object_toString)); + check_java_exception_and_rethrow(env); + std::string key_str = jstring_to_string(env, key_jstr); + + duckdb::Value dvalue = jobj_to_value(env, key_str, value); + + try { + config->SetOptionByName(key_str, dvalue); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + throw duckdb::CatalogException("Failed to set configuration option \"%s\", error: %s", key_str, + error.RawMessage()); + } + } + + return config; +} diff --git a/src/jni/config.hpp b/src/jni/config.hpp new file mode 100644 index 000000000..8218226fe --- /dev/null +++ b/src/jni/config.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "duckdb.hpp" + +#include +#include + +std::unique_ptr create_db_config(JNIEnv *env, jboolean read_only, jobject java_config); diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index ef051abe2..91cd75a41 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -1,3 +1,4 @@ +#include "config.hpp" #include "duckdb.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/common/arrow/result_arrow_wrapper.hpp" @@ -12,374 +13,46 @@ #include "duckdb/main/extension_util.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "functions.hpp" +#include "refs.hpp" +#include "util.hpp" using namespace duckdb; using namespace std; static jint JNI_VERSION = JNI_VERSION_1_6; -// Static global vars of cached Java classes, methods and fields -static jclass J_Charset; -static jmethodID J_Charset_decode; -static jobject J_Charset_UTF8; - -static jclass J_CharBuffer; -static jmethodID J_CharBuffer_toString; - -static jmethodID J_String_getBytes; - -static jclass J_SQLException; - -static jclass J_Bool; -static jclass J_Byte; -static jclass J_Short; -static jclass J_Int; -static jclass J_Long; -static jclass J_Float; -static jclass J_Double; -static jclass J_String; -static jclass J_Timestamp; -static jmethodID J_Timestamp_valueOf; -static jclass J_TimestampTZ; -static jclass J_Decimal; -static jclass J_ByteArray; - -static jmethodID J_Bool_booleanValue; -static jmethodID J_Byte_byteValue; -static jmethodID J_Short_shortValue; -static jmethodID J_Int_intValue; -static jmethodID J_Long_longValue; -static jmethodID J_Float_floatValue; -static jmethodID J_Double_doubleValue; -static jmethodID J_Timestamp_getMicrosEpoch; -static jmethodID J_TimestampTZ_getMicrosEpoch; -static jmethodID J_Decimal_precision; -static jmethodID J_Decimal_scale; -static jmethodID J_Decimal_scaleByPowTen; -static jmethodID J_Decimal_toPlainString; -static jmethodID J_Decimal_longValue; - -static jclass J_DuckResultSetMeta; -static jmethodID J_DuckResultSetMeta_init; - -static jclass J_DuckVector; -static jmethodID J_DuckVector_init; -static jfieldID J_DuckVector_constlen; -static jfieldID J_DuckVector_varlen; - -static jclass J_DuckArray; -static jmethodID J_DuckArray_init; - -static jclass J_Struct; -static jmethodID J_Struct_getSQLTypeName; -static jmethodID J_Struct_getAttributes; - -static jclass J_Array; -static jmethodID J_Array_getBaseTypeName; -static jmethodID J_Array_getArray; - -static jclass J_DuckStruct; -static jmethodID J_DuckStruct_init; - -static jclass J_ByteBuffer; - -static jclass J_DuckMap; -static jmethodID J_DuckMap_getSQLTypeName; - -static jmethodID J_Map_entrySet; -static jmethodID J_Set_iterator; -static jmethodID J_Iterator_hasNext; -static jmethodID J_Iterator_next; -static jmethodID J_Entry_getKey; -static jmethodID J_Entry_getValue; - -static jclass J_UUID; -static jmethodID J_UUID_getMostSignificantBits; -static jmethodID J_UUID_getLeastSignificantBits; - -static jclass J_DuckDBDate; -static jmethodID J_DuckDBDate_getDaysSinceEpoch; - -static jclass J_Object; -static jmethodID J_Object_toString; - -static jclass J_DuckDBTime; - void ThrowJNI(JNIEnv *env, const char *message) { D_ASSERT(J_SQLException); env->ThrowNew(J_SQLException, message); } -static duckdb::vector toFree; - -static jclass GetClassRef(JNIEnv *env, const string &name) { - jclass tmpLocalRef; - tmpLocalRef = env->FindClass(name.c_str()); - D_ASSERT(tmpLocalRef); - jclass globalRef = (jclass)env->NewGlobalRef(tmpLocalRef); - D_ASSERT(globalRef); - toFree.emplace_back(globalRef); - env->DeleteLocalRef(tmpLocalRef); - return globalRef; -} - JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { - // Get JNIEnv from vm JNIEnv *env; if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { return JNI_ERR; } - jclass tmpLocalRef; - - tmpLocalRef = env->FindClass("java/nio/charset/Charset"); - J_Charset = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - jmethodID forName = env->GetStaticMethodID(J_Charset, "forName", "(Ljava/lang/String;)Ljava/nio/charset/Charset;"); - J_Charset_decode = env->GetMethodID(J_Charset, "decode", "(Ljava/nio/ByteBuffer;)Ljava/nio/CharBuffer;"); - jobject charset = env->CallStaticObjectMethod(J_Charset, forName, env->NewStringUTF("UTF-8")); - J_Charset_UTF8 = env->NewGlobalRef(charset); // Prevent garbage collector from cleaning this up. - - tmpLocalRef = env->FindClass("java/nio/CharBuffer"); - J_CharBuffer = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - J_CharBuffer_toString = env->GetMethodID(J_CharBuffer, "toString", "()Ljava/lang/String;"); - - tmpLocalRef = env->FindClass("java/sql/SQLException"); - J_SQLException = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("java/lang/Boolean"); - J_Bool = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/Byte"); - J_Byte = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/Short"); - J_Short = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/Integer"); - J_Int = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/Long"); - J_Long = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/Float"); - J_Float = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/Double"); - J_Double = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("java/lang/String"); - J_String = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("org/duckdb/DuckDBTimestamp"); - J_Timestamp = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - J_Timestamp_valueOf = env->GetStaticMethodID(J_Timestamp, "valueOf", "(Ljava/lang/Object;)Ljava/lang/Object;"); - - tmpLocalRef = env->FindClass("org/duckdb/DuckDBTimestampTZ"); - J_TimestampTZ = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - J_DuckDBDate = GetClassRef(env, "org/duckdb/DuckDBDate"); - J_DuckDBDate_getDaysSinceEpoch = env->GetMethodID(J_DuckDBDate, "getDaysSinceEpoch", "()J"); - D_ASSERT(J_DuckDBDate_getDaysSinceEpoch); - - J_DuckDBTime = GetClassRef(env, "org/duckdb/DuckDBTime"); - - tmpLocalRef = env->FindClass("java/math/BigDecimal"); - J_Decimal = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - tmpLocalRef = env->FindClass("[B"); - J_ByteArray = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - J_DuckMap = GetClassRef(env, "org/duckdb/user/DuckDBMap"); - D_ASSERT(J_DuckMap); - J_DuckMap_getSQLTypeName = env->GetMethodID(J_DuckMap, "getSQLTypeName", "()Ljava/lang/String;"); - D_ASSERT(J_DuckMap_getSQLTypeName); - - tmpLocalRef = env->FindClass("java/util/Map"); - J_Map_entrySet = env->GetMethodID(tmpLocalRef, "entrySet", "()Ljava/util/Set;"); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("java/util/Set"); - J_Set_iterator = env->GetMethodID(tmpLocalRef, "iterator", "()Ljava/util/Iterator;"); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("java/util/Iterator"); - J_Iterator_hasNext = env->GetMethodID(tmpLocalRef, "hasNext", "()Z"); - J_Iterator_next = env->GetMethodID(tmpLocalRef, "next", "()Ljava/lang/Object;"); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("java/util/UUID"); - J_UUID = (jclass)env->NewGlobalRef(tmpLocalRef); - J_UUID_getMostSignificantBits = env->GetMethodID(J_UUID, "getMostSignificantBits", "()J"); - J_UUID_getLeastSignificantBits = env->GetMethodID(J_UUID, "getLeastSignificantBits", "()J"); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("org/duckdb/DuckDBArray"); - D_ASSERT(tmpLocalRef); - J_DuckArray = (jclass)env->NewGlobalRef(tmpLocalRef); - J_DuckArray_init = env->GetMethodID(J_DuckArray, "", "(Lorg/duckdb/DuckDBVector;II)V"); - D_ASSERT(J_DuckArray_init); - env->DeleteLocalRef(tmpLocalRef); - - J_DuckStruct = GetClassRef(env, "org/duckdb/DuckDBStruct"); - J_DuckStruct_init = - env->GetMethodID(J_DuckStruct, "", "([Ljava/lang/String;[Lorg/duckdb/DuckDBVector;ILjava/lang/String;)V"); - D_ASSERT(J_DuckStruct_init); - - J_Struct = GetClassRef(env, "java/sql/Struct"); - J_Struct_getSQLTypeName = env->GetMethodID(J_Struct, "getSQLTypeName", "()Ljava/lang/String;"); - J_Struct_getAttributes = env->GetMethodID(J_Struct, "getAttributes", "()[Ljava/lang/Object;"); - - J_Array = GetClassRef(env, "java/sql/Array"); - J_Array_getArray = env->GetMethodID(J_Array, "getArray", "()Ljava/lang/Object;"); - J_Array_getBaseTypeName = env->GetMethodID(J_Array, "getBaseTypeName", "()Ljava/lang/String;"); - - J_Object = GetClassRef(env, "java/lang/Object"); - J_Object_toString = env->GetMethodID(J_Object, "toString", "()Ljava/lang/String;"); - - tmpLocalRef = env->FindClass("java/util/Map$Entry"); - J_Entry_getKey = env->GetMethodID(tmpLocalRef, "getKey", "()Ljava/lang/Object;"); - J_Entry_getValue = env->GetMethodID(tmpLocalRef, "getValue", "()Ljava/lang/Object;"); - env->DeleteLocalRef(tmpLocalRef); - - J_Bool_booleanValue = env->GetMethodID(J_Bool, "booleanValue", "()Z"); - J_Byte_byteValue = env->GetMethodID(J_Byte, "byteValue", "()B"); - J_Short_shortValue = env->GetMethodID(J_Short, "shortValue", "()S"); - J_Int_intValue = env->GetMethodID(J_Int, "intValue", "()I"); - J_Long_longValue = env->GetMethodID(J_Long, "longValue", "()J"); - J_Float_floatValue = env->GetMethodID(J_Float, "floatValue", "()F"); - J_Double_doubleValue = env->GetMethodID(J_Double, "doubleValue", "()D"); - J_Timestamp_getMicrosEpoch = env->GetMethodID(J_Timestamp, "getMicrosEpoch", "()J"); - J_TimestampTZ_getMicrosEpoch = env->GetMethodID(J_TimestampTZ, "getMicrosEpoch", "()J"); - J_Decimal_precision = env->GetMethodID(J_Decimal, "precision", "()I"); - J_Decimal_scale = env->GetMethodID(J_Decimal, "scale", "()I"); - J_Decimal_scaleByPowTen = env->GetMethodID(J_Decimal, "scaleByPowerOfTen", "(I)Ljava/math/BigDecimal;"); - J_Decimal_toPlainString = env->GetMethodID(J_Decimal, "toPlainString", "()Ljava/lang/String;"); - J_Decimal_longValue = env->GetMethodID(J_Decimal, "longValue", "()J"); - - tmpLocalRef = env->FindClass("org/duckdb/DuckDBResultSetMetaData"); - J_DuckResultSetMeta = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - J_DuckResultSetMeta_init = - env->GetMethodID(J_DuckResultSetMeta, "", - "(II[Ljava/lang/String;[Ljava/lang/String;[Ljava/lang/String;Ljava/lang/String;)V"); - - tmpLocalRef = env->FindClass("org/duckdb/DuckDBVector"); - J_DuckVector = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - J_String_getBytes = env->GetMethodID(J_String, "getBytes", "(Ljava/nio/charset/Charset;)[B"); - - J_DuckVector_init = env->GetMethodID(J_DuckVector, "", "(Ljava/lang/String;I[Z)V"); - J_DuckVector_constlen = env->GetFieldID(J_DuckVector, "constlen_data", "Ljava/nio/ByteBuffer;"); - J_DuckVector_varlen = env->GetFieldID(J_DuckVector, "varlen_data", "[Ljava/lang/Object;"); - - tmpLocalRef = env->FindClass("java/nio/ByteBuffer"); - J_ByteBuffer = (jclass)env->NewGlobalRef(tmpLocalRef); - env->DeleteLocalRef(tmpLocalRef); - - tmpLocalRef = env->FindClass("java/lang/Object"); - J_Object_toString = env->GetMethodID(tmpLocalRef, "toString", "()Ljava/lang/String;"); - env->DeleteLocalRef(tmpLocalRef); + try { + create_refs(env); + } catch (const std::exception &e) { + if (!env->ExceptionCheck()) { + auto re_class = env->FindClass("java/lang/RuntimeException"); + if (nullptr != re_class) { + env->ThrowNew(re_class, e.what()); + } + } + return JNI_ERR; + } return JNI_VERSION; } JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { - // Get JNIEnv from vm JNIEnv *env; - vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); - - env->DeleteGlobalRef(J_Charset); - env->DeleteGlobalRef(J_CharBuffer); - env->DeleteGlobalRef(J_Charset_UTF8); - env->DeleteGlobalRef(J_SQLException); - env->DeleteGlobalRef(J_Bool); - env->DeleteGlobalRef(J_Byte); - env->DeleteGlobalRef(J_Short); - env->DeleteGlobalRef(J_Int); - env->DeleteGlobalRef(J_Long); - env->DeleteGlobalRef(J_Float); - env->DeleteGlobalRef(J_Double); - env->DeleteGlobalRef(J_String); - env->DeleteGlobalRef(J_Timestamp); - env->DeleteGlobalRef(J_TimestampTZ); - env->DeleteGlobalRef(J_Decimal); - env->DeleteGlobalRef(J_DuckResultSetMeta); - env->DeleteGlobalRef(J_DuckVector); - env->DeleteGlobalRef(J_ByteBuffer); - - for (auto &clazz : toFree) { - env->DeleteGlobalRef(clazz); - } -} - -static string byte_array_to_string(JNIEnv *env, jbyteArray ba_j) { - idx_t len = env->GetArrayLength(ba_j); - string ret; - ret.resize(len); - - jbyte *bytes = (jbyte *)env->GetByteArrayElements(ba_j, NULL); - - for (idx_t i = 0; i < len; i++) { - ret[i] = bytes[i]; - } - env->ReleaseByteArrayElements(ba_j, bytes, 0); - - return ret; -} - -static string jstring_to_string(JNIEnv *env, jstring string_j) { - jbyteArray bytes = (jbyteArray)env->CallObjectMethod(string_j, J_String_getBytes, J_Charset_UTF8); - return byte_array_to_string(env, bytes); -} - -static jobject decode_charbuffer_to_jstring(JNIEnv *env, const char *d_str, idx_t d_str_len) { - auto bb = env->NewDirectByteBuffer((void *)d_str, d_str_len); - auto j_cb = env->CallObjectMethod(J_Charset_UTF8, J_Charset_decode, bb); - auto j_str = env->CallObjectMethod(j_cb, J_CharBuffer_toString); - return j_str; -} - -static Value create_value_from_bigdecimal(JNIEnv *env, jobject decimal) { - jint precision = env->CallIntMethod(decimal, J_Decimal_precision); - jint scale = env->CallIntMethod(decimal, J_Decimal_scale); - - // Java BigDecimal type can have scale that exceeds the precision - // Which our DECIMAL type does not support (assert(width >= scale)) - if (scale > precision) { - precision = scale; - } - - // DECIMAL scale is unsigned, so negative values are not supported - if (scale < 0) { - throw InvalidInputException("Converting from a BigDecimal with negative scale is not supported"); - } - - Value val; - - if (precision <= 18) { // normal sizes -> avoid string processing - jobject no_point_dec = env->CallObjectMethod(decimal, J_Decimal_scaleByPowTen, scale); - jlong result = env->CallLongMethod(no_point_dec, J_Decimal_longValue); - val = Value::DECIMAL((int64_t)result, (uint8_t)precision, (uint8_t)scale); - } else if (precision <= 38) { // larger than int64 -> get string and cast - jobject str_val = env->CallObjectMethod(decimal, J_Decimal_toPlainString); - auto *str_char = env->GetStringUTFChars((jstring)str_val, 0); - val = Value(str_char); - val = val.DefaultCastAs(LogicalType::DECIMAL(precision, scale)); - env->ReleaseStringUTFChars((jstring)str_val, str_char); + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return; } - - return val; + delete_global_refs(env); } /** @@ -419,39 +92,11 @@ static Connection *get_connection(JNIEnv *env, jobject conn_ref_buf) { //! The database instance cache, used so that multiple connections to the same file point to the same database object duckdb::DBInstanceCache instance_cache; -static const char *const JDBC_STREAM_RESULTS = "jdbc_stream_results"; jobject _duckdb_jdbc_startup(JNIEnv *env, jclass, jbyteArray database_j, jboolean read_only, jobject props) { auto database = byte_array_to_string(env, database_j); - DBConfig config; - config.SetOptionByName("duckdb_api", "java"); - config.AddExtensionOption( - JDBC_STREAM_RESULTS, - "Whether to stream results. Only one ResultSet on a connection can be open at once when true", - LogicalType::BOOLEAN); - if (read_only) { - config.options.access_mode = AccessMode::READ_ONLY; - } - jobject entry_set = env->CallObjectMethod(props, J_Map_entrySet); - jobject iterator = env->CallObjectMethod(entry_set, J_Set_iterator); - - while (env->CallBooleanMethod(iterator, J_Iterator_hasNext)) { - jobject pair = env->CallObjectMethod(iterator, J_Iterator_next); - jobject key = env->CallObjectMethod(pair, J_Entry_getKey); - jobject value = env->CallObjectMethod(pair, J_Entry_getValue); - - const string &key_str = jstring_to_string(env, (jstring)env->CallObjectMethod(key, J_Object_toString)); - - const string &value_str = jstring_to_string(env, (jstring)env->CallObjectMethod(value, J_Object_toString)); - - try { - config.SetOptionByName(key_str, Value(value_str)); - } catch (const std::exception &e) { - ErrorData error(e); - throw CatalogException("Failed to set configuration option \"%s\", error: %s", key_str, error.RawMessage()); - } - } + std::unique_ptr config = create_db_config(env, read_only, props); bool cache_instance = database != ":memory:" && !database.empty(); - auto shared_db = instance_cache.GetOrCreateInstance(database, config, cache_instance); + auto shared_db = instance_cache.GetOrCreateInstance(database, *config, cache_instance); auto conn_holder = new ConnectionHolder(shared_db); return env->NewDirectByteBuffer(conn_holder, 0); @@ -719,7 +364,7 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA Value result; bool stream_results = - stmt_ref->stmt->context->TryGetCurrentSetting(JDBC_STREAM_RESULTS, result) ? result.GetValue() : false; + stmt_ref->stmt->context->TryGetCurrentSetting("jdbc_stream_results", result) ? result.GetValue() : false; res_ref->res = stmt_ref->stmt->Execute(duckdb_params, stream_results); if (res_ref->res->HasError()) { diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp new file mode 100644 index 000000000..6b3fefa1f --- /dev/null +++ b/src/jni/refs.cpp @@ -0,0 +1,265 @@ +#include "refs.hpp" + +#include +#include +#include +#include + +jclass J_Charset; +jmethodID J_Charset_decode; +jclass J_StandardCharsets; +jobject J_Charset_UTF8; + +jclass J_CharBuffer; +jmethodID J_CharBuffer_toString; + +jmethodID J_String_getBytes; + +jclass J_Throwable; +jmethodID J_Throwable_getMessage; +jclass J_SQLException; + +jclass J_Bool; +jclass J_Byte; +jclass J_Short; +jclass J_Int; +jclass J_Long; +jclass J_Float; +jclass J_Double; +jclass J_String; +jclass J_Timestamp; +jmethodID J_Timestamp_valueOf; +jclass J_TimestampTZ; +jclass J_Decimal; +jclass J_ByteArray; + +jmethodID J_Bool_booleanValue; +jmethodID J_Byte_byteValue; +jmethodID J_Short_shortValue; +jmethodID J_Int_intValue; +jmethodID J_Long_longValue; +jmethodID J_Float_floatValue; +jmethodID J_Double_doubleValue; +jmethodID J_Timestamp_getMicrosEpoch; +jmethodID J_TimestampTZ_getMicrosEpoch; +jmethodID J_Decimal_precision; +jmethodID J_Decimal_scale; +jmethodID J_Decimal_scaleByPowTen; +jmethodID J_Decimal_toPlainString; +jmethodID J_Decimal_longValue; + +jclass J_DuckResultSetMeta; +jmethodID J_DuckResultSetMeta_init; + +jclass J_DuckVector; +jmethodID J_DuckVector_init; +jfieldID J_DuckVector_constlen; +jfieldID J_DuckVector_varlen; + +jclass J_DuckArray; +jmethodID J_DuckArray_init; + +jclass J_Struct; +jmethodID J_Struct_getSQLTypeName; +jmethodID J_Struct_getAttributes; + +jclass J_Array; +jmethodID J_Array_getBaseTypeName; +jmethodID J_Array_getArray; + +jclass J_DuckStruct; +jmethodID J_DuckStruct_init; + +jclass J_ByteBuffer; + +jclass J_DuckMap; +jmethodID J_DuckMap_getSQLTypeName; + +jclass J_List; +jmethodID J_List_iterator; +jclass J_Map; +jmethodID J_Map_entrySet; +jclass J_Set; +jmethodID J_Set_iterator; +jclass J_Iterator; +jmethodID J_Iterator_hasNext; +jmethodID J_Iterator_next; +jclass J_Entry; +jmethodID J_Entry_getKey; +jmethodID J_Entry_getValue; + +jclass J_UUID; +jmethodID J_UUID_getMostSignificantBits; +jmethodID J_UUID_getLeastSignificantBits; + +jclass J_DuckDBDate; +jmethodID J_DuckDBDate_getDaysSinceEpoch; + +jclass J_Object; +jmethodID J_Object_toString; + +jclass J_DuckDBTime; + +static std::vector global_refs; + +template +static void check_not_null(T ptr, const std::string &message) { + if (nullptr == ptr) { + throw std::runtime_error(message); + } +} + +static jclass make_class_ref(JNIEnv *env, const std::string &name) { + jclass local_ref = env->FindClass(name.c_str()); + check_not_null(local_ref, "Class not found, name: [" + name + "]"); + jclass global_ref = reinterpret_cast(env->NewGlobalRef(local_ref)); + check_not_null(global_ref, "Cannot create global ref for class, name: [" + name + "]"); + env->DeleteLocalRef(local_ref); + global_refs.emplace_back(global_ref); + return global_ref; +} + +static jmethodID get_method_id(JNIEnv *env, jclass clazz, const std::string &name, const std::string &sig) { + jmethodID method_id = env->GetMethodID(clazz, name.c_str(), sig.c_str()); + check_not_null(method_id, "Method not found, name: [" + name + "], signature: [" + sig + "]"); + return method_id; +} + +static jmethodID get_static_method_id(JNIEnv *env, jclass clazz, const std::string &name, const std::string &sig) { + jmethodID method_id = env->GetStaticMethodID(clazz, name.c_str(), sig.c_str()); + check_not_null(method_id, "Static method not found, name: [" + name + "], signature: [" + sig + "]"); + return method_id; +} + +static jfieldID get_field_id(JNIEnv *env, jclass clazz, const std::string &name, const std::string &sig) { + jfieldID field_id = env->GetFieldID(clazz, name.c_str(), sig.c_str()); + check_not_null(field_id, "Field not found, name: [" + name + "], signature: [" + sig + "]"); + return field_id; +} + +static jobject make_static_object_field_ref(JNIEnv *env, jclass clazz, const std::string &name, + const std::string &sig) { + jfieldID field_id = env->GetStaticFieldID(clazz, name.c_str(), sig.c_str()); + check_not_null(field_id, "Static field not found, name: [" + name + "], signature: [" + sig + "]"); + jobject local_ref = env->GetStaticObjectField(clazz, field_id); + check_not_null(local_ref, "Specified static field is null, name: [" + name + "], signature: [" + sig + "]"); + jobject global_ref = env->NewGlobalRef(local_ref); + check_not_null(global_ref, + "Cannot create global ref for static field, name: [" + name + "], signature: [" + sig + "]"); + env->DeleteLocalRef(local_ref); + global_refs.emplace_back(global_ref); + return global_ref; +} + +void create_refs(JNIEnv *env) { + jclass tmpLocalRef; + + J_Charset = make_class_ref(env, "java/nio/charset/Charset"); + J_Charset_decode = get_method_id(env, J_Charset, "decode", "(Ljava/nio/ByteBuffer;)Ljava/nio/CharBuffer;"); + J_StandardCharsets = make_class_ref(env, "java/nio/charset/StandardCharsets"); + J_Charset_UTF8 = make_static_object_field_ref(env, J_StandardCharsets, "UTF_8", "Ljava/nio/charset/Charset;"); + J_CharBuffer = make_class_ref(env, "java/nio/CharBuffer"); + J_CharBuffer_toString = get_method_id(env, J_CharBuffer, "toString", "()Ljava/lang/String;"); + + J_Throwable = make_class_ref(env, "java/lang/Throwable"); + J_Throwable_getMessage = get_method_id(env, J_Throwable, "getMessage", "()Ljava/lang/String;"); + J_SQLException = make_class_ref(env, "java/sql/SQLException"); + + J_Bool = make_class_ref(env, "java/lang/Boolean"); + J_Byte = make_class_ref(env, "java/lang/Byte"); + J_Short = make_class_ref(env, "java/lang/Short"); + J_Int = make_class_ref(env, "java/lang/Integer"); + J_Long = make_class_ref(env, "java/lang/Long"); + J_Float = make_class_ref(env, "java/lang/Float"); + J_Double = make_class_ref(env, "java/lang/Double"); + J_String = make_class_ref(env, "java/lang/String"); + J_Decimal = make_class_ref(env, "java/math/BigDecimal"); + J_ByteArray = make_class_ref(env, "[B"); + + J_Timestamp = make_class_ref(env, "org/duckdb/DuckDBTimestamp"); + J_Timestamp_valueOf = get_static_method_id(env, J_Timestamp, "valueOf", "(Ljava/lang/Object;)Ljava/lang/Object;"); + J_TimestampTZ = make_class_ref(env, "org/duckdb/DuckDBTimestampTZ"); + + J_DuckDBDate = make_class_ref(env, "org/duckdb/DuckDBDate"); + J_DuckDBDate_getDaysSinceEpoch = get_method_id(env, J_DuckDBDate, "getDaysSinceEpoch", "()J"); + J_DuckDBTime = make_class_ref(env, "org/duckdb/DuckDBTime"); + + J_DuckMap = make_class_ref(env, "org/duckdb/user/DuckDBMap"); + J_DuckMap_getSQLTypeName = get_method_id(env, J_DuckMap, "getSQLTypeName", "()Ljava/lang/String;"); + + J_List = make_class_ref(env, "java/util/List"); + J_List_iterator = get_method_id(env, J_List, "iterator", "()Ljava/util/Iterator;"); + J_Map = make_class_ref(env, "java/util/Map"); + J_Map_entrySet = get_method_id(env, J_Map, "entrySet", "()Ljava/util/Set;"); + J_Set = make_class_ref(env, "java/util/Set"); + J_Set_iterator = get_method_id(env, J_Set, "iterator", "()Ljava/util/Iterator;"); + J_Iterator = make_class_ref(env, "java/util/Iterator"); + J_Iterator_hasNext = get_method_id(env, J_Iterator, "hasNext", "()Z"); + J_Iterator_next = get_method_id(env, J_Iterator, "next", "()Ljava/lang/Object;"); + + J_UUID = make_class_ref(env, "java/util/UUID"); + J_UUID_getMostSignificantBits = get_method_id(env, J_UUID, "getMostSignificantBits", "()J"); + J_UUID_getLeastSignificantBits = get_method_id(env, J_UUID, "getLeastSignificantBits", "()J"); + + J_DuckArray = make_class_ref(env, "org/duckdb/DuckDBArray"); + J_DuckArray_init = get_method_id(env, J_DuckArray, "", "(Lorg/duckdb/DuckDBVector;II)V"); + + J_DuckStruct = make_class_ref(env, "org/duckdb/DuckDBStruct"); + J_DuckStruct_init = get_method_id(env, J_DuckStruct, "", + "([Ljava/lang/String;[Lorg/duckdb/DuckDBVector;ILjava/lang/String;)V"); + + J_Struct = make_class_ref(env, "java/sql/Struct"); + J_Struct_getSQLTypeName = get_method_id(env, J_Struct, "getSQLTypeName", "()Ljava/lang/String;"); + J_Struct_getAttributes = get_method_id(env, J_Struct, "getAttributes", "()[Ljava/lang/Object;"); + + J_Array = make_class_ref(env, "java/sql/Array"); + J_Array_getArray = get_method_id(env, J_Array, "getArray", "()Ljava/lang/Object;"); + J_Array_getBaseTypeName = get_method_id(env, J_Array, "getBaseTypeName", "()Ljava/lang/String;"); + + J_Object = make_class_ref(env, "java/lang/Object"); + J_Object_toString = get_method_id(env, J_Object, "toString", "()Ljava/lang/String;"); + + J_Entry = make_class_ref(env, "java/util/Map$Entry"); + J_Entry_getKey = get_method_id(env, J_Entry, "getKey", "()Ljava/lang/Object;"); + J_Entry_getValue = get_method_id(env, J_Entry, "getValue", "()Ljava/lang/Object;"); + + J_Bool_booleanValue = get_method_id(env, J_Bool, "booleanValue", "()Z"); + J_Byte_byteValue = get_method_id(env, J_Byte, "byteValue", "()B"); + J_Short_shortValue = get_method_id(env, J_Short, "shortValue", "()S"); + J_Int_intValue = get_method_id(env, J_Int, "intValue", "()I"); + J_Long_longValue = get_method_id(env, J_Long, "longValue", "()J"); + J_Float_floatValue = get_method_id(env, J_Float, "floatValue", "()F"); + J_Double_doubleValue = get_method_id(env, J_Double, "doubleValue", "()D"); + J_Timestamp_getMicrosEpoch = get_method_id(env, J_Timestamp, "getMicrosEpoch", "()J"); + J_TimestampTZ_getMicrosEpoch = get_method_id(env, J_TimestampTZ, "getMicrosEpoch", "()J"); + J_Decimal_precision = get_method_id(env, J_Decimal, "precision", "()I"); + J_Decimal_scale = get_method_id(env, J_Decimal, "scale", "()I"); + J_Decimal_scaleByPowTen = get_method_id(env, J_Decimal, "scaleByPowerOfTen", "(I)Ljava/math/BigDecimal;"); + J_Decimal_toPlainString = get_method_id(env, J_Decimal, "toPlainString", "()Ljava/lang/String;"); + J_Decimal_longValue = get_method_id(env, J_Decimal, "longValue", "()J"); + + J_DuckResultSetMeta = make_class_ref(env, "org/duckdb/DuckDBResultSetMetaData"); + J_DuckResultSetMeta_init = + get_method_id(env, J_DuckResultSetMeta, "", + "(II[Ljava/lang/String;[Ljava/lang/String;[Ljava/lang/String;Ljava/lang/String;)V"); + + J_DuckVector = make_class_ref(env, "org/duckdb/DuckDBVector"); + + J_String_getBytes = get_method_id(env, J_String, "getBytes", "(Ljava/nio/charset/Charset;)[B"); + + J_DuckVector_init = get_method_id(env, J_DuckVector, "", "(Ljava/lang/String;I[Z)V"); + J_DuckVector_constlen = get_field_id(env, J_DuckVector, "constlen_data", "Ljava/nio/ByteBuffer;"); + J_DuckVector_varlen = get_field_id(env, J_DuckVector, "varlen_data", "[Ljava/lang/Object;"); + + J_ByteBuffer = make_class_ref(env, "java/nio/ByteBuffer"); +} + +void delete_global_refs(JNIEnv *env) noexcept { + try { + for (auto &rf : global_refs) { + env->DeleteGlobalRef(rf); + } + } catch (const std::exception e) { + std::cout << "ERROR: delete_global_refs: " << e.what() << std::endl; + } +} diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp new file mode 100644 index 000000000..6d317e73b --- /dev/null +++ b/src/jni/refs.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include + +extern jclass J_Charset; +extern jmethodID J_Charset_decode; +extern jclass J_StandardCharsets; +extern jobject J_Charset_UTF8; + +extern jclass J_CharBuffer; +extern jmethodID J_CharBuffer_toString; + +extern jmethodID J_String_getBytes; + +extern jclass J_Throwable; +extern jmethodID J_Throwable_getMessage; +extern jclass J_SQLException; + +extern jclass J_Bool; +extern jclass J_Byte; +extern jclass J_Short; +extern jclass J_Int; +extern jclass J_Long; +extern jclass J_Float; +extern jclass J_Double; +extern jclass J_String; +extern jclass J_Timestamp; +extern jmethodID J_Timestamp_valueOf; +extern jclass J_TimestampTZ; +extern jclass J_Decimal; +extern jclass J_ByteArray; + +extern jmethodID J_Bool_booleanValue; +extern jmethodID J_Byte_byteValue; +extern jmethodID J_Short_shortValue; +extern jmethodID J_Int_intValue; +extern jmethodID J_Long_longValue; +extern jmethodID J_Float_floatValue; +extern jmethodID J_Double_doubleValue; +extern jmethodID J_Timestamp_getMicrosEpoch; +extern jmethodID J_TimestampTZ_getMicrosEpoch; +extern jmethodID J_Decimal_precision; +extern jmethodID J_Decimal_scale; +extern jmethodID J_Decimal_scaleByPowTen; +extern jmethodID J_Decimal_toPlainString; +extern jmethodID J_Decimal_longValue; + +extern jclass J_DuckResultSetMeta; +extern jmethodID J_DuckResultSetMeta_init; + +extern jclass J_DuckVector; +extern jmethodID J_DuckVector_init; +extern jfieldID J_DuckVector_constlen; +extern jfieldID J_DuckVector_varlen; + +extern jclass J_DuckArray; +extern jmethodID J_DuckArray_init; + +extern jclass J_Struct; +extern jmethodID J_Struct_getSQLTypeName; +extern jmethodID J_Struct_getAttributes; + +extern jclass J_Array; +extern jmethodID J_Array_getBaseTypeName; +extern jmethodID J_Array_getArray; + +extern jclass J_DuckStruct; +extern jmethodID J_DuckStruct_init; + +extern jclass J_ByteBuffer; + +extern jclass J_DuckMap; +extern jmethodID J_DuckMap_getSQLTypeName; + +extern jclass J_List; +extern jmethodID J_List_iterator; +extern jclass J_Map; +extern jmethodID J_Map_entrySet; +extern jclass J_Set; +extern jmethodID J_Set_iterator; +extern jclass J_Iterator; +extern jmethodID J_Iterator_hasNext; +extern jmethodID J_Iterator_next; +extern jclass J_Entry; +extern jmethodID J_Entry_getKey; +extern jmethodID J_Entry_getValue; + +extern jclass J_UUID; +extern jmethodID J_UUID_getMostSignificantBits; +extern jmethodID J_UUID_getLeastSignificantBits; + +extern jclass J_DuckDBDate; +extern jmethodID J_DuckDBDate_getDaysSinceEpoch; + +extern jclass J_Object; +extern jmethodID J_Object_toString; + +extern jclass J_DuckDBTime; + +void create_refs(JNIEnv *env); + +void delete_global_refs(JNIEnv *env) noexcept; diff --git a/src/jni/util.cpp b/src/jni/util.cpp new file mode 100644 index 000000000..f88eb7e99 --- /dev/null +++ b/src/jni/util.cpp @@ -0,0 +1,76 @@ +#include "util.hpp" + +#include "refs.hpp" + +void check_java_exception_and_rethrow(JNIEnv *env) { + if (env->ExceptionCheck()) { + jthrowable exc = env->ExceptionOccurred(); + env->ExceptionClear(); + jclass clazz = env->GetObjectClass(exc); + jstring jmsg = reinterpret_cast(env->CallObjectMethod(exc, J_Throwable_getMessage)); + if (env->ExceptionCheck()) { + throw std::runtime_error("Error getting details of the Java exception"); + } + std::string msg = jstring_to_string(env, jmsg); + throw std::runtime_error(msg); + } +} + +std::string byte_array_to_string(JNIEnv *env, jbyteArray ba_j) { + idx_t len = env->GetArrayLength(ba_j); + std::string ret; + ret.resize(len); + + jbyte *bytes = (jbyte *)env->GetByteArrayElements(ba_j, NULL); + + for (idx_t i = 0; i < len; i++) { + ret[i] = bytes[i]; + } + env->ReleaseByteArrayElements(ba_j, bytes, 0); + + return ret; +} + +std::string jstring_to_string(JNIEnv *env, jstring string_j) { + jbyteArray bytes = (jbyteArray)env->CallObjectMethod(string_j, J_String_getBytes, J_Charset_UTF8); + return byte_array_to_string(env, bytes); +} + +jobject decode_charbuffer_to_jstring(JNIEnv *env, const char *d_str, idx_t d_str_len) { + auto bb = env->NewDirectByteBuffer((void *)d_str, d_str_len); + auto j_cb = env->CallObjectMethod(J_Charset_UTF8, J_Charset_decode, bb); + auto j_str = env->CallObjectMethod(j_cb, J_CharBuffer_toString); + return j_str; +} + +duckdb::Value create_value_from_bigdecimal(JNIEnv *env, jobject decimal) { + jint precision = env->CallIntMethod(decimal, J_Decimal_precision); + jint scale = env->CallIntMethod(decimal, J_Decimal_scale); + + // Java BigDecimal type can have scale that exceeds the precision + // Which our DECIMAL type does not support (assert(width >= scale)) + if (scale > precision) { + precision = scale; + } + + // DECIMAL scale is unsigned, so negative values are not supported + if (scale < 0) { + throw duckdb::InvalidInputException("Converting from a BigDecimal with negative scale is not supported"); + } + + duckdb::Value val; + + if (precision <= 18) { // normal sizes -> avoid string processing + jobject no_point_dec = env->CallObjectMethod(decimal, J_Decimal_scaleByPowTen, scale); + jlong result = env->CallLongMethod(no_point_dec, J_Decimal_longValue); + val = duckdb::Value::DECIMAL((int64_t)result, (uint8_t)precision, (uint8_t)scale); + } else if (precision <= 38) { // larger than int64 -> get string and cast + jobject str_val = env->CallObjectMethod(decimal, J_Decimal_toPlainString); + auto *str_char = env->GetStringUTFChars((jstring)str_val, 0); + val = duckdb::Value(str_char); + val = val.DefaultCastAs(duckdb::LogicalType::DECIMAL(precision, scale)); + env->ReleaseStringUTFChars((jstring)str_val, str_char); + } + + return val; +} diff --git a/src/jni/util.hpp b/src/jni/util.hpp new file mode 100644 index 000000000..921ffecbc --- /dev/null +++ b/src/jni/util.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "duckdb.hpp" + +#include +#include + +void check_java_exception_and_rethrow(JNIEnv *env); + +std::string byte_array_to_string(JNIEnv *env, jbyteArray ba_j); + +std::string jstring_to_string(JNIEnv *env, jstring string_j); + +jobject decode_charbuffer_to_jstring(JNIEnv *env, const char *d_str, idx_t d_str_len); + +duckdb::Value create_value_from_bigdecimal(JNIEnv *env, jobject decimal); diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 1908d89e9..f2fac8dd4 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -4676,6 +4676,56 @@ public static void test_blob_after_rs_next() throws Exception { } } + public static void test_typed_connection_properties() throws Exception { + Properties config = new Properties(); + config.put("autoinstall_known_extensions", false); // BOOLEAN + List allowedDirsList = new ArrayList<>(); + allowedDirsList.add("path/to/dir1"); + allowedDirsList.add("path/to/dir2"); + config.put("allowed_directories", allowedDirsList); // VARCHAR[] + config.put("catalog_error_max_schemas", 42); // UBIGINT + config.put("index_scan_percentage", 0.042); // DOUBLE + + try (Connection conn = DriverManager.getConnection(JDBC_URL, config)) { + try (Statement stmt = conn.createStatement()) { + try (ResultSet rs = stmt.executeQuery("SELECT current_setting('autoinstall_known_extensions')")) { + rs.next(); + boolean val = rs.getBoolean(1); + assertFalse(val, "autoinstall_known_extensions"); + } + try (ResultSet rs = stmt.executeQuery("SELECT UNNEST(current_setting('allowed_directories'))")) { + List values = new ArrayList<>(); + while (rs.next()) { + String val = rs.getString(1); + values.add(val); + } + assertTrue(values.size() >= 2); + boolean dir1Found = false; + boolean dir2Found = false; + for (String val : values) { + if (val.contains("dir1")) { + dir1Found = true; + } + if (val.contains("dir2")) { + dir2Found = true; + } + } + assertTrue(dir1Found && dir2Found, "allowed_directories 1"); + } + try (ResultSet rs = stmt.executeQuery("SELECT current_setting('catalog_error_max_schemas')")) { + rs.next(); + long val = rs.getLong(1); + assertEquals(val, 42l, "catalog_error_max_schemas"); + } + try (ResultSet rs = stmt.executeQuery("SELECT current_setting('index_scan_percentage')")) { + rs.next(); + double val = rs.getDouble(1); + assertEquals(val, 0.042d, "index_scan_percentage"); + } + } + } + } + public static void main(String[] args) throws Exception { System.exit(runTests(args, TestDuckDBJDBC.class, TestExtensionTypes.class)); } diff --git a/src/test/java/org/duckdb/test/Assertions.java b/src/test/java/org/duckdb/test/Assertions.java index de30d27ba..9eb412537 100644 --- a/src/test/java/org/duckdb/test/Assertions.java +++ b/src/test/java/org/duckdb/test/Assertions.java @@ -20,6 +20,10 @@ public static void assertFalse(boolean val) throws Exception { assertTrue(!val); } + public static void assertFalse(boolean val, String message) throws Exception { + assertTrue(!val, message); + } + public static void assertEquals(Object actual, Object expected) throws Exception { assertEquals(actual, expected, ""); }