// Package client provides a WebSocket client that keeps its connection alive
// with periodic pings and transparently reconnects, with exponential backoff,
// when the connection is lost.
package client

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/url"
	"strings"
	"sync"
	"time"

	custom_rwmutex "git.neurocipta.com/rogerferdinan/custom-rwmutex"
	"git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
	"github.com/gorilla/websocket"
)

const (
	// pingPeriod is the interval between keep-alive pings sent to the server.
	pingPeriod = 10 * time.Second

	// reconnectSettleDelay is how long the reconnect handler waits after
	// tearing down a connection before it starts dialing again.
	reconnectSettleDelay = 100 * time.Millisecond

	// initialBackoff and maxBackoff bound the delay between reconnect attempts.
	initialBackoff = 1 * time.Second
	maxBackoff     = 30 * time.Second
)

// SafeMap is a type-safe generic wrapper around sync.Map.
type SafeMap[K comparable, V any] struct {
	m sync.Map
}

// NewSafeMap returns an empty, ready-to-use SafeMap.
func NewSafeMap[K comparable, V any]() *SafeMap[K, V] {
	return &SafeMap[K, V]{}
}

// castStored converts a value held by the underlying sync.Map back to T.
// A nil interface value (which is what storing a nil value of an interface
// type produces) maps to the zero value of T instead of failing the assertion.
func castStored[T any](x any) (T, bool) {
	if x == nil {
		var zero T
		return zero, true
	}
	t, ok := x.(T)
	return t, ok
}

// Store sets the value for key, replacing any existing value.
func (sm *SafeMap[K, V]) Store(key K, value V) {
	sm.m.Store(key, value)
}

// Load returns the value stored for key. ok is false when the key is absent,
// in which case value is the zero value of V.
func (sm *SafeMap[K, V]) Load(key K) (value V, ok bool) {
	val, loaded := sm.m.Load(key)
	if !loaded {
		var zero V
		return zero, false
	}
	return castStored[V](val)
}

// Delete removes key from the map; it is a no-op when the key is absent.
func (sm *SafeMap[K, V]) Delete(key K) {
	sm.m.Delete(key)
}

// Range calls f for each entry in the map and stops early when f returns
// false. Entries whose key or value is not of the expected type are skipped.
func (sm *SafeMap[K, V]) Range(f func(K, V) bool) {
	sm.m.Range(func(key, value any) bool {
		k, okKey := castStored[K](key)
		v, okValue := castStored[V](value)
		if !okKey || !okValue {
			return true
		}
		return f(k, v)
	})
}

// Len returns the number of entries by walking the whole map, so it costs
// O(n) and is only a snapshot when the map is modified concurrently.
func (sm *SafeMap[K, V]) Len() int {
	count := 0
	sm.Range(func(_ K, _ V) bool {
		count++
		return true
	})
	return count
}

// SafeWebsocketClientBuilder collects the connection settings for a
// SafeWebsocketClient. Fields tagged `nil_checker:"required"` must be set
// before Build is called.
type SafeWebsocketClientBuilder struct {
	baseHost *string `nil_checker:"required"`
	basePort *uint16 `nil_checker:"required"`
	path     *string
	rawQuery *string
	useTLS   *bool
}

// NewSafeWebsocketClientBuilder returns a builder with no settings applied.
func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
	return &SafeWebsocketClientBuilder{}
}

// BaseHost sets the server host name or address (required).
func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
	b.baseHost = &host
	return b
}

// BasePort sets the server port (required).
func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
	b.basePort = &port
	return b
}

// UseTLS selects the "wss" scheme when true and "ws" otherwise. The default
// is false.
func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder {
	b.useTLS = &useTLS
	return b
}

// Path sets the URL path of the WebSocket endpoint.
func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder {
	b.path = &path
	return b
}

// RawQuery sets the already-encoded query string of the endpoint URL,
// without the leading "?".
func (b *SafeWebsocketClientBuilder) RawQuery(rawQuery string) *SafeWebsocketClientBuilder {
	b.rawQuery = &rawQuery
	return b
}

