-
-
Notifications
You must be signed in to change notification settings - Fork 246
/
Copy pathconnection_instrumented_test.go
101 lines (85 loc) · 2.37 KB
/
connection_instrumented_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package pop
import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/luna-duclos/instrumentedsql"
"github.com/stretchr/testify/suite"
)
// testInstrumentedDriver verifies that enabling UseInstrumentedDriver routes
// queries through instrumentedsql. It opens a fresh connection for the dialect
// named by SODA_DIALECT with a logger attached, runs one parameterized query,
// and fails unless a logged message contains that query (in "?" or "$1"
// placeholder form) within five seconds.
//
// The logger may be invoked on a goroutine other than the test's, so it never
// blocks and never calls the fatal require assertions itself. The only fatal
// failure is raised from the test goroutine, as testing.T.FailNow requires.
func testInstrumentedDriver(p *suite.Suite) {
	r := p.Require()

	deets := *Connections[os.Getenv("SODA_DIALECT")].Dialect.Details()

	ctx, cancel := context.WithTimeout(context.TODO(), time.Second*5)
	defer cancel()

	var (
		queryMySQL = "SELECT 1 FROM DUAL WHERE 1=?"
		queryOther = "SELECT 1 WHERE 1=?"
		// Accept both the raw query and the form rewritten with positional
		// ($n) placeholders.
		expected = []string{
			"SELECT 1 FROM DUAL WHERE 1=?",
			"SELECT 1 FROM DUAL WHERE 1=$1",
			"SELECT 1 WHERE 1=?",
			"SELECT 1 WHERE 1=$1",
		}

		// found is closed (exactly once) as soon as a logged message contains
		// an expected query. messages keeps everything the logger saw so a
		// failure can report it; mu guards it because the logger may run
		// concurrently with the test goroutine.
		found     = make(chan struct{})
		foundOnce sync.Once
		mu        sync.Mutex
		messages  []string
	)

	query := queryOther
	if os.Getenv("SODA_DIALECT") == "mysql" {
		query = queryMySQL
	}

	checker := instrumentedsql.LoggerFunc(func(_ context.Context, msg string, keyvals ...interface{}) {
		m := fmt.Sprintf("%s - %+v", msg, keyvals)
		p.T().Logf("Instrumentation received message: %s", m)

		mu.Lock()
		messages = append(messages, m)
		mu.Unlock()

		for _, e := range expected {
			if strings.Contains(m, e) {
				p.T().Logf("Found part %s in %s", e, m)
				// Signal without blocking; later matches are no-ops.
				foundOnce.Do(func() { close(found) })
				return
			}
		}
	})

	deets.UseInstrumentedDriver = true
	deets.InstrumentedDriverOptions = []instrumentedsql.Opt{instrumentedsql.WithLogger(checker)}

	c, err := NewConnection(&deets)
	r.NoError(err)
	r.NoError(c.Open())
	// Release the pool opened for this test. Use the non-fatal assertion so a
	// close error is still reported when this defer runs during FailNow.
	defer func() { p.NoError(c.Close()) }()

	r.NoError(c.WithContext(context.TODO()).RawQuery(query, 1).Exec())

	select {
	case <-found:
	case <-ctx.Done():
		mu.Lock()
		seen := strings.Join(messages, "\n\t")
		mu.Unlock()
		r.FailNow(fmt.Sprintf("Expected tracer to return the \"%s\" query but only the following messages have been received:\n\n\t%s", query, seen))
	}
}
// Test_Instrumentation runs the shared instrumented-driver check using the
// PostgreSQL suite's embedded testify suite.
func (s *PostgreSQLSuite) Test_Instrumentation() {
	base := &s.Suite
	testInstrumentedDriver(base)
}
// Test_Instrumentation runs the shared instrumented-driver check using the
// MySQL suite's embedded testify suite.
func (s *MySQLSuite) Test_Instrumentation() {
	base := &s.Suite
	testInstrumentedDriver(base)
}
// Test_Instrumentation runs the shared instrumented-driver check using the
// SQLite suite's embedded testify suite.
func (s *SQLiteSuite) Test_Instrumentation() {
	base := &s.Suite
	testInstrumentedDriver(base)
}
// Test_Instrumentation runs the shared instrumented-driver check using the
// CockroachDB suite's embedded testify suite.
func (s *CockroachSuite) Test_Instrumentation() {
	base := &s.Suite
	testInstrumentedDriver(base)
}