package client import ( "context" "fmt" "log" "net/url" "time" "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" "github.com/gorilla/websocket" ) const ( pingPeriod = 10 * time.Second ) type SafeWebsocketClientBuilder struct { baseHost *string `nil_checker:"required"` basePort *uint16 `nil_checker:"required"` path *string `nil_checkeer:"required"` useTLS *bool } func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { return &SafeWebsocketClientBuilder{} } func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder { b.baseHost = &host return b } func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder { b.basePort = &port return b } func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder { b.useTLS = &useTLS return b } func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder { b.path = &path return b } func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { if err := internal.NilChecker(b); err != nil { return nil, err } ctx, cancel := context.WithCancel(context.Background()) var useTLS bool if b.useTLS != nil { useTLS = *b.useTLS } wsClient := SafeWebsocketClient{ baseHost: *b.baseHost, basePort: *b.basePort, useTLS: useTLS, path: *b.path, ctx: ctx, cancel: cancel, dataChannel: make(chan []byte, 1), mu: internal.NewCustomRwMutex(), reconnectCh: make(chan struct{}, 1), isConnected: false, } if err := wsClient.connect(); err != nil { cancel() return nil, fmt.Errorf("failed to establish initial connection: %v", err) } return &wsClient, nil } type SafeWebsocketClient struct { baseHost string basePort uint16 useTLS bool path string dataChannel chan []byte mu *internal.CustomRwMutex conn *websocket.Conn ctx context.Context cancel context.CancelFunc reconnectCh chan struct{} isConnected bool } func (wsClient *SafeWebsocketClient) connect() error { var scheme string if wsClient.useTLS { scheme = "wss" } else { scheme = "ws" } newURL := url.URL{ Scheme: scheme, Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort), Path: wsClient.path, } conn, _, err := websocket.DefaultDialer.Dial(newURL.String(), nil) if err != nil { return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err) } conn.SetPingHandler(func(pingData string) error { if err := conn.WriteMessage(websocket.PongMessage, []byte(pingData)); err != nil { if err == websocket.ErrCloseSent { return nil } if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { return nil } return err } return nil }) wsClient.mu.WriteHandler(func() error { wsClient.conn = conn wsClient.isConnected = true return nil }) go wsClient.startPingTicker() go wsClient.startReceiveHandler() go wsClient.reconnectHandler() return nil } func (wsClient *SafeWebsocketClient) startPingTicker() { ticker := time.NewTicker(pingPeriod) defer ticker.Stop() for { select { case <-ticker.C: wsClient.mu.WriteHandler(func() error { if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { log.Printf("Ping failed: %v. Will attempt reconnect.", err) wsClient.triggerReconnect() } return nil }) case <-wsClient.ctx.Done(): log.Println("Ping ticker stopped") return } } } func (wsClient *SafeWebsocketClient) startReceiveHandler() { for { select { case <-wsClient.reconnectCh: case <-wsClient.ctx.Done(): log.Println("Reconnect handler stopped") return default: if err := wsClient.mu.WriteHandler(func() error { conn := wsClient.conn if conn == nil { return fmt.Errorf("no active connection, waiting for reconnect") } _, message, err := conn.ReadMessage() if err != nil { return err } fmt.Println(string(message)) wsClient.dataChannel <- message return nil }); err != nil { wsClient.triggerReconnect() return } } } } func (wsClient *SafeWebsocketClient) triggerReconnect() { select { case wsClient.reconnectCh <- struct{}{}: default: } } func (wsClient *SafeWebsocketClient) reconnectHandler() { for range wsClient.reconnectCh { wsClient.connect() } } func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte { return wsClient.dataChannel }