package client

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	custom_rwmutex "git.neurocipta.com/rogerferdinan/custom-rwmutex"
	safemap "git.neurocipta.com/rogerferdinan/safe-map"
	"git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
	"github.com/gorilla/websocket"
)

const (
	// pingPeriod is how often a ping frame is queued on an open connection.
	pingPeriod = 10 * time.Second
	// readDeadline bounds how long a read may wait without receiving a data
	// frame or a pong before the connection is treated as dead.
	readDeadline = 30 * time.Second
)

// ErrClientClosed is returned by Write once the client has been closed or the
// context passed to Build has been cancelled.
var ErrClientClosed = errors.New("client is closed")

// MessageType identifies the websocket frame type of a Message. Its values
// mirror the gorilla/websocket frame-type constants.
type MessageType uint

const (
	MessageTypeText  MessageType = websocket.TextMessage
	MessageTypePing  MessageType = websocket.PingMessage
	MessageTypePong  MessageType = websocket.PongMessage
	MessageTypeClose MessageType = websocket.CloseMessage
)

// Message is a single frame queued for the connection's writer goroutine.
type Message struct {
	MessageType MessageType
	Data        []byte
}

// SafeWebsocketClientBuilder collects the configuration for a
// SafeWebsocketClient. baseHost and basePort are tagged as required and are
// checked by internal.NilChecker in Build; every other option gets a default
// in Build.
type SafeWebsocketClientBuilder struct {
	baseHost         *string `nil_checker:"required"`
	basePort         *uint16 `nil_checker:"required"`
	headers          *map[string]string
	path             *string
	rawQuery         *string
	isDrop           *bool
	useTLS           *bool
	channelSize      *int64
	writeChannelSize *int64
}

// NewSafeWebsocketClientBuilder returns an empty builder.
func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
	return &SafeWebsocketClientBuilder{}
}

// BaseHost sets the host name or IP address to dial (required).
func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
	b.baseHost = &host
	return b
}

// BasePort sets the TCP port to dial (required).
func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
	b.basePort = &port
	return b
}

// Headers sets extra HTTP headers for the opening handshake. The map is
// referenced, not copied, and is read again on every reconnect.
func (b *SafeWebsocketClientBuilder) Headers(headers map[string]string) *SafeWebsocketClientBuilder {
	b.headers = &headers
	return b
}

// UseTLS selects the "wss" scheme when true and "ws" when false (default true).
func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder {
	b.useTLS = &useTLS
	return b
}

// Path sets the URL path; a blank path is ignored.
func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder {
	b.path = &path
	return b
}

// RawQuery sets the already-encoded URL query; a blank query is ignored.
func (b *SafeWebsocketClientBuilder) RawQuery(rawQuery string) *SafeWebsocketClientBuilder {
	b.rawQuery = &rawQuery
	return b
}

// IsDrop controls what happens when the data channel is full. When true
// (default), incoming text frames are dropped and logged. When false, the
// reader blocks until the consumer makes room.
func (b *SafeWebsocketClientBuilder) IsDrop(isDrop bool) *SafeWebsocketClientBuilder {
	b.isDrop = &isDrop
	return b
}

// ChannelSize sets the buffer size of the data channel (default 1).
func (b *SafeWebsocketClientBuilder) ChannelSize(channelSize int64) *SafeWebsocketClientBuilder {
	b.channelSize = &channelSize
	return b
}

// WriteChannelSize sets the buffer size of the outgoing write queue (default 1).
func (b *SafeWebsocketClientBuilder) WriteChannelSize(writeChannelSize int64) *SafeWebsocketClientBuilder {
	b.writeChannelSize = &writeChannelSize
	return b
}

