Skip to content

Commit 36657e4

Browse files
committed
CheckTxOptionsCompatibility using Connection.TransactionOptions
1 parent 717019a commit 36657e4

11 files changed

+208
-17
lines changed

config.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,17 @@ import (
1111
// Config for a connection.
1212
// For tips see https://www.alexedwards.net/blog/configuring-sqldb
1313
type Config struct {
14-
Driver string `json:"driver"`
15-
Host string `json:"host"`
16-
Port uint16 `json:"port,omitempty"`
17-
User string `json:"user,omitempty"`
18-
Password string `json:"password,omitempty"`
19-
Database string `json:"database"`
20-
Extra map[string]string `json:"misc,omitempty"`
21-
MaxOpenConns int `json:"maxOpenConns,omitempty"`
22-
MaxIdleConns int `json:"maxIdleConns,omitempty"`
23-
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"`
14+
Driver string `json:"driver"`
15+
Host string `json:"host"`
16+
Port uint16 `json:"port,omitempty"`
17+
User string `json:"user,omitempty"`
18+
Password string `json:"password,omitempty"`
19+
Database string `json:"database"`
20+
Extra map[string]string `json:"misc,omitempty"`
21+
MaxOpenConns int `json:"maxOpenConns,omitempty"`
22+
MaxIdleConns int `json:"maxIdleConns,omitempty"`
23+
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"`
24+
DefaultIsolationLevel sql.IsolationLevel `json:"-"`
2425
}
2526

2627
// ConnectURL for connecting to a database

connection.go

+5
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ type Connection interface {
128128
// IsTransaction returns if the connection is a transaction
129129
IsTransaction() bool
130130

131+
// TransactionOptions returns the sql.TxOptions of the
132+
// current transaction and true as second result value,
133+
// or false if the connection is not a transaction.
134+
TransactionOptions() (*sql.TxOptions, bool)
135+
131136
// Begin a new transaction.
132137
// Returns ErrWithinTransaction if the connection
133138
// is already within a transaction.

errors.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func (e connectionWithError) Stats() sql.DBStats {
7373
}
7474

7575
func (e connectionWithError) Config() *Config {
76-
return nil
76+
return new(Config)
7777
}
7878

7979
func (e connectionWithError) Exec(query string, args ...interface{}) error {
@@ -148,6 +148,10 @@ func (e connectionWithError) IsTransaction() bool {
148148
return false
149149
}
150150

151+
func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) {
152+
return nil, false
153+
}
154+
151155
func (e connectionWithError) Begin(opts *sql.TxOptions) (Connection, error) {
152156
return nil, e.err
153157
}

mockconn/connection.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,19 @@ func (conn *connection) QueryRows(query string, args ...interface{}) sqldb.RowsS
153153
return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...)
154154
}
155155

156-
// IsTransaction returns if the connection is a transaction
157156
func (conn *connection) IsTransaction() bool {
158157
return false
159158
}
160159

160+
func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) {
161+
return nil, false
162+
}
163+
161164
func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) {
162165
if conn.queryWriter != nil {
163166
fmt.Fprint(conn.queryWriter, "BEGIN")
164167
}
165-
return transaction{conn}, nil
168+
return transaction{conn, opts}, nil
166169
}
167170

