Files
chasquid/internal/haproxy/haproxy.go
Alberto Bertogli e79586a014 Implement HAProxy protocol support
This patch implements support for incoming connections wrapped in the
HAProxy protocol v1.

This is useful when running chasquid behind a HAProxy server, as chasquid
needs the original source IP to perform SPF checks.

This patch is a reimplementation of one originally provided by Denys
Vitali in pull request #15, except that the protocol handling logic is
moved to a new package, and the smtpsrv.Conn handling of the source
IP is simplified.

It is marked as experimental for now, since we want to give it a bit
more exposure just in case the option/API needs adjustment.

Thanks a lot to Denys Vitali (@denysvitali in github) for sending the
original patch for this, and helping test it!
2020-11-13 20:49:42 +00:00

77 lines
2.0 KiB
Go

// Package haproxy implements the handshake for the HAProxy client protocol
// version 1, as described in
// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt.
package haproxy
import (
"bufio"
"errors"
"net"
"strconv"
"strings"
)
var (
	errInvalidProtoID = errors.New("invalid protocol identifier")
	errUnkProtocol    = errors.New("unknown protocol")
	errInvalidFields  = errors.New("invalid number of fields")
	errInvalidSrcIP   = errors.New("invalid src ip")
	errInvalidDstIP   = errors.New("invalid dst ip")
	errInvalidSrcPort = errors.New("invalid src port")
	errInvalidDstPort = errors.New("invalid dst port")

	// errLineTooLong is returned when no '\n' is found within the first
	// maxHeaderLen bytes of the connection.
	errLineTooLong = errors.New("header line too long")
)

// maxHeaderLen is the longest v1 header line the specification allows,
// including the trailing CRLF (the "PROXY UNKNOWN" worst case, section 2.1).
// The TCP4 and TCP6 headers accepted here are always shorter than this.
const maxHeaderLen = 107

// Handshake performs the HAProxy protocol v1 handshake on the given reader,
// which is expected to be backed by a network connection.
//
// It reads a single header line of the form
//
//	PROXY <TCP4|TCP6> <src ip> <dst ip> <src port> <dst port>\r\n
//
// and returns the source and destination addresses (as *net.TCPAddr), or an
// error if the handshake could not complete. Only the TCP4 and TCP6
// protocols are accepted; any other protocol token yields an error. For
// TCP4 both addresses must be IPv4 literals, and for TCP6 both must be IPv6
// literals.
//
// At most maxHeaderLen bytes are consumed while looking for the end of the
// line, and nothing past the '\n' is consumed from r, so the caller can keep
// using r for the rest of the conversation. Timeouts must still be set by
// the caller on the underlying connection; this is only a helper to perform
// the handshake.
func Handshake(r *bufio.Reader) (src, dst net.Addr, err error) {
	line, err := readHeaderLine(r)
	if err != nil {
		return nil, nil, err
	}

	// strings.Fields also drops the trailing "\r\n".
	fields := strings.Fields(line)
	if len(fields) < 2 || fields[0] != "PROXY" {
		return nil, nil, errInvalidProtoID
	}

	proto := fields[1]
	switch proto {
	case "TCP4", "TCP6":
		// Allowed to continue, nothing to do.
	default:
		return nil, nil, errUnkProtocol
	}

	if len(fields) != 6 {
		return nil, nil, errInvalidFields
	}

	srcIP := net.ParseIP(fields[2])
	if srcIP == nil || !matchesFamily(proto, fields[2]) {
		return nil, nil, errInvalidSrcIP
	}
	dstIP := net.ParseIP(fields[3])
	if dstIP == nil || !matchesFamily(proto, fields[3]) {
		return nil, nil, errInvalidDstIP
	}

	// bitSize 16 makes ParseUint reject anything outside [0, 65535].
	srcPort, err := strconv.ParseUint(fields[4], 10, 16)
	if err != nil {
		return nil, nil, errInvalidSrcPort
	}
	dstPort, err := strconv.ParseUint(fields[5], 10, 16)
	if err != nil {
		return nil, nil, errInvalidDstPort
	}

	src = &net.TCPAddr{IP: srcIP, Port: int(srcPort)}
	dst = &net.TCPAddr{IP: dstIP, Port: int(dstPort)}
	return src, dst, nil
}

// readHeaderLine reads from r up to and including the first '\n', but never
// consumes more than maxHeaderLen bytes, so a peer that never sends a
// newline cannot make us buffer an unbounded amount of data.
// Read errors (including io.EOF before a newline is seen) are returned as-is.
func readHeaderLine(r *bufio.Reader) (string, error) {
	buf := make([]byte, 0, maxHeaderLen)
	for len(buf) < maxHeaderLen {
		c, err := r.ReadByte()
		if err != nil {
			return "", err
		}
		buf = append(buf, c)
		if c == '\n' {
			return string(buf), nil
		}
	}
	return "", errLineTooLong
}

// matchesFamily reports whether the address literal s belongs to the family
// declared by proto ("TCP4" or "TCP6"). s must already have been accepted by
// net.ParseIP, so the presence of a ':' is enough to tell an IPv6 literal
// (including IPv4-mapped forms like "::ffff:1.2.3.4") from an IPv4 one.
func matchesFamily(proto, s string) bool {
	isV6 := strings.Contains(s, ":")
	if proto == "TCP6" {
		return isV6
	}
	return !isV6
}