Skip to content

Commit 33fd643

Browse files
committed
added tx support
1 parent c431fd2 commit 33fd643

File tree

6 files changed

+154
-24
lines changed

6 files changed

+154
-24
lines changed

odbc/src/connection.cpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,49 @@ void TConnection::ClearErrors() {
122122
Errors_.clear();
123123
}
124124

125-
std::pair<std::string, std::string> TConnection::ParseConnectionString(const std::string& connectionString) {
126-
// TODO: Implement
127-
return {"", ""};
125+
SQLRETURN TConnection::SetAutocommit(bool value) {
126+
Autocommit_ = value;
127+
if (Autocommit_ && Tx_) {
128+
auto status = Tx_->Commit().ExtractValueSync();
129+
if (!status.IsSuccess()) {
130+
AddError("08001", 0, "Failed to commit transaction");
131+
return SQL_ERROR;
132+
}
133+
Tx_.reset();
134+
}
135+
return SQL_SUCCESS;
136+
}
137+
138+
bool TConnection::GetAutocommit() const {
139+
return Autocommit_;
140+
}
141+
142+
const std::optional<NQuery::TTransaction>& TConnection::GetTx() {
143+
return Tx_;
144+
}
145+
146+
void TConnection::SetTx(const NQuery::TTransaction& tx) {
147+
Tx_ = tx;
148+
}
149+
150+
SQLRETURN TConnection::CommitTx() {
151+
auto status = Tx_->Commit().ExtractValueSync();
152+
if (!status.IsSuccess()) {
153+
AddError("08001", 0, "Failed to commit transaction");
154+
return SQL_ERROR;
155+
}
156+
Tx_.reset();
157+
return SQL_SUCCESS;
158+
}
159+
160+
SQLRETURN TConnection::RollbackTx() {
161+
auto status = Tx_->Rollback().ExtractValueSync();
162+
if (!status.IsSuccess()) {
163+
AddError("08001", 0, "Failed to rollback transaction");
164+
return SQL_ERROR;
165+
}
166+
Tx_.reset();
167+
return SQL_SUCCESS;
128168
}
129169

130170
} // namespace NOdbc

odbc/src/connection.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@ class TStatement;
1919

2020
class TConnection {
2121
private:
22-
std::unique_ptr<NYdb::TDriver> YdbDriver_;
23-
std::unique_ptr<NYdb::NQuery::TQueryClient> YdbClient_;
22+
std::unique_ptr<TDriver> YdbDriver_;
23+
std::unique_ptr<NQuery::TQueryClient> YdbClient_;
24+
std::optional<NQuery::TTransaction> Tx_;
2425

2526
TErrorList Errors_;
2627
std::vector<std::unique_ptr<TStatement>> Statements_;
2728
std::string Endpoint_;
2829
std::string Database_;
2930
std::string AuthToken_;
3031

32+
bool Autocommit_ = true;
33+
3134
public:
3235
SQLRETURN Connect(const std::string& serverName,
3336
const std::string& userName,
@@ -46,8 +49,14 @@ class TConnection {
4649
void AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message);
4750
void ClearErrors();
4851

49-
private:
50-
std::pair<std::string, std::string> ParseConnectionString(const std::string& connectionString);
52+
SQLRETURN SetAutocommit(bool value);
53+
bool GetAutocommit() const;
54+
55+
const std::optional<NQuery::TTransaction>& GetTx();
56+
void SetTx(const NQuery::TTransaction& tx);
57+
58+
SQLRETURN CommitTx();
59+
SQLRETURN RollbackTx();
5160
};
5261

5362
} // namespace NOdbc

odbc/src/odbc_driver.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,62 @@ SQLRETURN SQL_API SQLBindParameter(SQLHSTMT statementHandle,
250250
return stmt->BindParameter(paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr);
251251
}
252252

