From f2bce110301fcb8b4ca1c1f783cec09f974d3284 Mon Sep 17 00:00:00 2001 From: Amirsalar Safaei Date: Tue, 31 Dec 2024 23:10:19 +0330 Subject: [PATCH] fix(logger): use logger from optCtx and write tests --- dbtracer/dbtracer.go | 3 +- dbtracer/dbtracer_test.go | 113 ++++++++++++++++++++++++++++++++++++-- dbtracer/traceprepare.go | 16 ++++-- dbtracer/tracequery.go | 10 +++- 4 files changed, 130 insertions(+), 12 deletions(-) diff --git a/dbtracer/dbtracer.go b/dbtracer/dbtracer.go index 880d572..3989de4 100644 --- a/dbtracer/dbtracer.go +++ b/dbtracer/dbtracer.go @@ -61,6 +61,7 @@ func NewDBTracer( unit: "s", name: "db_query_duration", }, + logger: slog.Default(), } for _, opt := range opts { opt(&optCtx) @@ -77,7 +78,7 @@ func NewDBTracer( } return &dbTracer{ - logger: slog.Default(), + logger: optCtx.logger, databaseName: databaseName, shouldLog: optCtx.shouldLog, logArgs: optCtx.logArgs, diff --git a/dbtracer/dbtracer_test.go b/dbtracer/dbtracer_test.go index b3ff038..0cbbaa9 100644 --- a/dbtracer/dbtracer_test.go +++ b/dbtracer/dbtracer_test.go @@ -1,9 +1,11 @@ package dbtracer import ( + "bytes" "context" "errors" "fmt" + "io" "log/slog" "math/rand" "sync" @@ -95,11 +97,12 @@ func (s *DBTracerSuite) SetupTest() { func (s *DBTracerSuite) TestNewDBTracer() { tests := []struct { - name string - databaseName string - opts []Option - setupMocks func(*mockmetric.MockMeterProvider, *mockmetric.MockMeter, *mockmetric.MockFloat64Histogram, *mocktracer.MockTracerProvider, *mocktracer.MockTracer) - wantErr bool + name string + databaseName string + opts []Option + setupMocks func(*mockmetric.MockMeterProvider, *mockmetric.MockMeter, *mockmetric.MockFloat64Histogram, *mocktracer.MockTracerProvider, *mocktracer.MockTracer) + validateTracer func(*DBTracerSuite, Tracer) + wantErr bool }{ { name: "successful creation with default options", @@ -118,6 +121,40 @@ func (s *DBTracerSuite) TestNewDBTracer() { Return(h, nil) }, wantErr: false, + validateTracer: func(s *DBTracerSuite, t Tracer) { + dbTracer, ok := t.(*dbTracer) + s.Require().True(ok) + s.Equal( + slog.Default(), + dbTracer.logger, + "Should use default logger when none specified", + ) + }, + }, + { + name: "successful creation with custom logger", + databaseName: "test_db", + opts: []Option{ + WithLogger(slog.New(slog.NewTextHandler(io.Discard, nil))), + }, + setupMocks: func(mp *mockmetric.MockMeterProvider, m *mockmetric.MockMeter, h *mockmetric.MockFloat64Histogram, tp *mocktracer.MockTracerProvider, t *mocktracer.MockTracer) { + mp.EXPECT(). + Meter("github.com/amirsalarsafaei/sqlc-pgx-monitoring"). + Return(m) + m.EXPECT(). + Float64Histogram( + "db_query_duration", + metric.WithDescription("The duration of database queries by sqlc function names"), + metric.WithUnit("s"), + ). + Return(h, nil) + }, + validateTracer: func(s *DBTracerSuite, t Tracer) { + dbTracer, ok := t.(*dbTracer) + s.Require().True(ok) + s.NotEqual(slog.Default(), dbTracer.logger, "Should use custom logger") + }, + wantErr: false, }, { name: "successful creation with custom histogram config", @@ -185,6 +222,9 @@ func (s *DBTracerSuite) TestNewDBTracer() { } else { s.NoError(err) s.NotNil(tracer) + if tt.validateTracer != nil { + tt.validateTracer(s, tracer) + } } }) } @@ -710,6 +750,69 @@ func (s *DBTracerSuite) TestTraceConcurrent() { s.NoError(errors.Join(errs...), "Expected no errors in concurrent execution") } +func (s *DBTracerSuite) TestLoggerBehavior() { + var logBuffer bytes.Buffer + customLogger := slog.New(slog.NewTextHandler(&logBuffer, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + tracer, err := NewDBTracer( + s.defaultDBName, + WithTraceProvider(s.tracerProvider), + WithMeterProvider(s.meterProvider), + WithLogger(customLogger), + WithShouldLog(func(err error) bool { return true }), + ) + s.Require().NoError(err) + + // Setup for query execution + s.tracer.EXPECT(). + Start(s.ctx, "postgresql.query"). + Return(s.ctx, s.span) + + s.span.EXPECT(). + SetAttributes( + attribute.String("db.name", s.defaultDBName), + attribute.String("db.query_name", "get_users"), + attribute.String("db.query_type", "one"), + attribute.String("db.operation", "query"), + ). + Return() + + ctx := tracer.TraceQueryStart(s.ctx, s.pgxConn, pgx.TraceQueryStartData{ + SQL: s.defaultQuerySQL, + Args: []interface{}{1}, + }) + + expectedErr := errors.New("test error code:9123") + + s.span.EXPECT(). + End(). + Return() + + s.span.EXPECT(). + SetStatus(codes.Error, expectedErr.Error()). + Return() + + s.span.EXPECT(). + RecordError(expectedErr). + Return() + + s.histogram.EXPECT(). + Record(ctx, mock.AnythingOfType("float64"), mock.Anything). + Return() + + tracer.TraceQueryEnd(ctx, s.pgxConn, pgx.TraceQueryEndData{ + CommandTag: pgconn.CommandTag{}, + Err: expectedErr, + }) + + logOutput := logBuffer.String() + s.Contains(logOutput, "test error code:9123") + s.Contains(logOutput, "get_users") + s.Contains(logOutput, "Query failed") +} + func (s *DBTracerSuite) TestTraceQueryEndOnError() { s.tracer.EXPECT(). Start(s.ctx, "postgresql.query"). diff --git a/dbtracer/traceprepare.go b/dbtracer/traceprepare.go index d67b1a9..e04dc38 100644 --- a/dbtracer/traceprepare.go +++ b/dbtracer/traceprepare.go @@ -21,7 +21,11 @@ type tracePrepareData struct { statementName string // 16 bytes } -func (dt *dbTracer) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context { +func (dt *dbTracer) TracePrepareStart( + ctx context.Context, + _ *pgx.Conn, + data pgx.TracePrepareStartData, +) context.Context { queryName, queryType := queryNameFromSQL(data.SQL) ctx, span := dt.getTracer().Start(ctx, "postgresql.prepare") span.SetAttributes( @@ -41,7 +45,11 @@ func (dt *dbTracer) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx }) } -func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { +func (dt *dbTracer) TracePrepareEnd( + ctx context.Context, + conn *pgx.Conn, + data pgx.TracePrepareEndData, +) { prepareData := ctx.Value(dbTracerPrepareCtxKey).(*tracePrepareData) defer prepareData.span.End() @@ -60,7 +68,7 @@ func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pg prepareData.span.RecordError(data.Err) if dt.shouldLog(data.Err) { dt.logger.LogAttrs(ctx, slog.LevelError, - "Prepare", + "prepare failed", slog.String("statement_name", prepareData.statementName), slog.String("sql", prepareData.sql), slog.Duration("time", interval), @@ -71,7 +79,7 @@ func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pg } else { prepareData.span.SetStatus(codes.Ok, "") dt.logger.LogAttrs(ctx, slog.LevelInfo, - "Prepare", + "prepare", slog.String("statement_name", prepareData.statementName), slog.String("sql", prepareData.sql), slog.Duration("time", interval), diff --git a/dbtracer/tracequery.go b/dbtracer/tracequery.go index 1130887..d36c9d0 100644 --- a/dbtracer/tracequery.go +++ b/dbtracer/tracequery.go @@ -22,7 +22,11 @@ type traceQueryData struct { startTime time.Time // 8 bytes } -func (dt *dbTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context { +func (dt *dbTracer) TraceQueryStart( + ctx context.Context, + _ *pgx.Conn, + data pgx.TraceQueryStartData, +) context.Context { queryName, queryType := queryNameFromSQL(data.SQL) ctx, span := dt.getTracer().Start(ctx, "postgresql.query") span.SetAttributes( @@ -62,8 +66,9 @@ func (dt *dbTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx. if dt.shouldLog(data.Err) { dt.logger.LogAttrs(ctx, slog.LevelError, - fmt.Sprintf("Query: %s", queryData.queryName), + fmt.Sprintf("Query failed: %s", queryData.queryName), slog.String("sql", queryData.sql), + slog.String("query_name", queryData.queryName), slog.Any("args", dt.logQueryArgs(queryData.args)), slog.String("query_type", queryData.queryType), slog.Duration("time", interval), @@ -76,6 +81,7 @@ func (dt *dbTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx. dt.logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf("Query: %s", queryData.queryName), slog.String("sql", queryData.sql), + slog.String("query_name", queryData.queryName), slog.String("query_type", queryData.queryType), slog.Any("args", dt.logQueryArgs(queryData.args)), slog.Duration("time", interval),