From dfbe2f2808310717bc2caf8e97eff29fe61d78f3 Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Mon, 29 Sep 2025 09:06:31 +0700 Subject: [PATCH] fix: fixing client deadlock condition & reconnection issues --- v1/client/client.go | 105 +++++++++++++++++++++---------------- v1/examples/client/main.go | 19 ++++++- 2 files changed, 79 insertions(+), 45 deletions(-) diff --git a/v1/client/client.go b/v1/client/client.go index 42f9869..8cecd38 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -65,11 +65,12 @@ func (sm *SafeMap[K, V]) Len() int { } type SafeWebsocketClientBuilder struct { - baseHost *string `nil_checker:"required"` - basePort *uint16 `nil_checker:"required"` - path *string - rawQuery *string - useTLS *bool + baseHost *string `nil_checker:"required"` + basePort *uint16 `nil_checker:"required"` + path *string + rawQuery *string + useTLS *bool + channelSize *int64 } func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { @@ -101,7 +102,12 @@ func (b *SafeWebsocketClientBuilder) RawQuery(rawQuery string) *SafeWebsocketCli return b } -func (b *SafeWebsocketClientBuilder) Build(channelSize int) (*SafeWebsocketClient, error) { +func (b *SafeWebsocketClientBuilder) ChannelSize(channelSize int64) *SafeWebsocketClientBuilder { + b.channelSize = &channelSize + return b +} + +func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketClient, error) { if err := internal.NilChecker(b); err != nil { return nil, err } @@ -111,14 +117,20 @@ func (b *SafeWebsocketClientBuilder) Build(channelSize int) (*SafeWebsocketClien useTLS = *b.useTLS } + if b.channelSize == nil { + channelSize := int64(1) + b.channelSize = &channelSize + } + wsClient := SafeWebsocketClient{ baseHost: *b.baseHost, basePort: *b.basePort, useTLS: useTLS, path: b.path, rawQuery: b.rawQuery, - dataChannel: make(chan []byte, channelSize), + dataChannel: make(chan []byte, *b.channelSize), mu: custom_rwmutex.NewCustomRwMutex(), + ctx: ctx, reconnectCh: make(chan struct{}, 1), isConnected: false, doneMap: NewSafeMap[string, chan struct{}](), @@ -142,8 +154,8 @@ type SafeWebsocketClient struct { dataChannel chan []byte mu *custom_rwmutex.CustomRwMutex conn *websocket.Conn + cancelFuncs []context.CancelFunc ctx context.Context - cancel context.CancelFunc reconnectCh chan struct{} isConnected bool @@ -176,18 +188,16 @@ func (wsClient *SafeWebsocketClient) connect() error { } conn.SetPingHandler(func(pingData string) error { - return wsClient.mu.WriteHandler(func() error { - if err := conn.WriteMessage(websocket.PongMessage, []byte(pingData)); err != nil { - if err == websocket.ErrCloseSent { - return nil - } - if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { - return nil - } - return err + if err := conn.WriteMessage(websocket.PongMessage, []byte(pingData)); err != nil { + if err == websocket.ErrCloseSent { + return nil } - return nil - }) + if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { + return nil + } + return err + } + return nil }) wsClient.mu.WriteHandler(func() error { @@ -195,14 +205,17 @@ func (wsClient *SafeWebsocketClient) connect() error { wsClient.conn.Close() } - ctx, cancel := context.WithCancel(context.Background()) - wsClient.ctx = ctx - wsClient.cancel = cancel wsClient.conn = conn + wsClient.isConnected = true - go wsClient.startPingTicker(ctx) - go wsClient.startReceiveHandler(ctx) + pingCtx, pingCancel := context.WithCancel(context.Background()) + go wsClient.startPingTicker(pingCtx) + wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel) + + receiverCtx, receiverCancel := context.WithCancel(context.Background()) + go wsClient.startReceiveHandler(receiverCtx) + wsClient.cancelFuncs = append(wsClient.cancelFuncs, receiverCancel) return nil }) @@ -215,10 +228,13 @@ func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) { for { select { + case <-ctx.Done(): + log.Println("ping ticker stopped") + return case <-ticker.C: wsClient.mu.WriteHandler(func() error { if wsClient.conn == nil { - return fmt.Errorf("connecrtion closed") + return fmt.Errorf("connection closed") } if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { log.Printf("Ping failed: %v. Will attempt reconnect.", err) @@ -226,9 +242,6 @@ func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) { } return nil }) - case <-ctx.Done(): - log.Println("Ping ticker stopped due to context cancellation") - return } } } @@ -237,10 +250,10 @@ func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) { for { select { case <-ctx.Done(): - log.Println("Reconnect handler stopped") + log.Println("receive handler stopped") return default: - wsClient.mu.ReadHandler(func() error { + if err := wsClient.mu.ReadHandler(func() error { conn := wsClient.conn if conn == nil { @@ -254,12 +267,16 @@ func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) { } select { case wsClient.dataChannel <- message: + case <-ctx.Done(): + log.Println("Reconnect handler stopped") + return nil default: log.Println("Data channel full, dropping message") } return nil - }) - + }); err != nil { + return + } } } } @@ -275,19 +292,15 @@ func (wsClient *SafeWebsocketClient) reconnectHandler() { backoff := 1 * time.Second maxBackoff := 30 * time.Second for { - if wsClient.ctx == nil { - continue - } select { case <-wsClient.reconnectCh: log.Println("Reconnect triggered") - wsClient.mu.WriteHandler(func() error { - if wsClient.cancel != nil { - wsClient.cancel() + if wsClient.cancelFuncs != nil { + for _, cancel := range wsClient.cancelFuncs { + cancel() } - wsClient.isConnected = false - return nil - }) + } + wsClient.isConnected = false time.Sleep(100 * time.Millisecond) @@ -311,8 +324,10 @@ func (wsClient *SafeWebsocketClient) reconnectHandler() { return } } + case <-wsClient.ctx.Done(): - log.Println("Reconnect handler stopped due to client shutdown") + log.Println("reconnect handler stopped due to client shutdown") + wsClient.Close() return } } @@ -324,8 +339,10 @@ func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte { func (wsClient *SafeWebsocketClient) Close() error { wsClient.mu.WriteHandler(func() error { - if wsClient.cancel != nil { - wsClient.cancel() + if wsClient.cancelFuncs != nil { + for _, cancel := range wsClient.cancelFuncs { + cancel() + } } if wsClient.conn != nil { wsClient.conn.Close() diff --git a/v1/examples/client/main.go b/v1/examples/client/main.go index bb3d9c3..deaac21 100644 --- a/v1/examples/client/main.go +++ b/v1/examples/client/main.go @@ -1,22 +1,39 @@ package main import ( + "context" "fmt" "log" + "os" + "os/signal" + "syscall" "git.neurocipta.com/rogerferdinan/safe-web-socket/v1/client" ) func main() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + <-sigChan + fmt.Println("\nReceived interrupt signal. Shutting down gracefully...") + cancel() + }() + wsClient, err := client.NewSafeWebsocketClientBuilder(). BaseHost("localhost"). BasePort(8080). Path("/ws/test/data_1"). UseTLS(false). - Build() + Build(ctx) if err != nil { log.Fatal(err) } + dataChannel := wsClient.DataChannel() for data := range dataChannel {