// Build applies defaults for unset options, dials the server, and starts the
// reconnect handler. The defaults are TLS on, drop-when-full on, and channel
// sizes of 1. Defaults are written back into the builder.
//
// The client runs until ctx is cancelled or Close is called. If the initial
// dial fails, no goroutines are left running and the dial error is wrapped in
// the returned error.
func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketClient, error) {
	if err := internal.NilChecker(b); err != nil {
		return nil, err
	}
	if b.useTLS == nil {
		useTLS := true
		b.useTLS = &useTLS
	}
	if b.isDrop == nil {
		isDrop := true
		b.isDrop = &isDrop
	}
	if b.channelSize == nil {
		channelSize := int64(1)
		b.channelSize = &channelSize
	}
	if b.writeChannelSize == nil {
		writeChannelSize := int64(1)
		b.writeChannelSize = &writeChannelSize
	}

	// The client owns a child context so that Close can stop the reconnect
	// handler and abort an in-flight dial even while the caller's ctx lives on.
	clientCtx, cancel := context.WithCancel(ctx)

	wsClient := &SafeWebsocketClient{
		baseHost:    *b.baseHost,
		basePort:    *b.basePort,
		headers:     b.headers,
		useTLS:      *b.useTLS,
		isDrop:      *b.isDrop,
		path:        b.path,
		rawQuery:    b.rawQuery,
		dataChannel: make(chan []byte, *b.channelSize),
		mu:          custom_rwmutex.NewCustomRwMutex(),
		ctx:         clientCtx,
		cancel:      cancel,
		reconnectCh: make(chan struct{}, 1),
		isConnected: false,
		doneMap:     safemap.NewSafeMap[string, chan struct{}](),
		writeChan:   make(chan Message, *b.writeChannelSize),
		pongChan:    make(chan error, 1),
	}

	if err := wsClient.connect(); err != nil {
		cancel()
		return nil, fmt.Errorf("failed to establish initial connection: %w", err)
	}
	// Started only after a successful dial so that a failed Build leaks
	// nothing. A failure signalled before this point is not lost, because
	// reconnectCh buffers one signal.
	go wsClient.reconnectHandler()

	return wsClient, nil
}

// SafeWebsocketClient is a websocket client that pings the server
// periodically. It reconnects with exponential backoff when the connection
// fails and delivers incoming text frames on DataChannel.
type SafeWebsocketClient struct {
	baseHost string
	basePort uint16
	headers  *map[string]string
	isDrop   bool
	useTLS   bool
	path     *string
	rawQuery *string

	// mu guards conn, cancelFuncs, reconnectChans, isConnected and closed.
	mu   *custom_rwmutex.CustomRwMutex
	conn *websocket.Conn
	// ctx is the client-lifetime context; cancel ends it (called by Close).
	ctx    context.Context
	cancel context.CancelFunc
	// cancelFuncs stops the goroutines of the current connection.
	cancelFuncs    []context.CancelFunc
	dataChannel    chan []byte
	reconnectCh    chan struct{}
	reconnectChans []chan struct{}
	isConnected    bool
	closed         bool
	// NOTE(review): doneMap is initialised in Build but not used in this file.
	doneMap   *safemap.SafeMap[string, chan struct{}]
	writeChan chan Message
	// pongChan receives (non-blockingly) the error of a failed pong write.
	pongChan chan error
}

// connect dials the configured endpoint. On success it installs the new
// connection and starts its ping ticker, writer and reader goroutines, then
// closes any previous connection.
//
// All goroutines of the connection share one context. Its cancel func is
// recorded in cancelFuncs before they start, so a reconnect or Close can
// always stop them.
func (wsClient *SafeWebsocketClient) connect() error {
	scheme := "ws"
	if wsClient.useTLS {
		scheme = "wss"
	}
	newURL := url.URL{
		Scheme: scheme,
		// JoinHostPort brackets IPv6 literals, which "%s:%d" does not.
		Host: net.JoinHostPort(wsClient.baseHost, strconv.FormatUint(uint64(wsClient.basePort), 10)),
	}
	if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" {
		newURL.Path = *wsClient.path
	}
	if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" {
		newURL.RawQuery = *wsClient.rawQuery
	}

	// Header values are deliberately not logged; they may carry credentials.
	header := make(http.Header)
	if wsClient.headers != nil {
		for k, v := range *wsClient.headers {
			header.Set(k, v)
		}
	}

	conn, _, err := websocket.DefaultDialer.DialContext(wsClient.ctx, newURL.String(), header)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err)
	}

	// Pongs are consumed inside ReadMessage and never returned to readPump,
	// so the deadline is extended here. Otherwise a server that only answers
	// our pings would still hit readDeadline and be reconnected. Incoming
	// pings are answered by gorilla's default ping handler.
	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(readDeadline))
	})

	connCtx, connCancel := context.WithCancel(wsClient.ctx)

	var old *websocket.Conn
	clientClosed := false
	wsClient.mu.WriteHandler(func() error {
		if wsClient.closed {
			clientClosed = true
			return nil
		}
		wsClient.cancelFuncs = append(wsClient.cancelFuncs, connCancel)
		old = wsClient.conn
		wsClient.conn = conn
		wsClient.isConnected = true
		return nil
	})
	if clientClosed {
		// Close ran while we were dialing; do not resurrect the client.
		connCancel()
		conn.Close()
		return ErrClientClosed
	}
	if old != nil {
		old.Close() // best effort; the old connection is being replaced
	}

	go wsClient.startPingTicker(connCtx)
	go wsClient.writePump(connCtx, conn)
	go wsClient.readPump(connCtx, conn)
	return nil
}

