@@ -2,6 +2,7 @@ package aws
2
2
3
3
import (
4
4
"context"
5
+ "fmt"
5
6
"strconv"
6
7
"testing"
7
8
"time"
@@ -11,6 +12,8 @@ import (
11
12
"github.com/aws/aws-sdk-go-v2/credentials"
12
13
"github.com/aws/aws-sdk-go-v2/service/sqs"
13
14
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
15
+ "github.com/benthosdev/benthos/v4/public/service"
16
+ "github.com/stretchr/testify/assert"
14
17
"github.com/stretchr/testify/require"
15
18
)
16
19
@@ -81,6 +84,8 @@ func (m *mockSqsInput) ChangeMessageVisibilityBatch(ctx context.Context, input *
81
84
for _ , entry := range input .Entries {
82
85
if _ , found := m .mesTimeouts [* entry .Id ]; found {
83
86
m .mesTimeouts [* entry .Id ] = entry .VisibilityTimeout
87
+ } else {
88
+ panic ("nope" )
84
89
}
85
90
}
86
91
@@ -200,3 +205,82 @@ func TestSQSInput(t *testing.T) {
200
205
return msgsLen == 0
201
206
}, 5 * time .Second , time .Second )
202
207
}
208
+
209
+ func TestSQSInputBatchAck (t * testing.T ) {
210
+ tCtx := context .Background ()
211
+ defer tCtx .Done ()
212
+
213
+ messages := []types.Message {}
214
+ for i := 0 ; i < 101 ; i ++ {
215
+ messages = append (messages , types.Message {
216
+ Body : aws .String (fmt .Sprintf ("message-%v" , i )),
217
+ MessageId : aws .String (fmt .Sprintf ("id-%v" , i )),
218
+ ReceiptHandle : aws .String (fmt .Sprintf ("h-%v" , i )),
219
+ })
220
+ }
221
+ expectedMessages := len (messages )
222
+
223
+ conf , err := config .LoadDefaultConfig (context .Background (),
224
+ config .WithCredentialsProvider (credentials .NewStaticCredentialsProvider ("xxxxx" , "xxxxx" , "xxxxx" )),
225
+ )
226
+ require .NoError (t , err )
227
+
228
+ r , err := newAWSSQSReader (
229
+ sqsiConfig {
230
+ URL : "http://foo.example.com" ,
231
+ WaitTimeSeconds : 0 ,
232
+ DeleteMessage : true ,
233
+ ResetVisibility : true ,
234
+ MaxNumberOfMessages : 10 ,
235
+ },
236
+ conf ,
237
+ nil ,
238
+ )
239
+ require .NoError (t , err )
240
+
241
+ mockInput := & mockSqsInput {
242
+ mtx : make (chan struct {}, 1 ),
243
+ queueTimeout : 10 ,
244
+ messages : messages ,
245
+ mesTimeouts : make (map [string ]int32 , expectedMessages ),
246
+ }
247
+ mockInput .mtx <- struct {}{}
248
+ r .sqs = mockInput
249
+ go mockInput .TimeoutLoop (tCtx )
250
+
251
+ defer r .closeSignal .TriggerHardStop ()
252
+ err = r .Connect (tCtx )
253
+ require .NoError (t , err )
254
+
255
+ receivedMessageAcks := map [string ]service.AckFunc {}
256
+
257
+ for _ , eMsg := range messages {
258
+ m , aFn , err := r .Read (tCtx )
259
+ require .NoError (t , err )
260
+
261
+ mBytes , err := m .AsBytes ()
262
+ require .NoError (t , err )
263
+
264
+ assert .Equal (t , * eMsg .Body , string (mBytes ))
265
+ receivedMessageAcks [string (mBytes )] = aFn
266
+ }
267
+
268
+ // Check that messages haven't been deleted from the queue
269
+ mockInput .do (func () {
270
+ require .Len (t , mockInput .messages , expectedMessages )
271
+ require .Len (t , mockInput .mesTimeouts , expectedMessages )
272
+ })
273
+
274
+ // Ack all messages as a batch
275
+ for _ , aFn := range receivedMessageAcks {
276
+ require .NoError (t , aFn (tCtx , err ))
277
+ }
278
+
279
+ require .Eventually (t , func () bool {
280
+ msgsLen := 0
281
+ mockInput .do (func () {
282
+ msgsLen = len (mockInput .messages )
283
+ })
284
+ return msgsLen == 0
285
+ }, 5 * time .Second , time .Second )
286
+ }
0 commit comments