From ce110ca46dfff542d27b92527622a3cd164e9258 Mon Sep 17 00:00:00 2001
From: Amirsalar Safaei
Date: Wed, 25 Dec 2024 22:41:01 +0330
Subject: [PATCH] a bit of refactor

---
 .github/workflows/go.yml  |  10 +-
 README.md                 |  18 ++-
 dbtracer/dbtracer.go      | 325 --------------------------------------
 dbtracer/dbtracer_test.go |   4 +-
 dbtracer/tracebatch.go    | 100 ++++++++++++
 dbtracer/traceconnect.go  |  68 ++++++++
 dbtracer/tracecopyfrom.go |  74 +++++++++
 dbtracer/traceprepare.go  |  80 ++++++++++
 dbtracer/tracequery.go    |  86 ++++++++++
 9 files changed, 432 insertions(+), 333 deletions(-)
 create mode 100644 dbtracer/tracebatch.go
 create mode 100644 dbtracer/traceconnect.go
 create mode 100644 dbtracer/tracecopyfrom.go
 create mode 100644 dbtracer/traceprepare.go
 create mode 100644 dbtracer/tracequery.go

diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index fa3782a..cee2c04 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -21,5 +21,11 @@ jobs:
         with:
           go-version: '1.23'
 
-      - name: Test
-        run: go test ./...
+      - name: Test with Coverage
+        run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
+
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          slug: amirsalarsafaei/sqlc-pgx-monitoring
diff --git a/README.md b/README.md
index 8b6f78d..4cf1c1c 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,21 @@
 # sqlc-pgx-monitoring
 
 ![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)