// Build validates the builder, dials the server and returns a connected
// client. It returns an error when a required setting is missing or the
// initial connection cannot be established; no goroutines are left running
// in that case.
func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) {
	if err := internal.NilChecker(b); err != nil {
		return nil, err
	}

	var useTLS bool
	if b.useTLS != nil {
		useTLS = *b.useTLS
	}

	rootCtx, rootCancel := context.WithCancel(context.Background())
	wsClient := &SafeWebsocketClient{
		baseHost:    *b.baseHost,
		basePort:    *b.basePort,
		useTLS:      useTLS,
		path:        b.path,
		rawQuery:    b.rawQuery,
		dataChannel: make(chan []byte, 1),
		mu:          custom_rwmutex.NewCustomRwMutex(),
		rootCtx:     rootCtx,
		rootCancel:  rootCancel,
		reconnectCh: make(chan struct{}, 1),
		isConnected: false,
		doneMap:     NewSafeMap[string, chan struct{}](),
	}

	if err := wsClient.connect(); err != nil {
		rootCancel()
		return nil, fmt.Errorf("failed to establish initial connection: %w", err)
	}

	// Start the reconnect handler only once a connection exists. reconnectCh
	// is buffered, so a failure reported before this goroutine is scheduled
	// is not lost.
	go wsClient.reconnectHandler()

	return wsClient, nil
}

// SafeWebsocketClient is a WebSocket client that pings the server every
// pingPeriod, forwards received messages to DataChannel and reconnects
// automatically when the connection fails.
type SafeWebsocketClient struct {
	baseHost string
	basePort uint16
	useTLS   bool
	path     *string
	rawQuery *string

	dataChannel chan []byte

	// mu guards conn, ctx, cancel and isConnected, and serializes writes to
	// the connection.
	mu   *custom_rwmutex.CustomRwMutex
	conn *websocket.Conn

	// ctx and cancel belong to the current connection; they are replaced on
	// every successful (re)connect.
	ctx    context.Context
	cancel context.CancelFunc

	// rootCtx lives as long as the client and is cancelled by Close. Every
	// per-connection ctx is derived from it.
	rootCtx    context.Context
	rootCancel context.CancelFunc

	reconnectCh chan struct{}
	isConnected bool
	doneMap     *SafeMap[string, chan struct{}]
}

// connect dials the configured endpoint, installs the new connection as the
// current one and starts its ping and receive goroutines. Any previous
// connection is cancelled and closed. It fails when the dial fails or the
// client has been closed.
func (wsClient *SafeWebsocketClient) connect() error {
	scheme := "ws"
	if wsClient.useTLS {
		scheme = "wss"
	}

	target := url.URL{
		Scheme: scheme,
		Host:   fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort),
	}
	if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" {
		target.Path = *wsClient.path
	}
	if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" {
		target.RawQuery = *wsClient.rawQuery
	}

	// Dialing with rootCtx lets Close abort a dial that is still in flight.
	conn, _, err := websocket.DefaultDialer.DialContext(wsClient.rootCtx, target.String(), nil)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err)
	}

	// Answer pings under the write lock so pongs never interleave with the
	// ping ticker's writes. The handler runs on the receive goroutine, which
	// does not hold the lock while reading.
	conn.SetPingHandler(func(pingData string) error {
		return wsClient.mu.WriteHandler(func() error {
			err := conn.WriteMessage(websocket.PongMessage, []byte(pingData))
			if err == nil || errors.Is(err, websocket.ErrCloseSent) {
				return nil
			}
			var timeoutErr interface{ Timeout() bool }
			if errors.As(err, &timeoutErr) && timeoutErr.Timeout() {
				return nil
			}
			return err
		})
	})

	var installErr error
	wsClient.mu.WriteHandler(func() error {
		if err := wsClient.rootCtx.Err(); err != nil {
			// Close was called while dialing; do not resurrect the client.
			installErr = fmt.Errorf("client closed: %w", err)
			return installErr
		}
		if wsClient.cancel != nil {
			wsClient.cancel()
		}
		if wsClient.conn != nil {
			_ = wsClient.conn.Close() // best effort: the old connection is being replaced
		}

		ctx, cancel := context.WithCancel(wsClient.rootCtx)
		wsClient.ctx = ctx
		wsClient.cancel = cancel
		wsClient.conn = conn
		wsClient.isConnected = true

		go wsClient.startPingTicker(ctx)
		go wsClient.startReceiveHandler(ctx)
		return nil
	})
	if installErr != nil {
		_ = conn.Close()
		return installErr
	}

	return nil
}

// connFor returns the connection that belongs to ctx, or nil when ctx has
// already been cancelled. Contexts are only cancelled and connections only
// replaced under the write lock, so a live ctx seen under the read lock
// guarantees that wsClient.conn is the connection created together with it.
func (wsClient *SafeWebsocketClient) connFor(ctx context.Context) *websocket.Conn {
	var conn *websocket.Conn
	wsClient.mu.ReadHandler(func() error {
		if ctx.Err() == nil {
			conn = wsClient.conn
		}
		return nil
	})
	return conn
}

