diff --git a/v1/client/client.go b/v1/client/client.go index 6774d48..fcfc200 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -38,6 +38,7 @@ type SafeWebsocketClientBuilder struct { basePort *uint16 `nil_checker:"required"` path *string rawQuery *string + isDrop *bool useTLS *bool channelSize *int64 } @@ -71,6 +72,11 @@ func (b *SafeWebsocketClientBuilder) RawQuery(rawQuery string) *SafeWebsocketCli return b } +func (b *SafeWebsocketClientBuilder) IsDrop(isDrop bool) *SafeWebsocketClientBuilder { + b.isDrop = &isDrop + return b +} + func (b *SafeWebsocketClientBuilder) ChannelSize(channelSize int64) *SafeWebsocketClientBuilder { b.channelSize = &channelSize return b @@ -81,9 +87,15 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC return nil, err } - var useTLS bool - if b.useTLS != nil { - useTLS = *b.useTLS + // var useTLS bool + if b.useTLS == nil { + useTLS := true + b.useTLS = &useTLS + } + + if b.isDrop == nil { + isDrop := true + b.isDrop = &isDrop } if b.channelSize == nil { @@ -94,7 +106,8 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC wsClient := SafeWebsocketClient{ baseHost: *b.baseHost, basePort: *b.basePort, - useTLS: useTLS, + useTLS: *b.useTLS, + isDrop: *b.isDrop, path: b.path, rawQuery: b.rawQuery, dataChannel: make(chan []byte, *b.channelSize), @@ -118,17 +131,16 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC type SafeWebsocketClient struct { baseHost string basePort uint16 + isDrop bool useTLS bool path *string rawQuery *string - dataChannel chan []byte mu *custom_rwmutex.CustomRwMutex conn *websocket.Conn + ctx context.Context cancelFuncs []context.CancelFunc - - ctx context.Context - Cancel context.CancelFunc + dataChannel chan []byte reconnectCh chan struct{} reconnectChans []chan struct{} @@ -214,18 +226,15 @@ func (wsClient *SafeWebsocketClient) writePump() { for { select { case <-ctx.Done(): - log.Println("Writer stopped due to client shutdown") + log.Println("Writer canceled by context") return case data := <-wsClient.writeChan: if c == nil { - wsClient.triggerReconnect() return } if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil { log.Printf("error on write message: %v\n", err) - c.Close() - wsClient.triggerReconnect() if data.MessageType == MessageTypePong { wsClient.pongChan <- err } @@ -243,47 +252,83 @@ func (wsClient *SafeWebsocketClient) readPump() { }) var c *websocket.Conn - wsClient.mu.ReadHandler(func() error { c = wsClient.conn return nil }) - if c == nil { - wsClient.triggerReconnect() - return - } - - for { - select { - case <-ctx.Done(): - log.Println("Reader stopped due to client shutdown") - return - default: - if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { - log.Printf("error on read deadline: %v\n", err) - return - } - - messageType, data, err := c.ReadMessage() - if err != nil { - log.Printf("error on read message: %v\n", err) - wsClient.triggerReconnect() - return - } - - if messageType != websocket.TextMessage { - continue - } + if wsClient.isDrop { + for { select { - case wsClient.dataChannel <- data: case <-ctx.Done(): + log.Println("Reader canceled by context") return default: - log.Println("Data channel full, dropping message") + if c == nil { + return + } + + if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + log.Printf("error on read deadline: %v\n", err) + return + } + + messageType, data, err := c.ReadMessage() + if err != nil { + log.Printf("error on read message: %v\n", err) + wsClient.triggerReconnect() + return + } + + if messageType != websocket.TextMessage { + continue + } + + select { + case wsClient.dataChannel <- data: + case <-ctx.Done(): + return + default: + log.Println("Data channel full, dropping message") + } + } + } + } else { + for { + select { + case <-ctx.Done(): + log.Println("Reader canceled by context") + return + default: + if c == nil { + return + } + + if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + log.Printf("error on read deadline: %v\n", err) + return + } + + messageType, data, err := c.ReadMessage() + if err != nil { + log.Printf("error on read message: %v\n", err) + wsClient.triggerReconnect() + return + } + + if messageType != websocket.TextMessage { + continue + } + + select { + case wsClient.dataChannel <- data: + case <-ctx.Done(): + return + } } } } + } func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) {