Skip to content

Commit cad67fc

Browse files
authored
Merge pull request #1072 from ellemouton/sql39
[sql39]: Actions schemas, queries and SQL store impl
2 parents bd704c4 + 89b807c commit cad67fc

20 files changed

+905
-57
lines changed

config_dev.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,23 +151,19 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
151151

152152
stores.sessions = sessionStore
153153
stores.closeFns["bbolt-sessions"] = sessionStore.Close
154-
}
155154

156-
firewallBoltDB, err := firewalldb.NewBoltDB(
157-
networkDir, firewalldb.DBFilename, stores.sessions,
158-
stores.accounts, clock,
159-
)
160-
if err != nil {
161-
return stores, fmt.Errorf("error creating firewall BoltDB: %v",
162-
err)
163-
}
155+
firewallBoltDB, err := firewalldb.NewBoltDB(
156+
networkDir, firewalldb.DBFilename, stores.sessions,
157+
stores.accounts, clock,
158+
)
159+
if err != nil {
160+
return stores, fmt.Errorf("error creating firewall "+
161+
"BoltDB: %v", err)
162+
}
164163

165-
if stores.firewall == nil {
166164
stores.firewall = firewalldb.NewDB(firewallBoltDB)
165+
stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close
167166
}
168167

169-
stores.firewallBolt = firewallBoltDB
170-
stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close
171-
172168
return stores, nil
173169
}

config_prod.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
6262
if err != nil {
6363
return stores, fmt.Errorf("error creating firewall DB: %v", err)
6464
}
65-
stores.firewallBolt = firewallDB
6665
stores.firewall = firewalldb.NewDB(firewallDB)
6766
stores.closeFns["firewall"] = firewallDB.Close
6867

db/interfaces.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,13 @@ type BatchedQuerier interface {
8787
// create a batched version of the normal methods they need.
8888
sqlc.Querier
8989

90+
// CustomQueries is the set of custom queries that we have manually
91+
// defined in addition to the ones generated by sqlc.
92+
sqlc.CustomQueries
93+
9094
// BeginTx creates a new database transaction given the set of
9195
// transaction options.
9296
BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error)
93-
94-
// Backend returns the type of the database backend used.
95-
Backend() sqlc.BackendType
9697
}
9798

9899
// txExecutorOptions is a struct that holds the options for the transaction

db/migrations.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const (
2222
// daemon.
2323
//
2424
// NOTE: This MUST be updated when a new migration is added.
25-
LatestMigrationVersion = 4
25+
LatestMigrationVersion = 5
2626
)
2727

2828
// MigrationTarget is a functional option that can be passed to applyMigrations

db/sqlc/actions.sql.go

Lines changed: 78 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

