package client

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	custom_rwmutex "git.neurocipta.com/rogerferdinan/custom-rwmutex"
	"git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
	"github.com/gorilla/websocket"
)

const (
	// pingPeriod is how often the ping ticker queues a ping frame.
	pingPeriod = 10 * time.Second
	// readDeadline is how long the connection may stay silent (no data
	// message and no pong) before a read fails and a reconnect is triggered.
	readDeadline = 120 * time.Second
	// writeWait bounds a single frame write so a stalled peer cannot block
	// the writer goroutine or the ping handler forever.
	writeWait = 10 * time.Second
)

// ErrClientClosed is returned by Write once the client has been closed or
// the context passed to Build has been cancelled.
var ErrClientClosed = errors.New("websocket client closed")

// SafeMap is a type-safe wrapper around sync.Map.
type SafeMap[K comparable, V any] struct {
	m sync.Map
}

// NewSafeMap returns an empty SafeMap.
func NewSafeMap[K comparable, V any]() *SafeMap[K, V] {
	return &SafeMap[K, V]{}
}

// Store sets the value for key.
func (sm *SafeMap[K, V]) Store(key K, value V) {
	sm.m.Store(key, value)
}

// Load returns the value stored for key and whether it was present. The zero
// value of V is returned when the key is absent.
func (sm *SafeMap[K, V]) Load(key K) (value V, ok bool) {
	val, loaded := sm.m.Load(key)
	if !loaded {
		return value, false
	}
	// Comma-ok form: a stored nil interface value (e.g. V = error) would make
	// a plain assertion panic; it is returned as the zero value instead.
	v, _ := val.(V)
	return v, true
}

// Delete removes key from the map.
func (sm *SafeMap[K, V]) Delete(key K) {
	sm.m.Delete(key)
}

// Range calls f for each entry until f returns false. It has the same
// consistency guarantees as sync.Map.Range.
func (sm *SafeMap[K, V]) Range(f func(K, V) bool) {
	sm.m.Range(func(key, value any) bool {
		k, ok := key.(K)
		if !ok {
			return true
		}
		v, ok := value.(V)
		// A nil value is a legitimately stored nil interface, not a foreign type.
		if !ok && value != nil {
			return true
		}
		return f(k, v)
	})
}

// Len counts the entries by iterating the map; it is O(n) and only a snapshot
// when the map is modified concurrently.
func (sm *SafeMap[K, V]) Len() int {
	count := 0
	sm.Range(func(_ K, _ V) bool {
		count++
		return true
	})
	return count
}

// MessageType identifies a WebSocket frame type. The values equal the
// gorilla/websocket frame constants so they can be passed directly to
// (*websocket.Conn).WriteMessage.
type MessageType uint

const (
	MessageTypeText  MessageType = websocket.TextMessage
	MessageTypePing  MessageType = websocket.PingMessage
	MessageTypePong  MessageType = websocket.PongMessage
	MessageTypeClose MessageType = websocket.CloseMessage
)

// Message is a frame queued for the writer goroutine.
type Message struct {
	MessageType MessageType
	Data        []byte
}

// SafeWebsocketClientBuilder collects the configuration for a
// SafeWebsocketClient. BaseHost and BasePort are required.
type SafeWebsocketClientBuilder struct {
	baseHost       *string `nil_checker:"required"`
	basePort       *uint16 `nil_checker:"required"`
	path           *string
	rawQuery       *string
	useTLS         *bool
	channelSize    *int64
	authenticateFn func(*SafeWebsocketClient) error
}

// NewSafeWebsocketClientBuilder returns an empty builder.
func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
	return &SafeWebsocketClientBuilder{}
}

// BaseHost sets the server host name or IP address (required). IPv6
// addresses may be given with or without surrounding brackets.
func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
	b.baseHost = &host
	return b
}

// BasePort sets the server port (required).
func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
	b.basePort = &port
	return b
}

// UseTLS selects the "wss" scheme when true; the default is "ws".
func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder {
	b.useTLS = &useTLS
	return b
}

// AuthenticateFn sets a function run after every (re)connection. Write may be
// used inside it; messages received while it runs are not forwarded to
// DataChannel because the reader starts only after it returns. A non-nil
// error tears the new connection down.
func (b *SafeWebsocketClientBuilder) AuthenticateFn(authenticateFn func(*SafeWebsocketClient) error) *SafeWebsocketClientBuilder {
	b.authenticateFn = authenticateFn
	return b
}

// Path sets the URL path; blank values are ignored.
func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder {
	b.path = &path
	return b
}

