🌐 AI搜索 & 代理 主页
Skip to content
Closed
54 changes: 54 additions & 0 deletions Lib/sqlite3/test/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,65 @@ class UnhashableType(type):
self.con.execute('SELECT %s()' % aggr_name)


class DMLStatementDetectionTestCase(unittest.TestCase):
"""
https://bugs.python.org/issue36859

Use sqlite3_stmt_readonly to determine if the statement is DML or not.
"""
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3),
'needs sqlite 3.8.3 or newer')
def test_dml_detection_cte(self):
conn = sqlite.connect(':memory:')
conn.execute('create table kv ("key" text, "val" integer)')
self.assertFalse(conn.in_transaction)
conn.execute('insert into kv (key, val) values (?, ?), (?, ?)',
('k1', 1, 'k2', 2))
self.assertTrue(conn.in_transaction)

conn.commit()
self.assertFalse(conn.in_transaction)

rc = conn.execute('update kv set val=val + ?', (10,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)
conn.commit()
self.assertFalse(conn.in_transaction)

rc = conn.execute('with c(k, v) as (select key, val + ? from kv) '
'update kv set val=(select v from c where k=kv.key)',
(100,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)

curs = conn.execute('select key, val from kv order by key')
self.assertEqual(curs.fetchall(), [('k1', 111), ('k2', 112)])

@unittest.skipIf(sqlite.sqlite_version_info < (3, 7, 11),
'needs sqlite 3.7.11 or newer')
def test_dml_detection_sql_comment(self):
conn = sqlite.connect(':memory:')
conn.execute('create table kv ("key" text, "val" integer)')
conn.execute('insert into kv (key, val) values (?, ?), (?, ?)',
('k1', 1, 'k2', 2))
conn.commit()

self.assertFalse(conn.in_transaction)
rc = conn.execute('-- a comment\nupdate kv set val=val + ?', (10,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)

curs = conn.execute('select key, val from kv order by key')
self.assertEqual(curs.fetchall(), [('k1', 11), ('k2', 12)])
conn.rollback()


def suite():
regression_suite = unittest.makeSuite(RegressionTests, "Check")
return unittest.TestSuite((
regression_suite,
unittest.makeSuite(UnhashableCallbacksTestCase),
unittest.makeSuite(DMLStatementDetectionTestCase),
))

def test():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `sqlite3_stmt_readonly` internally to determine if a SQL statement is data-modifying.
72 changes: 52 additions & 20 deletions Modules/_sqlite/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include "prepare_protocol.h"
#include "util.h"

#if SQLITE_VERSION_NUMBER >= 3007011
#define HAVE_SQLITE3_STMT_READONLY
#endif

/* prototypes */
static int pysqlite_check_remaining_sql(const char* tail);

Expand All @@ -48,13 +52,59 @@ typedef enum {
TYPE_UNKNOWN
} parameter_type;

int pysqlite_statement_is_dml(sqlite3_stmt *st, const char *sql)
{
const char* p;
int is_dml = 0;

#ifdef HAVE_SQLITE3_STMT_READONLY
is_dml = ! sqlite3_stmt_readonly(st);
if (is_dml) {
/* Retain backwards-compatibility, as sqlite3_stmt_readonly will return
* false for BEGIN [IMMEDIATE|EXCLUSIVE] or DDL statements.
*/
for (p = sql; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

is_dml = (PyOS_strnicmp(p, "begin", 5) &&
PyOS_strnicmp(p, "create", 6) &&
PyOS_strnicmp(p, "drop", 4));
break;
}
}
#else
/* Original implementation. */
for (p = sql; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}
#endif
return is_dml;
}

int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql)
{
const char* tail;
int rc;
const char* sql_cstr;
Py_ssize_t sql_cstr_len;
const char* p;

self->st = NULL;
self->in_use = 0;
Expand All @@ -73,25 +123,6 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
Py_INCREF(sql);
self->sql = sql;

/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
self->is_dml = 0;
for (p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}

Py_BEGIN_ALLOW_THREADS
rc = sqlite3_prepare_v2(connection->db,
sql_cstr,
Expand All @@ -101,6 +132,7 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
Py_END_ALLOW_THREADS

self->db = connection->db;
self->is_dml = pysqlite_statement_is_dml(self->st, sql_cstr);

if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) {
(void)sqlite3_finalize(self->st);
Expand Down