routing table refactor

This commit is contained in:
Josh Deprez 2024-04-14 13:44:35 +10:00
parent 21b59920d8
commit 90229d07b5
Signed by: josh
SSH key fingerprint: SHA256:zZji7w1Ilh2RuUpbQcqkLPrqmRwpiCSycbF2EfKm6Kw
4 changed files with 71 additions and 43 deletions

11
main.go
View file

@ -129,6 +129,12 @@ func main() {
}() }()
} }
// ----------------------------- Routing table ----------------------------
routing := &routingTable{
table: make(map[ddp.Network][]*route),
allRoutes: make(map[*route]struct{}),
}
// ------------------------- Configured peer setup ------------------------ // ------------------------- Configured peer setup ------------------------
for _, peerStr := range cfg.Peers { for _, peerStr := range cfg.Peers {
if !hasPortRE.MatchString(peerStr) { if !hasPortRE.MatchString(peerStr) {
@ -151,6 +157,7 @@ func main() {
conn: ln, conn: ln,
raddr: raddr, raddr: raddr,
recv: make(chan aurp.Packet, 1024), recv: make(chan aurp.Packet, 1024),
routingTable: routing,
} }
aurp.Inc(&nextConnID) aurp.Inc(&nextConnID)
peers[udpAddrFromNet(raddr)] = peer peers[udpAddrFromNet(raddr)] = peer
@ -167,6 +174,7 @@ func main() {
aarp: aarpMachine, aarp: aarpMachine,
cfg: cfg, cfg: cfg,
pcapHandle: pcapHandle, pcapHandle: pcapHandle,
routingTable: routing,
} }
rtmpCh := make(chan *ddp.ExtPacket, 1024) rtmpCh := make(chan *ddp.ExtPacket, 1024)
go rtmpMachine.Run(ctx, rtmpCh) go rtmpMachine.Run(ctx, rtmpCh)
@ -241,7 +249,7 @@ func main() {
// addressed to a node on the local network." // addressed to a node on the local network."
if ddpkt.DstNet != 0 && (ddpkt.DstNet < cfg.EtherTalk.NetStart || ddpkt.DstNet > cfg.EtherTalk.NetEnd) { if ddpkt.DstNet != 0 && (ddpkt.DstNet < cfg.EtherTalk.NetStart || ddpkt.DstNet > cfg.EtherTalk.NetEnd) {
// Is it for a network in the routing table? // Is it for a network in the routing table?
rt := lookupRoute(ddpkt.DstNet) rt := routing.lookupRoute(ddpkt.DstNet)
if rt == nil { if rt == nil {
log.Printf("DDP: no route for network %d", ddpkt.DstNet) log.Printf("DDP: no route for network %d", ddpkt.DstNet)
continue continue
@ -484,6 +492,7 @@ func main() {
conn: ln, conn: ln,
raddr: raddr, raddr: raddr,
recv: make(chan aurp.Packet, 1024), recv: make(chan aurp.Packet, 1024),
routingTable: routing,
} }
aurp.Inc(&nextConnID) aurp.Inc(&nextConnID)
peers[ra] = pr peers[ra] = pr

View file

@ -81,6 +81,8 @@ type peer struct {
conn *net.UDPConn conn *net.UDPConn
raddr *net.UDPAddr raddr *net.UDPAddr
recv chan aurp.Packet recv chan aurp.Packet
routingTable *routingTable
} }
// send encodes and sends pkt to the remote host. // send encodes and sends pkt to the remote host.
@ -277,7 +279,7 @@ func (p *peer) handle(ctx context.Context) error {
log.Printf("Learned about these networks: %v", pkt.Networks) log.Printf("Learned about these networks: %v", pkt.Networks)
for _, nt := range pkt.Networks { for _, nt := range pkt.Networks {
upsertRoutes( p.routingTable.upsertRoutes(
nt.Extended, nt.Extended,
ddp.Network(nt.RangeStart), ddp.Network(nt.RangeStart),
ddp.Network(nt.RangeEnd), ddp.Network(nt.RangeEnd),

View file

@ -10,6 +10,8 @@ import (
"github.com/sfiera/multitalk/pkg/ddp" "github.com/sfiera/multitalk/pkg/ddp"
) )
const maxRouteAge = 10 * time.Minute // TODO: confirm
type route struct { type route struct {
extended bool extended bool
netStart ddp.Network netStart ddp.Network
@ -19,29 +21,32 @@ type route struct {
last time.Time last time.Time
} }
var ( type routingTable struct {
routingTableMu sync.Mutex tableMu sync.Mutex
routingTable = make(map[ddp.Network][]*route) table map[ddp.Network][]*route
allRoutesMu sync.Mutex allRoutesMu sync.Mutex
allRoutes = make(map[*route]struct{}) allRoutes map[*route]struct{}
) }
func lookupRoute(network ddp.Network) *route { func (rt *routingTable) lookupRoute(network ddp.Network) *route {
routingTableMu.Lock() rt.tableMu.Lock()
defer routingTableMu.Unlock() defer rt.tableMu.Unlock()
rs := routingTable[network] for _, rs := range rt.table[network] {
if len(rs) == 0 { if time.Since(rs.last) > maxRouteAge {
continue
}
return rs
}
return nil return nil
} }
return rs[0]
}
func upsertRoutes(extended bool, netStart, netEnd ddp.Network, peer *peer, metric uint8) error { func (rt *routingTable) upsertRoutes(extended bool, netStart, netEnd ddp.Network, peer *peer, metric uint8) error {
if netStart > netEnd { if netStart > netEnd {
return fmt.Errorf("invalid network range [%d, %d]", netStart, netEnd) return fmt.Errorf("invalid network range [%d, %d]", netStart, netEnd)
} }
r := &route{ r := &route{
extended: extended, extended: extended,
netStart: netStart, netStart: netStart,
@ -51,17 +56,33 @@ func upsertRoutes(extended bool, netStart, netEnd ddp.Network, peer *peer, metri
last: time.Now(), last: time.Now(),
} }
allRoutesMu.Lock() rt.allRoutesMu.Lock()
allRoutes[r] = struct{}{} rt.allRoutes[r] = struct{}{}
allRoutesMu.Unlock() rt.allRoutesMu.Unlock()
routingTableMu.Lock() rt.tableMu.Lock()
defer routingTableMu.Unlock() defer rt.tableMu.Unlock()
for n := netStart; n <= netEnd; n++ { for n := netStart; n <= netEnd; n++ {
routingTable[n] = append(routingTable[n], r) rt.table[n] = append(rt.table[n], r)
slices.SortFunc(routingTable[n], func(r, s *route) int { slices.SortFunc(rt.table[n], func(r, s *route) int {
return cmp.Compare(r.metric, s.metric) return cmp.Compare(r.metric, s.metric)
}) })
} }
return nil return nil
} }
func (rt *routingTable) validRoutes() []*route {
rt.allRoutesMu.Lock()
defer rt.allRoutesMu.Unlock()
valid := make([]*route, 0, len(rt.allRoutes))
for r := range rt.allRoutes {
if r.peer == nil {
continue
}
if time.Since(r.last) > maxRouteAge {
continue
}
valid = append(valid, r)
}
return r
}

View file

@ -35,6 +35,7 @@ type RTMPMachine struct {
aarp *AARPMachine aarp *AARPMachine
cfg *config cfg *config
pcapHandle *pcap.Handle pcapHandle *pcap.Handle
routingTable *routingTable
} }
// Run executes the machine. // Run executes the machine.
@ -213,12 +214,7 @@ func (m *RTMPMachine) dataPacket(myAddr ddp.Addr) *rtmp.DataPacket {
}, },
}, },
} }
allRoutesMu.Lock() for _, rt := range m.routingTable.validRoutes() {
defer allRoutesMu.Unlock()
for rt := range allRoutes {
if rt.peer == nil {
continue
}
p.NetworkTuples = append(p.NetworkTuples, rtmp.NetworkTuple{ p.NetworkTuples = append(p.NetworkTuples, rtmp.NetworkTuple{
Extended: rt.extended, Extended: rt.extended,
RangeStart: rt.netStart, RangeStart: rt.netStart,