From 5e5df9090f18b5802a5a7a2674d808e91147acf5 Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Mon, 29 Sep 2025 17:55:24 +0700 Subject: [PATCH] fix: adding back read mutex for data race --- v1/client/client.go | 79 +++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/v1/client/client.go b/v1/client/client.go index 2206b21..3f4ec2a 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -65,12 +65,13 @@ func (sm *SafeMap[K, V]) Len() int { } type SafeWebsocketClientBuilder struct { - baseHost *string `nil_checker:"required"` - basePort *uint16 `nil_checker:"required"` - path *string - rawQuery *string - useTLS *bool - channelSize *int64 + baseHost *string `nil_checker:"required"` + basePort *uint16 `nil_checker:"required"` + path *string + rawQuery *string + useTLS *bool + channelSize *int64 + authenticateFn func() error } func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { @@ -92,6 +93,11 @@ func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBui return b } +func (b *SafeWebsocketClientBuilder) AuthenticateFn(authenticateFn func() error) *SafeWebsocketClientBuilder { + b.authenticateFn = authenticateFn + return b +} + func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder { b.path = &path return b @@ -136,6 +142,10 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC doneMap: NewSafeMap[string, chan struct{}](), } + if b.authenticateFn != nil { + wsClient.authenticateFn = b.authenticateFn + } + go wsClient.reconnectHandler() if err := wsClient.connect(); err != nil { @@ -160,6 +170,7 @@ type SafeWebsocketClient struct { reconnectChans []chan struct{} isConnected bool doneMap *SafeMap[string, chan struct{}] + authenticateFn func() error } func (wsClient *SafeWebsocketClient) connect() error { @@ -200,14 +211,17 @@ func (wsClient *SafeWebsocketClient) connect() error { return nil }) + if wsClient.conn != nil { + wsClient.conn.Close() + } + wsClient.conn = conn + wsClient.isConnected = true + + if err := wsClient.authenticateFn(); err != nil { + return err + } + wsClient.mu.WriteHandler(func() error { - if wsClient.conn != nil { - wsClient.conn.Close() - } - - wsClient.conn = conn - - wsClient.isConnected = true pingCtx, pingCancel := context.WithCancel(context.Background()) go wsClient.startPingTicker(pingCtx) @@ -253,28 +267,29 @@ func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) { log.Println("receive handler stopped") return default: - conn := wsClient.conn + if err := wsClient.mu.ReadHandler(func() error { + conn := wsClient.conn - if conn == nil { - wsClient.triggerReconnect() - return - // return fmt.Errorf("connection closed") - } - _, message, err := conn.ReadMessage() - if err != nil { - wsClient.triggerReconnect() - // return fmt.Errorf("failed to read message: %v", err) + if conn == nil { + wsClient.triggerReconnect() + return fmt.Errorf("connection closed") + } + _, message, err := conn.ReadMessage() + if err != nil { + wsClient.triggerReconnect() + return fmt.Errorf("failed to read message: %v", err) + } + select { + case wsClient.dataChannel <- message: + case <-ctx.Done(): + log.Println("Reconnect handler stopped") + default: + log.Println("Data channel full, dropping message") + } + return nil + }); err != nil { return } - select { - case wsClient.dataChannel <- message: - case <-ctx.Done(): - log.Println("Reconnect handler stopped") - return - default: - log.Println("Data channel full, dropping message") - } - return } } }