Skip to content

fix: closing db on finalizer and memory leak #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func (sc *statementCache) newDB(sqldb *sql.DB) *DB {
}

// lookupStmt checks if a Statement has been prepared on the db driver with the
// given primedSQL. If it has, the driver-prepared sql.Stmt is returned.
func (sc *statementCache) lookupStmt(db *DB, s *Statement, primedSQL string) (stmt *sql.Stmt, ok bool) {
// given primedSQL. If it has, the driverStmt is returned.
func (sc *statementCache) lookupStmt(db *DB, s *Statement, primedSQL string) (dStmt *driverStmt, ok bool) {
// The Statement cache ID is only removed from stmtDBCache when the
// finalizer is run. The Statement's cache ID must be in the stmtDBCache
// since we hold a reference to the Statement. It is therefore safe to
Expand All @@ -116,27 +116,32 @@ func (sc *statementCache) lookupStmt(db *DB, s *Statement, primedSQL string) (st
if !ok || ds.sql != primedSQL {
return nil, false
}
return ds.stmt, ok
return ds, ok
}

// driverPrepareStatement prepares a statement on the database and then stores
// the prepared *sql.Stmt in the cache.
func (sc *statementCache) driverPrepareStmt(ctx context.Context, db *DB, s *Statement, primedSQL string) (*sql.Stmt, error) {
func (sc *statementCache) driverPrepareStmt(ctx context.Context, db *DB, s *Statement, primedSQL string) (*driverStmt, error) {
sqlstmt, err := db.sqldb.PrepareContext(ctx, primedSQL)
if err != nil {
return nil, err
}

sc.mutex.Lock()
defer sc.mutex.Unlock()
// If there is already a statement in the cache, replace it with ours.
if driverStmt, ok := sc.stmtDBCache[s.cacheID][db.cacheID]; ok {
// Set a finalizer on the statement we evict from the cache to close it
// once current users have finished with it.
runtime.SetFinalizer(driverStmt.stmt, (*sql.Stmt).Close)

// If there is already a statement in the cache, set a finalizer on it to
// close it once concurrent users have finished with it and replace it with
// ours.
if ds, ok := sc.stmtDBCache[s.cacheID][db.cacheID]; ok {
runtime.SetFinalizer(ds, func(ds *driverStmt) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this the same as setting it on the way in? Why not do that so that you don't have to do the lookup each time you set it?

Copy link
Collaborator Author

@Aflynn50 Aflynn50 Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would have the same effect, but we are only setting this finalizer in the very specific case that a sqlair.Statement is executed with arguments that cause different SQL to be generate to previous executions, as can happen with a bulk insert or slice query.

This way, we are only setting the finalizer if this conflict has occurred, i.e. we don't set this finalizer on most driverStmts. Setting it on the way in would mean we set it for all driverStmts we create.

If we made this change we would avoid the lookup, but we would still have to set a finalizer here on the newly created driverStmt

We already set a finalizer on the sqlair.Statement when it is created but this could not be replaced with finalizers on the driverStmts because the driverStmts would then never be removed from the cache.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be safer if this was not a closure.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically you set if for all statements -1 i.e. except for the last one to (ever) enter the cache.

It is not unsafe to set it unconditionally (if no one references it, they can't use it), but you drop the lookup. The added simplicity is worth the change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how we set this for all statements -1. This is only set when you have a sqlair.Statement that has been executed with new arguments that cause the SQL generated in BindInputs to come out different to the SQL that was generated for the previous execution. e.g. if you have a bulk insert inserting 3 items instead of 2.

In that case, when the bulk insert is run with 3 items, the lookup in Query will fail, which will cause the stmtCache.driverPrepareStmt to be run. When it hits the if statement that sets the finalizer, it will find the bulk insert sql with 2 arguments, set a finalizer on it, and implicitly evict it by overwriting it below.

In the case of a sqlair.Statement that never changes the SQL it generates driverPrepareStmt will only be run once (as the lookup will always subsequently succeed), and when it is the if statement that conditionally sets the finalizer will fail as the cacheID of the sqlair.Statement is not yet in the cache.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be safer if this was not a closure.

Changed to be a function with a comment, also updated the comment above the setting of this finalizer to make things clearer.

ds.stmt.Close()
})
}
sc.stmtDBCache[s.cacheID][db.cacheID] = &driverStmt{sql: primedSQL, stmt: sqlstmt}
ds := &driverStmt{sql: primedSQL, stmt: sqlstmt}
sc.stmtDBCache[s.cacheID][db.cacheID] = ds
sc.dbStmtCache[db.cacheID][s.cacheID] = true
return sqlstmt, nil
return ds, nil
}

// removeAndCloseStmtFunc removes and closes all sql.Stmt objects associated
Expand All @@ -153,8 +158,7 @@ func (sc *statementCache) removeAndCloseStmtFunc(s *Statement) {
}

// removeAndCloseDBFunc closes and removes from the cache all sql.Stmt objects
// prepared on the database, removes the database from then cache, then closes
// the sql.DB.
// prepared on the database, removes the database from then cache.
func (sc *statementCache) removeAndCloseDBFunc(db *DB) {
sc.mutex.Lock()
defer sc.mutex.Unlock()
Expand All @@ -165,5 +169,4 @@ func (sc *statementCache) removeAndCloseDBFunc(db *DB) {
delete(dbCache, db.cacheID)
}
delete(sc.dbStmtCache, db.cacheID)
db.sqldb.Close()
}
38 changes: 0 additions & 38 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,44 +79,6 @@ func (s *CacheSuite) TestPreparedStatementReuse(c *C) {
s.checkDriverStmtsAllClosed(c)
}

func (s *CacheSuite) TestClosingDB(c *C) {
stmt, err := Prepare(`SELECT 'test'`)
c.Assert(err, IsNil)

var dbID uint64
// For a Statement or DB to be removed from the cache it needs to go out of
// scope and be garbage collected. A function is used to "forget" the
// statement.
func() {
db := s.openDB(c)
dbID = db.cacheID

// Start a query with stmt on db. This will prepare the stmt on the db.
err = db.Query(nil, stmt).Run()
c.Assert(err, IsNil)

// Check a statement is in the cache and a prepared statement has been
// opened on the DB.
s.checkStmtInCache(c, db.cacheID, stmt.cacheID)
s.checkNumDBStmts(c, db.cacheID, 1)
s.checkDriverStmtsOpened(c, 1)
}()

s.triggerFinalizers()
s.checkDBNotInCache(c, dbID)
s.checkDriverStmtsAllClosed(c)

// Check that the statement runs fine on a new DB.
db := s.openDB(c)
err = db.Query(nil, stmt).Run()
c.Assert(err, IsNil)

// Check the statement has been added to the cache for the new DB.
s.checkStmtInCache(c, db.cacheID, stmt.cacheID)
s.checkNumDBStmts(c, db.cacheID, 1)
s.checkDriverStmtsOpened(c, 2)
}

func (s *CacheSuite) TestStatementPreparedAndClosed(c *C) {
db := s.openDB(c)

Expand Down
38 changes: 22 additions & 16 deletions sqlair.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ func (db *DB) PlainDB() *sql.DB {
// Query represents a query on a database. It is designed to be run once and
// used immediately since it contains the query context.
type Query struct {
// run executes the Query against the DB or the TX.
run func(context.Context) (*sql.Rows, sql.Result, error)
// run executes the Query against the DB or the TX. It returns the results
// and a pointer to the driverStmt used to run the query if it needs to be
// kept in memory.
run func(context.Context) (*sql.Rows, sql.Result, *driverStmt, error)
ctx context.Context
err error
pq *expr.PrimedQuery
Expand All @@ -120,6 +122,10 @@ type Iterator struct {
err error
result sql.Result
started bool
// ds is the driverStmt used to run the query. The Iterator holds onto this
// so that it cannot be closed by finalizer while the rows are being
// iterated over. This finalizer can be set in the cache.
ds *driverStmt
}

// Query builds a new query from a context, a [Statement] and the input
Expand All @@ -138,22 +144,22 @@ func (db *DB) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
return &Query{ctx: ctx, err: err}
}

run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, err error) {
run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, ds *driverStmt, err error) {
primedSQL := pq.SQL()
sqlstmt, ok := stmtCache.lookupStmt(db, s, primedSQL)
ds, ok := stmtCache.lookupStmt(db, s, primedSQL)
if !ok {
sqlstmt, err = stmtCache.driverPrepareStmt(ctx, db, s, primedSQL)
ds, err = stmtCache.driverPrepareStmt(ctx, db, s, primedSQL)
if err != nil {
return nil, nil, err
return nil, nil, ds, err
}
}

if pq.HasOutputs() {
rows, err = sqlstmt.QueryContext(innerCtx, pq.Params()...)
rows, err = ds.stmt.QueryContext(innerCtx, pq.Params()...)
} else {
result, err = sqlstmt.ExecContext(innerCtx, pq.Params()...)
result, err = ds.stmt.ExecContext(innerCtx, pq.Params()...)
}
return rows, result, err
return rows, result, ds, err
}

return &Query{pq: pq, run: run, ctx: ctx, err: nil}
Expand Down Expand Up @@ -215,7 +221,7 @@ func (q *Query) Iter() *Iterator {
}

var cols []string
rows, result, err := q.run(q.ctx)
rows, result, ds, err := q.run(q.ctx)
if q.pq.HasOutputs() {
if err == nil { // if err IS nil
cols, err = rows.Columns()
Expand All @@ -225,7 +231,7 @@ func (q *Query) Iter() *Iterator {
return &Iterator{pq: q.pq, err: err}
}

return &Iterator{pq: q.pq, rows: rows, cols: cols, err: err, result: result}
return &Iterator{pq: q.pq, rows: rows, cols: cols, err: err, result: result, ds: ds}
}

// Next prepares the next row for [Iterator.Get]. If an error occurs during
Expand Down Expand Up @@ -486,28 +492,28 @@ func (tx *TX) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
return &Query{ctx: ctx, err: err}
}

run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, err error) {
sqlstmt, ok := stmtCache.lookupStmt(tx.db, s, pq.SQL())
run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, ds *driverStmt, err error) {
ds, ok := stmtCache.lookupStmt(tx.db, s, pq.SQL())
if ok {
// Register the prepared statement on the transaction. This function
// does not resend the prepare request to the database.
// The txstmt is closed by database/sql when the transaction is
// commited or rolled back.
txstmt := tx.sqltx.Stmt(sqlstmt)
txstmt := tx.sqltx.Stmt(ds.stmt)
if pq.HasOutputs() {
rows, err = txstmt.QueryContext(innerCtx, pq.Params()...)
} else {
result, err = txstmt.ExecContext(innerCtx, pq.Params()...)
}
return rows, result, err
return rows, result, ds, err
}

if pq.HasOutputs() {
rows, err = tx.sqltx.QueryContext(innerCtx, pq.SQL(), pq.Params()...)
} else {
result, err = tx.sqltx.ExecContext(innerCtx, pq.SQL(), pq.Params()...)
}
return rows, result, err
return rows, result, nil, err
}

return &Query{pq: pq, ctx: ctx, run: run, err: nil}
Expand Down
Loading