diff --git a/go.mod b/go.mod index 99b9ca7..ffc8e7e 100644 --- a/go.mod +++ b/go.mod @@ -9,11 +9,13 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.18.2 + github.com/stretchr/testify v1.10.0 ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -21,6 +23,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/nelkinda/http-go v0.0.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/procfs v0.14.0 // indirect @@ -30,7 +33,6 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.10.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/main.go b/main.go index 0ad0ff5..53b25b4 100644 --- a/main.go +++ b/main.go @@ -215,7 +215,7 @@ func runExporter() { log.Infof("Metrics set: %v", slices.Sorted(maps.Keys(enabledMetrics))) for _, metric := range enabledMetrics { - prometheus.MustRegister(metric) + metric.MustRegisterWith(prometheus.DefaultRegisterer) } scrapeInterval := time.Duration(viper.GetInt("scrape_interval")) * time.Second diff --git a/metrics.go b/metrics.go index 4c893f1..b73eab1 100644 --- a/metrics.go +++ b/metrics.go @@ -11,6 +11,10 @@ import ( const joinChar = "|" const maxAge = 1 * time.Hour +type trackedMetric interface { + MustRegisterWith(prometheus.Registerer) +} + type reportableMetrics []*metricTracker var registeredMetrics = reportableMetrics{} @@ -29,83 +33,91 @@ func register(m *metricTracker) { type metricTracker struct { sync.Mutex + registered bool + collector prometheus.Collector + del func(...string) bool expirations map[string]time.Time } -func newMetricTracker() *metricTracker { - mt := &metricTracker{expirations: map[string]time.Time{}} +func newMetricTracker(collector prometheus.Collector, del func(...string) bool) *metricTracker { + mt := &metricTracker{ + registered: false, + collector: collector, + del: del, + expirations: map[string]time.Time{}, + } register(mt) return mt } -func (m *metricTracker) touch(labels []string) { - m.Lock() - defer m.Unlock() - m.expirations[strings.Join(labels, joinChar)] = time.Now().Add(maxAge) +func (mt *metricTracker) update(labels []string, fn func()) { + if !mt.registered { + return + } + + mt.Lock() + mt.expirations[strings.Join(labels, joinChar)] = time.Now().Add(maxAge) + defer mt.Unlock() + + fn() +} + +func (mt *metricTracker) MustRegisterWith(registry prometheus.Registerer) { + mt.registered = true + registry.MustRegister(mt) +} + +func (mt *metricTracker) Describe(ch chan<- *prometheus.Desc) { + mt.collector.Describe(ch) } -func (m *metricTracker) collect(del func(...string) bool) { - m.Lock() - defer m.Unlock() +func (mt *metricTracker) Collect(ch chan<- prometheus.Metric) { + mt.Lock() + defer mt.Unlock() now := time.Now() - for k, v := range m.expirations { + for k, v := range mt.expirations { if v.Before(now) { - delete(m.expirations, k) - if del(strings.Split(k, joinChar)...) { + delete(mt.expirations, k) + if mt.del(strings.Split(k, joinChar)...) { expiredMetrics.Inc() } } } + mt.collector.Collect(ch) } type trackedGauge struct { - tracker *metricTracker - gauge *prometheus.GaugeVec + *metricTracker + gauge *prometheus.GaugeVec } func NewTrackedGauge(gauge *prometheus.GaugeVec) *trackedGauge { return &trackedGauge{ - tracker: newMetricTracker(), - gauge: gauge, + metricTracker: newMetricTracker(gauge, gauge.DeleteLabelValues), + gauge: gauge, } } -func (tg *trackedGauge) Describe(ch chan<- *prometheus.Desc) { - tg.gauge.Describe(ch) -} - -func (tg *trackedGauge) Collect(ch chan<- prometheus.Metric) { - tg.tracker.collect(tg.gauge.DeleteLabelValues) - tg.gauge.Collect(ch) -} - func (tg *trackedGauge) Set(val float64, labels ...string) { - tg.gauge.WithLabelValues(labels...).Set(val) - tg.tracker.touch(labels) + tg.update(labels, func() { + tg.gauge.WithLabelValues(labels...).Set(val) + }) } type trackedCounter struct { - tracker *metricTracker + *metricTracker counter *prometheus.CounterVec } func NewTrackedCounter(counter *prometheus.CounterVec) *trackedCounter { return &trackedCounter{ - tracker: newMetricTracker(), - counter: counter, + metricTracker: newMetricTracker(counter, counter.DeleteLabelValues), + counter: counter, } } -func (tc *trackedCounter) Describe(ch chan<- *prometheus.Desc) { - tc.counter.Describe(ch) -} - -func (tc *trackedCounter) Collect(ch chan<- prometheus.Metric) { - tc.tracker.collect(tc.counter.DeleteLabelValues) - tc.counter.Collect(ch) -} - func (tc *trackedCounter) Add(val float64, labels ...string) { - tc.counter.WithLabelValues(labels...).Add(val) - tc.tracker.touch(labels) + tc.update(labels, func() { + tc.counter.WithLabelValues(labels...).Add(val) + }) } diff --git a/metrics_test.go b/metrics_test.go new file mode 100644 index 0000000..4760299 --- /dev/null +++ b/metrics_test.go @@ -0,0 +1,318 @@ +package main + +import ( + "strings" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" +) + +func TestReportableMetricsReport(t *testing.T) { + // Save and restore original registeredMetrics + originalMetrics := registeredMetrics + defer func() { registeredMetrics = originalMetrics }() + + // Create a test gauge to report to + gauge := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "test_total", + Help: "Test gauge", + }) + + // Create test metric trackers with various expiration counts + tracker1 := &metricTracker{ + expirations: map[string]time.Time{ + "key1": time.Now(), + "key2": time.Now(), + }, + } + tracker2 := &metricTracker{ + expirations: map[string]time.Time{ + "key3": time.Now(), + "key4": time.Now(), + "key5": time.Now(), + }, + } + + metrics := reportableMetrics{tracker1, tracker2} + metrics.Report(gauge) + + // Should report total of 5 expirations + assert.Equal(t, 5.0, testutil.ToFloat64(gauge)) +} + +func TestRegister(t *testing.T) { + // Save and restore original registeredMetrics + originalMetrics := registeredMetrics + defer func() { registeredMetrics = originalMetrics }() + + registeredMetrics = reportableMetrics{} + + tracker := &metricTracker{} + register(tracker) + + assert.Equal(t, 1, len(registeredMetrics)) + assert.Equal(t, tracker, registeredMetrics[0]) +} + +func TestNewMetricTracker(t *testing.T) { + // Save and restore original registeredMetrics + originalMetrics := registeredMetrics + defer func() { registeredMetrics = originalMetrics }() + + registeredMetrics = reportableMetrics{} + + collector := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test", + Help: "Test", + }, []string{"label"}) + + delCalled := false + delFunc := func(_ ...string) bool { + delCalled = true + return true + } + + tracker := newMetricTracker(collector, delFunc) + + assert.False(t, tracker.registered, "newMetricTracker() should create unregistered tracker") + assert.Equal(t, collector, tracker.collector) + assert.NotNil(t, tracker.expirations) + + // Test that delete function was set + tracker.del() + assert.True(t, delCalled, "newMetricTracker() did not set delete function correctly") + + // Test that tracker was registered + assert.Equal(t, 1, len(registeredMetrics)) +} + +func TestMetricTrackerUpdate(t *testing.T) { + collector := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test", + Help: "Test", + }, []string{"label"}) + + tracker := newMetricTracker(collector, func(...string) bool { return true }) + + t.Run("unregistered tracker does not update", func(t *testing.T) { + fnCalled := false + tracker.update([]string{"label1"}, func() { + fnCalled = true + }) + + assert.False(t, fnCalled, "update() called function on unregistered tracker") + assert.Equal(t, 0, len(tracker.expirations)) + }) + + t.Run("registered tracker updates", func(t *testing.T) { + tracker.registered = true + fnCalled := false + + tracker.update([]string{"label1", "label2"}, func() { + fnCalled = true + }) + + assert.True(t, fnCalled, "update() did not call function on registered tracker") + + key := "label1|label2" + expTime, exists := tracker.expirations[key] + assert.True(t, exists, "update() did not add expiration") + assert.False(t, expTime.Before(time.Now()), "update() set expiration in the past") + assert.False(t, expTime.After(time.Now().Add(maxAge+time.Minute)), "update() set expiration too far in the future") + }) +} + +func TestMetricTrackerMustRegisterWith(t *testing.T) { + collector := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test_register", + Help: "Test", + }, []string{"label"}) + + tracker := newMetricTracker(collector, func(...string) bool { return true }) + + assert.False(t, tracker.registered, "tracker should start unregistered") + + registry := prometheus.NewRegistry() + tracker.MustRegisterWith(registry) + + assert.True(t, tracker.registered, "MustRegisterWith() did not set registered to true") +} + +func TestMetricTrackerCollect(t *testing.T) { + // Save and restore expiredMetrics + originalExpired := expiredMetrics + testExpiredMetrics := prometheus.NewCounter(prometheus.CounterOpts{ + Name: "test_expired", + Help: "Test expired metrics", + }) + expiredMetrics = testExpiredMetrics + defer func() { expiredMetrics = originalExpired }() + + gaugeVec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test_collect", + Help: "Test", + }, []string{"label"}) + + tracker := newMetricTracker(gaugeVec, gaugeVec.DeleteLabelValues) + tracker.registered = true + + // Add some expirations: one expired, one not expired + pastTime := time.Now().Add(-2 * time.Hour) + futureTime := time.Now().Add(2 * time.Hour) + + tracker.expirations["expired"] = pastTime + tracker.expirations["valid"] = futureTime + + // Set some values so we can verify collection works + gaugeVec.WithLabelValues("expired").Set(1.0) + gaugeVec.WithLabelValues("valid").Set(2.0) + + metricCh := make(chan prometheus.Metric, 10) + tracker.Collect(metricCh) + close(metricCh) + + // Check that expired key was removed + _, exists := tracker.expirations["expired"] + assert.False(t, exists, "Collect() did not remove expired key") + + // Check that valid key still exists + _, exists = tracker.expirations["valid"] + assert.True(t, exists, "Collect() removed non-expired key") + + // Check that expiredMetrics counter was incremented + assert.Equal(t, 1.0, testutil.ToFloat64(testExpiredMetrics)) + + // Check that metrics were collected + count := 0 + for range metricCh { + count++ + } + + assert.Greater(t, count, 0, "Collect() did not forward metrics from collector") +} + +func TestNewTrackedGauge(t *testing.T) { + gaugeVec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test_gauge", + Help: "Test", + }, []string{"label"}) + + tg := NewTrackedGauge(gaugeVec) + + assert.Equal(t, gaugeVec, tg.gauge) + assert.NotNil(t, tg.metricTracker) +} + +func TestTrackedGaugeSet(t *testing.T) { + gaugeVec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test_gauge_set", + Help: "Test", + }, []string{"label"}) + + tg := NewTrackedGauge(gaugeVec) + + t.Run("unregistered gauge does not set", func(t *testing.T) { + tg.Set(42.0, "label1") + + // Should not add expiration + assert.Equal(t, 0, len(tg.expirations)) + }) + + t.Run("registered gauge sets value", func(t *testing.T) { + tg.registered = true + tg.Set(42.0, "label1") + + // Should add expiration + _, exists := tg.expirations["label1"] + assert.True(t, exists, "Set() did not add expiration") + + // Should set the actual gauge value + assert.Equal(t, 42.0, testutil.ToFloat64(gaugeVec.WithLabelValues("label1"))) + }) + + t.Run("multiple labels", func(t *testing.T) { + gaugeVec2 := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "test_gauge_multi", + Help: "Test", + }, []string{"label1", "label2"}) + labelVals := []string{"value1", "value2"} + + tg2 := NewTrackedGauge(gaugeVec2) + tg2.registered = true + + tg2.Set(100.0, labelVals...) + + key := strings.Join(labelVals, joinChar) + _, exists := tg2.expirations[key] + assert.True(t, exists, "Set() did not create correct key for multiple labels") + + assert.Equal(t, 100.0, testutil.ToFloat64(gaugeVec2.WithLabelValues(labelVals...))) + }) +} + +func TestNewTrackedCounter(t *testing.T) { + counterVec := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_counter", + Help: "Test", + }, []string{"label"}) + + tc := NewTrackedCounter(counterVec) + + assert.Equal(t, counterVec, tc.counter) + assert.NotNil(t, tc.metricTracker) +} + +func TestTrackedCounterAdd(t *testing.T) { + counterVec := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_counter_add", + Help: "Test", + }, []string{"label"}) + + tc := NewTrackedCounter(counterVec) + + t.Run("unregistered counter does not add", func(t *testing.T) { + tc.Add(5.0, "label1") + + // Should not add expiration + assert.Equal(t, 0, len(tc.expirations)) + }) + + t.Run("registered counter adds value", func(t *testing.T) { + tc.registered = true + tc.Add(5.0, "label1") + + // Should add expiration + _, exists := tc.expirations["label1"] + assert.True(t, exists, "Add() did not add expiration") + + // Should add to the actual counter value + assert.Equal(t, 5.0, testutil.ToFloat64(counterVec.WithLabelValues("label1"))) + + // Add more to the same labels + tc.Add(3.0, "label1") + assert.Equal(t, 8.0, testutil.ToFloat64(counterVec.WithLabelValues("label1"))) + }) + + t.Run("multiple labels", func(t *testing.T) { + counterVec2 := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_counter_multi", + Help: "Test", + }, []string{"label1", "label2"}) + + labelVals := []string{"value1", "value2"} + + tc2 := NewTrackedCounter(counterVec2) + tc2.registered = true + + tc2.Add(10.0, labelVals...) + + key := strings.Join(labelVals, joinChar) + _, exists := tc2.expirations[key] + assert.True(t, exists, "Add() did not create correct key for multiple labels") + + assert.Equal(t, 10.0, testutil.ToFloat64(counterVec2.WithLabelValues(labelVals...))) + }) +} diff --git a/prometheus.go b/prometheus.go index 6109358..bb5ac9c 100644 --- a/prometheus.go +++ b/prometheus.go @@ -61,7 +61,7 @@ const ( r2OperationMetricName MetricName = "cloudflare_r2_operation_count" ) -type MetricsMap map[MetricName]prometheus.Collector +type MetricsMap map[MetricName]trackedMetric func recordError(action string, err error) { exporterErrors.WithLabelValues(action).Inc()