db/sqlc/actions_custom.go

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
package sqlc
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"strconv"
7+
"strings"
8+
)
9+
10+
// ActionQueryParams defines the parameters for querying actions.
11+
type ActionQueryParams struct {
12+
SessionID sql.NullInt64
13+
AccountID sql.NullInt64
14+
FeatureName sql.NullString
15+
ActorName sql.NullString
16+
RpcMethod sql.NullString
17+
State sql.NullInt16
18+
EndTime sql.NullTime
19+
StartTime sql.NullTime
20+
GroupID sql.NullInt64
21+
}
22+
23+
// ListActionsParams defines the parameters for listing actions, including
24+
// the ActionQueryParams for filtering and a Pagination struct for
25+
// pagination. The Reversed field indicates whether the results should be
26+
// returned in reverse order based on the created_at timestamp.
27+
type ListActionsParams struct {
28+
ActionQueryParams
29+
Reversed bool
30+
*Pagination
31+
}
32+
33+
// Pagination defines the pagination parameters for listing actions.
34+
type Pagination struct {
35+
NumOffset int32
36+
NumLimit int32
37+
}
38+
39+
// ListActions retrieves a list of actions based on the provided
40+
// ListActionsParams.
41+
func (q *Queries) ListActions(ctx context.Context,
42+
arg ListActionsParams) ([]Action, error) {
43+
44+
query, args := buildListActionsQuery(arg)
45+
rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...)
46+
if err != nil {
47+
return nil, err
48+
}
49+
defer rows.Close()
50+
var items []Action
51+
for rows.Next() {
52+
var i Action
53+
if err := rows.Scan(
54+
&i.ID,
55+
&i.SessionID,
56+
&i.AccountID,
57+
&i.MacaroonIdentifier,
58+
&i.ActorName,
59+
&i.FeatureName,
60+
&i.ActionTrigger,
61+
&i.Intent,
62+
&i.StructuredJsonData,
63+
&i.RpcMethod,
64+
&i.RpcParamsJson,
65+
&i.CreatedAt,
66+
&i.ActionState,
67+
&i.ErrorReason,
68+
); err != nil {
69+
return nil, err
70+
}
71+
items = append(items, i)
72+
}
73+
if err := rows.Close(); err != nil {
74+
return nil, err
75+
}
76+
if err := rows.Err(); err != nil {
77+
return nil, err
78+
}
79+
return items, nil
80+
}
81+
82+
// CountActions returns the number of actions that match the provided
83+
// ActionQueryParams.
84+
func (q *Queries) CountActions(ctx context.Context,
85+
arg ActionQueryParams) (int64, error) {
86+
87+
query, args := buildActionsQuery(arg, true)
88+
row := q.db.QueryRowContext(ctx, query, args...)
89+
90+
var count int64
91+
err := row.Scan(&count)
92+
93+
return count, err
94+
}
95+
96+
// buildActionsQuery constructs a SQL query to retrieve actions based on the
97+
// provided parameters. We do this manually so that if, for example, we have
98+
// a sessionID we are filtering by, then this appears in the query as:
99+
// `WHERE a.session_id = ?` which will properly make use of the underlying
100+
// index. If we were instead to use a single SQLC query, it would include many
101+
// WHERE clauses like:
102+
// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)".
103+
// This would use the index if run against postres but not when run against
104+
// sqlite.
105+
//
106+
// The 'count' param indicates whether the query should return a count of
107+
// actions that match the criteria or the actions themselves.
108+
func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) {
109+
var (
110+
conditions []string
111+
args []any
112+
)
113+
114+
if params.SessionID.Valid {
115+
conditions = append(conditions, "a.session_id = ?")
116+
args = append(args, params.SessionID.Int64)
117+
}
118+
if params.AccountID.Valid {
119+
conditions = append(conditions, "a.account_id = ?")
120+
args = append(args, params.AccountID.Int64)
121+
}
122+
if params.FeatureName.Valid {
123+
conditions = append(conditions, "a.feature_name = ?")
124+
args = append(args, params.FeatureName.String)
125+
}
126+
if params.ActorName.Valid {
127+
conditions = append(conditions, "a.actor_name = ?")
128+
args = append(args, params.ActorName.String)
129+
}
130+
if params.RpcMethod.Valid {
131+
conditions = append(conditions, "a.rpc_method = ?")
132+
args = append(args, params.RpcMethod.String)
133+
}
134+
if params.State.Valid {
135+
conditions = append(conditions, "a.action_state = ?")
136+
args = append(args, params.State.Int16)
137+
}
138+
if params.EndTime.Valid {
139+
conditions = append(conditions, "a.created_at <= ?")
140+
args = append(args, params.EndTime.Time)
141+
}
142+
if params.StartTime.Valid {
143+
conditions = append(conditions, "a.created_at >= ?")
144+
args = append(args, params.StartTime.Time)
145+
}
146+
if params.GroupID.Valid {
147+
conditions = append(conditions, `
148+
EXISTS (
149+
SELECT 1
150+
FROM sessions s
151+
WHERE s.id = a.session_id AND s.group_id = ?
152+
)`)
153+
args = append(args, params.GroupID.Int64)
154+
}
155+
156+
query := "SELECT a.* FROM actions a"
157+
if count {
158+
query = "SELECT COUNT(*) FROM actions a"
159+
}
160+
if len(conditions) > 0 {
161+
query += " WHERE " + strings.Join(conditions, " AND ")
162+
}
163+
164+
return query, args
165+
}
166+
167+
// buildListActionsQuery constructs a SQL query to retrieve a list of actions
168+
// based on the provided parameters. It builds upon the `buildActionsQuery`
169+
// function, adding pagination and ordering based on the reversed parameter.
170+
func buildListActionsQuery(params ListActionsParams) (string, []interface{}) {
171+
query, args := buildActionsQuery(params.ActionQueryParams, false)
172+
173+
// Determine order direction.
174+
order := "ASC"
175+
if params.Reversed {
176+
order = "DESC"
177+
}
178+
query += " ORDER BY a.created_at " + order
179+
180+
// Maybe paginate.
181+
if params.Pagination != nil {
182+
query += " LIMIT ? OFFSET ?"
183+
args = append(args, params.NumLimit, params.NumOffset)
184+
}
185+
186+
return query, args
187+
}
188+
189+
// fillPlaceHolders replaces all '?' placeholders in the SQL query with
190+
// positional placeholders like $1, $2, etc. This is necessary for
191+
// compatibility with Postgres.
192+
func fillPlaceHolders(query string) string {
193+
var (
194+
sb strings.Builder
195+
argNum = 1
196+
)
197+
198+
for i := range len(query) {
199+
if query[i] != '?' {
200+
sb.WriteByte(query[i])
201+
continue
202+
}
203+
204+
sb.WriteString("$")
205+
sb.WriteString(strconv.Itoa(argNum))
206+
argNum++
207+
}
208+
209+
return sb.String()
210+
}

db/sqlc/db_custom.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package sqlc
22

3+
import (
4+
"context"
5+
)
6+
37
// BackendType is an enum that represents the type of database backend we're
48
// using.
59
type BackendType uint8
@@ -44,3 +48,19 @@ func NewSqlite(db DBTX) *Queries {
4448
func NewPostgres(db DBTX) *Queries {
4549
return &Queries{db: &wrappedTX{db, BackendTypePostgres}}
4650
}
51+
52+
// CustomQueries defines a set of custom queries that we define in addition
53+
// to the ones generated by sqlc.
54+
type CustomQueries interface {
55+
// CountActions returns the number of actions that match the provided
56+
// ActionQueryParams.
57+
CountActions(ctx context.Context, arg ActionQueryParams) (int64, error)
58+
59+
// ListActions retrieves a list of actions based on the provided
60+
// ListActionsParams.
61+
ListActions(ctx context.Context,
62+
arg ListActionsParams) ([]Action, error)
63+
64+
// Backend returns the type of the database backend used.
65+
Backend() BackendType
66+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
DROP INDEX IF NOT EXISTS actions_state_idx;
2+
DROP INDEX IF NOT EXISTS actions_session_id_idx;
3+
DROP INDEX IF NOT EXISTS actions_feature_name_idx;
4+
DROP INDEX IF NOT EXISTS actions_created_at_idx;
5+
DROP TABLE IF EXISTS actions;

0 commit comments

Comments
 (0)