From ef98404caf422d04551453a1131aee2456fa2e98 Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Thu, 25 Sep 2025 14:14:24 +0700 Subject: [PATCH] feat: safe websocket client implementation --- README.md | 20 ++ internal/hub.go | 3 +- v1/client/client.go | 251 ++++++++++++++++++------ v1/examples/client/main.go | 25 +++ v1/{example => examples}/server/main.go | 24 ++- 5 files changed, 255 insertions(+), 68 deletions(-) create mode 100644 v1/examples/client/main.go rename v1/{example => examples}/server/main.go (58%) diff --git a/README.md b/README.md index 60ec9cb..02f7b14 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,22 @@ # safe-web-socket +> A secure, production-ready WebSocket wrapper for Go with built-in validation, rate limiting, authentication, and connection management. + +## Overview + +`SafeWebSocket` is a Go library designed to simplify the creation of **secure, scalable, and resilient WebSocket servers**. Built on top of `gorilla/websocket`, it adds essential safety layers — including authentication, input validation, connection limits, rate limiting, and graceful shutdown — so you can focus on your application logic, not security pitfalls. + +Whether you’re building real-time dashboards, chat apps, or live data feeds, `SafeWebSocket` ensures your WebSocket endpoints are protected against common attacks (e.g., DoS, injection, unauthorized access). + +## Features +- ✅ **Graceful Shutdown & Cleanup** +- ✅ **Automatic Reconnection Handling (Client-side helpers)** +- ✅ **WebSocket Ping/Pong Health Checks** + +## Installation + +```bash +go get git.neurocipta.com/rogerferdinan/safe-web-socket +``` + +> Requires Go 1.24+ \ No newline at end of file diff --git a/internal/hub.go b/internal/hub.go index 0e0c934..59dcec3 100644 --- a/internal/hub.go +++ b/internal/hub.go @@ -63,6 +63,7 @@ func (h *Hub) Run() { delete(h.Clients, c) close(c.Send) } + log.Println("Client Unregistered") case message := <-h.Broadcast: for client := range h.Clients { select { @@ -122,7 +123,7 @@ func ReadPump(c *Client, h *Hub) { for { messageType, message, err := c.Conn.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + if !websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error: %v", err) } break diff --git a/v1/client/client.go b/v1/client/client.go index 8c18f23..97f13c1 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -1,82 +1,203 @@ package client -// import ( -// "context" -// "fmt" -// "time" +import ( + "context" + "fmt" + "log" + "net/url" + "time" -// "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" -// "github.com/gorilla/websocket" -// ) + "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" + "github.com/gorilla/websocket" +) -// const ( -// pingPeriod = 30 * time.Second -// ) +const ( + pingPeriod = 10 * time.Second +) -// type SafeWebsocketClientBuilder struct { -// baseHost *string `nil_checker:"required"` -// basePort *uint16 `nil_checker:"required"` -// } +type SafeWebsocketClientBuilder struct { + baseHost *string `nil_checker:"required"` + basePort *uint16 `nil_checker:"required"` + path *string `nil_checkeer:"required"` + useTLS *bool +} -// func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { -// return &SafeWebsocketClientBuilder{} -// } +func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder { + return &SafeWebsocketClientBuilder{} +} -// func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder { -// b.baseHost = &host -// return b -// } +func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder { + b.baseHost = &host + return b +} -// func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder { -// b.basePort = &port -// return b -// } +func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder { + b.basePort = &port + return b +} -// func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { -// if err := internal.NilChecker(b); err != nil { -// return nil, err -// } +func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder { + b.useTLS = &useTLS + return b +} -// ctx, cancel := context.WithCancel(context.Background()) +func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder { + b.path = &path + return b +} -// wsClient := SafeWebsocketClient{ -// baseHost: b.baseHost, -// basePort: b.basePort, -// ctx: ctx, -// cancel: cancel, -// reconnectCh: make(chan struct{}, 1), -// isConnected: false, -// } +func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) { + if err := internal.NilChecker(b); err != nil { + return nil, err + } -// if err := wsClient.connect(); err != nil { -// cancel() -// return nil, fmt.Errorf("failed to establish initial connection: %v", err) -// } + ctx, cancel := context.WithCancel(context.Background()) -// wsClient.startPingTicker() -// wsClient.startReceiveHandler() + var useTLS bool + if b.useTLS != nil { + useTLS = *b.useTLS + } -// return &wsClient, nil -// } + wsClient := SafeWebsocketClient{ + baseHost: *b.baseHost, + basePort: *b.basePort, + useTLS: useTLS, + path: *b.path, + ctx: ctx, + cancel: cancel, + dataChannel: make(chan []byte, 1), + mu: internal.NewCustomRwMutex(), + reconnectCh: make(chan struct{}, 1), + isConnected: false, + } -// type SafeWebsocketClient struct { -// baseHost *string -// basePort *uint16 -// mu *internal.CustomRwMutex -// ctx context.Context -// cancel context.CancelFunc -// reconnectCh chan struct{} -// isConnected bool -// } + if err := wsClient.connect(); err != nil { + cancel() + return nil, fmt.Errorf("failed to establish initial connection: %v", err) + } -// func (wsClient *SafeWebsocketClient) connect() error { -// url := fmt.Sprintf("%s:%d", *wsClient.baseHost, *wsClient.basePort) -// conn, _, err := websocket.DefaultDialer.Dial(url, nil) -// if err != nil { -// return fmt.Errorf("failed to connect to %s: %w", *wsClient.baseHost, err) -// } + return &wsClient, nil +} -// conn.SetPingHandler(func(pingData string) error { -// conn.WriteMessage(websocket.PongMessage, []byte(pingData)) -// }) -// } +type SafeWebsocketClient struct { + baseHost string + basePort uint16 + useTLS bool + path string + dataChannel chan []byte + mu *internal.CustomRwMutex + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + reconnectCh chan struct{} + isConnected bool +} + +func (wsClient *SafeWebsocketClient) connect() error { + var scheme string + if wsClient.useTLS { + scheme = "wss" + } else { + scheme = "ws" + } + newURL := url.URL{ + Scheme: scheme, + Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort), + Path: wsClient.path, + } + conn, _, err := websocket.DefaultDialer.Dial(newURL.String(), nil) + if err != nil { + return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err) + } + + conn.SetPingHandler(func(pingData string) error { + if err := conn.WriteMessage(websocket.PongMessage, []byte(pingData)); err != nil { + if err == websocket.ErrCloseSent { + return nil + } + if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { + return nil + } + return err + } + return nil + }) + + wsClient.mu.WriteHandler(func() error { + wsClient.conn = conn + wsClient.isConnected = true + return nil + }) + + go wsClient.startPingTicker() + go wsClient.startReceiveHandler() + go wsClient.reconnectHandler() + + return nil +} + +func (wsClient *SafeWebsocketClient) startPingTicker() { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + wsClient.mu.WriteHandler(func() error { + if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + log.Printf("Ping failed: %v. Will attempt reconnect.", err) + wsClient.triggerReconnect() + } + return nil + }) + case <-wsClient.ctx.Done(): + log.Println("Ping ticker stopped") + return + } + } +} + +func (wsClient *SafeWebsocketClient) startReceiveHandler() { + for { + select { + case <-wsClient.reconnectCh: + case <-wsClient.ctx.Done(): + log.Println("Reconnect handler stopped") + return + default: + if err := wsClient.mu.WriteHandler(func() error { + conn := wsClient.conn + + if conn == nil { + return fmt.Errorf("no active connection, waiting for reconnect") + } + _, message, err := conn.ReadMessage() + if err != nil { + return err + } + fmt.Println(string(message)) + wsClient.dataChannel <- message + return nil + }); err != nil { + wsClient.triggerReconnect() + return + } + } + } +} + +func (wsClient *SafeWebsocketClient) triggerReconnect() { + select { + case wsClient.reconnectCh <- struct{}{}: + default: + } +} + +func (wsClient *SafeWebsocketClient) reconnectHandler() { + for range wsClient.reconnectCh { + wsClient.connect() + } +} + +func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte { + return wsClient.dataChannel +} diff --git a/v1/examples/client/main.go b/v1/examples/client/main.go new file mode 100644 index 0000000..bb3d9c3 --- /dev/null +++ b/v1/examples/client/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "fmt" + "log" + + "git.neurocipta.com/rogerferdinan/safe-web-socket/v1/client" +) + +func main() { + wsClient, err := client.NewSafeWebsocketClientBuilder(). + BaseHost("localhost"). + BasePort(8080). + Path("/ws/test/data_1"). + UseTLS(false). + Build() + if err != nil { + log.Fatal(err) + } + dataChannel := wsClient.DataChannel() + + for data := range dataChannel { + fmt.Println(string(data)) + } +} diff --git a/v1/example/server/main.go b/v1/examples/server/main.go similarity index 58% rename from v1/example/server/main.go rename to v1/examples/server/main.go index 31550ca..90510ad 100644 --- a/v1/example/server/main.go +++ b/v1/examples/server/main.go @@ -1,12 +1,18 @@ package main import ( + "encoding/json" "log" "time" "git.neurocipta.com/rogerferdinan/safe-web-socket/v1/server" ) +type ExampleData struct { + Time time.Time `json:"time"` + Data string `json:"data"` +} + func main() { s, err := server.NewSafeWebsocketServerBuilder(). BaseHost("localhost"). @@ -14,13 +20,27 @@ func main() { HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) { ticker := time.NewTicker(10 * time.Millisecond) for range ticker.C { - c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_1") + jsonBytes, err := json.Marshal(ExampleData{ + Time: time.Now(), + Data: "data_1", + }) + if err != nil { + continue + } + c <- jsonBytes } }). HandleFuncWebsocket("/ws/test/", "data_2", func(c chan []byte) { ticker := time.NewTicker(10 * time.Millisecond) for range ticker.C { - c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_2") + jsonBytes, err := json.Marshal(ExampleData{ + Time: time.Now(), + Data: "data_2", + }) + if err != nil { + continue + } + c <- jsonBytes } }). Build()