diff --git a/main.go b/main.go index 931f502..21501ed 100644 --- a/main.go +++ b/main.go @@ -129,6 +129,12 @@ func main() { }() } + // ----------------------------- Routing table ---------------------------- + routing := &routingTable{ + table: make(map[ddp.Network][]*route), + allRoutes: make(map[*route]struct{}), + } + // ------------------------- Configured peer setup ------------------------ for _, peerStr := range cfg.Peers { if !hasPortRE.MatchString(peerStr) { @@ -148,9 +154,10 @@ func main() { RemoteDI: aurp.IPDomainIdentifier(raddr.IP), LocalConnID: nextConnID, }, - conn: ln, - raddr: raddr, - recv: make(chan aurp.Packet, 1024), + conn: ln, + raddr: raddr, + recv: make(chan aurp.Packet, 1024), + routingTable: routing, } aurp.Inc(&nextConnID) peers[udpAddrFromNet(raddr)] = peer @@ -164,9 +171,10 @@ func main() { // --------------------------------- RTMP --------------------------------- rtmpMachine := &RTMPMachine{ - aarp: aarpMachine, - cfg: cfg, - pcapHandle: pcapHandle, + aarp: aarpMachine, + cfg: cfg, + pcapHandle: pcapHandle, + routingTable: routing, } rtmpCh := make(chan *ddp.ExtPacket, 1024) go rtmpMachine.Run(ctx, rtmpCh) @@ -241,7 +249,7 @@ func main() { // addressed to a node on the local network." if ddpkt.DstNet != 0 && (ddpkt.DstNet < cfg.EtherTalk.NetStart || ddpkt.DstNet > cfg.EtherTalk.NetEnd) { // Is it for a network in the routing table? - rt := lookupRoute(ddpkt.DstNet) + rt := routing.lookupRoute(ddpkt.DstNet) if rt == nil { log.Printf("DDP: no route for network %d", ddpkt.DstNet) continue @@ -481,9 +489,10 @@ func main() { RemoteDI: dh.SourceDI, // platinum rule LocalConnID: nextConnID, }, - conn: ln, - raddr: raddr, - recv: make(chan aurp.Packet, 1024), + conn: ln, + raddr: raddr, + recv: make(chan aurp.Packet, 1024), + routingTable: routing, } aurp.Inc(&nextConnID) peers[ra] = pr diff --git a/peer.go b/peer.go index b995015..684e901 100644 --- a/peer.go +++ b/peer.go @@ -81,6 +81,8 @@ type peer struct { conn *net.UDPConn raddr *net.UDPAddr recv chan aurp.Packet + + routingTable *routingTable } // send encodes and sends pkt to the remote host. @@ -277,7 +279,7 @@ func (p *peer) handle(ctx context.Context) error { log.Printf("Learned about these networks: %v", pkt.Networks) for _, nt := range pkt.Networks { - upsertRoutes( + p.routingTable.upsertRoutes( nt.Extended, ddp.Network(nt.RangeStart), ddp.Network(nt.RangeEnd), diff --git a/route.go b/route.go index 081d767..6dbafd8 100644 --- a/route.go +++ b/route.go @@ -10,6 +10,8 @@ import ( "github.com/sfiera/multitalk/pkg/ddp" ) +const maxRouteAge = 10 * time.Minute // TODO: confirm + type route struct { extended bool netStart ddp.Network @@ -19,29 +21,32 @@ type route struct { last time.Time } -var ( - routingTableMu sync.Mutex - routingTable = make(map[ddp.Network][]*route) +type routingTable struct { + tableMu sync.Mutex + table map[ddp.Network][]*route allRoutesMu sync.Mutex - allRoutes = make(map[*route]struct{}) -) - -func lookupRoute(network ddp.Network) *route { - routingTableMu.Lock() - defer routingTableMu.Unlock() - - rs := routingTable[network] - if len(rs) == 0 { - return nil - } - return rs[0] + allRoutes map[*route]struct{} } -func upsertRoutes(extended bool, netStart, netEnd ddp.Network, peer *peer, metric uint8) error { +func (rt *routingTable) lookupRoute(network ddp.Network) *route { + rt.tableMu.Lock() + defer rt.tableMu.Unlock() + + for _, rs := range rt.table[network] { + if time.Since(rs.last) > maxRouteAge { + continue + } + return rs + } + return nil +} + +func (rt *routingTable) upsertRoutes(extended bool, netStart, netEnd ddp.Network, peer *peer, metric uint8) error { if netStart > netEnd { return fmt.Errorf("invalid network range [%d, %d]", netStart, netEnd) } + r := &route{ extended: extended, netStart: netStart, @@ -51,17 +56,33 @@ func upsertRoutes(extended bool, netStart, netEnd ddp.Network, peer *peer, metri last: time.Now(), } - allRoutesMu.Lock() - allRoutes[r] = struct{}{} - allRoutesMu.Unlock() + rt.allRoutesMu.Lock() + rt.allRoutes[r] = struct{}{} + rt.allRoutesMu.Unlock() - routingTableMu.Lock() - defer routingTableMu.Unlock() + rt.tableMu.Lock() + defer rt.tableMu.Unlock() for n := netStart; n <= netEnd; n++ { - routingTable[n] = append(routingTable[n], r) - slices.SortFunc(routingTable[n], func(r, s *route) int { + rt.table[n] = append(rt.table[n], r) + slices.SortFunc(rt.table[n], func(r, s *route) int { return cmp.Compare(r.metric, s.metric) }) } return nil } + +func (rt *routingTable) validRoutes() []*route { + rt.allRoutesMu.Lock() + defer rt.allRoutesMu.Unlock() + valid := make([]*route, 0, len(rt.allRoutes)) + for r := range rt.allRoutes { + if r.peer == nil { + continue + } + if time.Since(r.last) > maxRouteAge { + continue + } + valid = append(valid, r) + } + return r +} diff --git a/rtmp.go b/rtmp.go index 33a4d38..61bc0b9 100644 --- a/rtmp.go +++ b/rtmp.go @@ -32,9 +32,10 @@ import ( // RTMPMachine implements RTMP on an AppleTalk network attached to the router. type RTMPMachine struct { - aarp *AARPMachine - cfg *config - pcapHandle *pcap.Handle + aarp *AARPMachine + cfg *config + pcapHandle *pcap.Handle + routingTable *routingTable } // Run executes the machine. @@ -213,12 +214,7 @@ func (m *RTMPMachine) dataPacket(myAddr ddp.Addr) *rtmp.DataPacket { }, }, } - allRoutesMu.Lock() - defer allRoutesMu.Unlock() - for rt := range allRoutes { - if rt.peer == nil { - continue - } + for _, rt := range m.routingTable.validRoutes() { p.NetworkTuples = append(p.NetworkTuples, rtmp.NetworkTuple{ Extended: rt.extended, RangeStart: rt.netStart,