Skip to content

Commit

Permalink
bridge,crypto: fix uses of deprecated NewRowIter
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jan 29, 2025
1 parent 30ad8a9 commit 7c0ed06
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 29 deletions.
6 changes: 1 addition & 5 deletions bridge/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
if err != nil {
log.Err(err).Msg("Failed to query rooms for resync")
return
}
roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
if err != nil {
log.Err(err).Msg("Failed to scan rooms for resync")
return
}
if len(roomIDs) > 0 {
log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms")
for _, roomID := range roomIDs {
Expand Down
6 changes: 1 addition & 5 deletions bridgev2/matrix/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
if err != nil {
log.Err(err).Msg("Failed to query rooms for resync")
return
}
roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
if err != nil {
log.Err(err).Msg("Failed to scan rooms for resync")
return
}
if len(roomIDs) > 0 {
log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms")
for _, roomID := range roomIDs {
Expand Down
15 changes: 3 additions & 12 deletions crypto/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,7 @@ func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList()
}

func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) {
Expand Down Expand Up @@ -459,10 +456,7 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]
return nil, fmt.Errorf("unsupported dialect")
}
res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList()
}

func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) {
Expand All @@ -472,10 +466,7 @@ func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList()
}

func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error {
Expand Down
12 changes: 5 additions & 7 deletions crypto/verificationhelper/verificationstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package verificationhelper_test
import (
"context"
"database/sql"
"errors"

_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog"
Expand Down Expand Up @@ -42,20 +43,17 @@ func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerific

func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) {
rows, err := s.db.QueryContext(ctx, selectVerifications)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) {
return dbutil.NewRowIterWithError(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) {
err = rows.Scan(&dbutil.JSON{Data: &txn})
return
}).AsList()
}, err).AsList()
}

func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) {
zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction")
row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID)
err = row.Scan(&dbutil.JSON{Data: &txn})
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
err = verificationhelper.ErrUnknownVerificationTransaction
}
return
Expand All @@ -64,7 +62,7 @@ func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Contex
func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) {
row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID)
err = row.Scan(&dbutil.JSON{Data: &txn})
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
err = verificationhelper.ErrUnknownVerificationTransaction
}
return
Expand Down

0 comments on commit 7c0ed06

Please sign in to comment.