diff --git a/account.go b/account.go index 56c8335..cbfbf5a 100644 --- a/account.go +++ b/account.go @@ -50,6 +50,7 @@ const ( type AccountInfo struct { Error bool `json:"error"` // error or not ErrMsg string `json:"errmsg,omitempty"` // error string message + TraceId string `json:"trace_id"` FCoin int `json:"fcoin"` // fcoin count FofaPoint int64 `json:"fofa_point"` // fofa point IsVIP bool `json:"isvip"` // is vip @@ -58,6 +59,10 @@ type AccountInfo struct { RemainApiData int `json:"remain_api_data"` // available data amount } +func (ai *AccountInfo) SetTraceId(traceId string) { + ai.TraceId = traceId +} + func (ai AccountInfo) String() string { d, _ := json.MarshalIndent(ai, "", " ") return string(d) diff --git a/account_test.go b/account_test.go index 33c583b..8b89a53 100644 --- a/account_test.go +++ b/account_test.go @@ -45,6 +45,7 @@ func TestAccountInfo_String(t *testing.T) { }, ai) assert.Equal(t, `{ "error": false, + "trace_id": "", "fcoin": 0, "fofa_point": 0, "isvip": true, diff --git a/client.go b/client.go index bd7ae05..7d80410 100644 --- a/client.go +++ b/client.go @@ -1,4 +1,5 @@ -/*Package gofofa fofa client in Go +/* +Package gofofa fofa client in Go env settings: - FOFA_CLIENT_URL full fofa connnection string, format: /?email=&key=&version= @@ -25,7 +26,7 @@ const ( type Client struct { Server string // can set local server for debugging, format: :// APIVersion string // api version - Email string // fofa email + Email string // Deprecated: As of gofofa 1.16, email will no longer be required Key string // fofa key Account AccountInfo // fofa account info @@ -37,6 +38,7 @@ type Client struct { onResults func(results [][]string) // when fetch results callback accountDebug bool // 调试账号明文信息 + traceId bool // 报错信息返回 trace id } // Update merge config from config url @@ -114,6 +116,14 @@ func WithAccountDebug(v bool) ClientOption { } } +// WithTraceId 报错信息中返回 trace id +func WithTraceId(v bool) ClientOption { + return func(c *Client) error { + c.traceId = v + return nil + } +} + // NewClient from fofa connection string to config // and with env config merge func NewClient(options ...ClientOption) (*Client, error) { diff --git a/go.mod b/go.mod index 351b8c2..5ea573a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/LubyRuffy/gofofa go 1.18 require ( + github.com/avast/retry-go v3.0.0+incompatible github.com/fatih/color v1.13.0 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/sirupsen/logrus v1.8.1 diff --git a/host.go b/host.go index ac77ec7..914500f 100644 --- a/host.go +++ b/host.go @@ -4,11 +4,18 @@ import ( "context" "encoding/base64" "errors" + "fmt" + "github.com/avast/retry-go" "math" "strconv" "strings" + "time" ) +type CommonResp interface { + SetTraceId(string) +} + const ( NoHostWithFixURL = "host field must included when fixUrl option set" ) @@ -23,12 +30,18 @@ type HostResults struct { Size int `json:"size"` // 总数 Results interface{} `json:"results"` Next string `json:"next"` + TraceId string `json:"trace_id"` +} + +func (h *HostResults) SetTraceId(traceId string) { + h.TraceId = traceId } // HostStatsData /host api results type HostStatsData struct { Error bool `json:"error"` Errmsg string `json:"errmsg"` + TraceId string `json:"trace_id"` Host string `json:"host"` IP string `json:"ip"` ASN int `json:"asn"` @@ -42,6 +55,10 @@ type HostStatsData struct { UpdateTime string `json:"update_time"` } +func (s *HostStatsData) SetTraceId(traceId string) { + s.TraceId = traceId +} + // SearchOptions options of search, for post processors type SearchOptions struct { FixUrl bool // each host fix as url, like 1.1.1.1,80 will change to http://1.1.1.1, https://1.1.1.1:8443 will no change @@ -225,22 +242,41 @@ func (c *Client) HostSearch(query string, size int, fields []string, options ... } var hr HostResults - err = c.Fetch("search/all", - map[string]string{ - "qbase64": base64.StdEncoding.EncodeToString([]byte(query)), - "size": strconv.Itoa(perPage), - "page": strconv.Itoa(page), - "fields": strings.Join(fields, ","), - "full": strconv.FormatBool(full), // 是否全部数据,非一年内 + err = retry.Do( + func() error { + err = c.Fetch("search/all", + map[string]string{ + "qbase64": base64.StdEncoding.EncodeToString([]byte(query)), + "size": strconv.Itoa(perPage), + "page": strconv.Itoa(page), + "fields": strings.Join(fields, ","), + "full": strconv.FormatBool(full), // 是否全部数据,非一年内 + }, + &hr) + if err != nil { + return err + } + return nil }, - &hr) + retry.Attempts(3), + retry.Delay(3*time.Second), + retry.DelayType(retry.RandomDelay), + retry.LastErrorOnly(true), + ) if err != nil { + if c.traceId { + err = fmt.Errorf("[%s]%s", hr.TraceId, err.Error()) + } return } // 报错,退出 if len(hr.Errmsg) > 0 { - err = errors.New(hr.Errmsg) + if c.traceId { + err = errors.New(hr.Errmsg + " trace id: " + hr.TraceId) + } else { + err = errors.New(hr.Errmsg) + } break } @@ -315,8 +351,12 @@ func (c *Client) HostSize(query string) (count int, err error) { }, &hr) if err != nil { + if c.traceId { + err = fmt.Errorf("[%s]%s", hr.TraceId, err.Error()) + } return } + count = hr.Size return } @@ -369,23 +409,43 @@ func (c *Client) DumpSearch(query string, allSize int, batchSize int, fields []s } } + // 添加默认三次重试,防止大数据量拉取时的报错 var hr HostResults - err = c.Fetch("search/next", - map[string]string{ - "qbase64": base64.StdEncoding.EncodeToString([]byte(query)), - "size": strconv.Itoa(perPage), - "fields": strings.Join(fields, ","), - "full": strconv.FormatBool(full), // 是否全部数据,非一年内 - "next": next, // 偏移 + err = retry.Do( + func() error { + err = c.Fetch("search/next", + map[string]string{ + "qbase64": base64.StdEncoding.EncodeToString([]byte(query)), + "size": strconv.Itoa(perPage), + "fields": strings.Join(fields, ","), + "full": strconv.FormatBool(full), // 是否全部数据,非一年内 + "next": next, // 偏移 + }, + &hr) + if err != nil { + return err + } + return nil }, - &hr) + retry.Attempts(3), + retry.Delay(3*time.Second), + retry.DelayType(retry.RandomDelay), + retry.LastErrorOnly(true), + ) if err != nil { - return + if c.traceId { + err = fmt.Errorf("[%s]%s", hr.TraceId, err.Error()) + } + return err } // 报错,退出 if len(hr.Errmsg) > 0 { - err = errors.New(hr.Errmsg) + if c.traceId { + err = errors.New(hr.Errmsg + " trace id: " + hr.TraceId) + } else { + err = errors.New(hr.Errmsg) + } break } diff --git a/request.go b/request.go index ed393d9..e061b27 100644 --- a/request.go +++ b/request.go @@ -37,7 +37,7 @@ func readAll(reader io.Reader, size int) ([]byte, error) { } // just fetch fofa body, no need to unmarshal -func (c *Client) fetchBody(apiURI string, params map[string]string) (body []byte, err error) { +func (c *Client) fetchBody(apiURI string, params map[string]string) (body []byte, traceId string, err error) { var req *http.Request var resp *http.Response @@ -53,6 +53,13 @@ func (c *Client) fetchBody(apiURI string, params map[string]string) (body []byte resp, err = c.httpClient.Do(req) //responseDump, _ := httputil.DumpResponse(resp, false) //log.Println(string(responseDump)) + + // 获取 traceId + if c.traceId { + // 获取请求头中的 trace id + traceId = resp.Header.Get("Trace-Id") + } + if err != nil { if !c.accountDebug { // 替换账号明文信息 @@ -107,14 +114,15 @@ func (c *Client) fetchBody(apiURI string, params map[string]string) (body []byte } // Fetch http request and parse as json return to v -func (c *Client) Fetch(apiURI string, params map[string]string, v interface{}) (err error) { - content, err := c.fetchBody(apiURI, params) +func (c *Client) Fetch(apiURI string, params map[string]string, v CommonResp) (err error) { + content, traceId, err := c.fetchBody(apiURI, params) if err != nil { return } - if err = json.Unmarshal(content, v); err != nil { - return + if err = json.Unmarshal(content, &v); err != nil { + return fmt.Errorf("fail search fofa content %s error %s", content, err.Error()) } + v.SetTraceId(traceId) return } diff --git a/request_test.go b/request_test.go index 3e19e10..12ecfe6 100644 --- a/request_test.go +++ b/request_test.go @@ -82,6 +82,17 @@ func newTcpTestServer(handler func(conn net.Conn, data []byte) error) *tcpTestSe return ts } +type resp struct { + Error bool `json:"error"` + Errmsg string `json:"errmsg"` + TraceId string `json:"trace_id"` + Text string `json:"text"` +} + +func (r *resp) SetTraceId(traceId string) { + r.TraceId = traceId +} + func TestClient_Fetch(t *testing.T) { _, err := NewClient(WithURL("http://127.0.0.1:55")) assert.Error(t, err) @@ -97,14 +108,14 @@ func TestClient_Fetch(t *testing.T) { } // 解析异常 - var a map[string]interface{} + var a = resp{} err = cli.Fetch("", nil, &a) assert.Error(t, err) // gzip err = cli.Fetch("gzip.json", nil, &a) assert.Nil(t, err) - assert.Equal(t, "hello world", a["text"].(string)) + assert.Equal(t, "hello world", a.Text) // content Length Error err = cli.Fetch("contentLengthError.json", nil, &a) diff --git a/stats.go b/stats.go index af05563..a6690e2 100644 --- a/stats.go +++ b/stats.go @@ -3,6 +3,7 @@ package gofofa import ( "encoding/base64" "errors" + "fmt" "strconv" "strings" ) @@ -11,11 +12,16 @@ import ( type StatsResults struct { Error bool `json:"error"` Errmsg string `json:"errmsg"` + TraceId string `json:"trace_id"` Distinct map[string]interface{} `json:"distinct"` Aggs map[string]interface{} `json:"aggs"` LastUpdateTime string `json:"lastupdatetime"` } +func (s *StatsResults) SetTraceId(traceId string) { + s.TraceId = traceId +} + // StatsItem one stats item type StatsItem struct { Name string @@ -47,10 +53,19 @@ func (c *Client) Stats(query string, size int, fields []string) (res []StatsObje }, &sr) if err != nil { + if c.traceId { + err = fmt.Errorf("[%s]%s", sr.TraceId, err.Error()) + } return } + + // 报错,退出 if len(sr.Errmsg) > 0 { - err = errors.New(sr.Errmsg) + if c.traceId { + err = errors.New(sr.Errmsg + " trace id: " + sr.TraceId) + } else { + err = errors.New(sr.Errmsg) + } return }