From bf5f3f00cfe63e69f6bf2d590198960e5405566a Mon Sep 17 00:00:00 2001 From: Josh Deprez Date: Fri, 12 Apr 2024 15:20:04 +1000 Subject: [PATCH] Improve AARP --- aarp.go | 91 ++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/aarp.go b/aarp.go index e71310f..fd861a9 100644 --- a/aarp.go +++ b/aarp.go @@ -53,6 +53,7 @@ type AARPMachine struct { mu sync.RWMutex myAddr aarp.AddrPair probes int + assigned bool assignedCh chan struct{} } @@ -74,7 +75,7 @@ func NewAARPMachine(cfg *config, pcapHandle *pcap.Handle, myHWAddr ethernet.Addr func (a *AARPMachine) Address() (aarp.AddrPair, bool) { a.mu.RLock() defer a.mu.RUnlock() - return a.myAddr, a.assigned() + return a.myAddr, a.assigned } // Assigned returns a channel that is closed when the local address is valid. @@ -102,9 +103,11 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack return ctx.Err() case <-ticker.C: - if a.assigned() { + if a.probes >= 10 { + a.mu.Lock() + a.assigned = true + a.mu.Unlock() close(a.assignedCh) - // No need to keep the ticker running if assigned ticker.Stop() continue } @@ -135,7 +138,7 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack a.addressMappingTable.Learn(aapkt.Src.Proto, aapkt.Src.Hardware) log.Printf("AARP: Gleaned that %v -> %v", aapkt.Src.Proto, aapkt.Src.Hardware) - if !(aapkt.Dst.Proto == a.myAddr.Proto && a.assigned()) { + if !(aapkt.Dst.Proto == a.myAddr.Proto && a.assigned) { continue } @@ -152,7 +155,7 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack if aapkt.Dst.Proto != a.myAddr.Proto { continue } - if !a.assigned() { + if !a.assigned { a.reroll() } @@ -163,7 +166,7 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack if aapkt.Dst.Proto != a.myAddr.Proto { continue } - if !a.assigned() { + if !a.assigned { // Another node is probing for the same address! Unlucky a.reroll() continue @@ -182,11 +185,30 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack // If the address is in the cache (AMT) and is still valid, that is used. // Otherwise, the address is resolved using AARP. func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.Addr, error) { - result, waitCh := a.lookupOrWait(ddpAddr) + result, waitCh, winner := a.lookupOrWait(ddpAddr) if waitCh == nil { return result, nil } + ctx, cancel := context.WithTimeout(ctx, aarpRequestTimeout) + defer cancel() + + if !winner { + // some other goroutine is running the request + for { + select { + case <-ctx.Done(): + return ethernet.Addr{}, ctx.Err() + case <-waitCh: + result, waitCh, _ = a.lookupOrWait(ddpAddr) + if waitCh == nil { + return result, nil + } + } + } + } + + // I am the winner! I get to send the request packets and run the ticker if err := a.request(ddpAddr); err != nil { return ethernet.Addr{}, err } @@ -194,16 +216,13 @@ func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.A ticker := time.NewTicker(aarpRequestRetransmit) defer ticker.Stop() - ctx, cancel := context.WithTimeout(ctx, aarpRequestTimeout) - defer cancel() - for { select { case <-ctx.Done(): return ethernet.Addr{}, ctx.Err() case <-waitCh: - result, waitCh = a.lookupOrWait(ddpAddr) + result, waitCh, _ = a.lookupOrWait(ddpAddr) if waitCh == nil { return result, nil } @@ -216,8 +235,6 @@ func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.A } } -func (a *AARPMachine) assigned() bool { return a.probes >= 10 } - // Re-roll a local address func (a *AARPMachine) reroll() { a.mu.Lock() @@ -229,7 +246,7 @@ func (a *AARPMachine) reroll() { ) + a.cfg.EtherTalk.NetStart } - // Can't use: 0x00, 0xff, 0xfe, or the existing node number + // Can't use: 0x00, 0xff, 0xfe, and should avoid the existing node number newNode := rand.N[ddp.Node](0xfd) + 1 for newNode != a.myAddr.Proto.Node { newNode = rand.N[ddp.Node](0xfd) + 1 @@ -280,9 +297,10 @@ func (a *AARPMachine) request(ddpAddr ddp.Addr) error { } type amtEntry struct { - hwAddr ethernet.Addr - last time.Time - updated chan struct{} + hwAddr ethernet.Addr + last time.Time + updated chan struct{} + requesting bool } // addressMappingTable implements a concurrent-safe Address Mapping Table for @@ -302,38 +320,45 @@ func (t *addressMappingTable) Learn(ddpAddr ddp.Addr, hwAddr ethernet.Addr) { oldEnt := t.table[ddpAddr] if oldEnt == nil { t.table[ddpAddr] = &amtEntry{ - hwAddr: hwAddr, - last: time.Now(), - updated: make(chan struct{}), + hwAddr: hwAddr, + last: time.Now(), + updated: make(chan struct{}), + requesting: false, } return } - if oldEnt.hwAddr == hwAddr && time.Since(oldEnt.last) < maxAMTEntryAge { - oldEnt.last = time.Now() - return - } oldEnt.hwAddr = hwAddr oldEnt.last = time.Now() + oldEnt.requesting = false close(oldEnt.updated) oldEnt.updated = make(chan struct{}) } // lookupOrWait returns either the valid cached Ethernet address for the given -// DDP address, or a channel that is closed when the entry is updated. -func (t *addressMappingTable) lookupOrWait(ddpAddr ddp.Addr) (ethernet.Addr, <-chan struct{}) { +// DDP address, or a non-nil channel that is closed when the entry is updated. +// It also reports if this is the first call since the entry became invalid. +func (t *addressMappingTable) lookupOrWait(ddpAddr ddp.Addr) (ethernet.Addr, <-chan struct{}, bool) { t.mu.Lock() defer t.mu.Unlock() if t.table == nil { t.table = make(map[ddp.Addr]*amtEntry) } - ent, ok := t.table[ddpAddr] - if ok && time.Since(ent.last) < maxAMTEntryAge { - return ent.hwAddr, nil + ent := t.table[ddpAddr] + if ent == nil { + ch := make(chan struct{}) + t.table[ddpAddr] = &amtEntry{ + updated: ch, + requesting: true, + } + return ethernet.Addr{}, ch, true } - ch := make(chan struct{}) - t.table[ddpAddr] = &amtEntry{ - updated: ch, + if time.Since(ent.last) >= maxAMTEntryAge { + if ent.requesting { + return ent.hwAddr, ent.updated, false + } + ent.requesting = true + return ent.hwAddr, ent.updated, true } - return ethernet.Addr{}, ch + return ent.hwAddr, nil, false }