feat: safe websocket client implementation
This commit is contained in:
20
README.md
20
README.md
@@ -1,2 +1,22 @@
|
|||||||
# safe-web-socket
|
# safe-web-socket
|
||||||
|
|
||||||
|
> A secure, production-ready WebSocket wrapper for Go with built-in validation, rate limiting, authentication, and connection management.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
`SafeWebSocket` is a Go library designed to simplify the creation of **secure, scalable, and resilient WebSocket servers**. Built on top of `gorilla/websocket`, it adds essential safety layers — including authentication, input validation, connection limits, rate limiting, and graceful shutdown — so you can focus on your application logic, not security pitfalls.
|
||||||
|
|
||||||
|
Whether you’re building real-time dashboards, chat apps, or live data feeds, `SafeWebSocket` ensures your WebSocket endpoints are protected against common attacks (e.g., DoS, injection, unauthorized access).
|
||||||
|
|
||||||
|
## Features
|
||||||
|
- ✅ **Graceful Shutdown & Cleanup**
|
||||||
|
- ✅ **Automatic Reconnection Handling (Client-side helpers)**
|
||||||
|
- ✅ **WebSocket Ping/Pong Health Checks**
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get git.neurocipta.com/rogerferdinan/safe-web-socket
|
||||||
|
```
|
||||||
|
|
||||||
|
> Requires Go 1.24+
|
||||||
@@ -63,6 +63,7 @@ func (h *Hub) Run() {
|
|||||||
delete(h.Clients, c)
|
delete(h.Clients, c)
|
||||||
close(c.Send)
|
close(c.Send)
|
||||||
}
|
}
|
||||||
|
log.Println("Client Unregistered")
|
||||||
case message := <-h.Broadcast:
|
case message := <-h.Broadcast:
|
||||||
for client := range h.Clients {
|
for client := range h.Clients {
|
||||||
select {
|
select {
|
||||||
@@ -122,7 +123,7 @@ func ReadPump(c *Client, h *Hub) {
|
|||||||
for {
|
for {
|
||||||
messageType, message, err := c.Conn.ReadMessage()
|
messageType, message, err := c.Conn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
if !websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
log.Printf("WebSocket error: %v", err)
|
log.Printf("WebSocket error: %v", err)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,82 +1,203 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
// import (
|
import (
|
||||||
// "context"
|
"context"
|
||||||
// "fmt"
|
"fmt"
|
||||||
// "time"
|
"log"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
// "git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
|
"git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
|
||||||
// "github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
// )
|
)
|
||||||
|
|
||||||
// const (
|
const (
|
||||||
// pingPeriod = 30 * time.Second
|
pingPeriod = 10 * time.Second
|
||||||
// )
|
)
|
||||||
|
|
||||||
// type SafeWebsocketClientBuilder struct {
|
type SafeWebsocketClientBuilder struct {
|
||||||
// baseHost *string `nil_checker:"required"`
|
baseHost *string `nil_checker:"required"`
|
||||||
// basePort *uint16 `nil_checker:"required"`
|
basePort *uint16 `nil_checker:"required"`
|
||||||
// }
|
path *string `nil_checkeer:"required"`
|
||||||
|
useTLS *bool
|
||||||
|
}
|
||||||
|
|
||||||
// func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
|
func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
|
||||||
// return &SafeWebsocketClientBuilder{}
|
return &SafeWebsocketClientBuilder{}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
|
func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
|
||||||
// b.baseHost = &host
|
b.baseHost = &host
|
||||||
// return b
|
return b
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
|
func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
|
||||||
// b.basePort = &port
|
b.basePort = &port
|
||||||
// return b
|
return b
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) {
|
func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder {
|
||||||
// if err := internal.NilChecker(b); err != nil {
|
b.useTLS = &useTLS
|
||||||
// return nil, err
|
return b
|
||||||
// }
|
}
|
||||||
|
|
||||||
// ctx, cancel := context.WithCancel(context.Background())
|
func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder {
|
||||||
|
b.path = &path
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// wsClient := SafeWebsocketClient{
|
func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) {
|
||||||
// baseHost: b.baseHost,
|
if err := internal.NilChecker(b); err != nil {
|
||||||
// basePort: b.basePort,
|
return nil, err
|
||||||
// ctx: ctx,
|
}
|
||||||
// cancel: cancel,
|
|
||||||
// reconnectCh: make(chan struct{}, 1),
|
|
||||||
// isConnected: false,
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if err := wsClient.connect(); err != nil {
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// cancel()
|
|
||||||
// return nil, fmt.Errorf("failed to establish initial connection: %v", err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// wsClient.startPingTicker()
|
var useTLS bool
|
||||||
// wsClient.startReceiveHandler()
|
if b.useTLS != nil {
|
||||||
|
useTLS = *b.useTLS
|
||||||
|
}
|
||||||
|
|
||||||
// return &wsClient, nil
|
wsClient := SafeWebsocketClient{
|
||||||
// }
|
baseHost: *b.baseHost,
|
||||||
|
basePort: *b.basePort,
|
||||||
|
useTLS: useTLS,
|
||||||
|
path: *b.path,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
dataChannel: make(chan []byte, 1),
|
||||||
|
mu: internal.NewCustomRwMutex(),
|
||||||
|
reconnectCh: make(chan struct{}, 1),
|
||||||
|
isConnected: false,
|
||||||
|
}
|
||||||
|
|
||||||
// type SafeWebsocketClient struct {
|
if err := wsClient.connect(); err != nil {
|
||||||
// baseHost *string
|
cancel()
|
||||||
// basePort *uint16
|
return nil, fmt.Errorf("failed to establish initial connection: %v", err)
|
||||||
// mu *internal.CustomRwMutex
|
}
|
||||||
// ctx context.Context
|
|
||||||
// cancel context.CancelFunc
|
|
||||||
// reconnectCh chan struct{}
|
|
||||||
// isConnected bool
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (wsClient *SafeWebsocketClient) connect() error {
|
return &wsClient, nil
|
||||||
// url := fmt.Sprintf("%s:%d", *wsClient.baseHost, *wsClient.basePort)
|
}
|
||||||
// conn, _, err := websocket.DefaultDialer.Dial(url, nil)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to connect to %s: %w", *wsClient.baseHost, err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// conn.SetPingHandler(func(pingData string) error {
|
type SafeWebsocketClient struct {
|
||||||
// conn.WriteMessage(websocket.PongMessage, []byte(pingData))
|
baseHost string
|
||||||
// })
|
basePort uint16
|
||||||
// }
|
useTLS bool
|
||||||
|
path string
|
||||||
|
dataChannel chan []byte
|
||||||
|
mu *internal.CustomRwMutex
|
||||||
|
conn *websocket.Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
reconnectCh chan struct{}
|
||||||
|
isConnected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsClient *SafeWebsocketClient) connect() error {
|
||||||
|
var scheme string
|
||||||
|
if wsClient.useTLS {
|
||||||
|
scheme = "wss"
|
||||||
|
} else {
|
||||||
|
scheme = "ws"
|
||||||
|
}
|
||||||
|
newURL := url.URL{
|
||||||
|
Scheme: scheme,
|
||||||
|
Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort),
|
||||||
|
Path: wsClient.path,
|
||||||
|
}
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(newURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetPingHandler(func(pingData string) error {
|
||||||
|
if err := conn.WriteMessage(websocket.PongMessage, []byte(pingData)); err != nil {
|
||||||
|
if err == websocket.ErrCloseSent {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
wsClient.mu.WriteHandler(func() error {
|
||||||
|
wsClient.conn = conn
|
||||||
|
wsClient.isConnected = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
go wsClient.startPingTicker()
|
||||||
|
go wsClient.startReceiveHandler()
|
||||||
|
go wsClient.reconnectHandler()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsClient *SafeWebsocketClient) startPingTicker() {
|
||||||
|
ticker := time.NewTicker(pingPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
wsClient.mu.WriteHandler(func() error {
|
||||||
|
if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||||
|
log.Printf("Ping failed: %v. Will attempt reconnect.", err)
|
||||||
|
wsClient.triggerReconnect()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
case <-wsClient.ctx.Done():
|
||||||
|
log.Println("Ping ticker stopped")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsClient *SafeWebsocketClient) startReceiveHandler() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-wsClient.reconnectCh:
|
||||||
|
case <-wsClient.ctx.Done():
|
||||||
|
log.Println("Reconnect handler stopped")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if err := wsClient.mu.WriteHandler(func() error {
|
||||||
|
conn := wsClient.conn
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("no active connection, waiting for reconnect")
|
||||||
|
}
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println(string(message))
|
||||||
|
wsClient.dataChannel <- message
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
wsClient.triggerReconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsClient *SafeWebsocketClient) triggerReconnect() {
|
||||||
|
select {
|
||||||
|
case wsClient.reconnectCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsClient *SafeWebsocketClient) reconnectHandler() {
|
||||||
|
for range wsClient.reconnectCh {
|
||||||
|
wsClient.connect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte {
|
||||||
|
return wsClient.dataChannel
|
||||||
|
}
|
||||||
|
|||||||
25
v1/examples/client/main.go
Normal file
25
v1/examples/client/main.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"git.neurocipta.com/rogerferdinan/safe-web-socket/v1/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
wsClient, err := client.NewSafeWebsocketClientBuilder().
|
||||||
|
BaseHost("localhost").
|
||||||
|
BasePort(8080).
|
||||||
|
Path("/ws/test/data_1").
|
||||||
|
UseTLS(false).
|
||||||
|
Build()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
dataChannel := wsClient.DataChannel()
|
||||||
|
|
||||||
|
for data := range dataChannel {
|
||||||
|
fmt.Println(string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,12 +1,18 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.neurocipta.com/rogerferdinan/safe-web-socket/v1/server"
|
"git.neurocipta.com/rogerferdinan/safe-web-socket/v1/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ExampleData struct {
|
||||||
|
Time time.Time `json:"time"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
s, err := server.NewSafeWebsocketServerBuilder().
|
s, err := server.NewSafeWebsocketServerBuilder().
|
||||||
BaseHost("localhost").
|
BaseHost("localhost").
|
||||||
@@ -14,13 +20,27 @@ func main() {
|
|||||||
HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) {
|
HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) {
|
||||||
ticker := time.NewTicker(10 * time.Millisecond)
|
ticker := time.NewTicker(10 * time.Millisecond)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_1")
|
jsonBytes, err := json.Marshal(ExampleData{
|
||||||
|
Time: time.Now(),
|
||||||
|
Data: "data_1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c <- jsonBytes
|
||||||
}
|
}
|
||||||
}).
|
}).
|
||||||
HandleFuncWebsocket("/ws/test/", "data_2", func(c chan []byte) {
|
HandleFuncWebsocket("/ws/test/", "data_2", func(c chan []byte) {
|
||||||
ticker := time.NewTicker(10 * time.Millisecond)
|
ticker := time.NewTicker(10 * time.Millisecond)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_2")
|
jsonBytes, err := json.Marshal(ExampleData{
|
||||||
|
Time: time.Now(),
|
||||||
|
Data: "data_2",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c <- jsonBytes
|
||||||
}
|
}
|
||||||
}).
|
}).
|
||||||
Build()
|
Build()
|
||||||
Reference in New Issue
Block a user