diff --git a/internal/dhcp6/dhcp6.go b/internal/dhcp6/dhcp6.go index 0795aa1..7a52537 100644 --- a/internal/dhcp6/dhcp6.go +++ b/internal/dhcp6/dhcp6.go @@ -62,7 +62,7 @@ type Client struct { raddr *net.UDPAddr timeNow func() time.Time duid *dhcpv6.Duid - advertise dhcpv6.DHCPv6 + advertise *dhcpv6.Message cfg Config err error @@ -159,7 +159,7 @@ func (c *Client) Close() error { const maxUDPReceivedPacketSize = 8192 // arbitrary size. Theoretically could be up to 65kb -func (c *Client) sendReceive(packet dhcpv6.DHCPv6, expectedType dhcpv6.MessageType) (dhcpv6.DHCPv6, error) { +func (c *Client) sendReceive(packet *dhcpv6.Message, expectedType dhcpv6.MessageType) (*dhcpv6.Message, error) { if packet == nil { return nil, fmt.Errorf("Packet to send cannot be nil") } @@ -185,46 +185,36 @@ func (c *Client) sendReceive(packet dhcpv6.DHCPv6, expectedType dhcpv6.MessageTy // wait for a reply c.Conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) var ( - adv dhcpv6.DHCPv6 - isMessage bool + adv *dhcpv6.Message ) - msg, ok := packet.(*dhcpv6.DHCPv6Message) - if ok { - isMessage = true - } for { buf := make([]byte, maxUDPReceivedPacketSize) n, _, err := c.Conn.ReadFrom(buf) if err != nil { return nil, err } - adv, err = dhcpv6.FromBytes(buf[:n]) + adv, err = dhcpv6.MessageFromBytes(buf[:n]) if err != nil { log.Printf("non-DHCP: %v", err) // skip non-DHCP packets continue } - if recvMsg, ok := adv.(*dhcpv6.DHCPv6Message); ok && isMessage { - // if a regular message, check the transaction ID first - // XXX should this unpack relay messages and check the XID of the - // inner packet too? - if msg.TransactionID() != recvMsg.TransactionID() { - log.Printf("different XID: got %v, want %v", recvMsg.TransactionID(), msg.TransactionID()) - // different XID, we don't want this packet for sure - continue - } + if packet.TransactionID != adv.TransactionID { + log.Printf("different XID: got %v, want %v", adv.TransactionID, packet.TransactionID) + // different XID, we don't want this packet for sure + continue } if expectedType == dhcpv6.MessageTypeNone { // just take whatever arrived break - } else if adv.Type() == expectedType { + } else if adv.MessageType == expectedType { break } } return adv, nil } -func (c *Client) solicit(solicit dhcpv6.DHCPv6) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) { +func (c *Client) solicit(solicit *dhcpv6.Message) (*dhcpv6.Message, *dhcpv6.Message, error) { var err error if solicit == nil { solicit, err = dhcpv6.NewSolicitForInterface(c.interfaceName, dhcpv6.WithClientID(*c.duid)) @@ -235,7 +225,7 @@ func (c *Client) solicit(solicit dhcpv6.DHCPv6) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, e if len(c.transactionIDs) > 0 { id := c.transactionIDs[0] c.transactionIDs = c.transactionIDs[1:] - solicit.(*dhcpv6.DHCPv6Message).SetTransactionID(id) + solicit.TransactionID = id } iapd := []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} opt, err := dhcpv6.ParseOptIAForPrefixDelegation(iapd) @@ -247,8 +237,7 @@ func (c *Client) solicit(solicit dhcpv6.DHCPv6) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, e return solicit, advertise, err } -func (c *Client) request(advertise dhcpv6.DHCPv6) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) { - +func (c *Client) request(advertise *dhcpv6.Message) (*dhcpv6.Message, *dhcpv6.Message, error) { request, err := dhcpv6.NewRequestFromAdvertise(advertise, dhcpv6.WithClientID(*c.duid)) if err != nil { return nil, nil, err @@ -260,7 +249,7 @@ func (c *Client) request(advertise dhcpv6.DHCPv6) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, if len(c.transactionIDs) > 0 { id := c.transactionIDs[0] c.transactionIDs = c.transactionIDs[1:] - request.(*dhcpv6.DHCPv6Message).SetTransactionID(id) + request.TransactionID = id } reply, err := c.sendReceive(request, dhcpv6.MessageTypeNone) return request, reply, err @@ -281,7 +270,7 @@ func (c *Client) ObtainOrRenew() bool { return true } var newCfg Config - for _, opt := range reply.Options() { + for _, opt := range reply.Options { switch o := opt.(type) { case *dhcpv6.OptIAForPrefixDelegation: t1 := c.timeNow().Add(time.Duration(o.T1) * time.Second) @@ -306,17 +295,17 @@ func (c *Client) ObtainOrRenew() bool { return true } -func (c *Client) Release() (release dhcpv6.DHCPv6, reply dhcpv6.DHCPv6, err error) { +func (c *Client) Release() (release *dhcpv6.Message, reply *dhcpv6.Message, err error) { release, err = dhcpv6.NewRequestFromAdvertise(c.advertise, dhcpv6.WithClientID(*c.duid)) if err != nil { return nil, nil, err } - release.(*dhcpv6.DHCPv6Message).SetMessage(dhcpv6.MessageTypeRelease) + release.MessageType = dhcpv6.MessageTypeRelease if len(c.transactionIDs) > 0 { id := c.transactionIDs[0] c.transactionIDs = c.transactionIDs[1:] - release.(*dhcpv6.DHCPv6Message).SetTransactionID(id) + release.TransactionID = id } reply, err = c.sendReceive(release, dhcpv6.MessageTypeNone) return release, reply, err