diff --git a/v1/client/client.go b/v1/client/client.go index 09ac0e3..201a52e 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "io" "log" "net/url" "strings" @@ -15,7 +16,8 @@ import ( ) const ( - pingPeriod = 10 * time.Second + pingPeriod = 10 * time.Second + readDeadline = 10 * time.Second ) type SafeMap[K comparable, V any] struct { @@ -64,6 +66,20 @@ func (sm *SafeMap[K, V]) Len() int { return count } +type MessageType uint + +const ( + MessageTypeText MessageType = iota + 1 + MessageTypePing MessageType = iota + MessageTypePong MessageType = iota + MessageTypeClose MessageType = iota +) + +type Message struct { + MessageType MessageType + Data []byte +} + type SafeWebsocketClientBuilder struct { baseHost *string `nil_checker:"required"` basePort *uint16 `nil_checker:"required"` @@ -156,21 +172,27 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC } type SafeWebsocketClient struct { - baseHost string - basePort uint16 - useTLS bool - path *string - rawQuery *string - dataChannel chan []byte - mu *custom_rwmutex.CustomRwMutex - conn *websocket.Conn - cancelFuncs []context.CancelFunc - ctx context.Context + baseHost string + basePort uint16 + useTLS bool + path *string + rawQuery *string + + dataChannel chan []byte + mu *custom_rwmutex.CustomRwMutex + conn *websocket.Conn + cancelFuncs []context.CancelFunc + + ctx context.Context + Cancel context.CancelFunc + reconnectCh chan struct{} reconnectChans []chan struct{} isConnected bool doneMap *SafeMap[string, chan struct{}] authenticateFn func(*SafeWebsocketClient) error + + writeChan chan Message } func (wsClient *SafeWebsocketClient) connect() error { @@ -211,6 +233,10 @@ func (wsClient *SafeWebsocketClient) connect() error { return nil }) + pingCtx, pingCancel := context.WithCancel(context.Background()) + go wsClient.startPingTicker(pingCtx) + wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel) + if wsClient.conn != nil { wsClient.conn.Close() } @@ -223,17 +249,82 @@ func (wsClient *SafeWebsocketClient) connect() error { } } - wsClient.mu.WriteHandler(func() error { - pingCtx, pingCancel := context.WithCancel(context.Background()) - go wsClient.startPingTicker(pingCtx) - wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel) + go func() { + ctx, cancel := context.WithCancel(context.Background()) + wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) + for { + select { + case <-ctx.Done(): + fmt.Println("Writer stopped due to client shutdown") + return + case data := <-wsClient.writeChan: + if conn == nil { + wsClient.triggerReconnect() + return + } + messageType := websocket.TextMessage + switch data.MessageType { + case MessageTypePing: + messageType = websocket.PingMessage + case MessageTypePong: + messageType = websocket.PongMessage + case MessageTypeClose: + messageType = websocket.CloseMessage + } + writer, err := conn.NextWriter(messageType) + if err != nil { + return + } - receiverCtx, receiverCancel := context.WithCancel(context.Background()) - go wsClient.startReceiveHandler(receiverCtx) - wsClient.cancelFuncs = append(wsClient.cancelFuncs, receiverCancel) + if _, err := writer.Write(data.Data); err != nil { + return + } + } + } + }() + + go func() { + ctx, cancel := context.WithCancel(context.Background()) + wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) + for { + select { + case <-ctx.Done(): + fmt.Println("Reader stopped due to client shutdown") + return + default: + if conn == nil { + return + } + if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + fmt.Printf("error on read deadline: %v\n", err) + return + } + _, reader, err := conn.NextReader() + if err != nil { + fmt.Printf("Next Reader Closed: %v\n", err) + wsClient.triggerReconnect() + return + } + + readerBytes, err := io.ReadAll(reader) + if err != nil { + fmt.Printf("io reader failed: %v\n", err) + wsClient.triggerReconnect() + return + } + select { + case wsClient.dataChannel <- readerBytes: + case <-ctx.Done(): + return + default: + fmt.Println("Data channel full, dropping message") + } + + } + + } + }() - return nil - }) return nil } @@ -247,54 +338,48 @@ func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) { log.Println("ping ticker stopped") return case <-ticker.C: - wsClient.mu.WriteHandler(func() error { - if wsClient.conn == nil { - return fmt.Errorf("connection closed") - } - if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { - log.Printf("Ping failed: %v. Will attempt reconnect.", err) - wsClient.triggerReconnect() - } - return nil - }) - } - } -} - -func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) { - for { - select { - case <-ctx.Done(): - log.Println("receive handler stopped") - return - default: - if err := wsClient.mu.ReadHandler(func() error { - conn := wsClient.conn - - if conn == nil { - wsClient.triggerReconnect() - return fmt.Errorf("connection closed") - } - _, message, err := conn.ReadMessage() - if err != nil { - wsClient.triggerReconnect() - return fmt.Errorf("failed to read message: %v", err) - } - select { - case wsClient.dataChannel <- message: - case <-ctx.Done(): - log.Println("Reconnect handler stopped") - default: - log.Println("Data channel full, dropping message") - } - return nil - }); err != nil { - return + wsClient.writeChan <- Message{ + MessageType: websocket.PingMessage, + Data: nil, } } } } +// func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) { +// for { +// select { +// case <-ctx.Done(): +// log.Println("receive handler stopped") +// return +// default: +// if err := wsClient.mu.ReadHandler(func() error { +// conn := wsClient.conn + +// if conn == nil { +// wsClient.triggerReconnect() +// return fmt.Errorf("connection closed") +// } +// _, message, err := conn.ReadMessage() +// if err != nil { +// wsClient.triggerReconnect() +// return fmt.Errorf("failed to read message: %v", err) +// } +// select { +// case wsClient.dataChannel <- message: +// case <-ctx.Done(): +// log.Println("Reconnect handler stopped") +// default: +// log.Println("") +// } +// return nil +// }); err != nil { +// return +// } +// } +// } +// } + func (wsClient *SafeWebsocketClient) triggerReconnect() { select { case wsClient.reconnectCh <- struct{}{}: @@ -309,6 +394,7 @@ func (wsClient *SafeWebsocketClient) reconnectHandler() { select { case <-wsClient.reconnectCh: log.Println("Reconnect triggered") + if wsClient.cancelFuncs != nil { for _, cancel := range wsClient.cancelFuncs { cancel() @@ -334,6 +420,8 @@ func (wsClient *SafeWebsocketClient) reconnectHandler() { isInnerLoop = false continue case <-wsClient.ctx.Done(): + log.Println("reconnect handler stopped due to client shutdown") + wsClient.Close() return } } @@ -364,31 +452,30 @@ func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte { return wsClient.dataChannel } -func (wsClient *SafeWebsocketClient) WriteJSON(message any) error { - return wsClient.mu.WriteHandler(func() error { - return wsClient.conn.WriteJSON(message) - }) +func (wsClient *SafeWebsocketClient) Write(data []byte) error { + wsClient.writeChan <- Message{ + MessageType: MessageTypeText, + Data: data, + } + return nil } func (wsClient *SafeWebsocketClient) Close() error { - wsClient.mu.WriteHandler(func() error { - if wsClient.cancelFuncs != nil { - for _, cancel := range wsClient.cancelFuncs { - cancel() - } + if wsClient.cancelFuncs != nil { + for _, cancel := range wsClient.cancelFuncs { + cancel() } + } - if wsClient.reconnectChans != nil { - for _, reconnectChan := range wsClient.reconnectChans { - close(reconnectChan) - } + if wsClient.reconnectChans != nil { + for _, reconnectChan := range wsClient.reconnectChans { + close(reconnectChan) } - if wsClient.conn != nil { - wsClient.conn.Close() - } - wsClient.isConnected = false - return nil - }) + } + if wsClient.conn != nil { + wsClient.conn.Close() + } + wsClient.isConnected = false close(wsClient.dataChannel) return nil diff --git a/v1/examples/client/main.go b/v1/examples/client/main.go index 5103196..f20907d 100644 --- a/v1/examples/client/main.go +++ b/v1/examples/client/main.go @@ -43,6 +43,7 @@ func main() { dataChannel := wsClient.DataChannel() for data := range dataChannel { + // _ = data fmt.Println(string(data)) } }