Skip to content

Commit

Permalink
Merge pull request #9 from amirsalarsafaei/fix-opt-ctx
Browse files Browse the repository at this point in the history
fix(logger): use logger from optCtx and write tests
  • Loading branch information
amirsalarsafaei authored Dec 31, 2024
2 parents 3235d3b + f2bce11 commit 4dcfabb
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 12 deletions.
3 changes: 2 additions & 1 deletion dbtracer/dbtracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func NewDBTracer(
unit: "s",
name: "db_query_duration",
},
logger: slog.Default(),
}
for _, opt := range opts {
opt(&optCtx)
Expand All @@ -77,7 +78,7 @@ func NewDBTracer(
}

return &dbTracer{
logger: slog.Default(),
logger: optCtx.logger,
databaseName: databaseName,
shouldLog: optCtx.shouldLog,
logArgs: optCtx.logArgs,
Expand Down
113 changes: 108 additions & 5 deletions dbtracer/dbtracer_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package dbtracer

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"sync"
Expand Down Expand Up @@ -95,11 +97,12 @@ func (s *DBTracerSuite) SetupTest() {

func (s *DBTracerSuite) TestNewDBTracer() {
tests := []struct {
name string
databaseName string
opts []Option
setupMocks func(*mockmetric.MockMeterProvider, *mockmetric.MockMeter, *mockmetric.MockFloat64Histogram, *mocktracer.MockTracerProvider, *mocktracer.MockTracer)
wantErr bool
name string
databaseName string
opts []Option
setupMocks func(*mockmetric.MockMeterProvider, *mockmetric.MockMeter, *mockmetric.MockFloat64Histogram, *mocktracer.MockTracerProvider, *mocktracer.MockTracer)
validateTracer func(*DBTracerSuite, Tracer)
wantErr bool
}{
{
name: "successful creation with default options",
Expand All @@ -118,6 +121,40 @@ func (s *DBTracerSuite) TestNewDBTracer() {
Return(h, nil)
},
wantErr: false,
validateTracer: func(s *DBTracerSuite, t Tracer) {
dbTracer, ok := t.(*dbTracer)
s.Require().True(ok)
s.Equal(
slog.Default(),
dbTracer.logger,
"Should use default logger when none specified",
)
},
},
{
name: "successful creation with custom logger",
databaseName: "test_db",
opts: []Option{
WithLogger(slog.New(slog.NewTextHandler(io.Discard, nil))),
},
setupMocks: func(mp *mockmetric.MockMeterProvider, m *mockmetric.MockMeter, h *mockmetric.MockFloat64Histogram, tp *mocktracer.MockTracerProvider, t *mocktracer.MockTracer) {
mp.EXPECT().
Meter("github.com/amirsalarsafaei/sqlc-pgx-monitoring").
Return(m)
m.EXPECT().
Float64Histogram(
"db_query_duration",
metric.WithDescription("The duration of database queries by sqlc function names"),
metric.WithUnit("s"),
).
Return(h, nil)
},
validateTracer: func(s *DBTracerSuite, t Tracer) {
dbTracer, ok := t.(*dbTracer)
s.Require().True(ok)
s.NotEqual(slog.Default(), dbTracer.logger, "Should use custom logger")
},
wantErr: false,
},
{
name: "successful creation with custom histogram config",
Expand Down Expand Up @@ -185,6 +222,9 @@ func (s *DBTracerSuite) TestNewDBTracer() {
} else {
s.NoError(err)
s.NotNil(tracer)
if tt.validateTracer != nil {
tt.validateTracer(s, tracer)
}
}
})
}
Expand Down Expand Up @@ -710,6 +750,69 @@ func (s *DBTracerSuite) TestTraceConcurrent() {
s.NoError(errors.Join(errs...), "Expected no errors in concurrent execution")
}

func (s *DBTracerSuite) TestLoggerBehavior() {
var logBuffer bytes.Buffer
customLogger := slog.New(slog.NewTextHandler(&logBuffer, &slog.HandlerOptions{
Level: slog.LevelDebug,
}))

tracer, err := NewDBTracer(
s.defaultDBName,
WithTraceProvider(s.tracerProvider),
WithMeterProvider(s.meterProvider),
WithLogger(customLogger),
WithShouldLog(func(err error) bool { return true }),
)
s.Require().NoError(err)

// Setup for query execution
s.tracer.EXPECT().
Start(s.ctx, "postgresql.query").
Return(s.ctx, s.span)

s.span.EXPECT().
SetAttributes(
attribute.String("db.name", s.defaultDBName),
attribute.String("db.query_name", "get_users"),
attribute.String("db.query_type", "one"),
attribute.String("db.operation", "query"),
).
Return()

ctx := tracer.TraceQueryStart(s.ctx, s.pgxConn, pgx.TraceQueryStartData{
SQL: s.defaultQuerySQL,
Args: []interface{}{1},
})

expectedErr := errors.New("test error code:9123")

s.span.EXPECT().
End().
Return()

s.span.EXPECT().
SetStatus(codes.Error, expectedErr.Error()).
Return()

s.span.EXPECT().
RecordError(expectedErr).
Return()

s.histogram.EXPECT().
Record(ctx, mock.AnythingOfType("float64"), mock.Anything).
Return()

tracer.TraceQueryEnd(ctx, s.pgxConn, pgx.TraceQueryEndData{
CommandTag: pgconn.CommandTag{},
Err: expectedErr,
})

logOutput := logBuffer.String()
s.Contains(logOutput, "test error code:9123")
s.Contains(logOutput, "get_users")
s.Contains(logOutput, "Query failed")
}

func (s *DBTracerSuite) TestTraceQueryEndOnError() {
s.tracer.EXPECT().
Start(s.ctx, "postgresql.query").
Expand Down
16 changes: 12 additions & 4 deletions dbtracer/traceprepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ type tracePrepareData struct {
statementName string // 16 bytes
}

func (dt *dbTracer) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
func (dt *dbTracer) TracePrepareStart(
ctx context.Context,
_ *pgx.Conn,
data pgx.TracePrepareStartData,
) context.Context {
queryName, queryType := queryNameFromSQL(data.SQL)
ctx, span := dt.getTracer().Start(ctx, "postgresql.prepare")
span.SetAttributes(
Expand All @@ -41,7 +45,11 @@ func (dt *dbTracer) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx
})
}

func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
func (dt *dbTracer) TracePrepareEnd(
ctx context.Context,
conn *pgx.Conn,
data pgx.TracePrepareEndData,
) {
prepareData := ctx.Value(dbTracerPrepareCtxKey).(*tracePrepareData)
defer prepareData.span.End()

Expand All @@ -60,7 +68,7 @@ func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pg
prepareData.span.RecordError(data.Err)
if dt.shouldLog(data.Err) {
dt.logger.LogAttrs(ctx, slog.LevelError,
"Prepare",
"prepare failed",
slog.String("statement_name", prepareData.statementName),
slog.String("sql", prepareData.sql),
slog.Duration("time", interval),
Expand All @@ -71,7 +79,7 @@ func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pg
} else {
prepareData.span.SetStatus(codes.Ok, "")
dt.logger.LogAttrs(ctx, slog.LevelInfo,
"Prepare",
"prepare",
slog.String("statement_name", prepareData.statementName),
slog.String("sql", prepareData.sql),
slog.Duration("time", interval),
Expand Down
10 changes: 8 additions & 2 deletions dbtracer/tracequery.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ type traceQueryData struct {
startTime time.Time // 8 bytes
}

func (dt *dbTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
func (dt *dbTracer) TraceQueryStart(
ctx context.Context,
_ *pgx.Conn,
data pgx.TraceQueryStartData,
) context.Context {
queryName, queryType := queryNameFromSQL(data.SQL)
ctx, span := dt.getTracer().Start(ctx, "postgresql.query")
span.SetAttributes(
Expand Down Expand Up @@ -62,8 +66,9 @@ func (dt *dbTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.

if dt.shouldLog(data.Err) {
dt.logger.LogAttrs(ctx, slog.LevelError,
fmt.Sprintf("Query: %s", queryData.queryName),
fmt.Sprintf("Query failed: %s", queryData.queryName),
slog.String("sql", queryData.sql),
slog.String("query_name", queryData.queryName),
slog.Any("args", dt.logQueryArgs(queryData.args)),
slog.String("query_type", queryData.queryType),
slog.Duration("time", interval),
Expand All @@ -76,6 +81,7 @@ func (dt *dbTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.
dt.logger.LogAttrs(ctx, slog.LevelInfo,
fmt.Sprintf("Query: %s", queryData.queryName),
slog.String("sql", queryData.sql),
slog.String("query_name", queryData.queryName),
slog.String("query_type", queryData.queryType),
slog.Any("args", dt.logQueryArgs(queryData.args)),
slog.Duration("time", interval),
Expand Down

0 comments on commit 4dcfabb

Please sign in to comment.