From eac2ed2bf1826c32dae1c6b45895d07b3bd7f2b9 Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Thu, 25 Sep 2025 10:42:40 +0700 Subject: [PATCH] fix: readPump and writePump hotfix --- internal/hub.go | 191 +++++++++++++------------------------- v1/client/client.go | 82 ++++++++++++++++ v1/example/server/main.go | 4 +- v1/server/server.go | 8 +- 4 files changed, 152 insertions(+), 133 deletions(-) create mode 100644 v1/client/client.go diff --git a/internal/hub.go b/internal/hub.go index 597a30b..0e0c934 100644 --- a/internal/hub.go +++ b/internal/hub.go @@ -8,6 +8,12 @@ import ( "github.com/gorilla/websocket" ) +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = 55 * time.Second +) + type Client struct { Conn *websocket.Conn Send chan []byte @@ -19,7 +25,7 @@ type Client struct { func NewClient(conn *websocket.Conn, subscribedPath string) *Client { return &Client{ Conn: conn, - Send: make(chan []byte, 1024), + Send: make(chan []byte, 64), SubscribedPath: subscribedPath, done: make(chan struct{}), mu: NewCustomRwMutex(), @@ -31,7 +37,6 @@ type Hub struct { Broadcast chan []byte Register chan *Client Unregister chan *Client - ClientData map[string]chan []byte writeMu *CustomRwMutex readMu *CustomRwMutex } @@ -42,72 +47,29 @@ func NewHub() *Hub { Register: make(chan *Client), Unregister: make(chan *Client), Clients: make(map[*Client]bool), - ClientData: make(map[string]chan []byte), writeMu: NewCustomRwMutex(), } } -func (h *Hub) AddDataChannel(dataID string) chan []byte { - ch := make(chan []byte, 256) - h.writeMu.WriteHandler(func() error { - if innerCh, ok := h.ClientData[dataID]; ok { - ch = innerCh - return nil - } - - h.ClientData[dataID] = ch - log.Printf("Created data channel for: %s\n", dataID) - return nil - }) - - return ch -} - -func (h *Hub) GetDataChannel(dataID string) (chan []byte, bool) { - var ch chan []byte - var ok bool - h.writeMu.ReadHandler(func() error { - innerCh, innerOk := h.ClientData[dataID] - ch = innerCh - ok = innerOk - return nil - }) - - return ch, ok -} - -func (h *Hub) RemoveDataChannel(dataID string) { - h.writeMu.WriteHandler(func() error { - if ch, ok := h.ClientData[dataID]; ok { - close(ch) - delete(h.ClientData, dataID) - log.Printf("Removed data channel for: %s\n", dataID) - } - return nil - }) -} - func (h *Hub) Run() { go func() { for { select { - case c := <-h.Register: - h.Clients[c] = true + case client := <-h.Register: + h.Clients[client] = true log.Println("Client registered") case c := <-h.Unregister: if _, ok := h.Clients[c]; ok { delete(h.Clients, c) close(c.Send) - c.Conn.Close() - log.Println("Client unregistered") } case message := <-h.Broadcast: - for c := range h.Clients { + for client := range h.Clients { select { - case c.Send <- message: + case client.Send <- message: default: - close(c.Send) - delete(h.Clients, c) + close(client.Send) + delete(h.Clients, client) } } } @@ -116,81 +78,58 @@ func (h *Hub) Run() { } func WritePump(c *Client, h *Hub) { - go func() { - defer func() { - h.Unregister <- c - c.Conn.Close() - }() - - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case message, ok := <-c.Send: - if err := c.mu.WriteHandler(func() error { - if !ok { - c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) - return fmt.Errorf("message not ok") - } - - if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil { - return err - } - - return nil - }); err != nil { - return - } - case <-ticker.C: - if err := c.mu.WriteHandler(func() error { - if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return err - } - return nil - }); err != nil { - return - } - } - } - }() -} - -func ReadPump(c *Client) { - go func() { - defer func() { - c.Conn.Close() - }() - - c.Conn.SetReadLimit(1024) - c.Conn.SetPongHandler(func(string) error { - return c.mu.WriteHandler(func() error { - if err := c.Conn.WriteMessage(websocket.PongMessage, []byte{}); err != nil { - return fmt.Errorf("failed to send pong: %v", err) - } - return nil - }) - }) - for { - var messageType int - var message []byte - var err error - // c.mu.ReadHandler(func() error { - messageType, message, err = c.Conn.ReadMessage() - // return nil - // }) - - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("WebSocket error: %v", err) - } - break - } - - if messageType == websocket.TextMessage { - fmt.Printf("Received: %s\n", message) - } - } + pingTicker := time.NewTicker(pingPeriod) + defer func() { + h.Unregister <- c + pingTicker.Stop() + c.Conn.Close() }() + for { + select { + case message, ok := <-c.Send: + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil { + return + } + case <-pingTicker.C: + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +func ReadPump(c *Client, h *Hub) { + defer func() { + h.Unregister <- c + c.Conn.Close() + }() + + c.Conn.SetReadLimit(512) + c.Conn.SetReadDeadline(time.Now().Add(pongWait)) + c.Conn.SetPongHandler(func(string) error { + c.Conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + for { + messageType, message, err := c.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + break + } + + if messageType == websocket.TextMessage { + fmt.Printf("Received: %s\n", message) + } + } } diff --git a/v1/client/client.go b/v1/client/client.go new file mode 100644 index 0000000..8c18f23 --- /dev/null +++ b/v1/client/client.go @@ -0,0 +1,82 @@ +package client + +// import ( +// "context" +// "fmt" +// "time" + +// "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" +// "github.com/gorilla/websocket" +// ) + +// const ( +// pingPeriod = 30 * time.Second +// ) + +// type SafeWebsocketClientBuilder struct { +// baseHost *string `nil_checker:"required"` +// basePort *uint16 `nil_checker:"required"` +// } + +// func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { +// return &SafeWebsocketClientBuilder{} +// } + +// func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder { +// b.baseHost = &host +// return b +// } + +// func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder { +// b.basePort = &port +// return b +// } + +// func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { +// if err := internal.NilChecker(b); err != nil { +// return nil, err +// } + +// ctx, cancel := context.WithCancel(context.Background()) + +// wsClient := SafeWebsocketClient{ +// baseHost: b.baseHost, +// basePort: b.basePort, +// ctx: ctx, +// cancel: cancel, +// reconnectCh: make(chan struct{}, 1), +// isConnected: false, +// } + +// if err := wsClient.connect(); err != nil { +// cancel() +// return nil, fmt.Errorf("failed to establish initial connection: %v", err) +// } + +// wsClient.startPingTicker() +// wsClient.startReceiveHandler() + +// return &wsClient, nil +// } + +// type SafeWebsocketClient struct { +// baseHost *string +// basePort *uint16 +// mu *internal.CustomRwMutex +// ctx context.Context +// cancel context.CancelFunc +// reconnectCh chan struct{} +// isConnected bool +// } + +// func (wsClient *SafeWebsocketClient) connect() error { +// url := fmt.Sprintf("%s:%d", *wsClient.baseHost, *wsClient.basePort) +// conn, _, err := websocket.DefaultDialer.Dial(url, nil) +// if err != nil { +// return fmt.Errorf("failed to connect to %s: %w", *wsClient.baseHost, err) +// } + +// conn.SetPingHandler(func(pingData string) error { +// conn.WriteMessage(websocket.PongMessage, []byte(pingData)) +// }) +// } diff --git a/v1/example/server/main.go b/v1/example/server/main.go index 875ffcb..31550ca 100644 --- a/v1/example/server/main.go +++ b/v1/example/server/main.go @@ -12,13 +12,13 @@ func main() { BaseHost("localhost"). BasePort(8080). HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) { - ticker := time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(10 * time.Millisecond) for range ticker.C { c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_1") } }). HandleFuncWebsocket("/ws/test/", "data_2", func(c chan []byte) { - ticker := time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(10 * time.Millisecond) for range ticker.C { c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_2") } diff --git a/v1/server/server.go b/v1/server/server.go index c83aa78..fd39d40 100644 --- a/v1/server/server.go +++ b/v1/server/server.go @@ -70,9 +70,9 @@ func (b *SafeWebsocketServerBuilder) HandleFuncWebsocket( } c := internal.NewClient(conn, subscribedPath) h.Register <- c - internal.WritePump(c, h) - internal.ReadPump(c) - writeFunc(h.Broadcast) + go internal.WritePump(c, h) + go internal.ReadPump(c, h) + go writeFunc(h.Broadcast) }) return b } @@ -85,7 +85,6 @@ func (b *SafeWebsocketServerBuilder) Build() (*SafeWebsocketServer, error) { safeServer := SafeWebsocketServer{ url: fmt.Sprintf("%s:%d", *b.baseHost, *b.basePort), mux: b.mux, - mu: internal.NewCustomRwMutex(), } return &safeServer, nil } @@ -94,7 +93,6 @@ type SafeWebsocketServer struct { hub *internal.Hub mux *http.ServeMux url string - mu *internal.CustomRwMutex } func (s *SafeWebsocketServer) ListenAndServe() error {