diff --git a/v1/client/client.go b/v1/client/client.go index 2a86580..b6c0429 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -16,8 +16,9 @@ import ( ) const ( - pingPeriod = 10 * time.Second - readDeadline = 30 * time.Second + pingPeriod = 10 * time.Second + readDeadline = 30 * time.Second + writeDeadline = 10 * time.Second ) type MessageType uint @@ -182,17 +183,14 @@ func (wsClient *SafeWebsocketClient) connect() error { Scheme: scheme, Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort), } - if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" { newURL.Path = *wsClient.path } - if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" { newURL.RawQuery = *wsClient.rawQuery } header := make(http.Header) - if wsClient.headers != nil { for k, v := range *wsClient.headers { header.Set(k, v) @@ -205,18 +203,22 @@ func (wsClient *SafeWebsocketClient) connect() error { } pingCtx, pingCancel := context.WithCancel(context.Background()) + pumpCtx, pumpCancel := context.WithCancel(context.Background()) + wsClient.mu.WriteHandler(func() error { - wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel) + if wsClient.conn != nil { + wsClient.conn.Close() + } + + wsClient.conn = conn + + wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel, pumpCancel) return nil }) go wsClient.startPingTicker(pingCtx) - - if wsClient.conn != nil { - wsClient.conn.Close() - } - wsClient.conn = conn - wsClient.isConnected = true + go wsClient.writePump(pumpCtx, conn) + go wsClient.readPump(pumpCtx, conn) conn.SetPingHandler(func(pingData string) error { if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { @@ -239,24 +241,72 @@ func (wsClient *SafeWebsocketClient) connect() error { return nil }) - go wsClient.writePump() - go wsClient.readPump() - + wsClient.isConnected = true return nil } -func (wsClient *SafeWebsocketClient) writePump() { - ctx, cancel := context.WithCancel(context.Background()) - wsClient.mu.WriteHandler(func() error { - wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) - return nil - }) +func (wsClient *SafeWebsocketClient) reconnectHandler() { + backoff := 1 * time.Second + maxBackoff := 15 * time.Second - var c *websocket.Conn - wsClient.mu.ReadHandler(func() error { - c = wsClient.conn - return nil - }) + for { + select { + case <-wsClient.reconnectCh: + log.Println("Reconnect triggered") + + wsClient.mu.WriteHandler(func() error { + if wsClient.cancelFuncs != nil { + for _, cancel := range wsClient.cancelFuncs { + cancel() + } + wsClient.cancelFuncs = nil + } + return nil + }) + + wsClient.isConnected = false + isInnerLoop := true + for isInnerLoop { + log.Printf("Attempting reconnect in %v...", backoff) + select { + case <-time.After(backoff): + if err := wsClient.connect(); err != nil { + log.Printf("Reconnect failed: %v", err) + if backoff < maxBackoff { + backoff *= 2 + } + continue + } + log.Println("Reconnected successfully") + backoff = 1 * time.Second + isInnerLoop = false + continue + case <-wsClient.ctx.Done(): + log.Println("reconnect handler stopped due to client shutdown") + wsClient.Close() + return + } + } + if wsClient.reconnectChans != nil { + for _, reconnectCh := range wsClient.reconnectChans { + select { + case reconnectCh <- struct{}{}: + default: // prevent blocking if chan is full + } + } + } + case <-wsClient.ctx.Done(): + log.Println("reconnect handler stopped due to client shutdown") + wsClient.Close() + return + } + } +} + +func (wsClient *SafeWebsocketClient) writePump(ctx context.Context, c *websocket.Conn) { + defer func() { + c.Close() + }() for { select { @@ -264,30 +314,25 @@ func (wsClient *SafeWebsocketClient) writePump() { log.Println("Writer canceled by context") return case data := <-wsClient.writeChan: - if c == nil { + if err := c.SetWriteDeadline(time.Now().Add(writeDeadline)); err != nil { + log.Printf("error setting write deadline: %v", err) return } if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil { log.Printf("error on write message: %v\n", err) + wsClient.triggerReconnect() // Trigger reconnect on write failure return } } } } -func (wsClient *SafeWebsocketClient) readPump() { - ctx, cancel := context.WithCancel(context.Background()) - wsClient.mu.WriteHandler(func() error { - wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) - return nil - }) - - var c *websocket.Conn - wsClient.mu.ReadHandler(func() error { - c = wsClient.conn - return nil - }) +func (wsClient *SafeWebsocketClient) readPump(ctx context.Context, c *websocket.Conn) { + defer func() { + wsClient.triggerReconnect() + c.Close() + }() for { select { @@ -295,19 +340,14 @@ func (wsClient *SafeWebsocketClient) readPump() { log.Println("Reader canceled by context") return default: - if c == nil { - return - } - + // Set read deadline if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { - log.Printf("error on read deadline: %v\n", err) return } messageType, data, err := c.ReadMessage() if err != nil { log.Printf("error on read message: %v\n", err) - wsClient.triggerReconnect() return } @@ -353,60 +393,6 @@ func (wsClient *SafeWebsocketClient) triggerReconnect() { } } -func (wsClient *SafeWebsocketClient) reconnectHandler() { - backoff := 1 * time.Second - maxBackoff := 30 * time.Second - for { - select { - case <-wsClient.reconnectCh: - log.Println("Reconnect triggered") - - wsClient.mu.ReadHandler(func() error { - if wsClient.cancelFuncs != nil { - for _, cancel := range wsClient.cancelFuncs { - cancel() - } - } - return nil - }) - - wsClient.isConnected = false - - isInnerLoop := true - for isInnerLoop { - log.Printf("Attempting reconnect in %v...", backoff) - select { - case <-time.After(backoff): - if err := wsClient.connect(); err != nil { - log.Printf("Reconnect failed: %v", err) - if backoff < maxBackoff { - backoff *= 2 - } - continue - } - log.Println("Reconnected successfully") - backoff = 1 * time.Second - isInnerLoop = false - continue - case <-wsClient.ctx.Done(): - log.Println("reconnect handler stopped due to client shutdown") - wsClient.Close() - return - } - } - if wsClient.reconnectChans != nil { - for _, reconnectCh := range wsClient.reconnectChans { - reconnectCh <- struct{}{} - } - } - case <-wsClient.ctx.Done(): - log.Println("reconnect handler stopped due to client shutdown") - wsClient.Close() - return - } - } -} - func (wsClient *SafeWebsocketClient) ReconnectChannel() <-chan struct{} { reconnectCh := make(chan struct{}, 1) wsClient.mu.WriteHandler(func() error {