From 4278628690a3a630947ac07d32877b9d9da3b0a1 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 30 Jan 2026 15:39:15 -0800 Subject: [PATCH] Support multiple multicast group joins per connection Co-Authored-By: Claude Opus 4.5 --- multicast.go | 106 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 93 insertions(+), 13 deletions(-) diff --git a/multicast.go b/multicast.go index 05c4232..b5b223a 100644 --- a/multicast.go +++ b/multicast.go @@ -2,8 +2,10 @@ package multicast import ( "context" + "fmt" "math/rand" "net" + "sync" "time" "github.com/google/gopacket" @@ -16,12 +18,13 @@ type Conn struct { *ipv4.PacketConn rawConn net.PacketConn iface *net.Interface - groupIP net.IP + groups []net.IP srcIP net.IP srcMAC net.HardwareAddr queryChan chan struct{} ctx context.Context cancel context.CancelFunc + mu sync.Mutex } func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr) (*Conn, error) { @@ -44,13 +47,18 @@ func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr ctx, cancel := context.WithCancel(context.Background()) + var srcMAC net.HardwareAddr + if iface != nil { + srcMAC = iface.HardwareAddr + } + conn := &Conn{ PacketConn: p, rawConn: c, iface: iface, - groupIP: gaddr.IP, + groups: []net.IP{gaddr.IP}, srcIP: srcIP, - srcMAC: iface.HardwareAddr, + srcMAC: srcMAC, queryChan: make(chan struct{}, 1), ctx: ctx, cancel: cancel, @@ -62,9 +70,62 @@ func ListenMulticastUDP(network string, iface *net.Interface, gaddr *net.UDPAddr return conn, nil } +func ListenMulticastUDPPort(network string, iface *net.Interface, port int) (*Conn, error) { + srcIP, _ := getInterfaceIPv4(iface) + + c, err := net.ListenPacket(network, fmt.Sprintf(":%d", port)) + if err != nil { + return nil, err + } + + p := ipv4.NewPacketConn(c) + if iface != nil { + p.SetMulticastInterface(iface) + } + + ctx, cancel := context.WithCancel(context.Background()) + + var srcMAC net.HardwareAddr + if iface != nil { + srcMAC = iface.HardwareAddr + } + + conn := &Conn{ + PacketConn: p, + rawConn: c, + iface: iface, + groups: nil, + srcIP: srcIP, + srcMAC: srcMAC, + queryChan: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } + + go conn.runAdvertiser() + go conn.listenForQueries() + + return conn, nil +} + +func (c *Conn) JoinGroup(gaddr *net.UDPAddr) error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.PacketConn.JoinGroup(c.iface, gaddr); err != nil { + return err + } + c.groups = append(c.groups, gaddr.IP) + return nil +} + func (c *Conn) Close() error { c.cancel() - c.PacketConn.LeaveGroup(c.iface, &net.UDPAddr{IP: c.groupIP}) + c.mu.Lock() + for _, groupIP := range c.groups { + c.PacketConn.LeaveGroup(c.iface, &net.UDPAddr{IP: groupIP}) + } + c.mu.Unlock() return c.rawConn.Close() } @@ -76,7 +137,7 @@ func (c *Conn) runAdvertiser() { ticker := time.NewTicker(60 * time.Second) defer ticker.Stop() - c.sendReport() + c.sendReports() for { select { @@ -85,13 +146,23 @@ func (c *Conn) runAdvertiser() { case <-c.queryChan: delay := time.Duration(rand.Intn(1000)) * time.Millisecond time.Sleep(delay) - c.sendReport() + c.sendReports() case <-ticker.C: - c.sendReport() + c.sendReports() } } } +func (c *Conn) sendReports() { + c.mu.Lock() + groups := append([]net.IP{}, c.groups...) + c.mu.Unlock() + + for _, groupIP := range groups { + c.sendReport(groupIP) + } +} + func (c *Conn) listenForQueries() { handle, err := pcap.OpenLive(c.iface.Name, 65536, true, 5*time.Second) if err != nil { @@ -131,7 +202,16 @@ func (c *Conn) isQuery(packet gopacket.Packet) bool { switch igmp := igmpLayer.(type) { case *layers.IGMPv1or2: if igmp.Type == layers.IGMPMembershipQuery { - return igmp.GroupAddress.IsUnspecified() || igmp.GroupAddress.Equal(c.groupIP) + if igmp.GroupAddress.IsUnspecified() { + return true + } + c.mu.Lock() + defer c.mu.Unlock() + for _, groupIP := range c.groups { + if igmp.GroupAddress.Equal(groupIP) { + return true + } + } } case *layers.IGMP: if igmp.Type == layers.IGMPMembershipQuery { @@ -141,8 +221,8 @@ func (c *Conn) isQuery(packet gopacket.Packet) bool { return false } -func (c *Conn) sendReport() { - if c.srcIP == nil { +func (c *Conn) sendReport(groupIP net.IP) { + if c.srcIP == nil || c.iface == nil { return } @@ -154,7 +234,7 @@ func (c *Conn) sendReport() { eth := &layers.Ethernet{ SrcMAC: c.srcMAC, - DstMAC: multicastIPToMAC(c.groupIP), + DstMAC: multicastIPToMAC(groupIP), EthernetType: layers.EthernetTypeIPv4, } @@ -164,13 +244,13 @@ func (c *Conn) sendReport() { TTL: 1, Protocol: layers.IPProtocolIGMP, SrcIP: c.srcIP, - DstIP: c.groupIP, + DstIP: groupIP, Options: []layers.IPv4Option{{OptionType: 148, OptionLength: 4, OptionData: []byte{0, 0}}}, } buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - gopacket.SerializeLayers(buf, opts, eth, ip, gopacket.Payload(buildIGMPv2Report(c.groupIP))) + gopacket.SerializeLayers(buf, opts, eth, ip, gopacket.Payload(buildIGMPv2Report(groupIP))) handle.WritePacketData(buf.Bytes()) }