diff --git a/go.mod b/go.mod index b4fdcaa..65d7e01 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,9 @@ go 1.24.5 require github.com/gorilla/websocket v1.5.3 -require git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0 +require ( + git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0 + github.com/google/uuid v1.6.0 +) -require git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48 // indirect +require git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48 diff --git a/go.sum b/go.sum index d4d557e..7595de4 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,7 @@ git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0 h1:KnNc40SrYsg0cksIIcQy/c git.neurocipta.com/rogerferdinan/custom-rwmutex v1.0.0/go.mod h1:9DvvHc2UZhBwEs63NgO4IhiuHnBNtTuBkTJgiMnnCss= git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48 h1:4wXSbEuwFd2gycaaGP35bjUkKEEO6WcVfJ6cetEyT5s= git.neurocipta.com/rogerferdinan/safe-map v0.0.0-20251011004629-ab0b119a7c48/go.mod h1:QtIxG0BYCCq8a5qyklpSHA8qWUvKr+mfl42qF9QxTc0= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/internal/hub.go b/internal/hub.go index dddc418..0485e7d 100644 --- a/internal/hub.go +++ b/internal/hub.go @@ -1,9 +1,11 @@ package internal import ( + "fmt" "log" "time" + "github.com/google/uuid" "github.com/gorilla/websocket" ) @@ -14,6 +16,7 @@ const ( ) type Client struct { + ID string Conn *websocket.Conn Send chan []byte SubscribedPath string @@ -22,8 +25,9 @@ type Client struct { func NewClient(conn *websocket.Conn, subscribedPath string) *Client { return &Client{ + ID: uuid.NewString(), Conn: conn, - Send: make(chan []byte, 2), + Send: make(chan []byte, 1), SubscribedPath: subscribedPath, done: make(chan struct{}, 1), } @@ -41,7 +45,7 @@ func NewHub() *Hub { Broadcast: make(chan []byte, 1), Register: make(chan *Client, 1), Unregister: make(chan *Client, 1), - Clients: make(map[*Client]bool, 0), + Clients: make(map[*Client]bool), } } @@ -53,19 +57,21 @@ func (h *Hub) Run() { h.Clients[client] = true log.Println("Client registered") case c := <-h.Unregister: - if _, ok := h.Clients[c]; ok { + if v, ok := h.Clients[c]; ok { + fmt.Println(v, c) delete(h.Clients, c) close(c.Send) } log.Println("Client Unregistered") case message := <-h.Broadcast: for client := range h.Clients { - select { - case client.Send <- message: - default: - close(client.Send) - delete(h.Clients, client) - } + client.Send <- message + // select { + // case client.Send <- message: + // default: + // close(client.Send) + // delete(h.Clients, client) + // } } } } @@ -86,15 +92,18 @@ func WritePump(c *Client, h *Hub) { c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + fmt.Println(ok) return } if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil { + fmt.Println(err) return } case <-pingTicker.C: c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + fmt.Println(err) return } } @@ -107,7 +116,7 @@ func ReadPump(c *Client, h *Hub) { c.Conn.Close() }() - c.Conn.SetReadLimit(512) + c.Conn.SetReadLimit(1024) c.Conn.SetReadDeadline(time.Now().Add(pongWait)) c.Conn.SetPongHandler(func(string) error { c.Conn.SetReadDeadline(time.Now().Add(pongWait)) diff --git a/v1/client/client.go b/v1/client/client.go index 1d9d40d..a6c663b 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -34,13 +34,14 @@ type Message struct { } type SafeWebsocketClientBuilder struct { - baseHost *string `nil_checker:"required"` - basePort *uint16 `nil_checker:"required"` - path *string - rawQuery *string - isDrop *bool - useTLS *bool - channelSize *int64 + baseHost *string `nil_checker:"required"` + basePort *uint16 `nil_checker:"required"` + path *string + rawQuery *string + isDrop *bool + useTLS *bool + channelSize *int64 + writeChannelSize *int64 } func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { @@ -82,6 +83,11 @@ func (b *SafeWebsocketClientBuilder) ChannelSize(channelSize int64) *SafeWebsock return b } +func (b *SafeWebsocketClientBuilder) WriteChannelSize(writeChannelSize int64) *SafeWebsocketClientBuilder { + b.writeChannelSize = &writeChannelSize + return b +} + func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketClient, error) { if err := internal.NilChecker(b); err != nil { return nil, err @@ -103,6 +109,11 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC b.channelSize = &channelSize } + if b.writeChannelSize == nil { + writeChannelSize := int64(1) + b.writeChannelSize = &writeChannelSize + } + wsClient := SafeWebsocketClient{ baseHost: *b.baseHost, basePort: *b.basePort, @@ -116,7 +127,7 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC reconnectCh: make(chan struct{}, 1), isConnected: false, doneMap: safemap.NewSafeMap[string, chan struct{}](), - writeChan: make(chan Message, 1), + writeChan: make(chan Message, *b.writeChannelSize), } go wsClient.reconnectHandler() @@ -193,19 +204,19 @@ func (wsClient *SafeWebsocketClient) connect() error { go wsClient.writePump() go wsClient.readPump() - conn.SetPingHandler(func(pingData string) error { - wsClient.writeChan <- Message{ - MessageType: MessageTypePong, - Data: []byte(pingData), - } + // conn.SetPingHandler(func(pingData string) error { + // wsClient.writeChan <- Message{ + // MessageType: MessageTypePong, + // Data: []byte(pingData), + // } - select { - case err := <-wsClient.pongChan: - return err - default: - } - return nil - }) + // select { + // case err := <-wsClient.pongChan: + // return err + // default: + // } + // return nil + // }) return nil } @@ -374,7 +385,6 @@ func (wsClient *SafeWebsocketClient) reconnectHandler() { }) wsClient.isConnected = false - // time.Sleep(100 * time.Millisecond) isInnerLoop := true for isInnerLoop { diff --git a/v1/examples/client/main.go b/v1/examples/client/main.go index 711b650..e469801 100644 --- a/v1/examples/client/main.go +++ b/v1/examples/client/main.go @@ -35,11 +35,11 @@ func main() { log.Fatal(err) } - go func() { - for range wsClient.ReconnectChannel() { - fmt.Println("Reconnection Success") - } - }() + // go func() { + // for range wsClient.ReconnectChannel() { + // fmt.Println("Reconnection Success") + // } + // }() dataChannel := wsClient.DataChannel()