diff --git a/go.mod b/go.mod index d4971f3..b4fdcaa 100644 --- a/go.mod +++ b/go.mod @@ -5,3 +5,5 @@ go 1.24.5 require github.com/gorilla/websocket v1.5.3 require git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0 + +require git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48 // indirect diff --git a/go.sum b/go.sum index 9bfc916..d4d557e 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0 h1:KnNc40SrYsg0cksIIcQy/ca6bunkGADQOs1u7O/E+iY= git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0/go.mod h1:9DvvHc2UZhBwEs63NgO4IhiuHnBNtTuBkTJgiMnnCss= +git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48 h1:4wXSbEuwFd2gycaaGP35bjUkKEEO6WcVfJ6cetEyT5s= +git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48/go.mod h1:QtIxG0BYCCq8a5qyklpSHA8qWUvKr+mfl42qF9QxTc0= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/v1/client/client.go b/v1/client/client.go index 5dbf903..6774d48 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -6,10 +6,10 @@ import ( "log" "net/url" "strings" - "sync" "time" custom_rwmutex "git.neurocipta.com/rogerferdinan/custom-rwmutex" + safemap "git.neurocipta.com/rogerferdinan/safe-map" "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" "github.com/gorilla/websocket" ) @@ -19,52 +19,6 @@ const ( readDeadline = 30 * time.Second ) -type SafeMap[K comparable, V any] struct { - m sync.Map -} - -func NewSafeMap[K comparable, V any]() *SafeMap[K, V] { - return &SafeMap[K, V]{ - m: sync.Map{}, - } -} - -func (sm *SafeMap[K, V]) Store(key K, value V) { - sm.m.Store(key, value) -} - -func (sm *SafeMap[K, V]) Load(key K) (value V, ok bool) { - val, loaded := sm.m.Load(key) - if !loaded { - return *new(V), false - } - return val.(V), true -} - -func (sm *SafeMap[K, V]) Delete(key K) { - sm.m.Delete(key) -} - -func (sm *SafeMap[K, V]) Range(f func(K, V) bool) { - sm.m.Range(func(key, value any) bool { - k, ok1 := key.(K) - v, ok2 := value.(V) - if !ok1 || !ok2 { - return true - } - return f(k, v) - }) -} - -func (sm *SafeMap[K, V]) Len() int { - count := 0 - sm.Range(func(_ K, _ V) bool { - count++ - return true - }) - return count -} - type MessageType uint const ( @@ -80,13 +34,12 @@ type Message struct { } type SafeWebsocketClientBuilder struct { - baseHost *string `nil_checker:"required"` - basePort *uint16 `nil_checker:"required"` - path *string - rawQuery *string - useTLS *bool - channelSize *int64 - authenticateFn func(*SafeWebsocketClient) error + baseHost *string `nil_checker:"required"` + basePort *uint16 `nil_checker:"required"` + path *string + rawQuery *string + useTLS *bool + channelSize *int64 } func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { @@ -108,11 +61,6 @@ func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBui return b } -func (b *SafeWebsocketClientBuilder) AuthenticateFn(authenticateFn func(*SafeWebsocketClient) error) *SafeWebsocketClientBuilder { - b.authenticateFn = authenticateFn - return b -} - func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder { b.path = &path return b @@ -154,14 +102,10 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC ctx: ctx, reconnectCh: make(chan struct{}, 1), isConnected: false, - doneMap: NewSafeMap[string, chan struct{}](), + doneMap: safemap.NewSafeMap[string, chan struct{}](), writeChan: make(chan Message, 1), } - if b.authenticateFn != nil { - wsClient.authenticateFn = b.authenticateFn - } - go wsClient.reconnectHandler() if err := wsClient.connect(); err != nil { @@ -189,8 +133,7 @@ type SafeWebsocketClient struct { reconnectCh chan struct{} reconnectChans []chan struct{} isConnected bool - doneMap *SafeMap[string, chan struct{}] - authenticateFn func(*SafeWebsocketClient) error + doneMap *safemap.SafeMap[string, chan struct{}] writeChan chan Message pongChan chan error @@ -235,97 +178,8 @@ func (wsClient *SafeWebsocketClient) connect() error { wsClient.conn = conn wsClient.isConnected = true - if wsClient.authenticateFn != nil { - if err := wsClient.authenticateFn(wsClient); err != nil { - return err - } - } - - go func() { - ctx, cancel := context.WithCancel(context.Background()) - wsClient.mu.WriteHandler(func() error { - wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) - return nil - }) - - var c *websocket.Conn - wsClient.mu.ReadHandler(func() error { - c = conn - return nil - }) - - for { - select { - case <-ctx.Done(): - log.Println("Writer stopped due to client shutdown") - return - case data := <-wsClient.writeChan: - if c == nil { - wsClient.triggerReconnect() - return - } - - if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil { - log.Printf("error on write message: %v\n", err) - wsClient.triggerReconnect() - if data.MessageType == MessageTypePong { - wsClient.pongChan <- err - } - return - } - } - } - }() - - go func() { - ctx, cancel := context.WithCancel(context.Background()) - wsClient.mu.WriteHandler(func() error { - wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) - return nil - }) - - var c *websocket.Conn - - wsClient.mu.ReadHandler(func() error { - c = conn - return nil - }) - if c == nil { - wsClient.triggerReconnect() - return - } - for { - select { - case <-ctx.Done(): - log.Println("Reader stopped due to client shutdown") - return - default: - if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { - log.Printf("error on read deadline: %v\n", err) - return - } - - messageType, data, err := c.ReadMessage() - if err != nil { - log.Printf("error on read message: %v\n", err) - wsClient.triggerReconnect() - return - } - - if messageType != websocket.TextMessage { - continue - } - - select { - case wsClient.dataChannel <- data: - case <-ctx.Done(): - return - default: - log.Println("Data channel full, dropping message") - } - } - } - }() + go wsClient.writePump() + go wsClient.readPump() conn.SetPingHandler(func(pingData string) error { wsClient.writeChan <- Message{ @@ -344,6 +198,94 @@ func (wsClient *SafeWebsocketClient) connect() error { return nil } +func (wsClient *SafeWebsocketClient) writePump() { + ctx, cancel := context.WithCancel(context.Background()) + wsClient.mu.WriteHandler(func() error { + wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) + return nil + }) + + var c *websocket.Conn + wsClient.mu.ReadHandler(func() error { + c = wsClient.conn + return nil + }) + + for { + select { + case <-ctx.Done(): + log.Println("Writer stopped due to client shutdown") + return + case data := <-wsClient.writeChan: + if c == nil { + wsClient.triggerReconnect() + return + } + + if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil { + log.Printf("error on write message: %v\n", err) + c.Close() + wsClient.triggerReconnect() + if data.MessageType == MessageTypePong { + wsClient.pongChan <- err + } + return + } + } + } +} + +func (wsClient *SafeWebsocketClient) readPump() { + ctx, cancel := context.WithCancel(context.Background()) + wsClient.mu.WriteHandler(func() error { + wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) + return nil + }) + + var c *websocket.Conn + + wsClient.mu.ReadHandler(func() error { + c = wsClient.conn + return nil + }) + if c == nil { + wsClient.triggerReconnect() + return + } + + for { + select { + case <-ctx.Done(): + log.Println("Reader stopped due to client shutdown") + return + default: + if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + log.Printf("error on read deadline: %v\n", err) + return + } + + messageType, data, err := c.ReadMessage() + if err != nil { + log.Printf("error on read message: %v\n", err) + wsClient.triggerReconnect() + return + } + + if messageType != websocket.TextMessage { + continue + } + + select { + case wsClient.dataChannel <- data: + case <-ctx.Done(): + return + default: + log.Println("Data channel full, dropping message") + } + } + } +} + func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) { ticker := time.NewTicker(pingPeriod) defer ticker.Stop() @@ -387,7 +329,7 @@ func (wsClient *SafeWebsocketClient) reconnectHandler() { }) wsClient.isConnected = false - time.Sleep(100 * time.Millisecond) + // time.Sleep(100 * time.Millisecond) isInnerLoop := true for isInnerLoop {