Skip to content

Commit 6693fbc

Browse files
firewalldb: use queries to assert migration results
As the firewalldb package kvdb to sql migration tests creates `sqlc` models to assert the migration results, we will need to update those call sites to instead use the `sqlcmig6` models instead, in order to be compatible with the `sqlcmig6.Queries` queries. However, since we can't update the `SQLDB` methods to use `sqlcmig6` models as params, we need to update the test code assertion to instead use the `sqlc.Queries` object directly instead of the `SQLDB` object. This makes it easy to swap that `sqlc.Queries` object to a `sqlcmig6.Queries` object in the commit that updates the firewalldb package to use the `sqlcmig6` package for the kvdb to sql migration.
1 parent 91432b0 commit 6693fbc

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

firewalldb/sql_migration_test.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func TestFirewallDBMigration(t *testing.T) {
7474
// The assertKvStoreMigrationResults function will currently assert that
7575
// the migrated kv stores entries in the SQLDB match the original kv
7676
// stores entries in the BoltDB.
77-
assertKvStoreMigrationResults := func(t *testing.T, store *SQLDB,
77+
assertKvStoreMigrationResults := func(t *testing.T, store *sqlc.Queries,
7878
kvEntries []*kvEntry) {
7979

8080
var (
@@ -87,9 +87,7 @@ func TestFirewallDBMigration(t *testing.T) {
8787
getRuleID := func(ruleName string) int64 {
8888
ruleID, ok := ruleIDs[ruleName]
8989
if !ok {
90-
ruleID, err = store.db.GetRuleID(
91-
ctx, ruleName,
92-
)
90+
ruleID, err = store.GetRuleID(ctx, ruleName)
9391
require.NoError(t, err)
9492

9593
ruleIDs[ruleName] = ruleID
@@ -101,7 +99,7 @@ func TestFirewallDBMigration(t *testing.T) {
10199
getGroupID := func(groupAlias []byte) int64 {
102100
groupID, ok := groupIDs[string(groupAlias)]
103101
if !ok {
104-
groupID, err = store.db.GetSessionIDByAlias(
102+
groupID, err = store.GetSessionIDByAlias(
105103
ctx, groupAlias,
106104
)
107105
require.NoError(t, err)
@@ -115,7 +113,7 @@ func TestFirewallDBMigration(t *testing.T) {
115113
getFeatureID := func(featureName string) int64 {
116114
featureID, ok := featureIDs[featureName]
117115
if !ok {
118-
featureID, err = store.db.GetFeatureID(
116+
featureID, err = store.GetFeatureID(
119117
ctx, featureName,
120118
)
121119
require.NoError(t, err)
@@ -126,10 +124,10 @@ func TestFirewallDBMigration(t *testing.T) {
126124
return featureID
127125
}
128126

129-
// First we extract all migrated kv entries from the SQLDB,
127+
// First we extract all migrated kv entries from the store,
130128
// in order to be able to compare them to the original kv
131129
// entries, to ensure that the migration was successful.
132-
sqlKvEntries, err := store.db.ListAllKVStoresRecords(ctx)
130+
sqlKvEntries, err := store.ListAllKVStoresRecords(ctx)
133131
require.NoError(t, err)
134132
require.Equal(t, len(kvEntries), len(sqlKvEntries))
135133

@@ -145,7 +143,7 @@ func TestFirewallDBMigration(t *testing.T) {
145143
ruleID := getRuleID(entry.ruleName)
146144

147145
if entry.groupAlias.IsNone() {
148-
sqlVal, err := store.db.GetGlobalKVStoreRecord(
146+
sqlVal, err := store.GetGlobalKVStoreRecord(
149147
ctx,
150148
sqlc.GetGlobalKVStoreRecordParams{
151149
Key: entry.key,
@@ -163,7 +161,7 @@ func TestFirewallDBMigration(t *testing.T) {
163161
groupAlias := entry.groupAlias.UnwrapOrFail(t)
164162
groupID := getGroupID(groupAlias[:])
165163

166-
v, err := store.db.GetGroupKVStoreRecord(
164+
v, err := store.GetGroupKVStoreRecord(
167165
ctx,
168166
sqlc.GetGroupKVStoreRecordParams{
169167
Key: entry.key,
@@ -188,7 +186,7 @@ func TestFirewallDBMigration(t *testing.T) {
188186
entry.featureName.UnwrapOrFail(t),
189187
)
190188

191-
sqlVal, err := store.db.GetFeatureKVStoreRecord(
189+
sqlVal, err := store.GetFeatureKVStoreRecord(
192190
ctx,
193191
sqlc.GetFeatureKVStoreRecordParams{
194192
Key: entry.key,
@@ -219,14 +217,14 @@ func TestFirewallDBMigration(t *testing.T) {
219217
// BoltDB. It also asserts that the SQL DB does not contain any other
220218
// privacy pairs than the expected ones.
221219
assertPrivacyMapperMigrationResults := func(t *testing.T,
222-
sqlStore *SQLDB, privPairs privacyPairs) {
220+
sqlStore *sqlc.Queries, privPairs privacyPairs) {
223221

224222
var totalExpectedPairs, totalPairs int
225223

226224
// First assert that the SQLDB contains the expected privacy
227225
// pairs.
228226
for groupID, groupPairs := range privPairs {
229-
storePairs, err := sqlStore.db.GetAllPrivacyPairs(
227+
storePairs, err := sqlStore.GetAllPrivacyPairs(
230228
ctx, groupID,
231229
)
232230
require.NoError(t, err)
@@ -246,14 +244,13 @@ func TestFirewallDBMigration(t *testing.T) {
246244
}
247245
}
248246

249-
// Then assert that SQLDB doesn't contain any other privacy
247+
// Then assert that store doesn't contain any other privacy
250248
// pairs than the expected ones.
251-
queries := sqlc.NewForType(sqlStore, sqlStore.BackendType)
252-
sessions, err := queries.ListSessions(ctx)
249+
sessions, err := sqlStore.ListSessions(ctx)
253250
require.NoError(t, err)
254251

255252
for _, dbSession := range sessions {
256-
sessionPairs, err := sqlStore.db.GetAllPrivacyPairs(
253+
sessionPairs, err := sqlStore.GetAllPrivacyPairs(
257254
ctx, dbSession.ID,
258255
)
259256
if errors.Is(err, sql.ErrNoRows) {
@@ -272,7 +269,7 @@ func TestFirewallDBMigration(t *testing.T) {
272269
// The assertMigrationResults asserts that the migrated entries in the
273270
// firewall SQLDB match the expected results which should represent the
274271
// original entries in the BoltDB.
275-
assertMigrationResults := func(t *testing.T, sqlStore *SQLDB,
272+
assertMigrationResults := func(t *testing.T, sqlStore *sqlc.Queries,
276273
expRes *expectedResult) {
277274

278275
// Assert that the kv store migration results match the expected
@@ -400,7 +397,10 @@ func TestFirewallDBMigration(t *testing.T) {
400397
require.NoError(t, err)
401398

402399
// Assert migration results.
403-
assertMigrationResults(t, sqlStore, entries)
400+
queries := sqlc.NewForType(
401+
sqlStore, sqlStore.BackendType,
402+
)
403+
assertMigrationResults(t, queries, entries)
404404
})
405405
}
406406
}

0 commit comments

Comments
 (0)