// writePump is the only goroutine that calls WriteMessage on conn.
// gorilla/websocket supports a single concurrent writer. It drains writeChan
// until ctx is cancelled or a write fails. A write failure that is not caused
// by our own cancellation asks the reconnect handler for a new connection.
func (wsClient *SafeWebsocketClient) writePump(ctx context.Context, conn *websocket.Conn) {
	for {
		select {
		case <-ctx.Done():
			log.Println("Writer canceled by context")
			return
		case data := <-wsClient.writeChan:
			if err := conn.WriteMessage(int(data.MessageType), data.Data); err != nil {
				log.Printf("error on write message: %v\n", err)
				if data.MessageType == MessageTypePong {
					// Non-blocking: nobody may be listening for pong errors.
					select {
					case wsClient.pongChan <- err:
					default:
					}
				}
				if ctx.Err() == nil {
					wsClient.triggerReconnect()
				}
				return
			}
		}
	}
}

// readPump reads frames from conn until ctx is cancelled or a read fails. It
// forwards text frames to dataChannel and discards other data frame types.
//
// With isDrop set, a frame that does not fit in dataChannel is dropped and
// logged. Otherwise the reader blocks until the consumer catches up or ctx is
// cancelled. A read failure triggers a reconnect unless it was caused by our
// own cancellation, which closes the old connection.
func (wsClient *SafeWebsocketClient) readPump(ctx context.Context, conn *websocket.Conn) {
	for {
		if ctx.Err() != nil {
			log.Println("Reader canceled by context")
			return
		}
		if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
			log.Printf("error on read deadline: %v\n", err)
			if ctx.Err() == nil {
				wsClient.triggerReconnect()
			}
			return
		}
		messageType, data, err := conn.ReadMessage()
		if err != nil {
			log.Printf("error on read message: %v\n", err)
			if ctx.Err() == nil {
				wsClient.triggerReconnect()
			}
			return
		}
		if messageType != websocket.TextMessage {
			continue
		}

		if wsClient.isDrop {
			select {
			case wsClient.dataChannel <- data:
			default:
				log.Println("Data channel full, dropping message")
			}
			continue
		}
		select {
		case wsClient.dataChannel <- data:
		case <-ctx.Done():
			log.Println("Reader canceled by context")
			return
		}
	}
}

// startPingTicker queues a ping frame every pingPeriod until ctx is
// cancelled. The enqueue also watches ctx, so a stalled writer cannot strand
// this goroutine.
func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			log.Println("ping ticker canceled by context")
			return
		case <-ticker.C:
			select {
			case wsClient.writeChan <- Message{MessageType: MessageTypePing, Data: []byte{}}:
			case <-ctx.Done():
				log.Println("ping ticker canceled by context")
				return
			}
		}
	}
}

// triggerReconnect signals the reconnect handler without blocking. At most
// one pending signal is buffered; extra signals are discarded.
func (wsClient *SafeWebsocketClient) triggerReconnect() {
	select {
	case wsClient.reconnectCh <- struct{}{}:
	default:
	}
}

// cancelConnection stops every goroutine started for the current connection
// and forgets their cancel funcs, so the slice does not grow across
// reconnects.
func (wsClient *SafeWebsocketClient) cancelConnection() {
	var cancels []context.CancelFunc
	wsClient.mu.WriteHandler(func() error {
		cancels = wsClient.cancelFuncs
		wsClient.cancelFuncs = nil
		wsClient.isConnected = false
		return nil
	})
	for _, cancel := range cancels {
		cancel()
	}
}

// notifyReconnected sends, without blocking, one value to every channel
// returned by ReconnectChannel. The read lock keeps Close from closing those
// channels mid-send.
func (wsClient *SafeWebsocketClient) notifyReconnected() {
	wsClient.mu.ReadHandler(func() error {
		if wsClient.closed {
			return nil
		}
		for _, ch := range wsClient.reconnectChans {
			select {
			case ch <- struct{}{}:
			default: // subscriber has not consumed the previous notice
			}
		}
		return nil
	})
}

