From 034e5b68e054a32d0c300f88c0ebe7cc92cc69c1 Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Fri, 26 Sep 2025 16:07:20 +0700 Subject: [PATCH] fix: client reconnection data race fix --- v1/client/client.go | 81 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 11 deletions(-) diff --git a/v1/client/client.go b/v1/client/client.go index 655e540..048228c 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -6,6 +6,7 @@ import ( "log" "net/url" "strings" + "sync" "time" custom_rwmutex "git.neurocipta.com/rogerferdinan/custom-rwmutex" @@ -17,6 +18,52 @@ const ( pingPeriod = 10 * time.Second ) +type SafeMap[K comparable, V any] struct { + m sync.Map +} + +func NewSafeMap[K comparable, V any]() *SafeMap[K, V] { + return &SafeMap[K, V]{ + m: sync.Map{}, + } +} + +func (sm *SafeMap[K, V]) Store(key K, value V) { + sm.m.Store(key, value) +} + +func (sm *SafeMap[K, V]) Load(key K) (value V, ok bool) { + val, loaded := sm.m.Load(key) + if !loaded { + return *new(V), false + } + return val.(V), true +} + +func (sm *SafeMap[K, V]) Delete(key K) { + sm.m.Delete(key) +} + +func (sm *SafeMap[K, V]) Range(f func(K, V) bool) { + sm.m.Range(func(key, value any) bool { + k, ok1 := key.(K) + v, ok2 := value.(V) + if !ok1 || !ok2 { + return true + } + return f(k, v) + }) +} + +func (sm *SafeMap[K, V]) Len() int { + count := 0 + sm.Range(func(_ K, _ V) bool { + count++ + return true + }) + return count +} + type SafeWebsocketClientBuilder struct { baseHost *string `nil_checker:"required"` basePort *uint16 `nil_checker:"required"` @@ -74,6 +121,7 @@ func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { mu: custom_rwmutex.NewCustomRwMutex(), reconnectCh: make(chan struct{}, 1), isConnected: false, + doneMap: NewSafeMap[string, chan struct{}](), } if err := wsClient.connect(); err != nil { @@ -96,6 +144,8 @@ type SafeWebsocketClient struct { cancel context.CancelFunc reconnectCh chan struct{} isConnected bool + + doneMap *SafeMap[string, chan struct{}] } func (wsClient *SafeWebsocketClient) connect() error { @@ -139,6 +189,7 @@ func (wsClient *SafeWebsocketClient) connect() error { }) wsClient.mu.WriteHandler(func() error { + wsClient.conn.Close() ctx, cancel := context.WithCancel(context.Background()) wsClient.ctx = ctx wsClient.cancel = cancel @@ -158,6 +209,10 @@ func (wsClient *SafeWebsocketClient) connect() error { func (wsClient *SafeWebsocketClient) startPingTicker() { ticker := time.NewTicker(pingPeriod) defer ticker.Stop() + + doneKey := "startPingTicker" + wsClient.doneMap.Store(doneKey, make(chan struct{})) + done, _ := wsClient.doneMap.Load(doneKey) for { select { case <-ticker.C: @@ -168,7 +223,7 @@ func (wsClient *SafeWebsocketClient) startPingTicker() { } return nil }) - case <-wsClient.ctx.Done(): + case <-done: log.Println("Ping ticker stopped") return } @@ -176,32 +231,28 @@ func (wsClient *SafeWebsocketClient) startPingTicker() { } func (wsClient *SafeWebsocketClient) startReceiveHandler() { + doneKey := "startReceiveHandler" + wsClient.doneMap.Store(doneKey, make(chan struct{})) + done, _ := wsClient.doneMap.Load(doneKey) + for { select { - case <-wsClient.ctx.Done(): + case <-done: log.Println("Reconnect handler stopped") return default: - // if err := wsClient.mu.ReadHandler(func() error { conn := wsClient.conn if conn == nil { wsClient.triggerReconnect() return - // return fmt.Errorf("no active connection, waiting for reconnect") } _, message, err := conn.ReadMessage() if err != nil { wsClient.triggerReconnect() return - // return err } wsClient.dataChannel <- message - // return nil - // }); err != nil { - // wsClient.triggerReconnect() - // return - // } } } } @@ -214,12 +265,20 @@ func (wsClient *SafeWebsocketClient) triggerReconnect() { } func (wsClient *SafeWebsocketClient) reconnectHandler() { + doneKey := "reconnectHandler" + wsClient.doneMap.Store(doneKey, make(chan struct{})) + done, _ := wsClient.doneMap.Load(doneKey) for { select { case <-wsClient.reconnectCh: wsClient.cancel() wsClient.connect() - case <-wsClient.ctx.Done(): + + wsClient.doneMap.Range(func(s string, c chan struct{}) bool { + c <- struct{}{} + return true + }) + case <-done: return } }