diff --git a/lambda-template.yaml b/lambda-template.yaml
index 6da697d..c8e8e0a 100644
--- a/lambda-template.yaml
+++ b/lambda-template.yaml
@@ -26,6 +26,8 @@ Metadata:
         default: 'Log Groups name & filter (if applicable)'
       S3BucketNames:
         default: 'S3 bucket Names & Prefix (if applicable)'
+      SNSTopicArn:
+        default: 'SNS Topic ARN (if applicable)'
       CommonAttributes:
         default: 'Common Attributes to be added to the log events'
       StoreNRLicenseKeyInSecretManager:
@@ -55,6 +57,10 @@ Parameters:
     Type: String
     Description: "JSON array of objects representing your S3Bucketname and prefixes (if applicable) For example: [{\"bucket\":\"bucket1\",\"prefix\":\"prefix/\"}]"
     Default: ""
+  SNSTopicArn:
+    Type: String
+    Description: "SNS Topic Arn that will trigger the lambda function (if applicable)"
+    Default: ""
   LogGroupConfig:
     Description: "JSON array of objects representing your LogGroup and Filters (if applicable). For example: [{\"LogGroupName\":\"logGroup1\",\"FilterPattern\":\"filter1\"}]"
     Type: String
@@ -73,6 +79,7 @@ Parameters:
 
 Conditions:
   ShouldCreateSecret: !Equals [ !Ref StoreNRLicenseKeyInSecretManager, "true" ]
   AddS3Trigger: !Not [ !Equals [!Ref S3BucketNames , ""]]
+  AddSNSTrigger: !Not [ !Equals [!Ref SNSTopicArn , ""]]
   AddCloudwatchTrigger: !Not [ !Equals [!Ref LogGroupConfig , ""]]
   IsCommonAttributesNotBlank: !Not [!Equals [!Ref CommonAttributes, ""]]
@@ -442,6 +449,15 @@ Resources:
       S3BucketArns: !GetAtt NewRelicLogsResourceForS3ARNConstruction.BucketARNs
       LambdaFunctionArn: !GetAtt NewRelicLogsServerlessLogForwarder.Arn
 
+  NewRelicLogsSNSTrigger:
+    Type: 'AWS::CloudFormation::Stack'
+    Condition: AddSNSTrigger
+    Properties:
+      TemplateURL: sns-trigger-stack.yaml
+      Parameters:
+        SNSTopicArn: !Ref SNSTopicArn
+        LambdaFunctionArn: !GetAtt NewRelicLogsServerlessLogForwarder.Arn
+
   NewRelicLogsCloudWatchTrigger:
     Type: 'AWS::CloudFormation::Stack'
     Condition: AddCloudwatchTrigger
diff --git a/sns-trigger-stack.yaml b/sns-trigger-stack.yaml
new file mode 100644
index 0000000..346cbef
--- /dev/null
+++ b/sns-trigger-stack.yaml
@@ -0,0 +1,27 @@
+AWSTemplateFormatVersion: '2010-09-09'
+
+Parameters:
+  SNSTopicArn:
+    Type: String
+    Description: ARN of the SNS topic which will act as trigger to lambda
+  LambdaFunctionArn:
+    Type: String
+    Description: Lambda arn to add event trigger
+
+
+Resources:
+  NewRelicSnsSubscription:
+    Type: AWS::SNS::Subscription
+    Properties:
+      TopicArn: !Ref SNSTopicArn
+      Protocol: lambda
+      Endpoint: !Ref LambdaFunctionArn
+
+  NewRelicSnsInvokePermission:
+    Type: AWS::Lambda::Permission
+    Properties:
+      Action: "lambda:InvokeFunction"
+      FunctionName: !Ref LambdaFunctionArn
+      Principal: "sns.amazonaws.com"
+      SourceArn: !Ref SNSTopicArn
+
diff --git a/src/main.go b/src/main.go
index 2db7275..244ee94 100644
--- a/src/main.go
+++ b/src/main.go
@@ -45,6 +45,14 @@ func handlerWithArgs(ctx context.Context, event unmarshal.Event, nrClient util.N
 			log.Fatalf("error creating s3 client: %v", err)
 		}
 		err = s3.GetLogsFromS3Event(ctx, event.S3Event, awsConfiguration, channel, s3Client, s3.DefaultReaderFactory)
+	case unmarshal.SNS:
+		log.Debugf("processing sns event: %v", event.SNSEvent)
+		var s3Client s3.ObjectClient
+		s3Client, err = s3.NewS3Client(ctx)
+		if err != nil {
+			log.Fatalf("error creating s3 client: %v", err)
+		}
+		err = s3.GetLogsFromSNSEvent(ctx, event.SNSEvent, awsConfiguration, channel, s3Client, s3.DefaultReaderFactory)
 	default:
-		log.Error("unable to process unknown event type. Supported event types are cloudwatch and s3")
+		log.Error("unable to process unknown event type. Supported event types are cloudwatch, s3 and sns")
 		return nil
diff --git a/src/s3/s3.go b/src/s3/s3.go
index 454aaa6..6727b47 100644
--- a/src/s3/s3.go
+++ b/src/s3/s3.go
@@ -6,6 +6,7 @@ import (
 	"compress/bzip2"
 	"compress/gzip"
 	"context"
+	"encoding/json"
 	"io"
 	"os"
 	"regexp"
@@ -62,6 +63,65 @@ func GetLogsFromS3Event(ctx context.Context, s3Event events.S3Event, awsConfigur
 	return nil
 }
 
+// GetLogsFromSNSEvent batches logs from SNS into DetailedJson format and sends them to the specified channel.
+// It returns an error if there is a problem retrieving or sending the logs.
+func GetLogsFromSNSEvent(ctx context.Context, snsEvent events.SNSEvent, awsConfiguration util.AWSConfiguration, channel chan common.DetailedLogsBatch, s3Client ObjectClient, readerFactory ReaderFactory) error {
+	for _, record := range snsEvent.Records {
+
+		// Unmarshal the Message field, which carries an embedded S3 event notification ("Records" array of bucket/object pairs).
+		var messageData struct {
+			Records []struct {
+				S3 struct {
+					Bucket struct {
+						Name string `json:"name"`
+					} `json:"bucket"`
+					Object struct {
+						Key string `json:"key"`
+					} `json:"object"`
+				} `json:"s3"`
+			}
+		}
+
+		err := json.Unmarshal([]byte(record.SNS.Message), &messageData)
+		if err != nil {
+			log.Errorf("failed to unmarshal SNS message: %v", err)
+			continue
+		}
+
+		if len(messageData.Records) != 0 {
+			for _, msg := range messageData.Records {
+				log.Debugf("processing sns event message: %v", msg)
+
+				// The Following are the common attributes for all log messages.
+				// New Relic uses these common attributes to generate Unique Entity ID.
+				attributes := common.LogAttributes{
+					"aws.accountId":            awsConfiguration.AccountID,
+					"logBucketName":            msg.S3.Bucket.Name,
+					"logObjectKey":             msg.S3.Object.Key,
+					"aws.realm":                awsConfiguration.Realm,
+					"aws.region":               awsConfiguration.Region,
+					"instrumentation.provider": common.InstrumentationProvider,
+					"instrumentation.name":     common.InstrumentationName,
+					"instrumentation.version":  common.InstrumentationVersion,
+				}
+
+				if err := util.AddCustomMetaData(os.Getenv(common.CustomMetaData), attributes); err != nil {
+					log.Errorf("failed to add custom metadata %v", err)
+					return err
+				}
+
+				if err := buildMeltLogsFromS3Bucket(ctx, msg.S3.Bucket.Name, msg.S3.Object.Key, channel, attributes, s3Client, readerFactory); err != nil {
+					return err
+				}
+			}
+		} else {
+			log.Debugf("SNS event Message field contains no records")
+		}
+	}
+
+	return nil
+}
+
 // fetchS3Reader fetches an S3 object from the specified bucket and returns an io.ReadCloser for reading its contents.
 // It returns the io.ReadCloser and any error encountered during the operation.
 func fetchS3Reader(ctx context.Context, bucketName string, objectName string, s3Client ObjectClient) (io.ReadCloser, error) {
diff --git a/src/s3/s3_test.go b/src/s3/s3_test.go
index 38e152a..30ec416 100644
--- a/src/s3/s3_test.go
+++ b/src/s3/s3_test.go
@@ -7,9 +7,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/dsnet/compress/bzip2"
-	"github.com/newrelic/aws-unified-lambda-logging/common"
-	"github.com/newrelic/aws-unified-lambda-logging/util"
 	"io"
 	"strings"
 	"testing"
@@ -17,6 +14,9 @@ import (
 
 	"github.com/aws/aws-lambda-go/events"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/dsnet/compress/bzip2"
+	"github.com/newrelic/aws-unified-lambda-logging/common"
+	"github.com/newrelic/aws-unified-lambda-logging/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 )
@@ -243,6 +243,144 @@ func TestGetLogsFromS3Event(t *testing.T) {
 	}
 }
 
+// TestGetLogsFromSNSEvent is a unit test function that tests the GetLogsFromSNSEvent function.
+// It tests different scenarios of SNS event processing and verifies the expected results.
+func TestGetLogsFromSNSEvent(t *testing.T) {
+	tests := []struct {
+		name          string                   // Name of the test case
+		setupSNSMock  func(*MockAPI)           // Function to set up the SNS mock
+		setupRFMock   func(*MockReaderFactory) // Function to set up the ReaderFactory mock
+		expectedError error                    // Expected error from the function
+		batchSize     int                      // Expected number of batches
+		Key           string                   // Key of the SNS Message object
+	}{
+		{
+			name: "Successful SNS event processing",
+			setupSNSMock: func(m *MockAPI) {
+				m.On("GetObject", mock.Anything, mock.Anything).Return(&s3.GetObjectOutput{
+					Body: io.NopCloser(bytes.NewReader([]byte("log content"))),
+				}, nil)
+			},
+			setupRFMock: func(m *MockReaderFactory) {
+				m.On("Create", mock.Anything, "test-key.gz").Return(strings.NewReader("log content"), nil)
+			},
+			expectedError: nil,
+			batchSize:     1,
+			Key:           "test-key.gz",
+		},
+		{
+			name: "Error fetching S3 object from SNS event",
+			setupSNSMock: func(m *MockAPI) {
+				m.On("GetObject", mock.Anything, mock.Anything).Return(&s3.GetObjectOutput{}, errors.New("s3 error"))
+			},
+			setupRFMock:   func(m *MockReaderFactory) {},
+			expectedError: errors.New("s3 error"),
+			Key:           "test-key.gz",
+		},
+		{
+			name: "Successful SNS S3 event processing. Used to test the maximum number of messages in a batch.",
+			setupSNSMock: func(m *MockAPI) {
+				m.On("GetObject", mock.Anything, mock.Anything).Return(&s3.GetObjectOutput{
+					Body: io.NopCloser(bytes.NewReader([]byte(generateLogsOnCount(common.MaxPayloadMessages + 10)))),
+				}, nil)
+			},
+			setupRFMock: func(m *MockReaderFactory) {
+				m.On("Create", mock.Anything, "test-key.gz").Return(strings.NewReader(generateLogsOnCount(common.MaxPayloadMessages+10)), nil)
+			},
+			expectedError: nil,
+			batchSize:     2,
+			Key:           "test-key.gz",
+		},
+		{
+			name: "Successful SNS S3 event processing. Used to test the maximum payload size in a batch.",
+			setupSNSMock: func(m *MockAPI) {
+				m.On("GetObject", mock.Anything, mock.Anything).Return(&s3.GetObjectOutput{
+					Body: io.NopCloser(bytes.NewReader([]byte(generateLogOnSize(1024*1024*1 + 10)))),
+				}, nil)
+			},
+			setupRFMock: func(m *MockReaderFactory) {
+				m.On("Create", mock.Anything, "test-key.gz").Return(strings.NewReader(generateLogOnSize(1024*1024*1+10)), nil)
+			},
+			expectedError: nil,
+			batchSize:     2,
+			Key:           "test-key.gz",
+		},
+		{
+			name:          "CloudTrail Digest Ignore Scenario.",
+			setupSNSMock:  func(m *MockAPI) {},
+			setupRFMock:   func(m *MockReaderFactory) {},
+			expectedError: nil,
+			batchSize:     0,
+			Key:           "test-key_CloudTrail-Digest_2021-09-01T00-00-00Z.json.gz",
+		},
+		{
+			name: "Reading CloudTrail logs from SNS event",
+			setupSNSMock: func(m *MockAPI) {
+				m.On("GetObject", mock.Anything, mock.Anything).Return(&s3.GetObjectOutput{
+					Body: io.NopCloser(bytes.NewReader([]byte(generateCloudTrailTestLogs(4)))),
+				}, nil)
+			},
+			setupRFMock: func(m *MockReaderFactory) {
+				m.On("Create", mock.Anything, "test-key_CloudTrail_2021-09-01T00-00-00Z.json.gz").Return(strings.NewReader(generateCloudTrailTestLogs(4)), nil)
+			},
+			expectedError: nil,
+			batchSize:     1,
+			Key:           "test-key_CloudTrail_2021-09-01T00-00-00Z.json.gz",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			ctx := context.Background()
+			awsConfiguration := util.AWSConfiguration{
+				AccountID: "123456789012",
+				Realm:     "aws",
+				Region:    "us-west-2",
+			}
+
+			channel := make(chan common.DetailedLogsBatch, 2)
+			messageBytes := []byte(fmt.Sprintf("{\"Records\":[{\"s3\":{\"bucket\":{\"name\":\"nr-log-test-bucket\"},\"object\":{\"key\":\"%s\"}}}]}", tc.Key))
+			snsEvent := events.SNSEvent{
+				Records: []events.SNSEventRecord{
+					{
+						SNS: events.SNSEntity{
+							Message: string(messageBytes),
+						},
+					},
+				},
+			}
+
+			mockSNSClient := new(MockAPI)
+			tc.setupSNSMock(mockSNSClient)
+
+			mockReaderFactory := new(MockReaderFactory)
+			tc.setupRFMock(mockReaderFactory)
+
+			// Call the GetLogsFromSNSEvent function
+			err := GetLogsFromSNSEvent(ctx, snsEvent, awsConfiguration, channel, mockSNSClient, mockReaderFactory.Create)
+			close(channel)
+
+			// Check for expected errors
+			if tc.expectedError != nil {
+				assert.Error(t, err)
+				assert.EqualError(t, err, tc.expectedError.Error())
+			} else {
+				assert.NoError(t, err)
+				batchCount := 0
+				for batch := range channel {
+					assert.NotEmpty(t, batch)
+					batchCount++
+				}
+				assert.Equal(t, tc.batchSize, batchCount)
+			}
+
+			// Assert that all expectations were met
+			mockSNSClient.AssertExpectations(t)
+			mockReaderFactory.AssertExpectations(t)
+		})
+	}
+}
+
 // compressBzip2 compresses the given data using the bzip2 algorithm and returns the compressed data.
 // It uses the bzip2.WriterConfig with the BestCompression level for optimal compression.
 func compressBzip2(data []byte) []byte {
diff --git a/src/unmarshal/unmarshal.go b/src/unmarshal/unmarshal.go
index 17c3886..a05b323 100644
--- a/src/unmarshal/unmarshal.go
+++ b/src/unmarshal/unmarshal.go
@@ -3,6 +3,7 @@ package unmarshal
 
 import (
 	"encoding/json"
+	"github.com/newrelic/aws-unified-lambda-logging/logger"
 
 	"github.com/aws/aws-lambda-go/events"
 
@@ -12,6 +13,7 @@ import (
 const (
 	CLOUDWATCH = "cloudwatch" // CLOUDWATCH represents the event type for CloudWatch logs.
 	S3         = "s3"         // S3 represents the event type for S3 events.
+	SNS        = "sns"        // SNS represents the event type for SNS events.
 )
 
 var log = logger.NewLogrusLogger(logger.WithDebugLevel())
@@ -21,6 +23,7 @@ type Event struct {
 	EventType          string                    // EventType represents the type of the event.
 	CloudwatchLogsData events.CloudwatchLogsData // CloudwatchLogsData represents the CloudWatch logs data.
 	S3Event            events.S3Event            // S3Event represents the S3 event data.
+	SNSEvent           events.SNSEvent           // SNSEvent represents the SNS event data.
 }
 
 // UnmarshalJSON unmarshals the JSON data into the Event struct.
@@ -50,5 +53,15 @@ func (event *Event) UnmarshalJSON(data []byte) error {
 		return err
 	}
 
+	// Try to unmarshal the event as SNSEvent
+	var snsEvent events.SNSEvent
+	err = json.Unmarshal(data, &snsEvent)
+	if err == nil && len(snsEvent.Records) != 0 && snsEvent.Records[0].EventSource == "aws:sns" {
+		event.EventType = SNS
+		event.SNSEvent = snsEvent
+
+		return nil
+	}
+
 	return nil
 }
diff --git a/src/unmarshal/unmarshal_test.go b/src/unmarshal/unmarshal_test.go
index 8a7c9c1..69658a1 100644
--- a/src/unmarshal/unmarshal_test.go
+++ b/src/unmarshal/unmarshal_test.go
@@ -63,3 +63,38 @@ func TestUnmarshalJSONCloudWatchLogsData(t *testing.T) {
 	assert.NotEqual(t, expected.EventType, event.EventType)
 	assert.NotEqual(t, expected.CloudwatchLogsData, event.CloudwatchLogsData)
 }
+
+// TestUnmarshalJSONSNSEvent is a unit test function that tests the unmarshaling of a JSON SNS event.
+// It verifies that the unmarshaled event matches the expected event.
+func TestUnmarshalJSONSNSEvent(t *testing.T) {
+	input := []byte(`{
+		"Records": [
+			{
+				"EventSource": "aws:sns",
+				"Sns": {
+					"Message": "test message"
+				}
+			}
+		]
+	}`)
+	expected := Event{
+		EventType: SNS,
+		SNSEvent: events.SNSEvent{
+			Records: []events.SNSEventRecord{
+				{
+					EventSource: "aws:sns",
+					SNS: events.SNSEntity{
+						Message: "test message",
+					},
+				},
+			},
+		},
+	}
+
+	var event Event
+	err := json.Unmarshal(input, &event)
+
+	assert.NoError(t, err)
+	assert.Equal(t, expected.EventType, event.EventType)
+	assert.Equal(t, expected.SNSEvent, event.SNSEvent)
+}