// startPingTicker sends a ping every pingPeriod until ctx is cancelled. A
// failed ping requests a reconnect.
func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			wsClient.mu.WriteHandler(func() error {
				if err := ctx.Err(); err != nil {
					// This connection was torn down while waiting for the lock.
					return err
				}
				if wsClient.conn == nil {
					return errors.New("connection closed")
				}
				if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
					log.Printf("Ping failed: %v. Will attempt reconnect.", err)
					wsClient.triggerReconnect()
				}
				return nil
			})
		case <-ctx.Done():
			log.Println("Ping ticker stopped due to context cancellation")
			return
		}
	}
}

// startReceiveHandler reads messages from the connection that belongs to ctx
// and forwards them to the data channel, dropping a message when the channel
// is full. It returns when ctx is cancelled or the first read fails; an
// unexpected read failure requests a reconnect.
func (wsClient *SafeWebsocketClient) startReceiveHandler(ctx context.Context) {
	conn := wsClient.connFor(ctx)
	if conn == nil {
		log.Println("Receive handler stopped")
		return
	}

	for {
		// Read without holding the lock: the read blocks until a message
		// arrives, and the ping handler invoked from inside ReadMessage needs
		// the write lock.
		_, message, err := conn.ReadMessage()
		if err != nil {
			if ctx.Err() != nil {
				// The connection was closed on purpose (reconnect or Close).
				log.Println("Receive handler stopped")
				return
			}
			log.Printf("Read failed: %v. Will attempt reconnect.", err)
			wsClient.triggerReconnect()
			// Never read again from a failed connection: gorilla/websocket
			// panics after repeated reads on one.
			return
		}

		select {
		case wsClient.dataChannel <- message:
		default:
			log.Println("Data channel full, dropping message")
		}
	}
}

// triggerReconnect asks the reconnect handler to re-establish the connection.
// It never blocks; a request made while one is already pending is coalesced.
func (wsClient *SafeWebsocketClient) triggerReconnect() {
	select {
	case wsClient.reconnectCh <- struct{}{}:
	default:
	}
}

// teardown cancels the current connection's goroutines and closes the
// connection so that a blocked read returns immediately.
func (wsClient *SafeWebsocketClient) teardown() {
	wsClient.mu.WriteHandler(func() error {
		if wsClient.cancel != nil {
			wsClient.cancel()
		}
		if wsClient.conn != nil {
			_ = wsClient.conn.Close() // best effort: the connection is already considered dead
			wsClient.conn = nil
		}
		wsClient.isConnected = false
		return nil
	})
}

// wait blocks for d and reports true, or reports false as soon as the client
// is closed.
func (wsClient *SafeWebsocketClient) wait(d time.Duration) bool {
	select {
	case <-time.After(d):
		return true
	case <-wsClient.rootCtx.Done():
		return false
	}
}

// reconnectHandler waits for reconnect requests and re-dials with exponential
// backoff (initialBackoff doubling up to maxBackoff) until a connection is
// established. It exits when the client is closed.
func (wsClient *SafeWebsocketClient) reconnectHandler() {
	for {
		select {
		case <-wsClient.rootCtx.Done():
			log.Println("Reconnect handler stopped due to client shutdown")
			return
		case <-wsClient.reconnectCh:
		}

		log.Println("Reconnect triggered")
		wsClient.teardown()

		// Give the old connection's goroutines a moment to exit.
		if !wsClient.wait(reconnectSettleDelay) {
			log.Println("Reconnect handler stopped due to client shutdown")
			return
		}

		// Drop a request queued by the connection that was just torn down;
		// otherwise it would force a second, pointless reconnect.
		select {
		case <-wsClient.reconnectCh:
		default:
		}

		backoff := initialBackoff
		for {
			log.Printf("Attempting reconnect in %v...", backoff)
			if !wsClient.wait(backoff) {
				log.Println("Reconnect handler stopped due to client shutdown")
				return
			}

			err := wsClient.connect()
			if err == nil {
				log.Println("Reconnected successfully")
				break
			}

			log.Printf("Reconnect failed: %v", err)
			backoff *= 2
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
		}
	}
}

// DataChannel returns the receive-only channel on which incoming messages are
// delivered. The channel has a buffer of one message; messages that arrive
// while it is full are dropped. The channel is never closed.
func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte {
	return wsClient.dataChannel
}

// Close shuts the client down: it stops the ping, receive and reconnect
// goroutines and closes the current connection, returning the error from
// closing it, if any. Close is safe to call more than once.
func (wsClient *SafeWebsocketClient) Close() error {
	var closeErr error
	wsClient.mu.WriteHandler(func() error {
		// Cancelling the root context also cancels the per-connection ctx.
		wsClient.rootCancel()
		if wsClient.conn != nil {
			closeErr = wsClient.conn.Close()
			wsClient.conn = nil
		}
		wsClient.isConnected = false
		return nil
	})
	return closeErr
}