+[![codecov](https://codecov.io/gh/amirsalarsafaei/sqlc-pgx-monitoring/branch/main/graph/badge.svg)](https://codecov.io/gh/amirsalarsafaei/sqlc-pgx-monitoring)
 
 `sqlc-pgx-monitoring` is a Go package that offers powerful query time monitoring and logging capabilities for applications using the popular `pgx` and `sqlc` libraries in Golang. 
If you want to gain insights into the performance of your PostgreSQL database queries and ensure the reliability of your application, this package is a valuable addition to your toolset. ## Features +- **Complete OpenTelemetry Support**: Built-in integration with OpenTelemetry for comprehensive observability, including metrics, traces, and spans for all database operations. Traces every database interaction including: + - Individual queries + - Batch operations + - Prepared statements + - Connection lifecycle + - COPY FROM operations + +- **Modern Structured Logging**: Native support for Go's `slog` package, providing structured, leveled logging that's easy to parse and analyze. + - **Query Time Monitoring**: Keep a close eye on the execution times of your SQL queries to identify and optimize slow or resource-intensive database operations. It uses name declared in sqlc queries in the label for distinguishing queries from each other. - **Detailed Logging**: Record detailed logs of executed queries, including name, parameters, timings, and outcomes, which can be invaluable for debugging and performance analysis. 
@@ -17,7 +27,7 @@ To get started with `sqlc-pgx-monitoring`, you can simply use `go get`: ```shell -go get github.com/amirsalarsafaei/sqlc-pgx-monitoring@latest +go get github.com/amirsalarsafaei/sqlc-pgx-monitoring@v1.4.0 ``` ## Usage @@ -34,15 +44,13 @@ To begin using `sqlc-pgx-monitoring` in your Go project, follow these basic step ### pgx.Conn ```go connConfig.Tracer = dbtracer.NewDBTracer( - logrus.New(), - prometheus.DefaultRegisterer, + "database_name", ) ``` ### pgxpool.Pool ```go poolConfig.ConnConfig.Tracer = dbtracer.NewDBTracer( - logrus.New(), - prometheus.DefaultRegisterer, + "database_name", ) ``` diff --git a/dbtracer/dbtracer.go b/dbtracer/dbtracer.go index a1d5ac0..789a0c1 100644 --- a/dbtracer/dbtracer.go +++ b/dbtracer/dbtracer.go @@ -1,17 +1,13 @@ package dbtracer import ( - "context" "encoding/hex" "fmt" "log/slog" - "time" "unicode/utf8" "github.com/jackc/pgx/v5" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/trace" ) @@ -93,327 +89,6 @@ const ( dbTracerPrepareCtxKey ) -type traceQueryData struct { - args []any // 24 bytes - span trace.Span // 16 bytes - sql string // 16 bytes - queryName string // 16 bytes - startTime time.Time // 8 bytes -} - -func (dt *dbTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context { - queryName, queryType := queryNameFromSQL(data.SQL) - ctx, span := dt.tracer.Start(ctx, "postgresql.query") - span.SetAttributes( - attribute.String("db.name", dt.databaseName), - attribute.String("db.query_name", queryName), - attribute.String("db.query_type", queryType), - attribute.String("db.operation", "query"), - ) - return context.WithValue(ctx, dbTracerQueryCtxKey, &traceQueryData{ - startTime: time.Now(), - sql: data.SQL, - args: data.Args, - queryName: queryName, - span: span, - }) -} - -func (dt *dbTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data 
pgx.TraceQueryEndData) { - queryData := ctx.Value(dbTracerQueryCtxKey).(*traceQueryData) - - endTime := time.Now() - interval := endTime.Sub(queryData.startTime) - dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes( - attribute.String("operation", "query"), - attribute.String("query_name", queryData.queryName), - )) - - defer queryData.span.End() - - if data.Err != nil { - queryData.span.SetStatus(codes.Error, data.Err.Error()) - queryData.span.RecordError(data.Err) - - if dt.shouldLog(data.Err) { - dt.logger.LogAttrs(ctx, slog.LevelError, - fmt.Sprintf("Query: %s", queryData.queryName), - slog.String("sql", queryData.sql), - slog.Any("args", dt.logQueryArgs(queryData.args)), - slog.Duration("time", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("error", data.Err.Error()), - ) - } - } else { - queryData.span.SetStatus(codes.Ok, "") - dt.logger.LogAttrs(ctx, slog.LevelInfo, - fmt.Sprintf("Query: %s", queryData.queryName), - slog.String("sql", queryData.sql), - slog.Any("args", dt.logQueryArgs(queryData.args)), - slog.Duration("time", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("commandTag", data.CommandTag.String()), - ) - } -} - -type traceBatchData struct { - span trace.Span // 16 bytes - startTime time.Time // 16 bytes - queryName string // 16 bytes -} - -func (dt *dbTracer) TraceBatchStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceBatchStartData) context.Context { - ctx, span := dt.tracer.Start(ctx, "postgresql.batch") - span.SetAttributes( - attribute.String("db.name", dt.databaseName), - attribute.String("db.operation", "batch"), - ) - return context.WithValue(ctx, dbTracerBatchCtxKey, &traceBatchData{ - startTime: time.Now(), - span: span, - }) -} - -func (dt *dbTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { - queryData := ctx.Value(dbTracerBatchCtxKey).(*traceBatchData) - queryName, queryType := queryNameFromSQL(data.SQL) 
- queryData.queryName = queryName - queryData.span.SetAttributes( - attribute.String("db.query_name", queryName), - attribute.String("db.query_type", queryType), - ) - - if data.Err != nil { - queryData.span.SetStatus(codes.Error, data.Err.Error()) - queryData.span.RecordError(data.Err) - - if dt.shouldLog(data.Err) { - dt.logger.LogAttrs(ctx, slog.LevelError, - "Query", - slog.String("sql", data.SQL), - slog.Any("args", dt.logQueryArgs(data.Args)), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("error", data.Err.Error()), - ) - } - } else { - queryData.span.SetStatus(codes.Ok, "") - dt.logger.LogAttrs(ctx, slog.LevelInfo, - "Query", - slog.String("sql", data.SQL), - slog.Any("args", dt.logQueryArgs(data.Args)), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("commandTag", data.CommandTag.String()), - ) - } -} - -func (dt *dbTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { - queryData := ctx.Value(dbTracerBatchCtxKey).(*traceBatchData) - defer queryData.span.End() - - endTime := time.Now() - interval := endTime.Sub(queryData.startTime) - dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes( - attribute.String("operation", "batch"), - attribute.String("query_name", queryData.queryName), - )) - - if data.Err != nil { - queryData.span.SetStatus(codes.Error, data.Err.Error()) - queryData.span.RecordError(data.Err) - - if dt.shouldLog(data.Err) { - dt.logger.LogAttrs(ctx, slog.LevelError, - fmt.Sprintf("Query: %s", queryData.queryName), - slog.Duration("interval", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("error", data.Err.Error()), - ) - } - } else { - queryData.span.SetStatus(codes.Ok, "") - dt.logger.LogAttrs(ctx, slog.LevelInfo, - "Query", - slog.Duration("interval", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - ) - } -} - -type traceCopyFromData struct { - ColumnNames []string // 24 bytes - span 
trace.Span // 16 bytes - startTime time.Time // 16 bytes - TableName pgx.Identifier // slice - 24 bytes -} - -func (dt *dbTracer) TraceCopyFromStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { - ctx, span := dt.tracer.Start(ctx, "postgresql.copy_from") - span.SetAttributes( - attribute.String("db.name", dt.databaseName), - attribute.String("db.operation", "copy"), - attribute.String("db.table", data.TableName.Sanitize()), - ) - return context.WithValue(ctx, dbTracerCopyFromCtxKey, &traceCopyFromData{ - startTime: time.Now(), - TableName: data.TableName, - ColumnNames: data.ColumnNames, - span: span, - }) -} - -func (dt *dbTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { - copyFromData := ctx.Value(dbTracerCopyFromCtxKey).(*traceCopyFromData) - defer copyFromData.span.End() - - endTime := time.Now() - interval := endTime.Sub(copyFromData.startTime) - dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes( - attribute.String("operation", "copy"), - attribute.String("table", copyFromData.TableName.Sanitize()), - )) - - if data.Err != nil { - copyFromData.span.SetStatus(codes.Error, data.Err.Error()) - copyFromData.span.RecordError(data.Err) - - if dt.shouldLog(data.Err) { - dt.logger.LogAttrs(ctx, slog.LevelError, - "CopyFrom", - slog.Any("tableName", copyFromData.TableName), - slog.Any("columnNames", copyFromData.ColumnNames), - slog.Duration("time", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("error", data.Err.Error()), - ) - } - } else { - copyFromData.span.SetStatus(codes.Ok, "") - dt.logger.LogAttrs(ctx, slog.LevelInfo, - "CopyFrom", - slog.Any("tableName", copyFromData.TableName), - slog.Any("columnNames", copyFromData.ColumnNames), - slog.Duration("time", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.Int64("rowCount", data.CommandTag.RowsAffected()), - ) - } -} - -type traceConnectData struct { - 
startTime time.Time - connConfig *pgx.ConnConfig -} - -func (dt *dbTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { - return context.WithValue(ctx, dbTracerConnectCtxKey, &traceConnectData{ - startTime: time.Now(), - connConfig: data.ConnConfig, - }) -} - -func (dt *dbTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { - connectData := ctx.Value(dbTracerConnectCtxKey).(*traceConnectData) - - endTime := time.Now() - interval := endTime.Sub(connectData.startTime) - - if data.Err != nil { - if dt.shouldLog(data.Err) { - dt.logger.LogAttrs(ctx, slog.LevelError, - "database connect", - slog.String("host", connectData.connConfig.Host), - slog.Uint64("port", uint64(connectData.connConfig.Port)), - slog.String("database", connectData.connConfig.Database), - slog.Duration("time", interval), - slog.String("error", data.Err.Error()), - ) - } - return - - } - - dt.logger.LogAttrs(ctx, slog.LevelInfo, - "database connect", - slog.String("host", connectData.connConfig.Host), - slog.Uint64("port", uint64(connectData.connConfig.Port)), - slog.String("database", connectData.connConfig.Database), - slog.Duration("time", interval), - ) - - if data.Conn != nil { - dt.logger.LogAttrs(ctx, slog.LevelInfo, - "database connect", - slog.String("host", connectData.connConfig.Host), - slog.Uint64("port", uint64(connectData.connConfig.Port)), - slog.String("database", connectData.connConfig.Database), - slog.Duration("time", interval), - ) - } -} - -type tracePrepareData struct { - span trace.Span // 16 bytes - startTime time.Time // 16 bytes - name string // 16 bytes - sql string // 16 bytes -} - -func (dt *dbTracer) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context { - ctx, span := dt.tracer.Start(ctx, "postgresql.prepare") - span.SetAttributes( - attribute.String("db.name", dt.databaseName), - attribute.String("db.operation", "prepare"), - 
attribute.String("db.prepared_statement_name", data.Name), - ) - return context.WithValue(ctx, dbTracerPrepareCtxKey, &tracePrepareData{ - startTime: time.Now(), - name: data.Name, - span: span, - sql: data.SQL, - }) -} - -func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { - prepareData := ctx.Value(dbTracerPrepareCtxKey).(*tracePrepareData) - defer prepareData.span.End() - - endTime := time.Now() - interval := endTime.Sub(prepareData.startTime) - dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes( - attribute.String("operation", "prepare"), - attribute.String("statement_name", prepareData.name), - )) - - if data.Err != nil { - prepareData.span.SetStatus(codes.Error, data.Err.Error()) - prepareData.span.RecordError(data.Err) - if dt.shouldLog(data.Err) { - dt.logger.LogAttrs(ctx, slog.LevelError, - "Prepare", - slog.String("name", prepareData.name), - slog.String("sql", prepareData.sql), - slog.Duration("time", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.String("error", data.Err.Error()), - ) - } - } else { - prepareData.span.SetStatus(codes.Ok, "") - dt.logger.LogAttrs(ctx, slog.LevelInfo, - "Prepare", - slog.String("name", prepareData.name), - slog.String("sql", prepareData.sql), - slog.Duration("time", interval), - slog.Uint64("pid", uint64(extractConnectionID(conn))), - slog.Bool("alreadyPrepared", data.AlreadyPrepared), - ) - } -} - func (dt *dbTracer) logQueryArgs(args []any) []any { if !dt.logArgs { return nil diff --git a/dbtracer/dbtracer_test.go b/dbtracer/dbtracer_test.go index 7159731..f531e65 100644 --- a/dbtracer/dbtracer_test.go +++ b/dbtracer/dbtracer_test.go @@ -292,7 +292,7 @@ func (s *DBTracerSuite) TestTraceBatchDuration() { } func (s *DBTracerSuite) TestTracePrepareWithDuration() { - prepareSQL := "SELECT * FROM users WHERE id = $1" + prepareSQL := s.defaultQuerySQL stmtName := "get_user_by_id" s.tracer.EXPECT(). 
@@ -304,6 +304,8 @@ func (s *DBTracerSuite) TestTracePrepareWithDuration() { attribute.String("db.name", s.defaultDBName), attribute.String("db.operation", "prepare"), attribute.String("db.prepared_statement_name", stmtName), + attribute.String("db.query_name", "get_users"), + attribute.String("db.query_type", "one"), ). Return() diff --git a/dbtracer/tracebatch.go b/dbtracer/tracebatch.go new file mode 100644 index 0000000..6e8718e --- /dev/null +++ b/dbtracer/tracebatch.go @@ -0,0 +1,100 @@ +package dbtracer + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/jackc/pgx/v5" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" +) + +type traceBatchData struct { + span trace.Span // 16 bytes + startTime time.Time // 16 bytes + queryName string // 16 bytes +} + +func (dt *dbTracer) TraceBatchStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceBatchStartData) context.Context { + ctx, span := dt.tracer.Start(ctx, "postgresql.batch") + span.SetAttributes( + attribute.String("db.name", dt.databaseName), + attribute.String("db.operation", "batch"), + ) + return context.WithValue(ctx, dbTracerBatchCtxKey, &traceBatchData{ + startTime: time.Now(), + span: span, + }) +} + +func (dt *dbTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + queryData := ctx.Value(dbTracerBatchCtxKey).(*traceBatchData) + queryName, queryType := queryNameFromSQL(data.SQL) + queryData.queryName = queryName + queryData.span.SetAttributes( + attribute.String("db.query_name", queryName), + attribute.String("db.query_type", queryType), + ) + + if data.Err != nil { + queryData.span.SetStatus(codes.Error, data.Err.Error()) + queryData.span.RecordError(data.Err) + + if dt.shouldLog(data.Err) { + dt.logger.LogAttrs(ctx, slog.LevelError, + "Query", + slog.String("sql", data.SQL), + slog.Any("args", dt.logQueryArgs(data.Args)), + slog.Uint64("pid", 
uint64(extractConnectionID(conn))), + slog.String("error", data.Err.Error()), + ) + } + } else { + queryData.span.SetStatus(codes.Ok, "") + dt.logger.LogAttrs(ctx, slog.LevelInfo, + "Query", + slog.String("sql", data.SQL), + slog.Any("args", dt.logQueryArgs(data.Args)), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.String("commandTag", data.CommandTag.String()), + ) + } +} + +func (dt *dbTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + queryData := ctx.Value(dbTracerBatchCtxKey).(*traceBatchData) + defer queryData.span.End() + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes( + attribute.String("operation", "batch"), + attribute.String("query_name", queryData.queryName), + attribute.Bool("error", data.Err != nil), + )) + + if data.Err != nil { + queryData.span.SetStatus(codes.Error, data.Err.Error()) + queryData.span.RecordError(data.Err) + + if dt.shouldLog(data.Err) { + dt.logger.LogAttrs(ctx, slog.LevelError, + fmt.Sprintf("Query: %s", queryData.queryName), + slog.Duration("interval", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.String("error", data.Err.Error()), + ) + } + } else { + queryData.span.SetStatus(codes.Ok, "") + dt.logger.LogAttrs(ctx, slog.LevelInfo, + "Query", + slog.Duration("interval", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + ) + } +} diff --git a/dbtracer/traceconnect.go b/dbtracer/traceconnect.go new file mode 100644 index 0000000..d59d11e --- /dev/null +++ b/dbtracer/traceconnect.go @@ -0,0 +1,68 @@ +package dbtracer + +import ( + "context" + "log/slog" + "time" + + "github.com/jackc/pgx/v5" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +type traceConnectData struct { + startTime time.Time + connConfig *pgx.ConnConfig +} + +func (dt *dbTracer) TraceConnectStart(ctx context.Context, data 
pgx.TraceConnectStartData) context.Context {
+	return context.WithValue(ctx, dbTracerConnectCtxKey, &traceConnectData{
+		startTime:  time.Now(),
+		connConfig: data.ConnConfig,
+	})
+}
+
+func (dt *dbTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
+	connectData := ctx.Value(dbTracerConnectCtxKey).(*traceConnectData)
+
+	endTime := time.Now()
+	interval := endTime.Sub(connectData.startTime)
+
+	dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes(
+		attribute.String("operation", "connect"),
+		attribute.Bool("error", data.Err != nil),
+	))
+
+	if data.Err != nil {
+		if dt.shouldLog(data.Err) {
+			dt.logger.LogAttrs(ctx, slog.LevelError,
+				"database connect",
+				slog.String("host", connectData.connConfig.Host),
+				slog.Uint64("port", uint64(connectData.connConfig.Port)),
+				slog.String("database", connectData.connConfig.Database),
+				slog.Duration("time", interval),
+				slog.String("error", data.Err.Error()),
+			)
+		}
+		return
+
+	}
+
+	// Emit exactly one success record. The previous code logged an
+	// identical "database connect" line twice (once unconditionally and
+	// once more when data.Conn != nil); fold the two into a single
+	// record that also carries the backend PID when available.
+	attrs := []slog.Attr{
+		slog.String("host", connectData.connConfig.Host),
+		slog.Uint64("port", uint64(connectData.connConfig.Port)),
+		slog.String("database", connectData.connConfig.Database),
+		slog.Duration("time", interval),
+	}
+	if data.Conn != nil {
+		attrs = append(attrs, slog.Uint64("pid", uint64(extractConnectionID(data.Conn))))
+	}
+	dt.logger.LogAttrs(ctx, slog.LevelInfo,
+		"database connect",
+		attrs...,
+	)
+}
diff --git a/dbtracer/tracecopyfrom.go b/dbtracer/tracecopyfrom.go
new file mode 100644
index 0000000..6500d49
--- /dev/null
+++ b/dbtracer/tracecopyfrom.go
@@ -0,0 +1,74 @@
+package dbtracer
+
+import (
+	"context"
+	"log/slog"
+	"time"
+
+	"github.com/jackc/pgx/v5"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/metric"
+	"go.opentelemetry.io/otel/trace"
+)
+
+type 
traceCopyFromData struct { + ColumnNames []string // 24 bytes + span trace.Span // 16 bytes + startTime time.Time // 16 bytes + TableName pgx.Identifier // slice - 24 bytes +} + +func (dt *dbTracer) TraceCopyFromStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + ctx, span := dt.tracer.Start(ctx, "postgresql.copy_from") + span.SetAttributes( + attribute.String("db.name", dt.databaseName), + attribute.String("db.operation", "copy"), + attribute.String("db.table", data.TableName.Sanitize()), + ) + return context.WithValue(ctx, dbTracerCopyFromCtxKey, &traceCopyFromData{ + startTime: time.Now(), + TableName: data.TableName, + ColumnNames: data.ColumnNames, + span: span, + }) +} + +func (dt *dbTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + copyFromData := ctx.Value(dbTracerCopyFromCtxKey).(*traceCopyFromData) + defer copyFromData.span.End() + + endTime := time.Now() + interval := endTime.Sub(copyFromData.startTime) + dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes( + attribute.String("operation", "copy"), + attribute.String("table", copyFromData.TableName.Sanitize()), + attribute.Bool("error", data.Err != nil), + )) + + if data.Err != nil { + copyFromData.span.SetStatus(codes.Error, data.Err.Error()) + copyFromData.span.RecordError(data.Err) + + if dt.shouldLog(data.Err) { + dt.logger.LogAttrs(ctx, slog.LevelError, + "CopyFrom", + slog.Any("tableName", copyFromData.TableName), + slog.Any("columnNames", copyFromData.ColumnNames), + slog.Duration("time", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.String("error", data.Err.Error()), + ) + } + } else { + copyFromData.span.SetStatus(codes.Ok, "") + dt.logger.LogAttrs(ctx, slog.LevelInfo, + "CopyFrom", + slog.Any("tableName", copyFromData.TableName), + slog.Any("columnNames", copyFromData.ColumnNames), + slog.Duration("time", interval), + slog.Uint64("pid", 
uint64(extractConnectionID(conn))),
+			slog.Int64("rowCount", data.CommandTag.RowsAffected()),
+		)
+	}
+}
diff --git a/dbtracer/traceprepare.go b/dbtracer/traceprepare.go
new file mode 100644
index 0000000..d108c0a
--- /dev/null
+++ b/dbtracer/traceprepare.go
@@ -0,0 +1,80 @@
+package dbtracer
+
+import (
+	"context"
+	"log/slog"
+	"time"
+
+	"github.com/jackc/pgx/v5"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/metric"
+	"go.opentelemetry.io/otel/trace"
+)
+
+type tracePrepareData struct {
+	span          trace.Span // 16 bytes
+	startTime     time.Time  // 16 bytes
+	queryName     string     // 16 bytes
+	queryType     string
+	sql           string // 16 bytes
+	statementName string // 16 bytes
+}
+
+func (dt *dbTracer) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
+	queryName, queryType := queryNameFromSQL(data.SQL)
+	ctx, span := dt.tracer.Start(ctx, "postgresql.prepare")
+	span.SetAttributes(
+		attribute.String("db.name", dt.databaseName),
+		attribute.String("db.operation", "prepare"),
+		attribute.String("db.prepared_statement_name", data.Name),
+		attribute.String("db.query_name", queryName),
+		attribute.String("db.query_type", queryType),
+	)
+	return context.WithValue(ctx, dbTracerPrepareCtxKey, &tracePrepareData{
+		startTime:     time.Now(),
+		statementName: data.Name,
+		queryName:     queryName, queryType: queryType, span: span, // populate the labels read in TracePrepareEnd (they were always empty)
+		sql:           data.SQL,
+	})
+}
+
+func (dt *dbTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
+	prepareData := ctx.Value(dbTracerPrepareCtxKey).(*tracePrepareData)
+	defer prepareData.span.End()
+
+	endTime := time.Now()
+	interval := endTime.Sub(prepareData.startTime)
+	dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes(
+		attribute.String("operation", "prepare"),
+		attribute.String("statement_name", prepareData.statementName),
+		attribute.String("query_name", prepareData.queryName),
+		attribute.String("query_type", prepareData.queryType),
+		
attribute.Bool("error", data.Err != nil), + )) + + if data.Err != nil { + prepareData.span.SetStatus(codes.Error, data.Err.Error()) + prepareData.span.RecordError(data.Err) + if dt.shouldLog(data.Err) { + dt.logger.LogAttrs(ctx, slog.LevelError, + "Prepare", + slog.String("statement_name", prepareData.statementName), + slog.String("sql", prepareData.sql), + slog.Duration("time", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.String("error", data.Err.Error()), + ) + } + } else { + prepareData.span.SetStatus(codes.Ok, "") + dt.logger.LogAttrs(ctx, slog.LevelInfo, + "Prepare", + slog.String("statement_name", prepareData.statementName), + slog.String("sql", prepareData.sql), + slog.Duration("time", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.Bool("alreadyPrepared", data.AlreadyPrepared), + ) + } +} diff --git a/dbtracer/tracequery.go b/dbtracer/tracequery.go new file mode 100644 index 0000000..be4d159 --- /dev/null +++ b/dbtracer/tracequery.go @@ -0,0 +1,86 @@ +package dbtracer + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/jackc/pgx/v5" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" +) + +type traceQueryData struct { + args []any // 24 bytes + span trace.Span // 16 bytes + sql string // 16 bytes + queryName string // 16 bytes + queryType string // 16 bytes + startTime time.Time // 8 bytes +} + +func (dt *dbTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + queryName, queryType := queryNameFromSQL(data.SQL) + ctx, span := dt.tracer.Start(ctx, "postgresql.query") + span.SetAttributes( + attribute.String("db.name", dt.databaseName), + attribute.String("db.query_name", queryName), + attribute.String("db.query_type", queryType), + attribute.String("db.operation", "query"), + ) + return context.WithValue(ctx, dbTracerQueryCtxKey, 
&traceQueryData{ + startTime: time.Now(), + sql: data.SQL, + args: data.Args, + queryName: queryName, + queryType: queryType, + span: span, + }) +} + +func (dt *dbTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + queryData := ctx.Value(dbTracerQueryCtxKey).(*traceQueryData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + histogramAttrs := []attribute.KeyValue{ + attribute.String("operation", "query"), + attribute.String("query_name", queryData.queryName), + attribute.String("query_type", queryData.queryType), + attribute.Bool("error", data.Err != nil), + } + dt.histogram.Record(ctx, interval.Seconds(), metric.WithAttributes(histogramAttrs...)) + + defer queryData.span.End() + + if data.Err != nil { + queryData.span.SetStatus(codes.Error, data.Err.Error()) + queryData.span.RecordError(data.Err) + + if dt.shouldLog(data.Err) { + dt.logger.LogAttrs(ctx, slog.LevelError, + fmt.Sprintf("Query: %s", queryData.queryName), + slog.String("sql", queryData.sql), + slog.Any("args", dt.logQueryArgs(queryData.args)), + slog.String("query_type", queryData.queryType), + slog.Duration("time", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.String("error", data.Err.Error()), + ) + } + } else { + queryData.span.SetStatus(codes.Ok, "") + dt.logger.LogAttrs(ctx, slog.LevelInfo, + fmt.Sprintf("Query: %s", queryData.queryName), + slog.String("sql", queryData.sql), + slog.String("query_type", queryData.queryType), + slog.Any("args", dt.logQueryArgs(queryData.args)), + slog.Duration("time", interval), + slog.Uint64("pid", uint64(extractConnectionID(conn))), + slog.String("commandTag", data.CommandTag.String()), + ) + } +}