diff --git a/config.go b/config.go index 2e85e76..4883d86 100644 --- a/config.go +++ b/config.go @@ -243,6 +243,10 @@ type UpstreamConfig struct { // The caller should not access this field directly. // Use UpstreamSendClientInfo instead. SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` + + // How to transmit client info to upstream. (e.g. default:Headers, Subdomain, Path) + ClientIdType string `mapstructure:"client_id_type" toml:"client_id_type,omitempty"` + // The caller should not access this field directly. // Use IsDiscoverable instead. Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` diff --git a/doh.go b/doh.go index d702995..0ac63b9 100644 --- a/doh.go +++ b/doh.go @@ -2,12 +2,15 @@ package ctrld import ( "context" + "crypto/sha256" "encoding/base64" + "encoding/hex" "errors" "fmt" "io" "net/http" "net/url" + "regexp" "runtime" "strings" "sync" @@ -94,6 +97,17 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro query.Add("dns", enc) endpoint := *r.endpoint + + if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { + switch r.uc.ClientIdType { + case "subdomain": + endpoint.Host = clientIdFromClientInfo(r.uc, ci) + "." + endpoint.Host + + case "path": + endpoint.Path = strings.TrimRight(endpoint.Path, "/") + "/" + clientIdFromClientInfo(r.uc, ci) + } + } + endpoint.RawQuery = query.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) if err != nil { @@ -147,7 +161,7 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" switch { - case uc.IsControlD(): + case uc.IsControlD() || uc.ClientIdType == "" || uc.ClientIdType == "headers": dohHeader = newControlDHeaders(ci) case uc.isNextDNS(): dohHeader = newNextDNSHeaders(ci) @@ -202,3 +216,39 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header { } return header } + +func clientIdFromClientInfo(uc *UpstreamConfig, ci *ClientInfo) string { + switch ci.ClientIDPref { + case "mac": + return clientIdFromMac(ci.Mac) + case "host": + return clientIdFromHostname(ci.Hostname) + } + return hashHostnameAndMac(ci.Hostname, ci.Mac) +} + +func hashHostnameAndMac(hostname, mac string) string { + h := sha256.New() + h.Write([]byte(hostname + mac)) + return hex.EncodeToString(h.Sum(nil)) +} + +func clientIdFromMac(mac string) string { + return strings.ReplaceAll(mac, ":", "-") +} + +func clientIdFromHostname(hostname string) string { + // Define a regular expression to match allowed characters + re := regexp.MustCompile(`[^a-zA-Z0-9-]`) + + // Remove chars not allowed in subdomain + subdomain := re.ReplaceAllString(hostname, "") + + // Replace spaces with -- + subdomain = strings.ReplaceAll(subdomain, " ", "--") + + // Trim leading and trailing hyphens + subdomain = strings.Trim(subdomain, "-") + + return subdomain +}