Improve AARP

This commit is contained in:
Josh Deprez 2024-04-12 15:20:04 +10:00
parent dbc3eaaf54
commit bf5f3f00cf
No known key found for this signature in database

75
aarp.go
View file

@ -53,6 +53,7 @@ type AARPMachine struct {
mu sync.RWMutex mu sync.RWMutex
myAddr aarp.AddrPair myAddr aarp.AddrPair
probes int probes int
assigned bool
assignedCh chan struct{} assignedCh chan struct{}
} }
@ -74,7 +75,7 @@ func NewAARPMachine(cfg *config, pcapHandle *pcap.Handle, myHWAddr ethernet.Addr
func (a *AARPMachine) Address() (aarp.AddrPair, bool) { func (a *AARPMachine) Address() (aarp.AddrPair, bool) {
a.mu.RLock() a.mu.RLock()
defer a.mu.RUnlock() defer a.mu.RUnlock()
return a.myAddr, a.assigned() return a.myAddr, a.assigned
} }
// Assigned returns a channel that is closed when the local address is valid. // Assigned returns a channel that is closed when the local address is valid.
@ -102,9 +103,11 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
if a.assigned() { if a.probes >= 10 {
a.mu.Lock()
a.assigned = true
a.mu.Unlock()
close(a.assignedCh) close(a.assignedCh)
// No need to keep the ticker running if assigned
ticker.Stop() ticker.Stop()
continue continue
} }
@ -135,7 +138,7 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack
a.addressMappingTable.Learn(aapkt.Src.Proto, aapkt.Src.Hardware) a.addressMappingTable.Learn(aapkt.Src.Proto, aapkt.Src.Hardware)
log.Printf("AARP: Gleaned that %v -> %v", aapkt.Src.Proto, aapkt.Src.Hardware) log.Printf("AARP: Gleaned that %v -> %v", aapkt.Src.Proto, aapkt.Src.Hardware)
if !(aapkt.Dst.Proto == a.myAddr.Proto && a.assigned()) { if !(aapkt.Dst.Proto == a.myAddr.Proto && a.assigned) {
continue continue
} }
@ -152,7 +155,7 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack
if aapkt.Dst.Proto != a.myAddr.Proto { if aapkt.Dst.Proto != a.myAddr.Proto {
continue continue
} }
if !a.assigned() { if !a.assigned {
a.reroll() a.reroll()
} }
@ -163,7 +166,7 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack
if aapkt.Dst.Proto != a.myAddr.Proto { if aapkt.Dst.Proto != a.myAddr.Proto {
continue continue
} }
if !a.assigned() { if !a.assigned {
// Another node is probing for the same address! Unlucky // Another node is probing for the same address! Unlucky
a.reroll() a.reroll()
continue continue
@ -182,11 +185,30 @@ func (a *AARPMachine) Run(ctx context.Context, incomingCh <-chan *ethertalk.Pack
// If the address is in the cache (AMT) and is still valid, that is used. // If the address is in the cache (AMT) and is still valid, that is used.
// Otherwise, the address is resolved using AARP. // Otherwise, the address is resolved using AARP.
func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.Addr, error) { func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.Addr, error) {
result, waitCh := a.lookupOrWait(ddpAddr) result, waitCh, winner := a.lookupOrWait(ddpAddr)
if waitCh == nil { if waitCh == nil {
return result, nil return result, nil
} }
ctx, cancel := context.WithTimeout(ctx, aarpRequestTimeout)
defer cancel()
if !winner {
// some other goroutine is running the request
for {
select {
case <-ctx.Done():
return ethernet.Addr{}, ctx.Err()
case <-waitCh:
result, waitCh, _ = a.lookupOrWait(ddpAddr)
if waitCh == nil {
return result, nil
}
}
}
}
// I am the winner! I get to send the request packets and run the ticker
if err := a.request(ddpAddr); err != nil { if err := a.request(ddpAddr); err != nil {
return ethernet.Addr{}, err return ethernet.Addr{}, err
} }
@ -194,16 +216,13 @@ func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.A
ticker := time.NewTicker(aarpRequestRetransmit) ticker := time.NewTicker(aarpRequestRetransmit)
defer ticker.Stop() defer ticker.Stop()
ctx, cancel := context.WithTimeout(ctx, aarpRequestTimeout)
defer cancel()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ethernet.Addr{}, ctx.Err() return ethernet.Addr{}, ctx.Err()
case <-waitCh: case <-waitCh:
result, waitCh = a.lookupOrWait(ddpAddr) result, waitCh, _ = a.lookupOrWait(ddpAddr)
if waitCh == nil { if waitCh == nil {
return result, nil return result, nil
} }
@ -216,8 +235,6 @@ func (a *AARPMachine) Resolve(ctx context.Context, ddpAddr ddp.Addr) (ethernet.A
} }
} }
func (a *AARPMachine) assigned() bool { return a.probes >= 10 }
// Re-roll a local address // Re-roll a local address
func (a *AARPMachine) reroll() { func (a *AARPMachine) reroll() {
a.mu.Lock() a.mu.Lock()
@ -229,7 +246,7 @@ func (a *AARPMachine) reroll() {
) + a.cfg.EtherTalk.NetStart ) + a.cfg.EtherTalk.NetStart
} }
// Can't use: 0x00, 0xff, 0xfe, or the existing node number // Can't use: 0x00, 0xff, 0xfe, and should avoid the existing node number
newNode := rand.N[ddp.Node](0xfd) + 1 newNode := rand.N[ddp.Node](0xfd) + 1
for newNode != a.myAddr.Proto.Node { for newNode != a.myAddr.Proto.Node {
newNode = rand.N[ddp.Node](0xfd) + 1 newNode = rand.N[ddp.Node](0xfd) + 1
@ -283,6 +300,7 @@ type amtEntry struct {
hwAddr ethernet.Addr hwAddr ethernet.Addr
last time.Time last time.Time
updated chan struct{} updated chan struct{}
requesting bool
} }
// addressMappingTable implements a concurrent-safe Address Mapping Table for // addressMappingTable implements a concurrent-safe Address Mapping Table for
@ -305,35 +323,42 @@ func (t *addressMappingTable) Learn(ddpAddr ddp.Addr, hwAddr ethernet.Addr) {
hwAddr: hwAddr, hwAddr: hwAddr,
last: time.Now(), last: time.Now(),
updated: make(chan struct{}), updated: make(chan struct{}),
requesting: false,
} }
return return
} }
if oldEnt.hwAddr == hwAddr && time.Since(oldEnt.last) < maxAMTEntryAge {
oldEnt.last = time.Now()
return
}
oldEnt.hwAddr = hwAddr oldEnt.hwAddr = hwAddr
oldEnt.last = time.Now() oldEnt.last = time.Now()
oldEnt.requesting = false
close(oldEnt.updated) close(oldEnt.updated)
oldEnt.updated = make(chan struct{}) oldEnt.updated = make(chan struct{})
} }
// lookupOrWait returns either the valid cached Ethernet address for the given // lookupOrWait returns either the valid cached Ethernet address for the given
// DDP address, or a channel that is closed when the entry is updated. // DDP address, or a non-nil channel that is closed when the entry is updated.
func (t *addressMappingTable) lookupOrWait(ddpAddr ddp.Addr) (ethernet.Addr, <-chan struct{}) { // It also reports if this is the first call since the entry became invalid.
func (t *addressMappingTable) lookupOrWait(ddpAddr ddp.Addr) (ethernet.Addr, <-chan struct{}, bool) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if t.table == nil { if t.table == nil {
t.table = make(map[ddp.Addr]*amtEntry) t.table = make(map[ddp.Addr]*amtEntry)
} }
ent, ok := t.table[ddpAddr] ent := t.table[ddpAddr]
if ok && time.Since(ent.last) < maxAMTEntryAge { if ent == nil {
return ent.hwAddr, nil
}
ch := make(chan struct{}) ch := make(chan struct{})
t.table[ddpAddr] = &amtEntry{ t.table[ddpAddr] = &amtEntry{
updated: ch, updated: ch,
requesting: true,
} }
return ethernet.Addr{}, ch return ethernet.Addr{}, ch, true
}
if time.Since(ent.last) >= maxAMTEntryAge {
if ent.requesting {
return ent.hwAddr, ent.updated, false
}
ent.requesting = true
return ent.hwAddr, ent.updated, true
}
return ent.hwAddr, nil, false
} }