253+
SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT completionType) {
254+
if (!handle) {
255+
return SQL_INVALID_HANDLE;
256+
}
257+
try {
258+
switch (handleType) {
259+
case SQL_HANDLE_DBC: {
260+
auto conn = static_cast<NYdb::NOdbc::TConnection*>(handle);
261+
if (completionType == SQL_COMMIT) {
262+
return conn->CommitTx();
263+
} else if (completionType == SQL_ROLLBACK) {
264+
return conn->RollbackTx();
265+
} else {
266+
return SQL_ERROR;
267+
}
268+
}
269+
case SQL_HANDLE_STMT: {
270+
auto stmt = static_cast<NYdb::NOdbc::TStatement*>(handle);
271+
auto conn = stmt->GetConnection();
272+
if (!conn) return SQL_INVALID_HANDLE;
273+
if (completionType == SQL_COMMIT) {
274+
return conn->CommitTx();
275+
} else if (completionType == SQL_ROLLBACK) {
276+
return conn->RollbackTx();
277+
} else {
278+
return SQL_ERROR;
279+
}
280+
}
281+
case SQL_HANDLE_ENV: {
282+
// TODO: if's list of connections in ENV, go through them and commit/rollback transactions
283+
return SQL_SUCCESS;
284+
}
285+
default:
286+
return SQL_ERROR;
287+
}
288+
} catch (...) {
289+
return SQL_ERROR;
290+
}
291+
}
292+
293+
SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) {
294+
auto conn = static_cast<NYdb::NOdbc::TConnection*>(connectionHandle);
295+
if (!conn) {
296+
return SQL_INVALID_HANDLE;
297+
}
298+
if (attribute == SQL_ATTR_AUTOCOMMIT) {
299+
if ((intptr_t)value == SQL_AUTOCOMMIT_ON) {
300+
return conn->SetAutocommit(true);
301+
} else if ((intptr_t)value == SQL_AUTOCOMMIT_OFF) {
302+
return conn->SetAutocommit(false);
303+
} else {
304+
return SQL_ERROR;
305+
}
306+
}
307+
// TODO: other attributes
308+
return SQL_ERROR;
309+
}
310+
253311
}

odbc/src/statement.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,23 @@ SQLRETURN TStatement::ExecDirect(const std::string& statementText) {
2222
return SQL_ERROR;
2323
}
2424

