From 73876ecf989272411e6b6d4257087272069d25ac Mon Sep 17 00:00:00 2001 From: Avneesh Hota Date: Mon, 22 Jul 2024 19:54:56 +0530 Subject: [PATCH 1/2] send logs to mongo in batches --- db/logs_dao.go | 95 ++++++++++++++++++++++++++++++++++++++ db/mongo.go | 6 ++- main.go | 123 ++++++++++++++++++++++++++----------------------- 3 files changed, 165 insertions(+), 59 deletions(-) create mode 100644 db/logs_dao.go diff --git a/db/logs_dao.go b/db/logs_dao.go new file mode 100644 index 0000000..7db8b2b --- /dev/null +++ b/db/logs_dao.go @@ -0,0 +1,95 @@ +package db + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "go.mongodb.org/mongo-driver/mongo" +) + +type LogDocument struct { + Log string `bson:"log"` + Key string `bson:"key"` + Timestamp int64 `bson:"timestamp"` +} + +var ( + logBuffer []LogDocument + bufferLock sync.Mutex + batchSize = 1000 + logCollection *mongo.Collection + insertInterval = time.Second * 60 +) + +func init() { + collection, err := logsInstance() + if err != nil { + log.Fatalf("Error while getting mongo client for logs: %v", err) + } + logCollection = collection + + go periodicInsert() +} + +func logsInstance() (*mongo.Collection, error) { + client, err := GetMongoClient() + if err != nil { + fmt.Println("Error while getting mongo client for logs: " + err.Error()) + return nil, err + } + + return client.Database(AccountID).Collection(LogsCollectionName), nil +} + +func InsertLog(logString string, key string) { + log.Println(logString) + + logString += "MIRRORING: " + logString + + logDoc := LogDocument{ + Log: logString, + Key: key, + Timestamp: time.Now().Unix(), + } + + bufferLock.Lock() + logBuffer = append(logBuffer, logDoc) + if len(logBuffer) >= batchSize { + flushLogs() + } + bufferLock.Unlock() +} + +func flushLogs() { + if len(logBuffer) == 0 { + return + } + + _, err := logCollection.InsertMany(context.Background(), toInterfaceSlice(logBuffer)) + if err != nil { + fmt.Println("Error while inserting logs: " + err.Error()) + } else { + fmt.Println("Logs inserted successfully") + } + logBuffer = logBuffer[:0] // reset the buffer +} + +func toInterfaceSlice(logs []LogDocument) []interface{} { + interfaceSlice := make([]interface{}, len(logs)) + for i, d := range logs { + interfaceSlice[i] = d + } + return interfaceSlice +} + +func periodicInsert() { + for { + time.Sleep(insertInterval) + bufferLock.Lock() + flushLogs() + bufferLock.Unlock() + } +} diff --git a/db/mongo.go b/db/mongo.go index c338ebe..94425cf 100644 --- a/db/mongo.go +++ b/db/mongo.go @@ -2,11 +2,12 @@ package db import ( "context" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" "os" "strconv" "sync" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) var ( @@ -18,6 +19,7 @@ var ( var AccountID = strconv.Itoa(1_000_000) var TrafficMetricsCollectionName = "traffic_metrics" var AccountSettingsCollectionName = "accounts_settings" +var LogsCollectionName = "logs" func GetMongoClient() (*mongo.Client, error) { once.Do(func() { diff --git a/main.go b/main.go index 04cd15b..b69ff28 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "net/http" "os" "runtime" @@ -61,6 +60,9 @@ var ( err error ) +const INFO = "info" +const ERROR = "error" + // key is used to map bidirectional streams to each other. type key struct { net, transport gopacket.Flow @@ -115,12 +117,12 @@ func (f *myFactory) New(netFlow, tcpFlow gopacket.Flow) tcpassembly.Stream { bd := f.bidiMap[k] if bd == nil { bd = &bidi{a: s, key: k, vxlanID: f.vxlanID, source: f.source} - //log.Printf("[%v] created first side of bidirectional stream", bd.key) + //db.InsertLog(fmt.Sprintf("[%v] created first side of bidirectional stream", bd.key), INFO) // Register bidirectional with the reverse key, so the matching stream going // the other direction will find it. f.bidiMap[key{netFlow.Reverse(), tcpFlow.Reverse()}] = bd } else { - //log.Printf("[%v] found second side of bidirectional stream", bd.key) + //db.InsertLog(fmt.Sprintf("[%v] found second side of bidirectional stream", bd.key), INFO) bd.b = s // Clear out the bidi we're using from the map, just in case. delete(f.bidiMap, k) @@ -141,7 +143,7 @@ func (f *myFactory) collectOldStreams() { cutoff := time.Now().Add(-timeout) for k, bd := range f.bidiMap { if bd.lastPacketSeen.Before(cutoff) { - log.Printf("[%v] timing out old stream", bd.key) + db.InsertLog(fmt.Sprintf("[%v] timing out old stream", bd.key), INFO) bd.b = emptyStream // stub out b with an empty stream. delete(f.bidiMap, k) // remove it from our map. bd.maybeFinish() // if b was the last stream we were waiting for, finish up. @@ -362,9 +364,9 @@ func (bd *bidi) maybeFinish() { timeNow := time.Now() switch { case bd.a == nil: - //log.Fatalf("[%v] a should always be non-nil, since it's set when bidis are created", bd.key) + //db.InsertLog(fmt.Sprintf("[%v] a should always be non-nil, since it's set when bidis are created", bd.key), INFO) case bd.b == nil: - //log.Printf("[%v] no second stream yet", bd.key) + //db.InsertLog(fmt.Sprintf("[%v] no second stream yet", bd.key), INFO) default: if bd.a.done && bd.b.done { tryReadFromBD(bd, false) @@ -389,7 +391,7 @@ func createAndGetAssembler(vxlanID int, source string) *tcpassembly.Assembler { _assembler := assemblerMap[vxlanID] if _assembler == nil { - log.Println("creating assembler for vxlanID=", vxlanID) + db.InsertLog("creating assembler for vxlanID="+strconv.Itoa(vxlanID), INFO) // Set up assembly streamFactory := &myFactory{bidiMap: make(map[key]*bidi), vxlanID: vxlanID, source: source} streamPool := tcpassembly.NewStreamPool(streamFactory) @@ -401,7 +403,7 @@ func createAndGetAssembler(vxlanID int, source string) *tcpassembly.Assembler { factoryMap[vxlanID] = streamFactory assemblerMap[vxlanID] = _assembler - log.Println("created assembler for vxlanID=", vxlanID) + db.InsertLog("created assembler for vxlanID="+strconv.Itoa(vxlanID), INFO) } return _assembler @@ -413,17 +415,18 @@ var kafkaWriter *kafka.Writer func flushAll() { for _, v := range assemblerMap { v.FlushOlderThan(time.Now().Add(time.Second * -5)) - //log.Println("num flushed/closed:", r, k) - //log.Println("streams before closing: ", len(factoryMap[k].bidiMap)) + //db.InsertLog("num flushed/closed: "+fmt.Sprintf("%d", r)+", k: "+fmt.Sprintf("%d", k), INFO) + //db.InsertLog("streams before closing: "+fmt.Sprintf("%d", len(factoryMap[k].bidiMap)), INFO) //factoryMap[k].collectOldStreams() - //log.Println("streams after closing: ", len(factoryMap[k].bidiMap)) + //db.InsertLog("streams after closing: "+fmt.Sprintf("%d", len(factoryMap[k].bidiMap)), INFO) } } func run(handle *pcap.Handle, apiCollectionId int, source string) { if err := handle.SetBPFFilter("tcp && not (port 9092 or port 22)"); err != nil { // optional - log.Fatal(err) + db.InsertLog(err.Error(), ERROR) + os.Exit(1) return } @@ -437,9 +440,9 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { if len(maintainTrafficIpMapInput) > 0 { val, err := strconv.ParseBool(maintainTrafficIpMapInput) if err != nil { - fmt.Println("invalid value set for flag MAINTAIN_TRAFFIC_IP_MAP") + db.InsertLog("invalid value set for flag MAINTAIN_TRAFFIC_IP_MAP", ERROR) } - fmt.Println("setting MAINTAIN_TRAFFIC_IP_MAP = ", val) + db.InsertLog(fmt.Sprintf("setting MAINTAIN_TRAFFIC_IP_MAP = %s", strconv.FormatBool(val)), INFO) maintainTrafficIpMap = val } @@ -447,10 +450,10 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { if len(aktoMemThresh) > 0 { aktoMemThreshRestart, err = strconv.Atoi(aktoMemThresh) if err != nil { - log.Println("AKTO_MEM_THRESH_RESTART should be valid integer. Found ", aktoMemThresh) + db.InsertLog("AKTO_MEM_THRESH_RESTART should be valid integer. Found "+aktoMemThresh, ERROR) return } else { - log.Println("Setting akto mem threshold threshold at " + strconv.Itoa(aktoMemThreshRestart)) + db.InsertLog("Setting akto mem threshold threshold at "+strconv.Itoa(aktoMemThreshRestart), INFO) } } @@ -461,7 +464,7 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { for _, i := range ifaces { addrs, err := i.Addrs() if err != nil { - fmt.Print(fmt.Errorf("localAddresses: %+v\n", err.Error())) + db.InsertLog(fmt.Sprintf("localAddresses: %+v\n", err.Error()), ERROR) continue } for _, a := range addrs { @@ -470,7 +473,7 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { // Check if it's an IPv4 address if ipnet.IP.To4() != nil { // Compare the address with the target address - fmt.Printf("Interface addr %s\n", ipnet.IP.To4().String()) + db.InsertLog(fmt.Sprintf("Interface addr %s\n", ipnet.IP.To4().String()), INFO) interfaceMap[ipnet.IP.To4().String()] = true } } @@ -489,7 +492,7 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { innerPacket := packet vxlanID := apiCollectionId if innerPacket.NetworkLayer() == nil || innerPacket.TransportLayer() == nil || innerPacket.TransportLayer().LayerType() != layers.LayerTypeTCP { - printLog("not a tcp payload") + db.InsertLog("not a tcp payload", INFO) continue } else { tcp := innerPacket.TransportLayer().(*layers.TCP) @@ -538,22 +541,22 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { bytesIn += len(tcp.Payload) if bytesIn > bytesInThreshold { - log.Println("exceeded bytesInThreshold: ", bytesInThreshold, " with curr: ", bytesIn) - log.Println("limit reached, sleeping", time.Now()) + db.InsertLog(fmt.Sprintf("exceeded bytesInThreshold: %d with curr: %d", bytesInThreshold, bytesIn), INFO) + db.InsertLog(fmt.Sprintf("limit reached, sleeping at %s", time.Now().String()), INFO) - log.Println("logging memory stats before wipeout", time.Now()) + db.InsertLog("logging memory stats before wipeout", INFO) logMemoryStats() wipeOut() - log.Println("wipeout done", time.Now()) - log.Println("logging memory stats post wipeout", time.Now()) + db.InsertLog("wipeout done", INFO) + db.InsertLog("logging memory stats post wipeout", INFO) logMemoryStats() for k, v := range incomingReqSrcIpCountMap { - log.Printf("srcIp %s, total req %d", k, v) + db.InsertLog(fmt.Sprintf("srcIp %s, total req %d", k, v), INFO) } for k, v := range incomingReqDstIpCountMap { - log.Printf("dstIp %s, total req %d", k, v) + db.InsertLog(fmt.Sprintf("dstIp %s, total req %d", k, v), INFO) } bytesIn = 0 @@ -573,7 +576,7 @@ func run(handle *pcap.Handle, apiCollectionId int, source string) { if time.Now().Sub(kafkaErrMsgEpoch).Seconds() >= 10 { if kafkaErrMsgCount > 1000 { - log.Println("kafka error messages exceeded threshold, sleeping for 10 sec ", time.Now()) + db.InsertLog(fmt.Sprintf("kafka error messages exceeded threshold, sleeping for 10 sec at %s", time.Now().String()), ERROR) time.Sleep(10 * time.Second) } kafkaErrMsgCount = 0 @@ -588,7 +591,7 @@ func kafkaCompletion() func(messages []kafka.Message, err error) { return func(messages []kafka.Message, err error) { if err != nil { kafkaErrMsgCount += len(messages) - log.Printf("kafkaErrMsgCount : %d, messagesCount %d", kafkaErrMsgCount, len(messages)) + db.InsertLog(fmt.Sprintf("kafkaErrMsgCount : %d, messagesCount %d", kafkaErrMsgCount, len(messages)), ERROR) } } } @@ -599,29 +602,29 @@ func initKafka() { if len(kafka_url) == 0 { kafka_url = os.Getenv("AKTO_KAFKA_BROKER_URL") } - printLog("kafka_url: " + kafka_url) + db.InsertLog("kafka_url: "+kafka_url, INFO) bytesInThresholdInput := os.Getenv("AKTO_BYTES_IN_THRESHOLD") if len(bytesInThresholdInput) > 0 { bytesInThreshold, err = strconv.Atoi(bytesInThresholdInput) if err != nil { - printLog("AKTO_BYTES_IN_THRESHOLD should be valid integer. Found " + bytesInThresholdInput) + db.InsertLog("AKTO_BYTES_IN_THRESHOLD should be valid integer. Found "+bytesInThresholdInput, ERROR) return } else { - printLog("Setting bytes in threshold at " + strconv.Itoa(bytesInThreshold)) + db.InsertLog("Setting bytes in threshold at "+strconv.Itoa(bytesInThreshold), INFO) } } kafka_batch_size, e := strconv.Atoi(os.Getenv("AKTO_TRAFFIC_BATCH_SIZE")) if e != nil { - printLog("AKTO_TRAFFIC_BATCH_SIZE should be valid integer") + db.InsertLog("AKTO_TRAFFIC_BATCH_SIZE should be valid integer", ERROR) return } kafka_batch_time_secs, e := strconv.Atoi(os.Getenv("AKTO_TRAFFIC_BATCH_TIME_SECS")) if e != nil { - printLog("AKTO_TRAFFIC_BATCH_TIME_SECS should be valid integer") + db.InsertLog("AKTO_TRAFFIC_BATCH_TIME_SECS should be valid integer", ERROR) return } kafka_batch_time_secs_duration := time.Duration(kafka_batch_time_secs) @@ -629,7 +632,7 @@ func initKafka() { for { kafkaWriter = GetKafkaWriter(kafka_url, "akto.api.logs", kafka_batch_size, kafka_batch_time_secs_duration*time.Second) logMemoryStats() - log.Println("logging kafka stats before pushing message") + db.InsertLog("logging kafka stats before pushing message", INFO) logKafkaStats() value := map[string]string{ "testConnectionString": "kafkaInit", @@ -638,14 +641,14 @@ func initKafka() { out, _ := json.Marshal(value) ctx := context.Background() err := Produce(kafkaWriter, ctx, string(out)) - log.Println("logging kafka stats post pushing message") + db.InsertLog("logging kafka stats post pushing message", INFO) logKafkaStats() if err != nil { - log.Println("error establishing connection with kafka, sending message failed, retrying in 2 seconds", err) + db.InsertLog("error establishing connection with kafka, sending message failed, retrying in 2 seconds "+err.Error(), ERROR) kafkaWriter.Close() time.Sleep(time.Second * 2) } else { - log.Println("connection establishing with kafka successfully") + db.InsertLog("connection establishing with kafka successfully", INFO) kafkaWriter.Completion = kafkaCompletion() break } @@ -673,23 +676,26 @@ func logMemoryStats() { runtime.ReadMemStats(&m) if int(m.Alloc/1024/1024) > aktoMemThreshRestart { - log.Println("current mem usage", m.Alloc/1024/1024) + memUsage := fmt.Sprintf("%d", m.Alloc/1024/1024) + // Concatenate the message + message := "current mem usage: " + memUsage + " MB" + db.InsertLog(message, ERROR) os.Exit(3) } - log.Println("Alloc in MB: ", m.Alloc/1024/1024) - log.Println("Sys in MB: ", m.Sys/1024/1024) + db.InsertLog(fmt.Sprintf("Alloc in MB: %d", m.Alloc/1024/1024), INFO) + db.InsertLog(fmt.Sprintf("Sys in MB: %d", m.Sys/1024/1024), INFO) } func logKafkaStats() { stats := kafkaWriter.Stats() - log.Printf("Stats - Dials %d, Writes %d, Messages %d, Bytes %d, Errors %d, DialTime %v, BatchTime %v, "+ + db.InsertLog(fmt.Sprintf("Stats - Dials %d, Writes %d, Messages %d, Bytes %d, Errors %d, DialTime %v, BatchTime %v, "+ "WriteTime %v, WaitTime %v, Retries %d, BatchSize %d, BatchBytes %d, MaxAttempts %d, MaxBatchSize %d, "+ "BatchTimeout %v, ReadTimeout %v, WriteTimeout %v, RequiredAcks %d, Async %t, Topic %s", stats.Dials, stats.Writes, stats.Messages, stats.Bytes, stats.Errors, stats.DialTime, stats.BatchTime, stats.WriteTime, stats.WaitTime, stats.Retries, stats.BatchSize, stats.BatchBytes, stats.MaxAttempts, stats.MaxBatchSize, - stats.BatchTimeout, stats.ReadTimeout, stats.WriteTimeout, stats.RequiredAcks, stats.Async, stats.Topic) - //log.Println(kafkaWriter.Stats()) + stats.BatchTimeout, stats.ReadTimeout, stats.WriteTimeout, stats.RequiredAcks, stats.Async, stats.Topic), INFO) + //db.InsertLog(kafkaWriter.Stats().String(), INFO) } //export readTcpDumpFile @@ -701,7 +707,8 @@ func readTcpDumpFile(filepath string, kafkaURL string, apiCollectionId int) { initKafka() if handle, err := pcap.OpenOffline(filepath); err != nil { - log.Fatal(err) + db.InsertLog(err.Error(), ERROR) + os.Exit(1) } else { run(handle, apiCollectionId, "PCAP") } @@ -711,14 +718,14 @@ func main() { disableOnDb := os.Getenv("AKTO_DISABLE_ON_DB") disableOnDbFlag := disableOnDb == "true" - log.Printf("Disable flag : %t", disableOnDbFlag) + db.InsertLog(fmt.Sprintf("Disable flag : %t", disableOnDbFlag), INFO) client, err := db.GetMongoClient() mongoPingErr := client.Ping(context.Background(), readpref.Primary()) if err != nil || mongoPingErr != nil { - log.Printf("Failed connecting to mongo %s", err) + db.InsertLog(fmt.Sprintf("Failed connecting to mongo %s", err), ERROR) if disableOnDbFlag { - log.Println("Exiting....") + db.InsertLog("Exiting....", ERROR) time.Sleep(time.Second * 60) panic("Failed connecting to mongo") // this will get restarted by docker } @@ -727,22 +734,23 @@ func main() { defer func() { if err := client.Disconnect(context.Background()); err != nil { // Handle error + db.InsertLog(err.Error(), ERROR) } }() ignoreIpTrafficVar := os.Getenv("AKTO_IGNORE_IP_TRAFFIC") if len(ignoreIpTrafficVar) > 0 { ignoreIpTraffic = strings.ToLower(ignoreIpTrafficVar) == "true" - log.Println("ignoreIpTraffic: ", ignoreIpTraffic) + db.InsertLog(fmt.Sprintf("ignoreIpTraffic: %t", ignoreIpTraffic), INFO) } else { - log.Println("ignoreIpTraffic: missing. defaulting to false") + db.InsertLog("ignoreIpTraffic: missing. defaulting to false", INFO) } ignoreCloudMetadataCallsVar := os.Getenv("AKTO_IGNORE_CLOUD_METADATA_CALLS") if len(ignoreCloudMetadataCallsVar) > 0 { ignoreCloudMetadataCalls = strings.ToLower(ignoreCloudMetadataCallsVar) == "true" - log.Println("ignoreCloudMetadataCalls: ", ignoreCloudMetadataCalls) + db.InsertLog(fmt.Sprintf("ignoreCloudMetadataCalls: %t", ignoreCloudMetadataCalls), INFO) } else { - log.Println("ignoreCloudMetadataCalls: missing. defaulting to false") + db.InsertLog("ignoreCloudMetadataCalls: missing. defaulting to false", INFO) } // Set up a ticker to run every 2 minutes @@ -762,17 +770,18 @@ func main() { initKafka() for { if handle, err := pcap.OpenLive(interfaceName, 128*1024, true, pcap.BlockForever); err != nil { - log.Fatal(err) + db.InsertLog(err.Error(), ERROR) + os.Exit(1) } else { run(handle, -1, "MIRRORING") - log.Println("closing pcap connection....") + db.InsertLog("closing pcap connection....", INFO) handle.Close() - log.Println("sleeping....") + db.InsertLog("sleeping....", INFO) assemblerMap = make(map[int]*tcpassembly.Assembler) incomingCountMap = make(map[string]utils.IncomingCounter) outgoingCountMap = make(map[string]utils.OutgoingCounter) time.Sleep(10 * time.Second) - log.Println("SLEPT") + db.InsertLog("SLEPT", INFO) initKafka() } } @@ -780,7 +789,7 @@ func main() { } func tickerCode() { - log.Println("Running ticker") + db.InsertLog("Running ticker", INFO) db.TrafficMetricsDbUpdates(incomingCountMap, outgoingCountMap) incomingCountMap = make(map[string]utils.IncomingCounter) outgoingCountMap = make(map[string]utils.OutgoingCounter) @@ -789,7 +798,7 @@ func tickerCode() { func printLog(val string) { if printCounter > 0 { - log.Println(val) + db.InsertLog(val, INFO) printCounter-- } } From 44efe709ba3c7d5ddad1e892c2a82c9a66f9e38f Mon Sep 17 00:00:00 2001 From: Avneesh Hota Date: Mon, 22 Jul 2024 20:27:19 +0530 Subject: [PATCH 2/2] fixed collection name and log string --- db/logs_dao.go | 2 +- db/mongo.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/db/logs_dao.go b/db/logs_dao.go index 7db8b2b..eb3e814 100644 --- a/db/logs_dao.go +++ b/db/logs_dao.go @@ -47,7 +47,7 @@ func logsInstance() (*mongo.Collection, error) { func InsertLog(logString string, key string) { log.Println(logString) - logString += "MIRRORING: " + logString + logString = "MIRRORING: " + logString logDoc := LogDocument{ Log: logString, diff --git a/db/mongo.go b/db/mongo.go index 94425cf..e6e467e 100644 --- a/db/mongo.go +++ b/db/mongo.go @@ -19,7 +19,7 @@ var ( var AccountID = strconv.Itoa(1_000_000) var TrafficMetricsCollectionName = "traffic_metrics" var AccountSettingsCollectionName = "accounts_settings" -var LogsCollectionName = "logs" +var LogsCollectionName = "logs_runtime" func GetMongoClient() (*mongo.Client, error) { once.Do(func() {