Support multiple multicast group joins per connection

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Ian Gulliver
2026-01-30 15:39:15 -08:00
parent 12d0b38af9
commit 4278628690

View File

@@ -2,8 +2,10 @@ package multicast
import ( import (
"context" "context"
"fmt"
"math/rand" "math/rand"
"net" "net"
"sync"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -16,12 +18,13 @@ type Conn struct {
*ipv4.PacketConn *ipv4.PacketConn
rawConn net.PacketConn rawConn net.PacketConn
iface *net.Interface iface *net.Interface
groupIP net.IP groups []net.IP
srcIP net.IP srcIP net.IP
srcMAC net.HardwareAddr srcMAC net.HardwareAddr
queryChan chan struct{} queryChan chan struct{}
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mu sync.Mutex
} }
func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr) (*Conn, error) { func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr) (*Conn, error) {
@@ -44,13 +47,18 @@ func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var srcMAC net.HardwareAddr
if iface != nil {
srcMAC = iface.HardwareAddr
}
conn := &Conn{ conn := &Conn{
PacketConn: p, PacketConn: p,
rawConn: c, rawConn: c,
iface: iface, iface: iface,
groupIP: gaddr.IP, groups: []net.IP{gaddr.IP},
srcIP: srcIP, srcIP: srcIP,
srcMAC: iface.HardwareAddr, srcMAC: srcMAC,
queryChan: make(chan struct{}, 1), queryChan: make(chan struct{}, 1),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@@ -62,9 +70,62 @@ func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr
return conn, nil return conn, nil
} }
func ListenMulticastUDPPort(network string, iface *net.Interface, port int) (*Conn, error) {
srcIP, _ := getInterfaceIPv4(iface)
c, err := net.ListenPacket(network, fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
p := ipv4.NewPacketConn(c)
if iface != nil {
p.SetMulticastInterface(iface)
}
ctx, cancel := context.WithCancel(context.Background())
var srcMAC net.HardwareAddr
if iface != nil {
srcMAC = iface.HardwareAddr
}
conn := &Conn{
PacketConn: p,
rawConn: c,
iface: iface,
groups: nil,
srcIP: srcIP,
srcMAC: srcMAC,
queryChan: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
}
go conn.runAdvertiser()
go conn.listenForQueries()
return conn, nil
}
func (c *Conn) JoinGroup(gaddr *net.UDPAddr) error {
c.mu.Lock()
defer c.mu.Unlock()
if err := c.PacketConn.JoinGroup(c.iface, gaddr); err != nil {
return err
}
c.groups = append(c.groups, gaddr.IP)
return nil
}
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.cancel() c.cancel()
c.PacketConn.LeaveGroup(c.iface, &net.UDPAddr{IP: c.groupIP}) c.mu.Lock()
for _, groupIP := range c.groups {
c.PacketConn.LeaveGroup(c.iface, &net.UDPAddr{IP: groupIP})
}
c.mu.Unlock()
return c.rawConn.Close() return c.rawConn.Close()
} }
@@ -76,7 +137,7 @@ func (c *Conn) runAdvertiser() {
ticker := time.NewTicker(60 * time.Second) ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop() defer ticker.Stop()
c.sendReport() c.sendReports()
for { for {
select { select {
@@ -85,13 +146,23 @@ func (c *Conn) runAdvertiser() {
case <-c.queryChan: case <-c.queryChan:
delay := time.Duration(rand.Intn(1000)) * time.Millisecond delay := time.Duration(rand.Intn(1000)) * time.Millisecond
time.Sleep(delay) time.Sleep(delay)
c.sendReport() c.sendReports()
case <-ticker.C: case <-ticker.C:
c.sendReport() c.sendReports()
} }
} }
} }
func (c *Conn) sendReports() {
c.mu.Lock()
groups := append([]net.IP{}, c.groups...)
c.mu.Unlock()
for _, groupIP := range groups {
c.sendReport(groupIP)
}
}
func (c *Conn) listenForQueries() { func (c *Conn) listenForQueries() {
handle, err := pcap.OpenLive(c.iface.Name, 65536, true, 5*time.Second) handle, err := pcap.OpenLive(c.iface.Name, 65536, true, 5*time.Second)
if err != nil { if err != nil {
@@ -131,7 +202,16 @@ func (c *Conn) isQuery(packet gopacket.Packet) bool {
switch igmp := igmpLayer.(type) { switch igmp := igmpLayer.(type) {
case *layers.IGMPv1or2: case *layers.IGMPv1or2:
if igmp.Type == layers.IGMPMembershipQuery { if igmp.Type == layers.IGMPMembershipQuery {
return igmp.GroupAddress.IsUnspecified() || igmp.GroupAddress.Equal(c.groupIP) if igmp.GroupAddress.IsUnspecified() {
return true
}
c.mu.Lock()
defer c.mu.Unlock()
for _, groupIP := range c.groups {
if igmp.GroupAddress.Equal(groupIP) {
return true
}
}
} }
case *layers.IGMP: case *layers.IGMP:
if igmp.Type == layers.IGMPMembershipQuery { if igmp.Type == layers.IGMPMembershipQuery {
@@ -141,8 +221,8 @@ func (c *Conn) isQuery(packet gopacket.Packet) bool {
return false return false
} }
func (c *Conn) sendReport() { func (c *Conn) sendReport(groupIP net.IP) {
if c.srcIP == nil { if c.srcIP == nil || c.iface == nil {
return return
} }
@@ -154,7 +234,7 @@ func (c *Conn) sendReport() {
eth := &layers.Ethernet{ eth := &layers.Ethernet{
SrcMAC: c.srcMAC, SrcMAC: c.srcMAC,
DstMAC: multicastIPToMAC(c.groupIP), DstMAC: multicastIPToMAC(groupIP),
EthernetType: layers.EthernetTypeIPv4, EthernetType: layers.EthernetTypeIPv4,
} }
@@ -164,13 +244,13 @@ func (c *Conn) sendReport() {
TTL: 1, TTL: 1,
Protocol: layers.IPProtocolIGMP, Protocol: layers.IPProtocolIGMP,
SrcIP: c.srcIP, SrcIP: c.srcIP,
DstIP: c.groupIP, DstIP: groupIP,
Options: []layers.IPv4Option{{OptionType: 148, OptionLength: 4, OptionData: []byte{0, 0}}}, Options: []layers.IPv4Option{{OptionType: 148, OptionLength: 4, OptionData: []byte{0, 0}}},
} }
buf := gopacket.NewSerializeBuffer() buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
gopacket.SerializeLayers(buf, opts, eth, ip, gopacket.Payload(buildIGMPv2Report(c.groupIP))) gopacket.SerializeLayers(buf, opts, eth, ip, gopacket.Payload(buildIGMPv2Report(groupIP)))
handle.WritePacketData(buf.Bytes()) handle.WritePacketData(buf.Bytes())
} }