// RawQuery sets the already-encoded URL query; blank values are ignored.
func (b *SafeWebsocketClientBuilder) RawQuery(rawQuery string) *SafeWebsocketClientBuilder {
	b.rawQuery = &rawQuery
	return b
}

// ChannelSize sets the buffer size of DataChannel (default 1). Messages that
// do not fit are dropped. A negative size makes Build fail.
func (b *SafeWebsocketClientBuilder) ChannelSize(channelSize int64) *SafeWebsocketClientBuilder {
	b.channelSize = &channelSize
	return b
}

// Build validates the configuration, dials the server and returns a connected
// client. The client keeps reconnecting until ctx is cancelled, Close is
// called, or its Cancel field is invoked.
//
// Build fails if a required field is missing, the channel size is negative,
// or the initial connection (including authentication) fails.
func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketClient, error) {
	if err := internal.NilChecker(b); err != nil {
		return nil, err
	}

	useTLS := b.useTLS != nil && *b.useTLS

	channelSize := int64(1)
	if b.channelSize != nil {
		channelSize = *b.channelSize
	}
	if channelSize < 0 {
		return nil, fmt.Errorf("channel size must not be negative, got %d", channelSize)
	}

	clientCtx, cancel := context.WithCancel(ctx)
	wsClient := &SafeWebsocketClient{
		baseHost:       *b.baseHost,
		basePort:       *b.basePort,
		useTLS:         useTLS,
		path:           b.path,
		rawQuery:       b.rawQuery,
		dataChannel:    make(chan []byte, channelSize),
		mu:             custom_rwmutex.NewCustomRwMutex(),
		ctx:            clientCtx,
		Cancel:         cancel,
		reconnectCh:    make(chan struct{}, 1),
		isConnected:    false,
		doneMap:        NewSafeMap[string, chan struct{}](),
		authenticateFn: b.authenticateFn,
		writeChan:      make(chan Message),
	}

	if err := wsClient.connect(); err != nil {
		cancel()
		return nil, fmt.Errorf("failed to establish initial connection: %w", err)
	}

	// Started only after a successful initial connection so that a failed
	// Build leaves no goroutine behind. A reconnect requested in the meantime
	// waits in the buffered reconnectCh.
	go wsClient.reconnectHandler()
	return wsClient, nil
}

// SafeWebsocketClient is a WebSocket client that owns one connection at a
// time and transparently reconnects with exponential backoff when a read or
// write fails.
type SafeWebsocketClient struct {
	baseHost string
	basePort uint16
	useTLS   bool
	path     *string
	rawQuery *string

	// dataChannel receives incoming text messages.
	dataChannel chan []byte
	// mu guards conn, isConnected, cancelFuncs and reconnectChans.
	mu   *custom_rwmutex.CustomRwMutex
	conn *websocket.Conn
	// cancelFuncs stop the goroutines of the current connection.
	cancelFuncs []context.CancelFunc

	ctx context.Context
	// Cancel cancels the client's context, which shuts the client down as if
	// Close had been called.
	Cancel context.CancelFunc

	reconnectCh    chan struct{}
	reconnectChans []chan struct{}
	isConnected    bool
	doneMap        *SafeMap[string, chan struct{}]
	authenticateFn func(*SafeWebsocketClient) error
	writeChan      chan Message

	closeOnce sync.Once
}

// dialURL builds the ws:// or wss:// URL of the configured endpoint.
func (wsClient *SafeWebsocketClient) dialURL() string {
	scheme := "ws"
	if wsClient.useTLS {
		scheme = "wss"
	}
	// JoinHostPort adds the brackets an IPv6 literal needs; strip any the
	// caller already supplied so they are not doubled.
	host := strings.TrimSuffix(strings.TrimPrefix(wsClient.baseHost, "["), "]")
	endpoint := url.URL{
		Scheme: scheme,
		Host:   net.JoinHostPort(host, strconv.Itoa(int(wsClient.basePort))),
	}
	if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" {
		endpoint.Path = *wsClient.path
	}
	if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" {
		endpoint.RawQuery = *wsClient.rawQuery
	}
	return endpoint.String()
}

