Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ef98404caf |
20
README.md
20
README.md
@@ -1,2 +1,22 @@
|
||||
# safe-web-socket
|
||||
|
||||
> A secure, production-ready WebSocket wrapper for Go with built-in validation, rate limiting, authentication, and connection management.
|
||||
|
||||
## Overview
|
||||
|
||||
`SafeWebSocket` is a Go library designed to simplify the creation of **secure, scalable, and resilient WebSocket servers**. Built on top of `gorilla/websocket`, it adds essential safety layers — including authentication, input validation, connection limits, rate limiting, and graceful shutdown — so you can focus on your application logic, not security pitfalls.
|
||||
|
||||
Whether you’re building real-time dashboards, chat apps, or live data feeds, `SafeWebSocket` ensures your WebSocket endpoints are protected against common attacks (e.g., DoS, injection, unauthorized access).
|
||||
|
||||
## Features
|
||||
- ✅ **Graceful Shutdown & Cleanup**
|
||||
- ✅ **Automatic Reconnection Handling (Client-side helpers)**
|
||||
- ✅ **WebSocket Ping/Pong Health Checks**
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get git.neurocipta.com/rogerferdinan/safe-web-socket
|
||||
```
|
||||
|
||||
> Requires Go 1.24+
|
||||
@@ -63,6 +63,7 @@ func (h *Hub) Run() {
|
||||
delete(h.Clients, c)
|
||||
close(c.Send)
|
||||
}
|
||||
log.Println("Client Unregistered")
|
||||
case message := <-h.Broadcast:
|
||||
for client := range h.Clients {
|
||||
select {
|
||||
@@ -122,7 +123,7 @@ func ReadPump(c *Client, h *Hub) {
|
||||
for {
|
||||
messageType, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
if !websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
break
|
||||
|
||||
@@ -1,82 +1,203 @@
|
||||
package client
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "time"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
// "git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
|
||||
// "github.com/gorilla/websocket"
|
||||
// )
|
||||
"git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// const (
|
||||
// pingPeriod = 30 * time.Second
|
||||
// )
|
||||
const (
|
||||
pingPeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
// type SafeWebsocketClientBuilder struct {
|
||||
// baseHost *string `nil_checker:"required"`
|
||||
// basePort *uint16 `nil_checker:"required"`
|
||||
// }
|
||||
type SafeWebsocketClientBuilder struct {
|
||||
baseHost *string `nil_checker:"required"`
|
||||
basePort *uint16 `nil_checker:"required"`
|
||||
path *string `nil_checkeer:"required"`
|
||||
useTLS *bool
|
||||
}
|
||||
|
||||
// func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
|
||||
// return &SafeWebsocketClientBuilder{}
|
||||
// }
|
||||
func NewSafeWebsocketClientBuilder() *SafeWebsocketClientBuilder {
|
||||
return &SafeWebsocketClientBuilder{}
|
||||
}
|
||||
|
||||
// func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
|
||||
// b.baseHost = &host
|
||||
// return b
|
||||
// }
|
||||
func (b *SafeWebsocketClientBuilder) BaseHost(host string) *SafeWebsocketClientBuilder {
|
||||
b.baseHost = &host
|
||||
return b
|
||||
}
|
||||
|
||||
// func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
|
||||
// b.basePort = &port
|
||||
// return b
|
||||
// }
|
||||
func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientBuilder {
|
||||
b.basePort = &port
|
||||
return b
|
||||
}
|
||||
|
||||
// func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) {
|
||||
// if err := internal.NilChecker(b); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder {
|
||||
b.useTLS = &useTLS
|
||||
return b
|
||||
}
|
||||
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
func (b *SafeWebsocketClientBuilder) Path(path string) *SafeWebsocketClientBuilder {
|
||||
b.path = &path
|
||||
return b
|
||||
}
|
||||
|
||||
// wsClient := SafeWebsocketClient{
|
||||
// baseHost: b.baseHost,
|
||||
// basePort: b.basePort,
|
||||
// ctx: ctx,
|
||||
// cancel: cancel,
|
||||
// reconnectCh: make(chan struct{}, 1),
|
||||
// isConnected: false,
|
||||
// }
|
||||
func (b *SafeWebsocketClientBuilder) Build() (*SafeWebsocketClient, error) {
|
||||
if err := internal.NilChecker(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if err := wsClient.connect(); err != nil {
|
||||
// cancel()
|
||||
// return nil, fmt.Errorf("failed to establish initial connection: %v", err)
|
||||
// }
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// wsClient.startPingTicker()
|
||||
// wsClient.startReceiveHandler()
|
||||
var useTLS bool
|
||||
if b.useTLS != nil {
|
||||
useTLS = *b.useTLS
|
||||
}
|
||||
|
||||
// return &wsClient, nil
|
||||
// }
|
||||
wsClient := SafeWebsocketClient{
|
||||
baseHost: *b.baseHost,
|
||||
basePort: *b.basePort,
|
||||
useTLS: useTLS,
|
||||
path: *b.path,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
dataChannel: make(chan []byte, 1),
|
||||
mu: internal.NewCustomRwMutex(),
|
||||
reconnectCh: make(chan struct{}, 1),
|
||||
isConnected: false,
|
||||
}
|
||||
|
||||
// type SafeWebsocketClient struct {
|
||||
// baseHost *string
|
||||
// basePort *uint16
|
||||
// mu *internal.CustomRwMutex
|
||||
// ctx context.Context
|
||||
// cancel context.CancelFunc
|
||||
// reconnectCh chan struct{}
|
||||
// isConnected bool
|
||||
// }
|
||||
if err := wsClient.connect(); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to establish initial connection: %v", err)
|
||||
}
|
||||
|
||||
// func (wsClient *SafeWebsocketClient) connect() error {
|
||||
// url := fmt.Sprintf("%s:%d", *wsClient.baseHost, *wsClient.basePort)
|
||||
// conn, _, err := websocket.DefaultDialer.Dial(url, nil)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to connect to %s: %w", *wsClient.baseHost, err)
|
||||
// }
|
||||
return &wsClient, nil
|
||||
}
|
||||
|
||||
// conn.SetPingHandler(func(pingData string) error {
|
||||
// conn.WriteMessage(websocket.PongMessage, []byte(pingData))
|
||||
// })
|
||||
// }
|
||||
type SafeWebsocketClient struct {
|
||||
baseHost string
|
||||
basePort uint16
|
||||
useTLS bool
|
||||
path string
|
||||
dataChannel chan []byte
|
||||
mu *internal.CustomRwMutex
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
reconnectCh chan struct{}
|
||||
isConnected bool
|
||||
}
|
||||
|
||||
func (wsClient *SafeWebsocketClient) connect() error {
|
||||
var scheme string
|
||||
if wsClient.useTLS {
|
||||
scheme = "wss"
|
||||
} else {
|
||||
scheme = "ws"
|
||||
}
|
||||
newURL := url.URL{
|
||||
Scheme: scheme,
|
||||
Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort),
|
||||
Path: wsClient.path,
|
||||
}
|
||||
conn, _, err := websocket.DefaultDialer.Dial(newURL.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err)
|
||||
}
|
||||
|
||||
conn.SetPingHandler(func(pingData string) error {
|
||||
if err := conn.WriteMessage(websocket.PongMessage, []byte(pingData)); err != nil {
|
||||
if err == websocket.ErrCloseSent {
|
||||
return nil
|
||||
}
|
||||
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
wsClient.mu.WriteHandler(func() error {
|
||||
wsClient.conn = conn
|
||||
wsClient.isConnected = true
|
||||
return nil
|
||||
})
|
||||
|
||||
go wsClient.startPingTicker()
|
||||
go wsClient.startReceiveHandler()
|
||||
go wsClient.reconnectHandler()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsClient *SafeWebsocketClient) startPingTicker() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
wsClient.mu.WriteHandler(func() error {
|
||||
if err := wsClient.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||
log.Printf("Ping failed: %v. Will attempt reconnect.", err)
|
||||
wsClient.triggerReconnect()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case <-wsClient.ctx.Done():
|
||||
log.Println("Ping ticker stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wsClient *SafeWebsocketClient) startReceiveHandler() {
|
||||
for {
|
||||
select {
|
||||
case <-wsClient.reconnectCh:
|
||||
case <-wsClient.ctx.Done():
|
||||
log.Println("Reconnect handler stopped")
|
||||
return
|
||||
default:
|
||||
if err := wsClient.mu.WriteHandler(func() error {
|
||||
conn := wsClient.conn
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("no active connection, waiting for reconnect")
|
||||
}
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println(string(message))
|
||||
wsClient.dataChannel <- message
|
||||
return nil
|
||||
}); err != nil {
|
||||
wsClient.triggerReconnect()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wsClient *SafeWebsocketClient) triggerReconnect() {
|
||||
select {
|
||||
case wsClient.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (wsClient *SafeWebsocketClient) reconnectHandler() {
|
||||
for range wsClient.reconnectCh {
|
||||
wsClient.connect()
|
||||
}
|
||||
}
|
||||
|
||||
func (wsClient *SafeWebsocketClient) DataChannel() <-chan []byte {
|
||||
return wsClient.dataChannel
|
||||
}
|
||||
|
||||
25
v1/examples/client/main.go
Normal file
25
v1/examples/client/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"git.neurocipta.com/rogerferdinan/safe-web-socket/v1/client"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wsClient, err := client.NewSafeWebsocketClientBuilder().
|
||||
BaseHost("localhost").
|
||||
BasePort(8080).
|
||||
Path("/ws/test/data_1").
|
||||
UseTLS(false).
|
||||
Build()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
dataChannel := wsClient.DataChannel()
|
||||
|
||||
for data := range dataChannel {
|
||||
fmt.Println(string(data))
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"git.neurocipta.com/rogerferdinan/safe-web-socket/v1/server"
|
||||
)
|
||||
|
||||
type ExampleData struct {
|
||||
Time time.Time `json:"time"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
s, err := server.NewSafeWebsocketServerBuilder().
|
||||
BaseHost("localhost").
|
||||
@@ -14,13 +20,27 @@ func main() {
|
||||
HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
for range ticker.C {
|
||||
c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_1")
|
||||
jsonBytes, err := json.Marshal(ExampleData{
|
||||
Time: time.Now(),
|
||||
Data: "data_1",
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c <- jsonBytes
|
||||
}
|
||||
}).
|
||||
HandleFuncWebsocket("/ws/test/", "data_2", func(c chan []byte) {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
for range ticker.C {
|
||||
c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_2")
|
||||
jsonBytes, err := json.Marshal(ExampleData{
|
||||
Time: time.Now(),
|
||||
Data: "data_2",
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c <- jsonBytes
|
||||
}
|
||||
}).
|
||||
Build()
|
||||
Reference in New Issue
Block a user