diff --git a/pkg/frontend/v2/frontend_test.go b/pkg/frontend/v2/frontend_test.go index e9316ae93b6..3e1ceb78023 100644 --- a/pkg/frontend/v2/frontend_test.go +++ b/pkg/frontend/v2/frontend_test.go @@ -292,12 +292,16 @@ func TestFrontend_Protobuf_HappyPath(t *testing.T) { msg, err := resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf := msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[0], msg) msg, err = resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf = msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[1], msg) // Response stream exhausted. @@ -397,12 +401,16 @@ func TestFrontend_Protobuf_QuerierResponseReceivedBeforeSchedulerResponse(t *tes msg, err := resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf := msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[0], msg) msg, err = resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf = msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[1], msg) // Response stream exhausted. @@ -443,7 +451,9 @@ func TestFrontend_Protobuf_ResponseClosedBeforeStreamExhausted(t *testing.T) { msg, err := resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf := msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[0], msg) resp.Close() // We expect all goroutines to be cleaned up after this (verified by the VerifyNoLeakTestMain call in TestMain above) } @@ -680,7 +690,9 @@ func TestFrontend_Protobuf_RetryEnqueue(t *testing.T) { msg, err := resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf := msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[0], msg) } @@ -729,12 +741,16 @@ func TestFrontend_Protobuf_ReadingResponseAfterAllMessagesReceived(t *testing.T) msg, err := resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf := msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[0], msg) msg, err = resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf = msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. 
require.Equal(t, expectedMessages[1], msg) // Wait until the last message has been buffered into the stream channel and the stream's context has been cancelled by DoProtobufRequest. @@ -747,7 +763,9 @@ func TestFrontend_Protobuf_ReadingResponseAfterAllMessagesReceived(t *testing.T) msg, err = resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf = msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, expectedMessages[2], msg, "should still be able to read last message after stream has been completely read") msg, err = resp.Next(ctx) @@ -1302,12 +1320,16 @@ func TestFrontend_Protobuf_ResponseSentTwice(t *testing.T) { msg, err := resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf := msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, firstMessage, msg) msg, err = resp.Next(ctx) require.NoError(t, err) - msg.FreeBuffer() // We don't care about the contents of the buffer in the assertion below. + buf = msg.Buffer() + defer buf.Free() + msg.SetBuffer(nil) // We don't care about the contents of the buffer in the assertion below. require.Equal(t, secondMessage, msg) // Response stream exhausted. diff --git a/pkg/mimirpb/custom.go b/pkg/mimirpb/custom.go index e042da838bf..091c9548069 100644 --- a/pkg/mimirpb/custom.go +++ b/pkg/mimirpb/custom.go @@ -158,6 +158,10 @@ type BufferHolder struct { buffer mem.Buffer } +func (m *BufferHolder) Buffer() mem.Buffer { + return m.buffer +} + func (m *BufferHolder) SetBuffer(buf mem.Buffer) { m.buffer = buf } @@ -165,7 +169,6 @@ func (m *BufferHolder) SetBuffer(buf mem.Buffer) { func (m *BufferHolder) FreeBuffer() { if m.buffer != nil { m.buffer.Free() - m.buffer = nil } } @@ -534,3 +537,30 @@ type orderAwareMetricMetadata struct { // order is the 0-based index of this metadata object in a wider metadata array. order int } + +func (m *WriteRequest) FreeBuffer() { + m.BufferHolder.FreeBuffer() + for p := range m.sourceBufferHolders { + p.FreeBuffer() + } +} + +// AddSourceBufferHolder adds a source BufferHolder to the WriteRequest, +// retaining a strong reference to the source buffer. See +// [WriteRequest.SourceBufferHolders]. +func (m *WriteRequest) AddSourceBufferHolder(bufh *BufferHolder) { + buf := bufh.Buffer() + if buf == nil { + return + } + if _, ok := m.sourceBufferHolders[bufh]; ok { + return + } + + buf.Ref() + + if m.sourceBufferHolders == nil { + m.sourceBufferHolders = map[*BufferHolder]struct{}{} + } + m.sourceBufferHolders[bufh] = struct{}{} +} diff --git a/pkg/mimirpb/mimir.pb.go b/pkg/mimirpb/mimir.pb.go index 3a886f50499..986c4e18677 100644 --- a/pkg/mimirpb/mimir.pb.go +++ b/pkg/mimirpb/mimir.pb.go @@ -290,6 +290,11 @@ func (MetadataRW2_MetricType) EnumDescriptor() ([]byte, []int) { type WriteRequest struct { // Keep reference to buffer for unsafe references. BufferHolder + // sourceBufferHolders is populated when the WriteRequest is synthesized + // from other WriteRequests, e. g. when batching, and thus holds references + // to those source buffers. The WriteRequest must hold a strong reference to + // each of these buffers. 
+ sourceBufferHolders map[*BufferHolder]struct{} Timeseries []PreallocTimeseries `protobuf:"bytes,1,rep,name=timeseries,proto3,customtype=PreallocTimeseries" json:"timeseries"` Source WriteRequest_SourceEnum `protobuf:"varint,2,opt,name=Source,proto3,enum=cortexpb.WriteRequest_SourceEnum" json:"Source,omitempty"` diff --git a/pkg/mimirpb/mimir.pb.go.expdiff b/pkg/mimirpb/mimir.pb.go.expdiff index d858944c833..eeb3418e926 100644 --- a/pkg/mimirpb/mimir.pb.go.expdiff +++ b/pkg/mimirpb/mimir.pb.go.expdiff @@ -1,5 +1,5 @@ diff --git a/pkg/mimirpb/mimir.pb.go b/pkg/mimirpb/mimir.pb.go -index 3a886f5049..dda8609298 100644 +index 986c4e1867..dda8609298 100644 --- a/pkg/mimirpb/mimir.pb.go +++ b/pkg/mimirpb/mimir.pb.go @@ -14,7 +14,6 @@ import ( @@ -10,17 +10,22 @@ index 3a886f5049..dda8609298 100644 strconv "strconv" strings "strings" ) -@@ -288,9 +287,6 @@ func (MetadataRW2_MetricType) EnumDescriptor() ([]byte, []int) { +@@ -288,14 +287,6 @@ func (MetadataRW2_MetricType) EnumDescriptor() ([]byte, []int) { } type WriteRequest struct { - // Keep reference to buffer for unsafe references. - BufferHolder +- // sourceBufferHolders is populated when the WriteRequest is synthesized +- // from other WriteRequests, e. g. when batching, and thus holds references +- // to those source buffers. The WriteRequest must hold a strong reference to +- // each of these buffers. +- sourceBufferHolders map[*BufferHolder]struct{} - Timeseries []PreallocTimeseries `protobuf:"bytes,1,rep,name=timeseries,proto3,customtype=PreallocTimeseries" json:"timeseries"` Source WriteRequest_SourceEnum `protobuf:"varint,2,opt,name=Source,proto3,enum=cortexpb.WriteRequest_SourceEnum" json:"Source,omitempty"` Metadata []*MetricMetadata `protobuf:"bytes,3,rep,name=metadata,proto3" json:"metadata,omitempty"` -@@ -302,16 +298,6 @@ type WriteRequest struct { +@@ -307,16 +298,6 @@ type WriteRequest struct { SkipLabelValidation bool `protobuf:"varint,1000,opt,name=skip_label_validation,json=skipLabelValidation,proto3" json:"skip_label_validation,omitempty"` // Skip label count validation. SkipLabelCountValidation bool `protobuf:"varint,1001,opt,name=skip_label_count_validation,json=skipLabelCountValidation,proto3" json:"skip_label_count_validation,omitempty"` @@ -37,7 +42,7 @@ index 3a886f5049..dda8609298 100644 } func (m *WriteRequest) Reset() { *m = WriteRequest{} } -@@ -474,11 +460,6 @@ func (m *ErrorDetails) GetSoft() bool { +@@ -479,11 +460,6 @@ func (m *ErrorDetails) GetSoft() bool { return false } @@ -49,7 +54,7 @@ index 3a886f5049..dda8609298 100644 type TimeSeries struct { Labels []UnsafeMutableLabel `protobuf:"bytes,1,rep,name=labels,proto3,customtype=UnsafeMutableLabel" json:"labels"` // Sorted by time, oldest sample first. -@@ -491,9 +472,6 @@ type TimeSeries struct { +@@ -496,9 +472,6 @@ type TimeSeries struct { // Zero value means value not set. If you need to use exactly zero value for // the timestamp, use 1 millisecond before or after. 
CreatedTimestamp int64 `protobuf:"varint,6,opt,name=created_timestamp,json=createdTimestamp,proto3" json:"created_timestamp,omitempty"` @@ -59,7 +64,7 @@ index 3a886f5049..dda8609298 100644 } func (m *TimeSeries) Reset() { *m = TimeSeries{} } -@@ -5963,25 +5941,19 @@ func (m *TimeSeriesRW2) MarshalToSizedBuffer(dAtA []byte) (int, error) { +@@ -5968,25 +5941,19 @@ func (m *TimeSeriesRW2) MarshalToSizedBuffer(dAtA []byte) (int, error) { } } if len(m.LabelsRefs) > 0 { @@ -90,7 +95,7 @@ index 3a886f5049..dda8609298 100644 i = encodeVarintMimir(dAtA, i, uint64(j21)) i-- dAtA[i] = 0xa -@@ -6021,25 +5993,19 @@ func (m *ExemplarRW2) MarshalToSizedBuffer(dAtA []byte) (int, error) { +@@ -6026,25 +5993,19 @@ func (m *ExemplarRW2) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x11 } if len(m.LabelsRefs) > 0 { @@ -121,7 +126,7 @@ index 3a886f5049..dda8609298 100644 i = encodeVarintMimir(dAtA, i, uint64(j23)) i-- dAtA[i] = 0xa -@@ -7387,9 +7353,6 @@ func valueToStringMimir(v interface{}) string { +@@ -7392,9 +7353,6 @@ func valueToStringMimir(v interface{}) string { return fmt.Sprintf("*%v", pv) } func (m *WriteRequest) Unmarshal(dAtA []byte) error { @@ -131,7 +136,7 @@ index 3a886f5049..dda8609298 100644 l := len(dAtA) iNdEx := 0 for iNdEx < l { -@@ -7419,9 +7382,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7424,9 +7382,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { } switch fieldNum { case 1: @@ -141,7 +146,7 @@ index 3a886f5049..dda8609298 100644 if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Timeseries", wireType) } -@@ -7451,8 +7411,7 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7456,8 +7411,7 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { return io.ErrUnexpectedEOF } m.Timeseries = append(m.Timeseries, PreallocTimeseries{}) @@ -151,7 +156,7 @@ index 3a886f5049..dda8609298 100644 return err } iNdEx = postIndex -@@ -7476,9 +7435,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7481,9 +7435,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { } } case 3: @@ -161,7 +166,7 @@ index 3a886f5049..dda8609298 100644 if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Metadata", wireType) } -@@ -7513,9 +7469,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7518,9 +7469,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { } iNdEx = postIndex case 4: @@ -171,7 +176,7 @@ index 3a886f5049..dda8609298 100644 if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field SymbolsRW2", wireType) } -@@ -7545,16 +7498,9 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7550,16 +7498,9 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } @@ -189,7 +194,7 @@ index 3a886f5049..dda8609298 100644 if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field TimeseriesRW2", wireType) } -@@ -7583,12 +7529,8 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7588,12 +7529,8 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } @@ -204,7 +209,7 @@ index 3a886f5049..dda8609298 100644 return err } iNdEx = postIndex -@@ -7651,12 +7593,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { +@@ -7656,12 +7593,6 @@ func (m *WriteRequest) Unmarshal(dAtA []byte) error { if iNdEx > l { return io.ErrUnexpectedEOF } @@ -217,7 +222,7 @@ index 3a886f5049..dda8609298 100644 return nil } func (m *WriteResponse) Unmarshal(dAtA 
[]byte) error { -@@ -7924,11 +7860,9 @@ func (m *TimeSeries) Unmarshal(dAtA []byte) error { +@@ -7929,11 +7860,9 @@ func (m *TimeSeries) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } @@ -232,7 +237,7 @@ index 3a886f5049..dda8609298 100644 } iNdEx = postIndex case 4: -@@ -11237,10 +11171,6 @@ func (m *WriteRequestRW2) Unmarshal(dAtA []byte) error { +@@ -11242,10 +11171,6 @@ func (m *WriteRequestRW2) Unmarshal(dAtA []byte) error { return nil } func (m *TimeSeriesRW2) Unmarshal(dAtA []byte) error { @@ -243,7 +248,7 @@ index 3a886f5049..dda8609298 100644 l := len(dAtA) iNdEx := 0 for iNdEx < l { -@@ -11271,7 +11201,22 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat +@@ -11276,7 +11201,22 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat switch fieldNum { case 1: if wireType == 0 { @@ -267,7 +272,7 @@ index 3a886f5049..dda8609298 100644 } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { -@@ -11306,14 +11251,9 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat +@@ -11311,14 +11251,9 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat } } elementCount = count @@ -284,7 +289,7 @@ index 3a886f5049..dda8609298 100644 for iNdEx < postIndex { var v uint32 for shift := uint(0); ; shift += 7 { -@@ -11330,27 +11270,7 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat +@@ -11335,27 +11270,7 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat break } } @@ -313,7 +318,7 @@ index 3a886f5049..dda8609298 100644 } } else { return fmt.Errorf("proto: wrong wireType = %d for field LabelsRefs", wireType) -@@ -11452,11 +11372,9 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat +@@ -11457,11 +11372,9 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat if postIndex > l { return io.ErrUnexpectedEOF } @@ -328,7 +333,7 @@ index 3a886f5049..dda8609298 100644 } iNdEx = postIndex case 5: -@@ -11488,7 +11406,7 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat +@@ -11493,7 +11406,7 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat if postIndex > l { return io.ErrUnexpectedEOF } @@ -337,7 +342,7 @@ index 3a886f5049..dda8609298 100644 return err } iNdEx = postIndex -@@ -11533,10 +11451,6 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat +@@ -11538,10 +11451,6 @@ func (m *TimeSeries) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadat return nil } func (m *ExemplarRW2) Unmarshal(dAtA []byte) error { @@ -348,7 +353,7 @@ index 3a886f5049..dda8609298 100644 l := len(dAtA) iNdEx := 0 for iNdEx < l { -@@ -11567,7 +11481,22 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { +@@ -11572,7 +11481,22 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { switch fieldNum { case 1: if wireType == 0 { @@ -372,7 +377,7 @@ index 3a886f5049..dda8609298 100644 } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { -@@ -11602,13 +11531,9 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { +@@ -11607,13 +11531,9 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { } } elementCount = count @@ -388,7 +393,7 @@ index 3a886f5049..dda8609298 100644 for iNdEx < postIndex { var v 
uint32 for shift := uint(0); ; shift += 7 { -@@ -11625,20 +11550,7 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { +@@ -11630,20 +11550,7 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { break } } @@ -410,7 +415,7 @@ index 3a886f5049..dda8609298 100644 } } else { return fmt.Errorf("proto: wrong wireType = %d for field LabelsRefs", wireType) -@@ -11658,7 +11570,7 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { +@@ -11663,7 +11570,7 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Timestamp", wireType) } @@ -419,7 +424,7 @@ index 3a886f5049..dda8609298 100644 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowMimir -@@ -11668,7 +11580,7 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { +@@ -11673,7 +11580,7 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { } b := dAtA[iNdEx] iNdEx++ @@ -428,7 +433,7 @@ index 3a886f5049..dda8609298 100644 if b < 0x80 { break } -@@ -11695,16 +11607,6 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { +@@ -11700,16 +11607,6 @@ func (m *Exemplar) UnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols) error { return nil } func (m *MetadataRW2) Unmarshal(dAtA []byte) error { @@ -445,7 +450,7 @@ index 3a886f5049..dda8609298 100644 l := len(dAtA) iNdEx := 0 for iNdEx < l { -@@ -11737,7 +11639,7 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata +@@ -11742,7 +11639,7 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) } @@ -454,7 +459,7 @@ index 3a886f5049..dda8609298 100644 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowMimir -@@ -11747,7 +11649,7 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata +@@ -11752,7 +11649,7 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata } b := dAtA[iNdEx] iNdEx++ @@ -463,7 +468,7 @@ index 3a886f5049..dda8609298 100644 if b < 0x80 { break } -@@ -11756,7 +11658,7 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata +@@ -11761,7 +11658,7 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field HelpRef", wireType) } @@ -472,7 +477,7 @@ index 3a886f5049..dda8609298 100644 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowMimir -@@ -11766,20 +11668,16 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata +@@ -11771,20 +11668,16 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata } b := dAtA[iNdEx] iNdEx++ @@ -495,7 +500,7 @@ index 3a886f5049..dda8609298 100644 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowMimir -@@ -11789,15 +11687,11 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata +@@ -11794,15 +11687,11 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata } b := dAtA[iNdEx] iNdEx++ @@ -512,7 +517,7 @@ index 3a886f5049..dda8609298 100644 default: iNdEx = preIndex skippy, err := skipMimir(dAtA[iNdEx:]) -@@ -11817,23 +11711,6 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, 
symbols *rw2PagedSymbols, metadata +@@ -11822,23 +11711,6 @@ func MetricMetadataUnmarshalRW2(dAtA []byte, symbols *rw2PagedSymbols, metadata if iNdEx > l { return io.ErrUnexpectedEOF } diff --git a/pkg/mimirpb/prealloc_rw2_test.go b/pkg/mimirpb/prealloc_rw2_test.go index 96503905a74..112968a7906 100644 --- a/pkg/mimirpb/prealloc_rw2_test.go +++ b/pkg/mimirpb/prealloc_rw2_test.go @@ -684,7 +684,7 @@ func TestWriteRequestRW2Conversion_WriteRequestHasChanged(t *testing.T) { } // If the fields of WriteRequest have changed, then you will probably need to modify - // the FromWriteRequestToRW2Request() implementation accordingly! + // the [FromWriteRequestToRW2Request] implementation accordingly! assert.ElementsMatch(t, []string{ "Timeseries", "Source", @@ -699,5 +699,6 @@ func TestWriteRequestRW2Conversion_WriteRequestHasChanged(t *testing.T) { "unmarshalFromRW2", "rw2symbols", "BufferHolder", + "sourceBufferHolders", }, fieldNames) } diff --git a/pkg/mimirpb/split_test.go b/pkg/mimirpb/split_test.go index b761f4f824d..7b9e5f7a4b1 100644 --- a/pkg/mimirpb/split_test.go +++ b/pkg/mimirpb/split_test.go @@ -434,7 +434,7 @@ func TestSplitWriteRequestByMaxMarshalSize_WriteRequestHasChanged(t *testing.T) } // If the fields of WriteRequest have changed, then you will probably need to modify - // the SplitWriteRequestByMaxMarshalSize() and SplitWriteRequestByMaxMarshalSize() implementations accordingly! + // the [SplitWriteRequestByMaxMarshalSize] and [SplitWriteRequestByMaxMarshalSizeRW2] implementations accordingly! assert.ElementsMatch(t, []string{ "Timeseries", "Source", @@ -449,6 +449,7 @@ func TestSplitWriteRequestByMaxMarshalSize_WriteRequestHasChanged(t *testing.T) "unmarshalFromRW2", "rw2symbols", "BufferHolder", + "sourceBufferHolders", }, fieldNames) } diff --git a/pkg/mimirpb/testutil/buffer.go b/pkg/mimirpb/testutil/buffer.go new file mode 100644 index 00000000000..9a6be8a3155 --- /dev/null +++ b/pkg/mimirpb/testutil/buffer.go @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +package testutil + +import ( + "fmt" + + "go.uber.org/atomic" + "google.golang.org/grpc/mem" + + "github.com/grafana/mimir/pkg/mimirpb" +) + +// TrackBufferRefCount instruments the WriteRequest's underlying buffer's +// reference count. It can be inspected with [BufferRefCount]. +func TrackBufferRefCount(wr *mimirpb.WriteRequest) { + buf := wr.Buffer() + if buf == nil { + // Just set some fake buffer. We care about the reference count, not the + // data itself. + buf = mem.SliceBuffer([]byte("fake data")) + } + ibuf := &memBufferWithInstrumentedRefCount{Buffer: buf} + ibuf.refCount.Add(1) // Match the refCount of buf.Buffer + wr.SetBuffer(ibuf) +} + +// BufferRefCount returns the current reference of a WriteRequest's buffer. +// [TrackBufferRefCount] must have been called on the WriteRequest beforehand. 
+func BufferRefCount(wr *mimirpb.WriteRequest) int { + switch b := wr.Buffer().(type) { + case *memBufferWithInstrumentedRefCount: + return int(b.refCount.Load()) + default: + panic(fmt.Errorf("expected WriteRequest.Buffer to be a *memBufferWithInstrumentedRefCount from asDeserializedWriteRequest, got %T", wr.Buffer())) + } +} + +type memBufferWithInstrumentedRefCount struct { + mem.Buffer + refCount atomic.Int64 +} + +func (b *memBufferWithInstrumentedRefCount) Ref() { + b.Buffer.Ref() + + b.refCount.Add(1) +} + +func (b *memBufferWithInstrumentedRefCount) Free() { + b.Buffer.Free() + + refCount := b.refCount.Sub(1) + if refCount < 0 { + panic("memBufferWithInstrumentedRefCount reference count below zero") + } +} diff --git a/pkg/storage/ingest/pusher.go b/pkg/storage/ingest/pusher.go index 3bd76116697..d34423744cb 100644 --- a/pkg/storage/ingest/pusher.go +++ b/pkg/storage/ingest/pusher.go @@ -432,7 +432,15 @@ func LabelAdaptersHash(b []byte, ls []mimirpb.LabelAdapter) ([]byte, uint64) { // PushToStorageAndReleaseRequest hashes each time series in the write requests and sends them to the appropriate shard which is then handled by the current batchingQueue in that shard. // PushToStorageAndReleaseRequest ignores SkipLabelNameValidation because that field is only used in the distributor and not in the ingester. // PushToStorageAndReleaseRequest aborts the request if it encounters an error. +// +// Even though it's called "...AndReleaseRequest", it only releases the WriteRequest itself; ownership of the underlying +// buffer is transferred to the currentBatch of one or more shards. This is required because those currentBatches still +// reference the underlying buffer. As a result, the WriteRequest:buffer ownership changes from 1:1 to M:N, as one buffer +// may be referenced by multiple shards, and each shard may reference multiple buffers. The buffer's ownership is tracked +// via mem.Buffer's reference counting, and the buffer is finally freed once all the shards' currentBatches have been pushed. func (p *parallelStorageShards) PushToStorageAndReleaseRequest(ctx context.Context, request *mimirpb.WriteRequest) error { + defer request.FreeBuffer() + // Shard series by the hash of their labels. Skip sharding and always append series to // the first shard when there's only one shard. var hashBuf []byte @@ -446,7 +454,7 @@ func (p *parallelStorageShards) PushToStorageAndReleaseRequest(ctx context.Conte shard = shard % uint64(p.numShards) } - if err := p.shards[shard].AddToBatch(ctx, request.Source, request.Timeseries[i]); err != nil { + if err := p.shards[shard].AddToBatch(ctx, request.Source, request.Timeseries[i], &request.BufferHolder); err != nil { return fmt.Errorf("encountered a non-client error when ingesting; this error was for a previous write request for the same tenant: %w", err) } // We're transferring ownership of the timeseries to the batch, clear the slice as we go so we can reuse it.
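To make the ownership hand-off described in the comment above concrete, here is a minimal sketch (not part of the patch) that exercises the new APIs together with the instrumented buffer from pkg/mimirpb/testutil. The batch1/batch2 variables and package main are illustrative stand-ins for two shards' currentBatch; only the counts shown in the comments follow from the code in this diff.

package main

import (
	"fmt"

	"github.com/grafana/mimir/pkg/mimirpb"
	mimirpb_testutil "github.com/grafana/mimir/pkg/mimirpb/testutil"
)

func main() {
	// One incoming request; TrackBufferRefCount installs an instrumented
	// buffer whose reference count starts at 1 (owned by the request).
	req := &mimirpb.WriteRequest{}
	mimirpb_testutil.TrackBufferRefCount(req)

	// Series from the request are spread over two shards, so both shard
	// batches end up holding a strong reference to the request's buffer.
	batch1 := &mimirpb.WriteRequest{}
	batch2 := &mimirpb.WriteRequest{}
	batch1.AddSourceBufferHolder(&req.BufferHolder) // Ref(): count 1 -> 2
	batch2.AddSourceBufferHolder(&req.BufferHolder) // Ref(): count 2 -> 3

	// PushToStorageAndReleaseRequest defers request.FreeBuffer(): the request
	// drops its own reference, but the buffer stays alive because the batches
	// still reference it.
	req.FreeBuffer() // count 3 -> 2

	// Once each batch has been pushed downstream, freeing it releases its
	// source-buffer references; the last Free releases the buffer for real.
	batch1.FreeBuffer() // count 2 -> 1
	batch2.FreeBuffer() // count 1 -> 0

	fmt.Println(mimirpb_testutil.BufferRefCount(req)) // prints 0
}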
@@ -464,7 +472,7 @@ func (p *parallelStorageShards) PushToStorageAndReleaseRequest(ctx context.Conte shard = rand.IntN(p.numShards) } for mdIdx := range request.Metadata { - if err := p.shards[shard].AddMetadataToBatch(ctx, request.Source, request.Metadata[mdIdx]); err != nil { + if err := p.shards[shard].AddMetadataToBatch(ctx, request.Source, request.Metadata[mdIdx], &request.BufferHolder); err != nil { return fmt.Errorf("encountered a non-client error when ingesting; this error was for a previous write request for the same tenant: %w", err) } if p.numShards > 1 { @@ -646,25 +654,33 @@ func newBatchingQueue(capacity int, batchSize int, metrics *batchingQueueMetrics // AddToBatch adds a time series to the current batch. If the batch size is reached, the batch is pushed to the Channel(). // If an error occurs while pushing the batch, it returns the error and ensures the batch is pushed. -func (q *batchingQueue) AddToBatch(ctx context.Context, source mimirpb.WriteRequest_SourceEnum, ts mimirpb.PreallocTimeseries) error { +func (q *batchingQueue) AddToBatch(ctx context.Context, source mimirpb.WriteRequest_SourceEnum, ts mimirpb.PreallocTimeseries, bufh *mimirpb.BufferHolder) error { if q.currentBatch.startedAt.IsZero() { q.currentBatch.startedAt = time.Now() } q.currentBatch.Timeseries = append(q.currentBatch.Timeseries, ts) q.currentBatch.Context = ctx q.currentBatch.Source = source + // Because currentBatch now contains references to the original buffer, we + // need to add it as a source buffer (which adds a strong reference) to + // avoid use-after-free. + q.currentBatch.AddSourceBufferHolder(bufh) return q.pushIfFull() } // AddMetadataToBatch adds metadata to the current batch. -func (q *batchingQueue) AddMetadataToBatch(ctx context.Context, source mimirpb.WriteRequest_SourceEnum, metadata *mimirpb.MetricMetadata) error { +func (q *batchingQueue) AddMetadataToBatch(ctx context.Context, source mimirpb.WriteRequest_SourceEnum, metadata *mimirpb.MetricMetadata, bufh *mimirpb.BufferHolder) error { if q.currentBatch.startedAt.IsZero() { q.currentBatch.startedAt = time.Now() } q.currentBatch.Metadata = append(q.currentBatch.Metadata, metadata) q.currentBatch.Context = ctx q.currentBatch.Source = source + // Because currentBatch now contains references to the original buffer, we + // need to add it as a source buffer (which adds a strong reference) to + // avoid use-after-free. 
+ q.currentBatch.AddSourceBufferHolder(bufh) return q.pushIfFull() } diff --git a/pkg/storage/ingest/pusher_test.go b/pkg/storage/ingest/pusher_test.go index 95ad2dde7ac..4c8a6665410 100644 --- a/pkg/storage/ingest/pusher_test.go +++ b/pkg/storage/ingest/pusher_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/codes" "github.com/grafana/mimir/pkg/mimirpb" + mimirpb_testutil "github.com/grafana/mimir/pkg/mimirpb/testutil" util_log "github.com/grafana/mimir/pkg/util/log" "github.com/grafana/mimir/pkg/util/test" ) @@ -609,10 +610,37 @@ type mockPusher struct { } func (m *mockPusher) PushToStorageAndReleaseRequest(ctx context.Context, request *mimirpb.WriteRequest) error { - args := m.Called(ctx, request) + args := m.Called(ctx, comparableWriteRequest(request)) + request.FreeBuffer() return args.Error(0) } +func comparableWriteRequest(wr *mimirpb.WriteRequest) *mimirpb.WriteRequest { + b, err := wr.Marshal() + if err != nil { + panic(err) + } + var clone mimirpb.WriteRequest + err = clone.Unmarshal(b) + if err != nil { + panic(err) + } + return &clone +} + +func asDeserializedWriteRequest(wr *mimirpb.WriteRequest) *mimirpb.WriteRequest { + b, err := wr.Marshal() + if err != nil { + panic(err) + } + var preq mimirpb.PreallocWriteRequest + err = DeserializeRecordContent(b, &preq, 1) + if err != nil { + panic(err) + } + return &preq.WriteRequest +} + func (m *mockPusher) NotifyPreCommit(_ context.Context) error { args := m.Called() return args.Error(0) @@ -950,10 +978,21 @@ func TestParallelStorageShards_ShardWriteRequest(t *testing.T) { errorHandler := newPushErrorHandler(metrics, nil, log.NewNopLogger()) shardingP := newParallelStorageShards(metrics, errorHandler, tc.shardCount, tc.batchSize, buffer, pusher) + for i, req := range tc.requests { + req = asDeserializedWriteRequest(req) + mimirpb_testutil.TrackBufferRefCount(req) + tc.requests[i] = req + + // When everything's done, check that there are no buffer leaks. + defer func() { + require.Equal(t, 0, mimirpb_testutil.BufferRefCount(req)) + }() + } + upstreamPushErrsCount := 0 for i, req := range tc.expectedUpstreamPushes { err := tc.upstreamPushErrs[i] - pusher.On("PushToStorageAndReleaseRequest", mock.Anything, req).Return(err) + pusher.On("PushToStorageAndReleaseRequest", mock.Anything, comparableWriteRequest(req)).Return(err) if err != nil { upstreamPushErrsCount++ } @@ -1379,7 +1418,7 @@ func TestBatchingQueue_NoDeadlock(t *testing.T) { // Add items to the queue for i := 0; i < batchSize*(capacity+1); i++ { - require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series)) + require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series, &mimirpb.BufferHolder{})) } // Close the queue to signal no more items will be added @@ -1426,7 +1465,7 @@ func TestBatchingQueue(t *testing.T) { queue := setupQueue(t, capacity, batchSize, series) series3 := mockPreallocTimeseries("series_3") - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, series3)) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, series3, &mimirpb.BufferHolder{})) select { case batch := <-queue.Channel(): @@ -1480,7 +1519,7 @@ func TestBatchingQueue(t *testing.T) { // Add items to the queue until it's full. 
for i := 0; i < capacity*batchSize; i++ { s := mockPreallocTimeseries(fmt.Sprintf("series_%d", i)) - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s)) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s, &mimirpb.BufferHolder{})) } // We should have 5 items in the queue channel and 0 items in the currentBatch. @@ -1497,9 +1536,9 @@ func TestBatchingQueue(t *testing.T) { // Add three more items to fill up the queue again, this shouldn't block. s := mockPreallocTimeseries("series_100") - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s)) - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s)) - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s)) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s, &mimirpb.BufferHolder{})) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s, &mimirpb.BufferHolder{})) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s, &mimirpb.BufferHolder{})) require.Len(t, queue.ch, 5) require.Len(t, queue.currentBatch.Timeseries, 0) @@ -1530,10 +1569,10 @@ func TestBatchingQueue(t *testing.T) { } // Add timeseries with exemplars to the queue - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, timeSeries)) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, timeSeries, &mimirpb.BufferHolder{})) // Add metadata to the queue - require.NoError(t, queue.AddMetadataToBatch(context.Background(), mimirpb.API, md)) + require.NoError(t, queue.AddMetadataToBatch(context.Background(), mimirpb.API, md, &mimirpb.BufferHolder{})) // Read the batch from the queue select { @@ -1572,14 +1611,14 @@ func TestBatchingQueue_ErrorHandling(t *testing.T) { ctx := context.Background() // Push 1 series so that the next push will complete the batch. - require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series2)) + require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series2, &mimirpb.BufferHolder{})) // Push an error to fill the error channel. queue.ReportError(fmt.Errorf("mock error 1")) queue.ReportError(fmt.Errorf("mock error 2")) // AddToBatch should return an error now. - err := queue.AddToBatch(ctx, mimirpb.API, series2) + err := queue.AddToBatch(ctx, mimirpb.API, series2, &mimirpb.BufferHolder{}) assert.Equal(t, "2 errors: mock error 1; mock error 2", err.Error()) // Also the batch was pushed. select { @@ -1591,8 +1630,8 @@ func TestBatchingQueue_ErrorHandling(t *testing.T) { } // AddToBatch should work again. 
- require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series2)) - require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series2)) + require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series2, &mimirpb.BufferHolder{})) + require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series2, &mimirpb.BufferHolder{})) }) t.Run("Any errors pushed after last AddToBatch call are received on Close", func(t *testing.T) { @@ -1600,7 +1639,7 @@ func TestBatchingQueue_ErrorHandling(t *testing.T) { ctx := context.Background() // Add a batch to a batch but make sure nothing is pushed., - require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series1)) + require.NoError(t, queue.AddToBatch(ctx, mimirpb.API, series1, &mimirpb.BufferHolder{})) select { case <-queue.Channel(): @@ -1636,7 +1675,7 @@ func setupQueue(t *testing.T, capacity, batchSize int, series []mimirpb.Prealloc queue := newBatchingQueue(capacity, batchSize, m) for _, s := range series { - require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s)) + require.NoError(t, queue.AddToBatch(context.Background(), mimirpb.API, s, &mimirpb.BufferHolder{})) } return queue
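A closing note on AddSourceBufferHolder as implemented in custom.go above: it ignores holders with a nil buffer and deduplicates holders it has already seen, so calling AddToBatch once per series with the same &request.BufferHolder takes exactly one extra reference per batch. A small, hedged sketch using the testutil counters (package main and the variable names are illustrative):

package main

import (
	"fmt"

	"github.com/grafana/mimir/pkg/mimirpb"
	mimirpb_testutil "github.com/grafana/mimir/pkg/mimirpb/testutil"
)

func main() {
	req := &mimirpb.WriteRequest{}
	mimirpb_testutil.TrackBufferRefCount(req) // instrumented buffer, count = 1

	batch := &mimirpb.WriteRequest{}

	// Repeated adds of the same holder (one per series routed to this shard)
	// take only a single strong reference.
	batch.AddSourceBufferHolder(&req.BufferHolder) // count 1 -> 2
	batch.AddSourceBufferHolder(&req.BufferHolder) // already tracked, still 2

	// Holders without a buffer (e.g. the zero-value &mimirpb.BufferHolder{}
	// passed by the batching queue tests) are a no-op.
	batch.AddSourceBufferHolder(&mimirpb.BufferHolder{}) // still 2

	req.FreeBuffer()   // count 2 -> 1, the batch keeps the buffer alive
	batch.FreeBuffer() // count 1 -> 0

	fmt.Println(mimirpb_testutil.BufferRefCount(req)) // prints 0
}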