Skip to content

bpo-44329: Refactor sqlite3 statement creation #26566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 35 additions & 43 deletions Modules/_sqlite/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,47 +51,50 @@ typedef enum {
pysqlite_Statement *
pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
{
const char* tail;
int rc;
const char* sql_cstr;
Py_ssize_t sql_cstr_len;
const char* p;

assert(PyUnicode_Check(sql));

sql_cstr = PyUnicode_AsUTF8AndSize(sql, &sql_cstr_len);
Py_ssize_t size;
const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size);
if (sql_cstr == NULL) {
PyErr_Format(pysqlite_Warning,
"SQL is of wrong type ('%s'). Must be string.",
Py_TYPE(sql)->tp_name);
return NULL;
}

int max_length = sqlite3_limit(connection->db, SQLITE_LIMIT_LENGTH, -1);
if (sql_cstr_len >= max_length) {
sqlite3 *db = connection->db;
int max_length = sqlite3_limit(db, SQLITE_LIMIT_LENGTH, -1);
if (size >= max_length) {
PyErr_SetString(pysqlite_DataError, "query string is too large");
return NULL;
}
if (strlen(sql_cstr) != (size_t)sql_cstr_len) {
if (strlen(sql_cstr) != (size_t)size) {
PyErr_SetString(PyExc_ValueError,
"the query contains a null character");
return NULL;
}

pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
pysqlite_StatementType);
if (self == NULL) {
sqlite3_stmt *stmt;
const char *tail;
int rc;
Py_BEGIN_ALLOW_THREADS
rc = sqlite3_prepare_v2(db, sql_cstr, (int)size + 1, &stmt, &tail);
Py_END_ALLOW_THREADS

if (rc != SQLITE_OK) {
_pysqlite_seterror(db);
return NULL;
}

self->st = NULL;
self->in_use = 0;
self->is_dml = 0;
self->in_weakreflist = NULL;
if (pysqlite_check_remaining_sql(tail)) {
PyErr_SetString(pysqlite_Warning,
"You can only execute one statement at a time.");
goto error;
}

/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
for (p = sql_cstr; *p != 0; p++) {
int is_dml = 0;
for (const char *p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
Expand All @@ -100,40 +103,29 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
continue;
}

self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}

Py_BEGIN_ALLOW_THREADS
rc = sqlite3_prepare_v2(connection->db,
sql_cstr,
(int)sql_cstr_len + 1,
&self->st,
&tail);
Py_END_ALLOW_THREADS

PyObject_GC_Track(self);

if (rc != SQLITE_OK) {
_pysqlite_seterror(connection->db);
pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
pysqlite_StatementType);
if (self == NULL) {
goto error;
}

if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) {
(void)sqlite3_finalize(self->st);
self->st = NULL;
PyErr_SetString(pysqlite_Warning,
"You can only execute one statement at a time.");
goto error;
}
self->st = stmt;
self->in_use = 0;
self->is_dml = is_dml;
self->in_weakreflist = NULL;

PyObject_GC_Track(self);
return self;

error:
Py_DECREF(self);
(void)sqlite3_finalize(stmt);
return NULL;
}

Expand Down