Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dhcpv4/bsdp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *Client) Exchange(ifname string) ([]*Packet, error) {
conversation = append(conversation, informList)

// ACK[LIST]
ackForList, err := c.Client.SendReceive(sendFd, recvFd, informList.v4(), dhcpv4.MessageTypeAck)
ackForList, err := c.Client.SendReceive(sendFd, recvFd, informList.v4(), dhcpv4.MessageTypeAck, dhcpv4.WithDefault())
if err != nil {
return conversation, err
}
Expand All @@ -67,7 +67,7 @@ func (c *Client) Exchange(ifname string) ([]*Packet, error) {
conversation = append(conversation, informSelect)

// ACK[SELECT]
ackForSelect, err := c.Client.SendReceive(sendFd, recvFd, informSelect.v4(), dhcpv4.MessageTypeAck)
ackForSelect, err := c.Client.SendReceive(sendFd, recvFd, informSelect.v4(), dhcpv4.MessageTypeAck, dhcpv4.WithDefault())
if err != nil {
return conversation, err
}
Expand Down
11 changes: 7 additions & 4 deletions dhcpv4/client4/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (c *Client) getRemoteUDPAddr() (*net.UDPAddr, error) {
// ordered as Discovery, Offer, Request and Acknowledge. In case of errors, an
// error is returned, and the list of DHCPv4 objects will be shorted than 4,
// containing all the sent and received DHCPv4 messages.
func (c *Client) Exchange(ifname string, modifiers ...dhcpv4.Modifier) ([]*dhcpv4.DHCPv4, error) {
func (c *Client) Exchange(ifname string, selector dhcpv4.Selector, modifiers ...dhcpv4.Modifier) ([]*dhcpv4.DHCPv4, error) {
conversation := make([]*dhcpv4.DHCPv4, 0)
raddr, err := c.getRemoteUDPAddr()
if err != nil {
Expand Down Expand Up @@ -229,7 +229,7 @@ func (c *Client) Exchange(ifname string, modifiers ...dhcpv4.Modifier) ([]*dhcpv
conversation = append(conversation, discover)

// Offer
offer, err := c.SendReceive(sfd, rfd, discover, dhcpv4.MessageTypeOffer)
offer, err := c.SendReceive(sfd, rfd, discover, dhcpv4.MessageTypeOffer, selector)
if err != nil {
return conversation, err
}
Expand All @@ -243,7 +243,7 @@ func (c *Client) Exchange(ifname string, modifiers ...dhcpv4.Modifier) ([]*dhcpv
conversation = append(conversation, request)

// Ack
ack, err := c.SendReceive(sfd, rfd, request, dhcpv4.MessageTypeAck)
ack, err := c.SendReceive(sfd, rfd, request, dhcpv4.MessageTypeAck, selector)
if err != nil {
return conversation, err
}
Expand All @@ -255,7 +255,7 @@ func (c *Client) Exchange(ifname string, modifiers ...dhcpv4.Modifier) ([]*dhcpv
// SendReceive sends a packet (with some write timeout) and waits for a
// response up to some read timeout value. If the message type is not
// MessageTypeNone, it will wait for a specific message type
func (c *Client) SendReceive(sendFd, recvFd int, packet *dhcpv4.DHCPv4, messageType dhcpv4.MessageType) (*dhcpv4.DHCPv4, error) {
func (c *Client) SendReceive(sendFd, recvFd int, packet *dhcpv4.DHCPv4, messageType dhcpv4.MessageType, selector dhcpv4.Selector) (*dhcpv4.DHCPv4, error) {
raddr, err := c.getRemoteUDPAddr()
if err != nil {
return nil, err
Expand Down Expand Up @@ -333,6 +333,9 @@ func (c *Client) SendReceive(sendFd, recvFd int, packet *dhcpv4.DHCPv4, messageT
if response.TransactionID != packet.TransactionID {
continue
}
if !selector(response) {
continue
}
// wait for a response message
if response.OpCode != dhcpv4.OpcodeBootReply {
continue
Expand Down
12 changes: 12 additions & 0 deletions dhcpv4/selector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dhcpv4

// Selector defines the signature for functions that can select DHCPv4
// structures. This is used to drop illegal packets.
type Selector func(d *DHCPv4) bool

// WithDefault returns true by default
func WithDefault() Selector {
return func(d *DHCPv4) bool {
return true
}
}
17 changes: 10 additions & 7 deletions dhcpv6/client6/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func NewClient() *Client {
// Reply). The modifiers will be applied to the Solicit and Request packets.
// A common use is to make sure that the Solicit packet has the right options,
// see modifiers.go
func (c *Client) Exchange(ifname string, modifiers ...dhcpv6.Modifier) ([]dhcpv6.DHCPv6, error) {
func (c *Client) Exchange(ifname string, selector dhcpv6.Selector, modifiers ...dhcpv6.Modifier) ([]dhcpv6.DHCPv6, error) {
conversation := make([]dhcpv6.DHCPv6, 0)
var err error

// Solicit
solicit, advertise, err := c.Solicit(ifname, modifiers...)
solicit, advertise, err := c.Solicit(ifname, selector, modifiers...)
if solicit != nil {
conversation = append(conversation, solicit)
}
Expand Down Expand Up @@ -72,7 +72,7 @@ func (c *Client) Exchange(ifname string, modifiers ...dhcpv6.Modifier) ([]dhcpv6
return conversation, nil
}

func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType dhcpv6.MessageType) (dhcpv6.DHCPv6, error) {
func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType dhcpv6.MessageType, selector dhcpv6.Selector) (dhcpv6.DHCPv6, error) {
if packet == nil {
return nil, fmt.Errorf("Packet to send cannot be nil")
}
Expand Down Expand Up @@ -191,6 +191,9 @@ func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType d
// different XID, we don't want this packet for sure
continue
}
if !selector(recvMsg) {
continue
}
}
if expectedType == dhcpv6.MessageTypeNone {
// just take whatever arrived
Expand All @@ -205,7 +208,7 @@ func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType d
// Solicit sends a Solicit, returns the Solicit, an Advertise (if not nil), and
// an error if any. The modifiers will be applied to the Solicit before sending
// it, see modifiers.go
func (c *Client) Solicit(ifname string, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) {
func (c *Client) Solicit(ifname string, selector dhcpv6.Selector, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) {
iface, err := net.InterfaceByName(ifname)
if err != nil {
return nil, nil, err
Expand All @@ -217,21 +220,21 @@ func (c *Client) Solicit(ifname string, modifiers ...dhcpv6.Modifier) (dhcpv6.DH
for _, mod := range modifiers {
mod(solicit)
}
advertise, err := c.sendReceive(ifname, solicit, dhcpv6.MessageTypeNone)
advertise, err := c.sendReceive(ifname, solicit, dhcpv6.MessageTypeNone, selector)
return solicit, advertise, err
}

// Request sends a Request built from an Advertise. It returns the Request, a
// Reply (if not nil), and an error if any. The modifiers will be applied to
// the Request before sending it, see modifiers.go
func (c *Client) Request(ifname string, advertise *dhcpv6.Message, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) {
func (c *Client) Request(ifname string, advertise *dhcpv6.Message, selector dhcpv6.Selector, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) {
request, err := dhcpv6.NewRequestFromAdvertise(advertise)
if err != nil {
return nil, nil, err
}
for _, mod := range modifiers {
mod(request)
}
reply, err := c.sendReceive(ifname, request, dhcpv6.MessageTypeNone)
reply, err := c.sendReceive(ifname, request, dhcpv6.MessageTypeNone, selector)
return request, reply, err
}
12 changes: 12 additions & 0 deletions dhcpv6/selector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dhcpv6

// Selector defines the signature for functions that can select DHCPv6
// structures. This is used to drop illegal packets.
type Selector func(d *Message) bool

// WithDefault returns true by default
func WithDefault() Selector {
return func(d *Message) bool {
return true
}
}
3 changes: 2 additions & 1 deletion examples/client6/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"flag"
"github.com/insomniacslk/dhcp/dhcpv6"
"log"

"github.com/insomniacslk/dhcp/dhcpv6/client6"
Expand All @@ -28,7 +29,7 @@ func main() {
// still want to know what packets were exchanged until then.
// A default Solicit packet will be used during the "conversation",
// which can be manipulated by using modifiers.
conversation, err := client.Exchange(*iface)
conversation, err := client.Exchange(*iface, dhcpv6.WithDefault())

// Summary() prints a verbose representation of the exchanged packets.
for _, packet := range conversation {
Expand Down
4 changes: 2 additions & 2 deletions netboot/netboot.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func RequestNetbootv6(ifname string, timeout time.Duration, retries int, modifie

client := client6.NewClient()
client.ReadTimeout = timeout
conversation, err = client.Exchange(ifname, modifiers...)
conversation, err = client.Exchange(ifname, dhcpv6.WithDefault(), modifiers...)
if err != nil {
log.Printf("Client.Exchange failed: %v", err)
if i >= retries {
Expand Down Expand Up @@ -75,7 +75,7 @@ func RequestNetbootv4(ifname string, timeout time.Duration, retries int, modifie
log.Printf("sending request, attempt #%d", i+1)
client := client4.NewClient()
client.ReadTimeout = timeout
conversation, err = client.Exchange(ifname, modifiers...)
conversation, err = client.Exchange(ifname, dhcpv4.WithDefault(), modifiers...)
if err != nil {
log.Printf("Client.Exchange failed: %v", err)
log.Printf("sleeping %v before retrying", delay)
Expand Down