// connect dials the server, installs the new connection and starts its
// writer, ping ticker and reader goroutines. All of them share a context
// derived from the client context; dropConnection cancels it.
func (wsClient *SafeWebsocketClient) connect() error {
	endpoint := wsClient.dialURL()
	conn, _, err := websocket.DefaultDialer.DialContext(wsClient.ctx, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", endpoint, err)
	}

	// Handlers are installed before any goroutine reads from conn; setting
	// them while ReadMessage runs would be a data race.
	conn.SetPongHandler(func(string) error {
		// Pongs prove the peer is alive, so extend the read deadline even when
		// no data message arrives.
		return conn.SetReadDeadline(time.Now().Add(readDeadline))
	})
	conn.SetPingHandler(func(appData string) error {
		// WriteControl may be called concurrently with the writer goroutine,
		// and it never blocks on writeChan from inside the reader.
		err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(writeWait))
		if err == nil || errors.Is(err, websocket.ErrCloseSent) {
			return nil
		}
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			return nil
		}
		return err
	})

	connCtx, connCancel := context.WithCancel(wsClient.ctx)
	var previous *websocket.Conn
	wsClient.mu.WriteHandler(func() error {
		previous = wsClient.conn
		wsClient.conn = conn
		wsClient.isConnected = true
		wsClient.cancelFuncs = append(wsClient.cancelFuncs, connCancel)
		return nil
	})
	if previous != nil && previous != conn {
		_ = previous.Close()
	}

	// The writer must run before authentication so authenticateFn can Write.
	go wsClient.writeLoop(connCtx, conn)
	go wsClient.startPingTicker(connCtx)

	if wsClient.authenticateFn != nil {
		if err := wsClient.authenticateFn(wsClient); err != nil {
			wsClient.dropConnection()
			return fmt.Errorf("authentication failed: %w", err)
		}
	}

	go wsClient.readLoop(connCtx, conn)
	return nil
}

// writeLoop sends queued messages on conn until connCtx is cancelled or a
// write fails. A failure on a still-current connection requests a reconnect.
func (wsClient *SafeWebsocketClient) writeLoop(connCtx context.Context, conn *websocket.Conn) {
	for {
		select {
		case <-connCtx.Done():
			log.Println("Writer stopped due to client shutdown")
			return
		case msg := <-wsClient.writeChan:
			if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
				if connCtx.Err() == nil {
					log.Printf("error on write deadline: %v\n", err)
					wsClient.triggerReconnect()
				}
				return
			}
			if err := conn.WriteMessage(int(msg.MessageType), msg.Data); err != nil {
				// A cancelled context means the connection was torn down on
				// purpose; reconnecting again would loop forever.
				if connCtx.Err() == nil {
					log.Printf("error on write message: %v\n", err)
					wsClient.triggerReconnect()
				}
				return
			}
		}
	}
}

// readLoop forwards text messages from conn to dataChannel until connCtx is
// cancelled or a read fails. Other data frames are ignored, and messages that
// do not fit into dataChannel are dropped rather than blocking the reader.
func (wsClient *SafeWebsocketClient) readLoop(connCtx context.Context, conn *websocket.Conn) {
	for {
		if connCtx.Err() != nil {
			log.Println("Reader stopped due to client shutdown")
			return
		}
		if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
			if connCtx.Err() == nil {
				log.Printf("error on read deadline: %v\n", err)
				wsClient.triggerReconnect()
			}
			return
		}
		messageType, data, err := conn.ReadMessage()
		if err != nil {
			if connCtx.Err() != nil {
				// The connection was closed deliberately (reconnect or shutdown).
				log.Println("Reader stopped due to client shutdown")
				return
			}
			log.Printf("error on read message: %v\n", err)
			wsClient.triggerReconnect()
			return
		}
		if messageType != websocket.TextMessage {
			continue
		}
		select {
		case wsClient.dataChannel <- data:
		case <-connCtx.Done():
			return
		default:
			log.Println("Data channel full, dropping message")
		}
	}
}

// startPingTicker queues a ping frame every pingPeriod until ctx is cancelled.
func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			log.Println("ping ticker stopped")
			return
		case <-ticker.C:
			// Also watch ctx here: the writer may already be gone, and an
			// unconditional send would block this goroutine forever.
			select {
			case wsClient.writeChan <- Message{MessageType: MessageTypePing}:
			case <-ctx.Done():
				log.Println("ping ticker stopped")
				return
			}
		}
	}
}

// triggerReconnect asks the reconnect handler to re-establish the connection.
// Requests coalesce through the one-slot buffer and are ignored once the
// client is shutting down.
func (wsClient *SafeWebsocketClient) triggerReconnect() {
	if wsClient.ctx.Err() != nil {
		return
	}
	select {
	case wsClient.reconnectCh <- struct{}{}:
	default:
	}
}

// dropConnection cancels the goroutines of the current connection, closes it
// and marks the client as disconnected. It is safe to call repeatedly.
func (wsClient *SafeWebsocketClient) dropConnection() {
	var (
		cancels []context.CancelFunc
		conn    *websocket.Conn
	)
	wsClient.mu.WriteHandler(func() error {
		cancels = wsClient.cancelFuncs
		wsClient.cancelFuncs = nil
		conn = wsClient.conn
		wsClient.isConnected = false
		return nil
	})
	// Cancel before closing so goroutines unblocked by the close observe a
	// cancelled context and do not request another reconnect.
	for _, cancel := range cancels {
		cancel()
	}
	if conn != nil {
		_ = conn.Close()
	}
}

