diff --git a/README.md b/README.md index 8a063a6..203bd15 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,8 @@ Connect to a database. - `db` - a database name - `conn_string` (mutual exclusive with host, port, user, pass, db) - PostgreSQL [connection string][PQconnstring] + - `dec_cast` - an option that switches casting types for `NUMERIC` PostgreSQL + type. Possible values: `n` (`number`), `s` (`string`), `d` (`decimal`). *Returns*: diff --git a/pg/driver.c b/pg/driver.c index 0bc9620..1fba631 100644 --- a/pg/driver.c +++ b/pg/driver.c @@ -49,6 +49,31 @@ #undef PACKAGE_VERSION #include +/** + * The fallthrough attribute with a null statement serves as a fallthrough + * statement. It hints to the compiler that a statement that falls through + * to another case label, or user-defined label in a switch statement is + * intentional and thus the -Wimplicit-fallthrough warning must not trigger. + * The fallthrough attribute may appear at most once in each attribute list, + * and may not be mixed with other attributes. It can only be used in a switch + * statement (the compiler will issue an error otherwise), after a preceding + * statement and before a logically succeeding case label, or user-defined + * label. + */ +#if defined(__cplusplus) && __has_cpp_attribute(fallthrough) +# define FALLTHROUGH [[fallthrough]] +#elif __has_attribute(fallthrough) || (defined(__GNUC__) && __GNUC__ >= 7) +# define FALLTHROUGH __attribute__((fallthrough)) +#else +# define FALLTHROUGH +#endif + +struct dec_opt { + char cast; + int dnew_index; +}; +typedef struct dec_opt dec_opt_t; + /** * Infinity timeout from tarantool_ev.c. I mean, this should be in * a module.h file. @@ -97,7 +122,7 @@ lua_push_error(struct lua_State *L) * Parse pg values to lua */ static int -parse_pg_value(struct lua_State *L, PGresult *res, int row, int col) +parse_pg_value(struct lua_State *L, PGresult *res, int row, int col, dec_opt_t *dopt) { if (PQgetisnull(res, row, col)) return false; @@ -107,9 +132,26 @@ parse_pg_value(struct lua_State *L, PGresult *res, int row, int col) int len = PQgetlength(res, row, col); switch (PQftype(res, col)) { - case INT2OID: - case INT4OID: case NUMERICOID: { + if (dopt->cast == 's') { + lua_pushlstring(L, val, len); + break; + } + else if (dopt->cast == 'd' && dopt->dnew_index != -1) { + lua_rawgeti(L, LUA_REGISTRYINDEX, dopt->dnew_index); + lua_pushlstring(L, val, len); + int fail = lua_pcall(L, 1, 1, 0); + if (fail) { + lua_pop(L, 2); + return false; + } + break; + } + /* 'n': fallthrough */ + FALLTHROUGH; + } + case INT2OID: + case INT4OID: { lua_pushlstring(L, val, len); double v = lua_tonumber(L, -1); lua_pop(L, 1); @@ -141,6 +183,7 @@ static int safe_pg_parsetuples(struct lua_State *L) { PGresult *res = (PGresult *)lua_topointer(L, 1); + dec_opt_t *dopt = (dec_opt_t *)lua_topointer(L, 2); int row, rows = PQntuples(res); int col, cols = PQnfields(res); lua_newtable(L); @@ -148,7 +191,7 @@ safe_pg_parsetuples(struct lua_State *L) lua_pushnumber(L, row + 1); lua_newtable(L); for (col = 0; col < cols; ++col) - parse_pg_value(L, res, row, col); + parse_pg_value(L, res, row, col, dopt); lua_settable(L, -3); } return 1; @@ -205,7 +248,7 @@ pg_wait_for_result(PGconn *conn) * Appends result fom postgres to lua table */ static int -pg_resultget(struct lua_State *L, PGconn *conn, int *res_no, int status_ok) +pg_resultget(struct lua_State *L, PGconn *conn, int *res_no, int status_ok, dec_opt_t *dopt) { int wait_res = pg_wait_for_result(conn); if (wait_res != 1) @@ -235,9 +278,13 @@ pg_resultget(struct lua_State *L, PGconn *conn, int *res_no, int status_ok) lua_pushinteger(L, (*res_no)++); lua_pushcfunction(L, safe_pg_parsetuples); lua_pushlightuserdata(L, pg_res); - fail = lua_pcall(L, 1, 1, 0); - if (!fail) + lua_pushlightuserdata(L, dopt); + fail = lua_pcall(L, 2, 1, 0); + if (!fail) { lua_settable(L, -3); + break; + } + break; case PGRES_COMMAND_OK: res = 1; break; @@ -269,6 +316,15 @@ static void lua_parse_param(struct lua_State *L, int idx, const char **value, int *length, Oid *type) { + /* Serialized [u]int64_t */ + static char buf[512]; + static char *pos = NULL; + /* lua_parse_param(L, idx + 5, ...) */ + if (idx == 5) { + *buf = '\0'; + pos = buf; + } + if (lua_isnil(L, idx)) { *value = NULL; *length = 0; @@ -293,6 +349,27 @@ lua_parse_param(struct lua_State *L, return; } + if (luaL_iscdata(L, idx)) { + uint32_t ctypeid = 0; + void *cdata = luaL_checkcdata(L, idx, &ctypeid); + int len = 0; + if (ctypeid == luaL_ctypeid(L, "int64_t")) { + len = snprintf(pos, sizeof(buf) - (pos - buf), "%ld", *(int64_t*)cdata); + *type = INT8OID; + } + else if (ctypeid == luaL_ctypeid(L, "uint64_t")) { + len = snprintf(pos, sizeof(buf) - (pos - buf), "%lu", *(uint64_t*)cdata); + *type = NUMERICOID; + } + + if (len > 0) { + *value = pos; + *length = len; + pos += len + 1; + return; + } + } + // We will pass all other types as strings size_t len; *value = lua_tolstring(L, idx, &len); @@ -307,12 +384,28 @@ static int lua_pg_execute(struct lua_State *L) { PGconn *conn = lua_check_pgconn(L, 1); - if (!lua_isstring(L, 2)) { + + dec_opt_t dopt = {'n', -1}; + if (lua_isstring(L, 2)) { + const char *dec_cast_type = lua_tostring(L, 2); + if (*dec_cast_type == 'n' || + *dec_cast_type == 's' || + *dec_cast_type == 'd') + dopt.cast = *dec_cast_type; + } + + if (!lua_isstring(L, 4)) { safe_pushstring(L, "Second param should be a sql command"); return lua_push_error(L); } - const char *sql = lua_tostring(L, 2); - int paramCount = lua_gettop(L) - 2; + + if (lua_isfunction(L, 3)) { + lua_pushvalue(L, 3); + dopt.dnew_index = luaL_ref(L, LUA_REGISTRYINDEX); + } + + const char *sql = lua_tostring(L, 4); + int paramCount = lua_gettop(L) - 4; const char **paramValues = NULL; int *paramLengths = NULL; @@ -333,7 +426,7 @@ lua_pg_execute(struct lua_State *L) int idx; for (idx = 0; idx < paramCount; ++idx) { - lua_parse_param(L, idx + 3, paramValues + idx, + lua_parse_param(L, idx + 5, paramValues + idx, paramLengths + idx, paramTypes + idx); } res = PQsendQueryParams(conn, sql, paramCount, paramTypes, @@ -345,6 +438,8 @@ lua_pg_execute(struct lua_State *L) if (res == -1) { lua_pushinteger(L, PQstatus(conn) == CONNECTION_BAD ? -1: 0); lua_pushstring(L, PQerrorMessage(conn)); + if (dopt.dnew_index != -1) + luaL_unref(L, LUA_REGISTRYINDEX, dopt.dnew_index); return 2; } lua_pushinteger(L, 0); @@ -352,7 +447,10 @@ lua_pg_execute(struct lua_State *L) int res_no = 1; int status_ok = 1; - while ((status_ok = pg_resultget(L, conn, &res_no, status_ok))); + while ((status_ok = pg_resultget(L, conn, &res_no, status_ok, &dopt))); + + if (dopt.dnew_index != -1) + luaL_unref(L, LUA_REGISTRYINDEX, dopt.dnew_index); return 2; } diff --git a/pg/init.lua b/pg/init.lua index b1000d6..91ff0ea 100644 --- a/pg/init.lua +++ b/pg/init.lua @@ -4,6 +4,11 @@ local fiber = require('fiber') local driver = require('pg.driver') local ffi = require('ffi') +local has_decimal, dec = pcall(require, 'decimal') +if has_decimal then + dnew = dec.new +end + local pool_mt local conn_mt @@ -15,6 +20,10 @@ local function conn_create(pg_conn) usable = true, conn = pg_conn, queue = queue, + dec_cast = 'n' -- Defined in pg/driver.c: + -- 'n' - number, + -- 's' - string, + -- 'd' - decimal. }, conn_mt) return conn @@ -60,7 +69,7 @@ conn_mt = { self.queue:put(false) return get_error(self.raise.pool, 'Connection is broken') end - local status, datas = self.conn:execute(sql, ...) + local status, datas = self.conn:execute(self.dec_cast, dnew, sql, ...) if status ~= 0 then self.queue:put(status > 0) return error(datas) diff --git a/test/pg.test.lua b/test/pg.test.lua index 144ceea..871f395 100755 --- a/test/pg.test.lua +++ b/test/pg.test.lua @@ -129,6 +129,41 @@ function test_pg_int64(t, p) p:put(conn) end +function test_pg_decimal(t, p) + t:plan(8) + + -- Setup + conn = p:get() + t:isnt(conn, nil, 'connection is established') + local num = 4500 + conn:execute('CREATE TABLE dectest (num NUMERIC(7,2))') + conn:execute(('INSERT INTO dectest VALUES(%d)'):format(num)) + + local res, r, _ + -- dec_cast is 'n' + t:is(conn.dec_cast, 'n', 'decimal casting type is "n" by default') + r, _ = conn:execute('SELECT num FROM dectest') + res = r[1][1]['num'] + t:is(type(res), 'number', 'type is "number"') + t:is(res, num, 'decimal number is correct') + -- dec_cast is 's' + conn.dec_cast = 's' + r, _ = conn:execute('SELECT num FROM dectest') + res = r[1][1]['num'] + t:is(type(res), 'string', 'type is "string"') + t:is(res, '4500.00', 'decimal number is correct') + -- dec_cast is 'd' + conn.dec_cast = 'd' + r, _ = conn:execute('SELECT num FROM dectest') + res = r[1][1]['num'] + t:is(type(res), 'cdata', 'type is "decimal"') + t:is(res, num, 'decimal number is correct') + + -- Teardown + conn:execute('DROP TABLE dectest') + p:put(conn) +end + tap.test('connection old api', test_old_api, conn) local pool_conn = p:get() tap.test('connection old api via pool', test_old_api, pool_conn) @@ -136,4 +171,5 @@ p:put(pool_conn) tap.test('test collection connections', test_gc, p) tap.test('connection concurrent', test_conn_concurrent, p) tap.test('int64', test_pg_int64, p) +tap.test('decimal', test_pg_decimal, p) p:close()