168171
func (conn *connection) Commit() error {

mockconn/connection_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type testRow struct {
2424
Str string `db:"str"`
2525
StrPtr *string `db:"str_ptr"`
2626
NilPtr *byte `db:"nil_ptr"`
27+
ReadOnly int `db:"read_only,readonly"`
2728
Ignore int `db:"-"`
2829
UntaggedField int
2930
CreatedAt time.Time `db:"created_at"`

mockconn/transaction.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,24 @@ import (
1010

1111
type transaction struct {
1212
*connection
13+
opts *sql.TxOptions
1314
}
1415

1516
func (conn transaction) WithContext(ctx context.Context) sqldb.Connection {
1617
return transaction{
1718
connection: conn.connection.WithContext(ctx).(*connection), // TODO better way than type cast?
19+
opts: conn.opts,
1820
}
1921
}
2022

21-
// IsTransaction returns if the connection is a transaction
2223
func (conn transaction) IsTransaction() bool {
2324
return true
2425
}
2526

27+
func (conn transaction) TransactionOptions() (*sql.TxOptions, bool) {
28+
return conn.opts, true
29+
}
30+
2631
func (conn transaction) Begin(opts *sql.TxOptions) (sqldb.Connection, error) {
2732
return nil, sqldb.ErrWithinTransaction
2833
}

pqconn/connection.go

+7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) {
1717
if config.Driver != "postgres" {
1818
return nil, fmt.Errorf(`invalid driver %q, pqconn expects "postgres"`, config.Driver)
1919
}
20+
config.DefaultIsolationLevel = sql.LevelReadCommitted // postgres default
21+
2022
db, err := config.Connect(ctx)
2123
if err != nil {
2224
return nil, err
@@ -166,6 +168,10 @@ func (conn *connection) IsTransaction() bool {
166168
return false
167169
}
168170

171+
func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) {
172+
return nil, false
173+
}
174+
169175
func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) {
170176
tx, err := conn.db.BeginTx(conn.ctx, opts)
171177
if err != nil {
@@ -174,6 +180,7 @@ func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) {
174180
return &transaction{
175181
connection: conn,
176182
tx: tx,
183+
opts: opts,
177184
structFieldNamer: conn.structFieldNamer,
178185
}, nil
179186
}

pqconn/transaction.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ import (
1111
type transaction struct {
1212
*connection
1313
tx *sql.Tx
14+
opts *sql.TxOptions
1415
structFieldNamer sqldb.StructFieldNamer
1516
}
1617

1718
func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection {
1819
return &transaction{
1920
connection: conn.connection.WithContext(ctx).(*connection), // TODO better way than type cast?
2021
tx: conn.tx,
22+
opts: conn.opts,
2123
structFieldNamer: conn.structFieldNamer,
2224
}
2325
}
@@ -26,6 +28,7 @@ func (conn *transaction) WithStructFieldNamer(namer sqldb.StructFieldNamer) sqld
2628
return &transaction{
2729
connection: conn.connection,
2830
tx: conn.tx,
31+
opts: conn.opts,
2932
structFieldNamer: namer,
3033
}
3134
}
@@ -89,11 +92,14 @@ func (conn *transaction) QueryRows(query string, args ...interface{}) sqldb.Rows
8992
return impl.NewRowsScanner(conn.connection.ctx, rows, conn.structFieldNamer, query, args)
9093
}
9194

92-
// IsTransaction returns if the connection is a transaction
9395
func (conn *transaction) IsTransaction() bool {
9496
return true
9597
}
9698

99+
func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) {
100+
return conn.opts, true
101+
}
102+
97103
func (conn *transaction) Begin(opts *sql.TxOptions) (sqldb.Connection, error) {
98104
return nil, sqldb.ErrWithinTransaction
99105
}

structfieldnamer_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func TestStructFieldTagNaming_StructFieldName(t *testing.T) {
5858
{name: "read_only", structField: st.Field(2), wantName: "read_only", wantFlags: FieldFlagReadOnly, wantOk: true},
5959
{name: "untagged_field", structField: st.Field(3), wantName: "untagged_field", wantFlags: 0, wantOk: true},
6060
{name: "ignore", structField: st.Field(4), wantName: "", wantFlags: 0, wantOk: false},
61-
{name: "pk_read_only", structField: st.Field(5), wantName: "pk_read_only", wantFlags: FieldFlagPrimaryKey + FieldFlagReadOnly, wantOk: true},
61+
{name: "pk_read_only", structField: st.Field(5), wantName: "pk_read_only", wantFlags: FieldFlagPrimaryKey | FieldFlagReadOnly, wantOk: true},
6262
{name: "no_flag", structField: st.Field(6), wantName: "no_flag", wantFlags: 0, wantOk: true},
6363
{name: "malformed_flags", structField: st.Field(7), wantName: "malformed_flags", wantFlags: FieldFlagReadOnly, wantOk: true},
6464
{name: "Embedded", structField: st.Field(8), wantName: "", wantFlags: 0, wantOk: true},

transaction.go

+38-1
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,16 @@ import (
1010
// Transaction returns all errors from txFunc or transaction commit errors happening after txFunc.
1111
// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx Connection
1212
// and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call.
13+
// An error is returned, if the requested transaction options passed via opts
14+
// are stricter than the options of the parent transaction.
1315
// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction.
1416
// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger.
1517
func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) {
16-
if parentConn.IsTransaction() {
18+
if parentOpts, parentIsTx := parentConn.TransactionOptions(); parentIsTx {
19+
err = CheckTxOptionsCompatibility(parentOpts, opts, parentConn.Config().DefaultIsolationLevel)
20+
if err != nil {
21+
return err
22+
}
1723
return txFunc(parentConn)
1824
}
1925

@@ -52,3 +58,34 @@ func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Conn
5258

5359
return txFunc(tx)
5460
}
61+
62+
// CheckTxOptionsCompatibility returns an error
63+
// if the parent transaction options are less strict than the child options.
64+
func CheckTxOptionsCompatibility(parent, child *sql.TxOptions, defaultIsolation sql.IsolationLevel) error {
65+
var (
66+
parentReadOnly = false
67+
parentIsolation = defaultIsolation
68+
childReadOnly = false
69+
childIsolation = defaultIsolation
70+
)
71+
if parent != nil {
72+
parentReadOnly = parent.ReadOnly
73+
if parent.Isolation != sql.LevelDefault {
74+
parentIsolation = parent.Isolation
75+
}
76+
}
77+
if child != nil {
78+
childReadOnly = child.ReadOnly
79+
if child.Isolation != sql.LevelDefault {
80+
childIsolation = child.Isolation
81+
}
82+
}
83+
84+
if parentReadOnly && !childReadOnly {
85+
return errors.New("parent transaction is read-only but child is not")
86+
}
87+
if parentIsolation < childIsolation {
88+
return fmt.Errorf("parent transaction isolation level '%s' is less strict child level '%s'", parentIsolation, childIsolation)
89+
}
90+
return nil
91+
}

transaction_test.go

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package sqldb
2+
3+
import (
4+
"database/sql"
5+
"testing"
6+
)
7+
8+
func TestCheckTxOptionsCompatibility(t *testing.T) {
9+
type args struct {
10+
parent *sql.TxOptions
11+
child *sql.TxOptions
12+
defaultIsolation sql.IsolationLevel
13+
}
14+
tests := []struct {
15+
name string
16+
args args
17+
wantErr bool
18+
}{
19+
{
20+
name: "nil, nil",
21+
args: args{
22+
parent: nil,
23+
child: nil,
24+
},
25+
},
26+
{
27+
name: "nil, default",
28+
args: args{
29+
parent: nil,
30+
child: &sql.TxOptions{},
31+
},
32+
},
33+
{
34+
name: "default, nil",
35+
args: args{
36+
parent: &sql.TxOptions{},
37+
child: nil,
38+
},
39+
},
40+
{
41+
name: "default, default",
42+
args: args{
43+
parent: &sql.TxOptions{},
44+
child: &sql.TxOptions{},
45+
},
46+
},
47+
{
48+
name: "nil, ReadOnly",
49+
args: args{
50+
parent: nil,
51+
child: &sql.TxOptions{ReadOnly: true},
52+
},
53+
},
54+
{
55+
name: "ReadOnly, ReadOnly",
56+
args: args{
57+
parent: nil,
58+
child: &sql.TxOptions{ReadOnly: true},
59+
},
60+
wantErr: false,
61+
},
62+
{
63+
name: "ReadOnly, nil",
64+
args: args{
65+
parent: &sql.TxOptions{ReadOnly: true},
66+
child: nil,
67+
},
68+
wantErr: true,
69+
},
70+
{
71+
name: "ReadCommitted, ReadCommitted",
72+
args: args{
73+
parent: &sql.TxOptions{Isolation: sql.LevelReadCommitted},
74+
child: &sql.TxOptions{Isolation: sql.LevelReadCommitted},
75+
},
76+
},
77+
{
78+
name: "Serializable, ReadCommitted",
79+
args: args{
80+
parent: &sql.TxOptions{Isolation: sql.LevelSerializable},
81+
child: &sql.TxOptions{Isolation: sql.LevelReadCommitted},
82+
},
83+
},
84+
{
85+
name: "ReadCommitted, Serializable",
86+
args: args{
87+
parent: &sql.TxOptions{Isolation: sql.LevelReadCommitted},
88+
child: &sql.TxOptions{Isolation: sql.LevelSerializable},
89+
},
90+
wantErr: true,
91+
},
92+
{
93+
name: "ReadCommitted, Serializable/ReadOnly",
94+
args: args{
95+
parent: &sql.TxOptions{Isolation: sql.LevelReadCommitted},
96+
child: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true},
97+
},
98+
wantErr: true,
99+
},
100+
{
101+
name: "ReadCommitted/ReadOnly, ReadCommitted/ReadOnly",
102+
args: args{
103+
parent: &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true},
104+
child: &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true},
105+
},
106+
},
107+
{
108+
name: "Serializable/ReadOnly, ReadCommitted/ReadOnly",
109+
args: args{
110+
parent: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true},
111+
child: &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true},
112+
},
113+
},
114+
}
115+
for _, tt := range tests {
116+
t.Run(tt.name, func(t *testing.T) {
117+
if err := CheckTxOptionsCompatibility(tt.args.parent, tt.args.child, tt.args.defaultIsolation); (err != nil) != tt.wantErr {
118+
t.Errorf("CheckTxOptionsCompatibility() error = %v, wantErr %v", err, tt.wantErr)
119+
}
120+
})
121+
}
122+
}

0 commit comments

Comments
 (0)