// notifyReconnect signals every ReconnectChannel subscriber without blocking.
func (wsClient *SafeWebsocketClient) notifyReconnect() {
	wsClient.mu.ReadHandler(func() error {
		for _, subscriber := range wsClient.reconnectChans {
			select {
			case subscriber <- struct{}{}:
			default: // a notification is already pending for this subscriber
			}
		}
		return nil
	})
}

// reconnectHandler serves reconnect requests until the client context is
// cancelled. Each request tears down the current connection, then retries
// connect with exponential backoff (1s doubling up to 30s), and finally
// notifies ReconnectChannel subscribers.
func (wsClient *SafeWebsocketClient) reconnectHandler() {
	const (
		initialBackoff = 1 * time.Second
		maxBackoff     = 30 * time.Second
		settleDelay    = 100 * time.Millisecond
	)
	backoff := initialBackoff

	// wait sleeps for d and reports false if the client shut down meanwhile.
	wait := func(d time.Duration) bool {
		timer := time.NewTimer(d)
		defer timer.Stop()
		select {
		case <-timer.C:
			return true
		case <-wsClient.ctx.Done():
			return false
		}
	}
	shutdown := func() {
		log.Println("reconnect handler stopped due to client shutdown")
		_ = wsClient.Close()
	}

	for {
		select {
		case <-wsClient.ctx.Done():
			shutdown()
			return
		case <-wsClient.reconnectCh:
		}

		log.Println("Reconnect triggered")
		wsClient.dropConnection()
		if !wait(settleDelay) {
			shutdown()
			return
		}
		// Discard a request raised by the other goroutine of the connection
		// that was just torn down; it would cause a needless second reconnect.
		select {
		case <-wsClient.reconnectCh:
		default:
		}

		for {
			log.Printf("Attempting reconnect in %v...", backoff)
			if !wait(backoff) {
				shutdown()
				return
			}
			err := wsClient.connect()
			if err == nil {
				break
			}
			log.Printf("Reconnect failed: %v", err)
			backoff *= 2
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
		}

		log.Println("Reconnected successfully")
		backoff = initialBackoff
		wsClient.notifyReconnect()
	}
}

// ReconnectChannel returns a channel that receives a value after each
// successful reconnection. It buffers one pending notification; further
// notifications are coalesced until it is read, so a slow subscriber never
// stalls reconnection. The channel is closed by Close, and is returned
// already closed if the client has shut down.
func (wsClient *SafeWebsocketClient) ReconnectChannel() <-chan struct{} {
	reconnectCh := make(chan struct{}, 1)
	wsClient.mu.WriteHandler(func() error {
		if wsClient.ctx.Err() != nil {
			close(reconnectCh)
			return nil
		}
		wsClient.reconnectChans = append(wsClient.reconnectChans, reconnectCh)
		return nil
	})
	return reconnectCh
}

// DataChannel returns the channel on which incoming text messages are
// delivered. It is never closed.
func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte {
	return wsClient.dataChannel
}

// Write queues data to be sent as a text message. It blocks until the writer
// goroutine accepts the message, which may last until a pending reconnection
// completes. It returns ErrClientClosed if the client shuts down first. A nil
// error means the message was handed to the writer, not that the peer
// received it.
func (wsClient *SafeWebsocketClient) Write(data []byte) error {
	if wsClient.ctx.Err() != nil {
		return ErrClientClosed
	}
	select {
	case wsClient.writeChan <- Message{MessageType: MessageTypeText, Data: data}:
		return nil
	case <-wsClient.ctx.Done():
		return ErrClientClosed
	}
}

// Close shuts the client down. It cancels the client context (stopping the
// reconnect handler), tears down the current connection and closes every
// channel returned by ReconnectChannel. It is idempotent and always returns
// nil. DataChannel is deliberately left open because the reader may still be
// sending on it.
func (wsClient *SafeWebsocketClient) Close() error {
	wsClient.closeOnce.Do(func() {
		if wsClient.Cancel != nil {
			wsClient.Cancel()
		}
		wsClient.dropConnection()

		var subscribers []chan struct{}
		wsClient.mu.WriteHandler(func() error {
			subscribers = wsClient.reconnectChans
			wsClient.reconnectChans = nil
			return nil
		})
		for _, subscriber := range subscribers {
			close(subscriber)
		}
	})
	return nil
}