diff --git a/v1/client/client.go b/v1/client/client.go index f252857..aad0e1d 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -59,8 +59,6 @@ func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { return nil, err } - ctx, cancel := context.WithCancel(context.Background()) - var useTLS bool if b.useTLS != nil { useTLS = *b.useTLS @@ -72,8 +70,6 @@ func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { useTLS: useTLS, path: b.path, rawQuery: b.rawQuery, - ctx: ctx, - cancel: cancel, dataChannel: make(chan []byte, 1), mu: custom_rwmutex.NewCustomRwMutex(), reconnectCh: make(chan struct{}, 1), @@ -81,7 +77,6 @@ func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { } if err := wsClient.connect(); err != nil { - cancel() return nil, fmt.Errorf("failed to establish initial connection: %v", err) } @@ -142,6 +137,9 @@ func (wsClient *SafeWebsocketClient) connect() error { }) wsClient.mu.WriteHandler(func() error { + ctx, cancel := context.WithCancel(context.Background()) + wsClient.ctx = ctx + wsClient.cancel = cancel wsClient.conn = conn wsClient.isConnected = true return nil @@ -177,7 +175,6 @@ func (wsClient *SafeWebsocketClient) startPingTicker() { func (wsClient *SafeWebsocketClient) startReceiveHandler() { for { select { - case <-wsClient.reconnectCh: case <-wsClient.ctx.Done(): log.Println("Reconnect handler stopped") return @@ -210,8 +207,14 @@ func (wsClient *SafeWebsocketClient) triggerReconnect() { } func (wsClient *SafeWebsocketClient) reconnectHandler() { - for range wsClient.reconnectCh { - wsClient.connect() + for { + select { + case <-wsClient.reconnectCh: + wsClient.cancel() + wsClient.connect() + case <-wsClient.ctx.Done(): + return + } } }