diff --git a/v1/client/client.go b/v1/client/client.go index 098cdd6..9531928 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -276,6 +276,7 @@ func (wsClient *SafeWebsocketClient) connect() error { } if _, err := writer.Write(data.Data); err != nil { + wsClient.triggerReconnect() return } } @@ -288,10 +289,7 @@ func (wsClient *SafeWebsocketClient) connect() error { wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) return nil }) - _, reader, err := conn.NextReader() - if err != nil { - return - } + for { select { case <-ctx.Done(): @@ -306,6 +304,16 @@ func (wsClient *SafeWebsocketClient) connect() error { return } + mt, reader, err := conn.NextReader() + if err != nil { + wsClient.triggerReconnect() + return + } + + if mt != websocket.TextMessage { + continue + } + readerBytes, err := io.ReadAll(reader) if err != nil { fmt.Printf("io reader failed: %v\n", err) @@ -346,40 +354,6 @@ func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) { } } -// func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) { -// for { -// select { -// case <-ctx.Done(): -// log.Println("receive handler stopped") -// return -// default: -// if err := wsClient.mu.ReadHandler(func() error { -// conn := wsClient.conn - -// if conn == nil { -// wsClient.triggerReconnect() -// return fmt.Errorf("connection closed") -// } -// _, message, err := conn.ReadMessage() -// if err != nil { -// wsClient.triggerReconnect() -// return fmt.Errorf("failed to read message: %v", err) -// } -// select { -// case wsClient.dataChannel <- message: -// case <-ctx.Done(): -// log.Println("Reconnect handler stopped") -// default: -// log.Println("") -// } -// return nil -// }); err != nil { -// return -// } -// } -// } -// } - func (wsClient *SafeWebsocketClient) triggerReconnect() { select { case wsClient.reconnectCh <- struct{}{}: diff --git a/v1/examples/client/main.go b/v1/examples/client/main.go index f20907d..711b650 100644 --- a/v1/examples/client/main.go +++ b/v1/examples/client/main.go @@ -29,6 +29,7 @@ func main() { BasePort(8080). Path("/ws/test/data_1"). UseTLS(false). + ChannelSize(30). Build(ctx) if err != nil { log.Fatal(err)