25-
auto sessionResult = client->GetSession().ExtractValueSync();
26-
if (!sessionResult.IsSuccess()) {
27-
return SQL_ERROR;
25+
if (!Conn_->GetTx()) {
26+
auto sessionResult = client->GetSession().ExtractValueSync();
27+
if (!sessionResult.IsSuccess()) {
28+
return SQL_ERROR;
29+
}
30+
auto session = sessionResult.GetSession();
31+
auto beginTxResult = session.BeginTransaction(NQuery::TTxSettings::SerializableRW()).ExtractValueSync();
32+
if (!beginTxResult.IsSuccess()) {
33+
return SQL_ERROR;
34+
}
35+
Conn_->SetTx(beginTxResult.GetTransaction());
2836
}
2937

30-
auto session = sessionResult.GetSession();
38+
auto session = Conn_->GetTx()->GetSession();
39+
auto iterator = session.StreamExecuteQuery(statementText,
40+
NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(Conn_->GetAutocommit()), params).ExtractValueSync();
3141

32-
auto iterator = session.StreamExecuteQuery(statementText, NYdb::NQuery::TTxControl::NoTx(), params).ExtractValueSync();
3342
if (!iterator.IsSuccess()) {
3443
return SQL_ERROR;
3544
}

odbc/src/utils/convert.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ template<SQLSMALLINT CType>
99
struct TSqlTypeTraits;
1010

1111
template<> struct TSqlTypeTraits<SQL_C_CHAR> { using Type = std::string; };
12+
template<> struct TSqlTypeTraits<SQL_C_BINARY> { using Type = std::string; };
1213
template<> struct TSqlTypeTraits<SQL_C_SBIGINT> { using Type = SQLBIGINT; };
1314
template<> struct TSqlTypeTraits<SQL_C_UBIGINT> { using Type = SQLUBIGINT; };
1415
template<> struct TSqlTypeTraits<SQL_C_LONG> { using Type = SQLINTEGER; };
@@ -38,6 +39,11 @@ TTypedValue<SQL_C_CHAR>::TTypedValue(const TBoundParam& param) {
3839
Data = std::string(static_cast<const char*>(param.ParameterValuePtr), param.BufferLength);
3940
}
4041

42+
template<>
43+
TTypedValue<SQL_C_BINARY>::TTypedValue(const TBoundParam& param) {
44+
Data = std::string(static_cast<const char*>(param.ParameterValuePtr), param.BufferLength);
45+
}
46+
4147
class IConverter {
4248
public:
4349
virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) = 0;
@@ -259,15 +265,15 @@ REGISTER_CONVERTER(SQL_C_CHAR, SQL_LONGVARCHAR, EPrimitiveType::Utf8) {
259265

260266
// Binary types
261267

262-
REGISTER_CONVERTER(SQL_C_CHAR, SQL_BINARY, EPrimitiveType::String) {
268+
REGISTER_CONVERTER(SQL_C_BINARY, SQL_BINARY, EPrimitiveType::String) {
263269
builder.OptionalString(std::move(data));
264270
}
265271

266-
REGISTER_CONVERTER(SQL_C_CHAR, SQL_VARBINARY, EPrimitiveType::String) {
272+
REGISTER_CONVERTER(SQL_C_BINARY, SQL_VARBINARY, EPrimitiveType::String) {
267273
builder.OptionalString(std::move(data));
268274
}
269275

270-
REGISTER_CONVERTER(SQL_C_CHAR, SQL_LONGVARBINARY, EPrimitiveType::String) {
276+
REGISTER_CONVERTER(SQL_C_BINARY, SQL_LONGVARBINARY, EPrimitiveType::String) {
271277
builder.OptionalString(std::move(data));
272278
}
273279

odbc/tests/unit/convert_ut.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
using namespace NYdb::NOdbc;
1313
using namespace NYdb;
1414

15-
void CheckProtoValue(const Ydb::Value& value, const std::string& expected) {
15+
template<typename T>
16+
void CheckProto(const T& value, const std::string& expected) {
1617
std::string protoStr;
1718
google::protobuf::TextFormat::PrintToString(value, &protoStr);
1819
ASSERT_EQ(protoStr, expected);
@@ -36,7 +37,8 @@ TEST(OdbcConvert, Int64ToYdb) {
3637
auto params = paramsBuilder.Build();
3738
auto value = params.GetValue("$p1");
3839
ASSERT_TRUE(value);
39-
CheckProtoValue(value->GetProto(), "int64_value: 42\n");
40+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n");
41+
CheckProto(value->GetProto(), "int64_value: 42\n");
4042
}
4143

4244
TEST(OdbcConvert, Uint64ToYdb) {
@@ -49,7 +51,8 @@ TEST(OdbcConvert, Uint64ToYdb) {
4951
auto params = paramsBuilder.Build();
5052
auto value = params.GetValue("$p1");
5153
ASSERT_TRUE(value);
52-
CheckProtoValue(value->GetProto(), "uint64_value: 123\n");
54+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UINT64\n }\n}\n");
55+
CheckProto(value->GetProto(), "uint64_value: 123\n");
5356
}
5457

5558
TEST(OdbcConvert, DoubleToYdb) {
@@ -62,7 +65,8 @@ TEST(OdbcConvert, DoubleToYdb) {
6265
auto params = paramsBuilder.Build();
6366
auto value = params.GetValue("$p1");
6467
ASSERT_TRUE(value);
65-
CheckProtoValue(value->GetProto(), "double_value: 3.14\n");
68+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: DOUBLE\n }\n}\n");
69+
CheckProto(value->GetProto(), "double_value: 3.14\n");
6670
}
6771

6872
TEST(OdbcConvert, StringToYdbUtf8) {
@@ -76,21 +80,23 @@ TEST(OdbcConvert, StringToYdbUtf8) {
7680
auto params = paramsBuilder.Build();
7781
auto value = params.GetValue("$p1");
7882
ASSERT_TRUE(value);
79-
CheckProtoValue(value->GetProto(), "text_value: \"hello\"\n");
83+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n");
84+
CheckProto(value->GetProto(), "text_value: \"hello\"\n");
8085
}
8186

8287
TEST(OdbcConvert, StringToYdbBinary) {
8388
const char* str = "bin\x01\x02";
8489
SQLLEN len = 5;
8590
TBoundParam param{
86-
1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_BINARY, 0, 0, (SQLPOINTER)str, len, nullptr
91+
1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_BINARY, 0, 0, (SQLPOINTER)str, len, nullptr
8792
};
8893
TParamsBuilder paramsBuilder;
8994
ConvertValue(param, paramsBuilder.AddParam("$p1"));
9095
auto params = paramsBuilder.Build();
9196
auto value = params.GetValue("$p1");
9297
ASSERT_TRUE(value);
93-
ASSERT_EQ(value->GetProto().bytes_value(), std::string(str, len));
98+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: STRING\n }\n}\n");
99+
CheckProto(value->GetProto(), "bytes_value: \"bin\\001\\002\"\n");
94100
}
95101

96102
TEST(OdbcConvert, Int64NullToYdb) {
@@ -104,7 +110,8 @@ TEST(OdbcConvert, Int64NullToYdb) {
104110
auto params = paramsBuilder.Build();
105111
auto value = params.GetValue("$p1");
106112
ASSERT_TRUE(value);
107-
ASSERT_EQ(value->GetProto().null_flag_value(), ::google::protobuf::NullValue::NULL_VALUE);
113+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n");
114+
CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n");
108115
}
109116

110117
TEST(OdbcConvert, StringNullToYdb) {
@@ -118,5 +125,6 @@ TEST(OdbcConvert, StringNullToYdb) {
118125
auto params = paramsBuilder.Build();
119126
auto value = params.GetValue("$p1");
120127
ASSERT_TRUE(value);
121-
ASSERT_EQ(value->GetProto().null_flag_value(), ::google::protobuf::NullValue::NULL_VALUE);
128+
CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n");
129+
CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n");
122130
}

0 commit comments

Comments
 (0)