// reconnectHandler waits for reconnect signals. On each one it cancels the
// current connection's goroutines and re-dials with exponential backoff: it
// starts at 1s, doubles per failure, caps at 30s, and resets after a success.
// After reconnecting it notifies ReconnectChannel subscribers. It exits, and
// calls Close, when the client context is done.
func (wsClient *SafeWebsocketClient) reconnectHandler() {
	const maxBackoff = 30 * time.Second
	backoff := 1 * time.Second

	for {
		select {
		case <-wsClient.ctx.Done():
			log.Println("reconnect handler stopped due to client shutdown")
			wsClient.Close()
			return
		case <-wsClient.reconnectCh:
		}

		log.Println("Reconnect triggered")
		wsClient.cancelConnection()
		// Both pumps of a failing connection may have signalled before being
		// cancelled. Discard a stale second signal so it cannot tear down the
		// connection we are about to create.
		select {
		case <-wsClient.reconnectCh:
		default:
		}

		for {
			log.Printf("Attempting reconnect in %v...", backoff)
			select {
			case <-wsClient.ctx.Done():
				log.Println("reconnect handler stopped due to client shutdown")
				wsClient.Close()
				return
			case <-time.After(backoff):
			}

			err := wsClient.connect()
			if err == nil {
				break
			}
			log.Printf("Reconnect failed: %v", err)
			backoff *= 2
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
		}

		log.Println("Reconnected successfully")
		backoff = 1 * time.Second
		wsClient.notifyReconnected()
	}
}

// ReconnectChannel returns a new channel that receives a value after every
// successful reconnect. Sends are non-blocking into a buffer of one, so a
// subscriber that has not read the previous notice misses the next one. The
// channel is closed by Close; if the client is already closed, it is returned
// closed.
func (wsClient *SafeWebsocketClient) ReconnectChannel() <-chan struct{} {
	reconnectCh := make(chan struct{}, 1)
	wsClient.mu.WriteHandler(func() error {
		if wsClient.closed {
			close(reconnectCh)
			return nil
		}
		wsClient.reconnectChans = append(wsClient.reconnectChans, reconnectCh)
		return nil
	})
	return reconnectCh
}

// DataChannel returns the channel on which incoming text frames are delivered.
func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte {
	return wsClient.dataChannel
}

// CloseDataChannel closes the data channel.
//
// NOTE(review): a reader goroutine may still be sending at the moment of the
// call, and a second call panics. Callers should call this only once, after
// Close.
func (wsClient *SafeWebsocketClient) CloseDataChannel() {
	close(wsClient.dataChannel)
}

// Write queues data to be sent as a text frame by the writer goroutine. It
// blocks while the write queue is full. It returns ErrClientClosed once the
// client has been closed or its context cancelled. A nil return means the
// frame was queued, not that it reached the server.
func (wsClient *SafeWebsocketClient) Write(data []byte) error {
	if wsClient.ctx.Err() != nil {
		return ErrClientClosed
	}
	select {
	case wsClient.writeChan <- Message{MessageType: MessageTypeText, Data: data}:
		return nil
	case <-wsClient.ctx.Done():
		return ErrClientClosed
	}
}

// Close shuts the client down. It performs these steps:
//   - stops the reconnect handler and any in-flight dial;
//   - cancels the current connection's goroutines;
//   - closes every channel handed out by ReconnectChannel;
//   - closes the underlying connection.
//
// Only the first call does anything; later calls are no-ops. The data channel
// is left open (see CloseDataChannel). Close always returns nil; closing the
// connection is best effort.
func (wsClient *SafeWebsocketClient) Close() error {
	var (
		alreadyClosed bool
		cancels       []context.CancelFunc
		conn          *websocket.Conn
	)
	wsClient.mu.WriteHandler(func() error {
		if wsClient.closed {
			alreadyClosed = true
			return nil
		}
		wsClient.closed = true
		wsClient.isConnected = false
		cancels = wsClient.cancelFuncs
		wsClient.cancelFuncs = nil
		for _, reconnectChan := range wsClient.reconnectChans {
			close(reconnectChan)
		}
		conn = wsClient.conn
		return nil
	})
	if alreadyClosed {
		return nil
	}

	if wsClient.cancel != nil {
		wsClient.cancel()
	}
	for _, cancel := range cancels {
		cancel()
	}
	if conn != nil {
		conn.Close()
	}
	return nil
}