From 6104681aacb65ddb8b80c9de861c147f10eba2d5 Mon Sep 17 00:00:00 2001 From: Carlo Teubner Date: Sat, 30 Sep 2023 17:07:41 +0100 Subject: [PATCH] Add SetErrorLog function This function sets a callback that SQLite invokes when it detects an anomaly. See https://sqlite.org/errlog.html. --- callback.go | 9 +++++++++ sqlite3.go | 39 +++++++++++++++++++++++++++++++++++++-- sqlite3_test.go | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/callback.go b/callback.go index b794bcd8..c2fdf365 100644 --- a/callback.go +++ b/callback.go @@ -16,6 +16,7 @@ package sqlite3 #else #include #endif +#include #include void _sqlite3_result_text(sqlite3_context* ctx, const char* s); @@ -29,9 +30,17 @@ import ( "math" "reflect" "sync" + "sync/atomic" "unsafe" ) +var errorLogCallback atomic.Value + +//export errorLogTrampoline +func errorLogTrampoline(_ C.uintptr_t, errCode C.int, msg *C.char) { + errorLogCallback.Load().(func(Error, string))(errorFromCode(errCode), C.GoString(msg)) +} + //export callbackTrampoline func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] diff --git a/sqlite3.go b/sqlite3.go index a16d8abf..69b6574d 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -59,6 +59,13 @@ package sqlite3 # define USE_PWRITE64 1 #endif +void errorLogTrampoline(void *userPtr, int errCode, const char *msg); + +static int +_sqlite3_config_log() { + return sqlite3_config(SQLITE_CONFIG_LOG, &errorLogTrampoline, NULL); +} + static int _sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) { #ifdef SQLITE_OPEN_URI @@ -263,14 +270,35 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) { return libVersion, libVersionNumber, sourceID } +// SetErrorLog registers the given callback to be invoked with a message whenever SQLite detects an +// anomaly. It is good practice to redirect such messages to the application log. See +// https://sqlite.org/errlog.html. +// The provided callback function receives an SQLite error object, denoting the broad category of +// error, and a message string. It must not call any SQLite functions; in fact, the SQLite docs +// recommend treating the callback function like a signal handler, minimizing the work done in it. +// SetErrorLog must not be called while any other goroutine is running that might be calling into +// the SQLite library. +func SetErrorLog(callback func(err Error, msg string)) error { + errorLogCallback.Store(callback) + if rc := C._sqlite3_config_log(); rc == 0 { + return nil + } else { + return errorFromCode(rc) + } +} + const ( + // some common return codes + SQLITE_OK = C.SQLITE_OK + SQLITE_NOTICE = C.SQLITE_NOTICE + SQLITE_WARNING = C.SQLITE_WARNING + // used by authorizer and pre_update_hook SQLITE_DELETE = C.SQLITE_DELETE SQLITE_INSERT = C.SQLITE_INSERT SQLITE_UPDATE = C.SQLITE_UPDATE - // used by authorzier - as return value - SQLITE_OK = C.SQLITE_OK + // used by authorizer as return value, in addition to SQLITE_OK SQLITE_IGNORE = C.SQLITE_IGNORE SQLITE_DENY = C.SQLITE_DENY @@ -845,6 +873,13 @@ func lastError(db *C.sqlite3) error { } } +func errorFromCode(rc C.int) Error { + return Error{ + Code: ErrNo(rc & ErrNoMask), + ExtendedCode: ErrNoExtended(rc), + } +} + // Exec implements Execer. func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { list := make([]driver.NamedValue, len(args)) diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d3..b7745aa4 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1248,6 +1248,40 @@ func TestVersion(t *testing.T) { } } +func TestErrorLog(t *testing.T) { + var errorLogged bool + var capturedErr Error + var capturedMsg string + err := SetErrorLog(func(err Error, msg string) { + errorLogged = true + capturedErr = err + capturedMsg = msg + }) + if err != nil { + t.Fatal("Failed to set error logger:", err) + } + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + if _, err := db.Exec(`SELECT "foo"`); err != nil { + t.Fatal("SELECT failed:", err) + } + + if !errorLogged { + t.Fatal("No error was logged") + } + if capturedErr.Code != SQLITE_WARNING { + t.Errorf("Unexpected error log code: %d", capturedErr.Code) + } + if !strings.Contains(capturedMsg, "double-quoted string literal") { + t.Errorf("Unexpected error log message: '%s'", capturedMsg) + } +} + func TestStringContainingZero(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename)