Skip to content

Commit

Permalink
refactor: remove protocol specific timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
natesales committed May 10, 2024
1 parent d646dc1 commit 64457fb
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 102 deletions.
160 changes: 90 additions & 70 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,98 +404,118 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
}
msgs := createQuery(opts, rrTypesSlice)

var entries []*output.Entry
for _, serverStr := range opts.Server {
// Parse server address and transport type
server, transportType, err := parseServer(serverStr)
if err != nil {
return err
}
log.Debugf("Using server %s with transport %s", server, transportType)
errChan := make(chan error)

// Recursive zone transfer
if opts.RecAXFR {
if opts.Name == "" {
return fmt.Errorf("no name specified for AXFR")
go func() {
var entries []*output.Entry
for _, serverStr := range opts.Server {
// Parse server address and transport type
server, transportType, err := parseServer(serverStr)
if err != nil {
errChan <- fmt.Errorf("parsing server %s: %s", serverStr, err)
}
_ = RecAXFR(opts.Name, server, out)
return nil
}
log.Debugf("Using server %s with transport %s", server, transportType)

// Create transport
txp, err := newTransport(server, transportType, tlsConfig)
if err != nil {
return err
}
// Recursive zone transfer
if opts.RecAXFR {
if opts.Name == "" {
errChan <- fmt.Errorf("no name specified for AXFR")
}
_ = RecAXFR(opts.Name, server, out)
errChan <- nil // exit immediately
}

startTime := time.Now()
var replies []*dns.Msg
for _, msg := range msgs {
reply, err := (*txp).Exchange(&msg)
// Create transport
txp, err := newTransport(server, transportType, tlsConfig)
if err != nil {
return err
errChan <- fmt.Errorf("creating transport: %s", err)
}

if transportType != transport.TypeQUIC && opts.IDCheck && reply.Id != msg.Id {
return fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id)
startTime := time.Now()
var replies []*dns.Msg
for _, msg := range msgs {
if txp == nil {
errChan <- fmt.Errorf("transport is nil")
}
reply, err := (*txp).Exchange(&msg)
if err != nil {
errChan <- fmt.Errorf("exchange: %s", err)
}

if reply == nil {
errChan <- fmt.Errorf("no reply from server")
}

if transportType != transport.TypeQUIC && opts.IDCheck && reply.Id != msg.Id {
errChan <- fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id)
}
replies = append(replies, reply)
}
replies = append(replies, reply)
}

// Process TXT parsing
if opts.TXTConcat {
for _, reply := range replies {
txtConcat(reply)
// Process TXT parsing
if opts.TXTConcat {
for _, reply := range replies {
txtConcat(reply)
}
}
}

// Round TTL
if opts.RoundTTLs {
for _, reply := range replies {
for _, rr := range reply.Answer {
rr.Header().Ttl = rr.Header().Ttl - (rr.Header().Ttl % 60)
// Round TTL
if opts.RoundTTLs {
for _, reply := range replies {
for _, rr := range reply.Answer {
rr.Header().Ttl = rr.Header().Ttl - (rr.Header().Ttl % 60)
}
}
}
}

e := &output.Entry{
Queries: msgs,
Replies: replies,
Server: server,
Time: time.Since(startTime),
}
e := &output.Entry{
Queries: msgs,
Replies: replies,
Server: server,
Time: time.Since(startTime),
}

if opts.ResolveIPs {
e.LoadPTRs(txp)
}

entries = append(entries, e)

if opts.ResolveIPs {
e.LoadPTRs(txp)
if err := (*txp).Close(); err != nil {
errChan <- fmt.Errorf("closing transport: %s", err)
}
}

entries = append(entries, e)
printer := output.Printer{
Out: out,
Opts: &opts,
}

if err := (*txp).Close(); err != nil {
return fmt.Errorf("closing transport: %s", err)
if opts.NSID && opts.Format == "pretty" {
printer.PrettyPrintNSID(entries)
}
}

printer := output.Printer{
Out: out,
Opts: &opts,
}
switch opts.Format {
case "pretty":
printer.PrintPretty(entries)
case "column":
printer.PrintColumn(entries)
case "raw":
printer.PrintRaw(entries)
case "json", "yml", "yaml":
printer.PrintStructured(entries)
default:
errChan <- fmt.Errorf("invalid output format")
}

if opts.NSID && opts.Format == "pretty" {
printer.PrettyPrintNSID(entries)
}
errChan <- nil
}()

switch opts.Format {
case "pretty":
printer.PrintPretty(entries)
case "column":
printer.PrintColumn(entries)
case "raw":
printer.PrintRaw(entries)
case "json", "yml", "yaml":
printer.PrintStructured(entries)
default:
return fmt.Errorf("invalid output format")
select {
case <-time.After(opts.Timeout):
return fmt.Errorf("timeout")
case err := <-errChan:
return err
}

return nil
Expand Down
1 change: 0 additions & 1 deletion resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ func newTransport(server string, transportType transport.Type, tlsConfig *tls.Co
common := transport.Common{
Server: server,
ReuseConn: opts.ReuseConn,
Timeout: opts.Timeout,
}

switch transportType {
Expand Down
2 changes: 0 additions & 2 deletions transport/dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ type DNSCrypt struct {
func (d *DNSCrypt) setup() {
if d.client == nil || d.resolver == nil || !d.ReuseConn {
d.client = &dnscrypt.Client{
Net: "udp",
Timeout: d.Timeout,
UDPSize: d.UDPSize,
}

Expand Down
6 changes: 1 addition & 5 deletions transport/dnscrypt_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package transport

import "time"

func dnscryptTransport() *DNSCrypt {
d := &DNSCrypt{
Common: Common{Timeout: 1 * time.Second},
return &DNSCrypt{
ServerStamp: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
}
return d
}
10 changes: 2 additions & 8 deletions transport/http_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transport

import (
"crypto/tls"
"net/http"
"testing"
"time"
Expand All @@ -13,14 +12,9 @@ import (
func httpTransport() *HTTP {
return &HTTP{
Common: Common{
Server: "https://cloudflare-dns.com/dns-query",
Timeout: 2 * time.Second,
Server: "https://cloudflare-dns.com/dns-query",
},
TLSConfig: &tls.Config{},
UserAgent: "",
Method: http.MethodGet,
HTTP3: false,
NoPMTUd: false,
Method: http.MethodGet,
}
}

Expand Down
4 changes: 2 additions & 2 deletions transport/plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ type Plain struct {
}

func (p *Plain) Exchange(m *dns.Msg) (*dns.Msg, error) {
tcpClient := dns.Client{Net: "tcp", Timeout: p.Timeout}
tcpClient := dns.Client{Net: "tcp"}
if p.PreferTCP {
reply, _, tcpErr := tcpClient.Exchange(m, p.Server)
return reply, tcpErr
}

client := dns.Client{Timeout: p.Timeout, UDPSize: p.UDPBuffer}
client := dns.Client{UDPSize: p.UDPBuffer}
reply, _, err := client.Exchange(m, p.Server)

if reply != nil && reply.Truncated {
Expand Down
4 changes: 1 addition & 3 deletions transport/plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package transport

import (
"testing"
"time"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
Expand All @@ -11,8 +10,7 @@ import (
func plainTransport() *Plain {
return &Plain{
Common: Common{
Server: "9.9.9.9:53",
Timeout: 5 * time.Second,
Server: "9.9.9.9:53",
},
PreferTCP: false,
UDPBuffer: 1232,
Expand Down
4 changes: 1 addition & 3 deletions transport/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ func (t *TLS) Exchange(msg *dns.Msg) (*dns.Msg, error) {
if t.conn == nil || !t.ReuseConn {
var err error
t.conn, err = tls.DialWithDialer(
&net.Dialer{
Timeout: t.Timeout,
},
&net.Dialer{},
"tcp",
t.Server,
t.TLSConfig,
Expand Down
5 changes: 0 additions & 5 deletions transport/tls_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
package transport

import (
"time"
)

func tlsTransport() *TLS {
return &TLS{
Common: Common{
Server: "dns.quad9.net:853",
Timeout: 1 * time.Second,
ReuseConn: false,
},
}
Expand Down
3 changes: 0 additions & 3 deletions transport/transport.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package transport

import (
"time"

"github.com/miekg/dns"
)

Expand All @@ -14,7 +12,6 @@ type Transport interface {
type Common struct {
Server string
ReuseConn bool
Timeout time.Duration
}

type Type string
Expand Down

0 comments on commit 64457fb

Please sign in to comment.