From 34811d4132e99bc02419b450e580b10cbac7df6a Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Wed, 24 Sep 2025 23:06:40 +0700 Subject: [PATCH] feat: safe websocket server implementation --- go.mod | 5 + go.sum | 2 + internal/hub.go | 188 ++++++++++++++++++++++++++++++++++++++ internal/mutex.go | 33 +++++++ internal/nil_checker.go | 40 ++++++++ v1/example/server/main.go | 32 +++++++ v1/server/server.go | 107 ++++++++++++++++++++++ 7 files changed, 407 insertions(+) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/hub.go create mode 100644 internal/mutex.go create mode 100644 internal/nil_checker.go create mode 100644 v1/example/server/main.go create mode 100644 v1/server/server.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..155947b --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.neurocipta.com/rogerferdinan/safe-web-socket + +go 1.24.5 + +require github.com/gorilla/websocket v1.5.3 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..25a9fc4 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/internal/hub.go b/internal/hub.go new file mode 100644 index 0000000..74d3294 --- /dev/null +++ b/internal/hub.go @@ -0,0 +1,188 @@ +package internal + +import ( + "fmt" + "log" + "time" + + "github.com/gorilla/websocket" +) + +type Client struct { + Conn *websocket.Conn + Send chan []byte + SubscribedPath string + done chan struct{} + mu *CustomRwMutex +} + +func NewClient(conn *websocket.Conn, subscribedPath string) *Client { + return &Client{ + Conn: conn, + Send: make(chan []byte, 1024), + SubscribedPath: subscribedPath, + done: make(chan struct{}), + mu: NewCustomRwMutex(), + } +} + +type Hub struct { + Clients map[*Client]bool + Broadcast chan []byte + Register chan *Client + Unregister chan *Client + ClientData map[string]chan []byte + writeMu *CustomRwMutex + readMu *CustomRwMutex +} + +func NewHub() *Hub { + return &Hub{ + Broadcast: make(chan []byte), + Register: make(chan *Client), + Unregister: make(chan *Client), + Clients: make(map[*Client]bool), + ClientData: make(map[string]chan []byte), + writeMu: NewCustomRwMutex(), + } +} + +func (h *Hub) AddDataChannel(dataID string) chan []byte { + ch := make(chan []byte, 256) + h.writeMu.WriteHandler(func() error { + if innerCh, ok := h.ClientData[dataID]; ok { + ch = innerCh + return nil + } + + h.ClientData[dataID] = ch + log.Printf("Created data channel for: %s\n", dataID) + return nil + }) + + return ch +} + +func (h *Hub) GetDataChannel(dataID string) (chan []byte, bool) { + var ch chan []byte + var ok bool + h.writeMu.ReadHandler(func() error { + innerCh, innerOk := h.ClientData[dataID] + ch = innerCh + ok = innerOk + return nil + }) + + return ch, ok +} + +func (h *Hub) RemoveDataChannel(dataID string) { + h.writeMu.WriteHandler(func() error { + if ch, ok := h.ClientData[dataID]; ok { + close(ch) + delete(h.ClientData, dataID) + log.Printf("Removed data channel for: %s\n", dataID) + } + return nil + }) +} + +func (h *Hub) Run() { + go func() { + for { + select { + case c := <-h.Register: + h.Clients[c] = true + log.Println("Client registered") + case c := <-h.Unregister: + if _, ok := h.Clients[c]; ok { + delete(h.Clients, c) + close(c.Send) + c.Conn.Close() + log.Println("Client unregistered") + } + case message := <-h.Broadcast: + for c := range h.Clients { + select { + case c.Send <- message: + default: + close(c.Send) + delete(h.Clients, c) + } + } + } + } + }() +} + +func WritePump(c *Client, h *Hub) { + go func() { + defer func() { + h.Unregister <- c + c.Conn.Close() + }() + + c.Conn.SetReadLimit(1024) + c.Conn.SetPongHandler(func(string) error { + if err := c.Conn.WriteMessage(websocket.PongMessage, []byte{}); err != nil { + return fmt.Errorf("failed to send pong: %v", err) + } + return nil + }) + + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case message, ok := <-c.Send: + if err := c.mu.WriteHandler(func() error { + if !ok { + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return fmt.Errorf("message not ok") + } + + w, err := c.Conn.NextWriter(websocket.TextMessage) + if err != nil { + return fmt.Errorf("failed to get writer: %q", err) + } + + w.Write(message) + + if err := w.Close(); err != nil { + return err + } + return nil + }); err == nil { + continue + } + case <-ticker.C: + if err := c.mu.WriteHandler(func() error { + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return err + } + return nil + }); err != nil { + return + } + } + } + }() +} + +func ReadPump(c *Client) { + go func() { + for { + _, message, err := c.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + break + } + + fmt.Println(string(message)) + } + }() + +} diff --git a/internal/mutex.go b/internal/mutex.go new file mode 100644 index 0000000..9c8821e --- /dev/null +++ b/internal/mutex.go @@ -0,0 +1,33 @@ +package internal + +import ( + "sync" +) + +type CustomRwMutex struct { + mu *sync.RWMutex +} + +func NewCustomRwMutex() *CustomRwMutex { + return &CustomRwMutex{ + mu: &sync.RWMutex{}, + } +} + +func (rwMu *CustomRwMutex) WriteHandler(fn func() error) error { + rwMu.mu.Lock() + defer rwMu.mu.Unlock() + if err := fn(); err != nil { + return err + } + return nil +} + +func (rwMu *CustomRwMutex) ReadHandler(fn func() error) error { + rwMu.mu.RLock() + defer rwMu.mu.RUnlock() + if err := fn(); err != nil { + return err + } + return nil +} diff --git a/internal/nil_checker.go b/internal/nil_checker.go new file mode 100644 index 0000000..affcb59 --- /dev/null +++ b/internal/nil_checker.go @@ -0,0 +1,40 @@ +package internal + +import ( + "fmt" + "reflect" + "strings" +) + +func NilChecker(data any) error { + val := reflect.ValueOf(data) + + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + valType := val.Type() + + if val.Kind() != reflect.Struct { + return fmt.Errorf("data is not a struct") + } + + nilFields := []string{} + for i := range val.NumField() { + field := val.Field(i) + fieldType := valType.Field(i) + tagValue := fieldType.Tag.Get("nil_checker") + + if tagValue == "required" { + if field.Kind() == reflect.Ptr { + if field.IsNil() { + nilFields = append(nilFields, fieldType.Name) + } + } + } + } + if len(nilFields) > 0 { + return fmt.Errorf("%s is empty", strings.Join(nilFields, ",")) + } + return nil +} diff --git a/v1/example/server/main.go b/v1/example/server/main.go new file mode 100644 index 0000000..875ffcb --- /dev/null +++ b/v1/example/server/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "log" + "time" + + "git.neurocipta.com/rogerferdinan/safe-web-socket/v1/server" +) + +func main() { + s, err := server.NewSafeWebsocketServerBuilder(). + BaseHost("localhost"). + BasePort(8080). + HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) { + ticker := time.NewTicker(100 * time.Millisecond) + for range ticker.C { + c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_1") + } + }). + HandleFuncWebsocket("/ws/test/", "data_2", func(c chan []byte) { + ticker := time.NewTicker(100 * time.Millisecond) + for range ticker.C { + c <- []byte(time.Now().Format("2006-01-02 15:04:05") + "_data_2") + } + }). + Build() + + if err != nil { + log.Fatal(err) + } + s.ListenAndServe() +} diff --git a/v1/server/server.go b/v1/server/server.go new file mode 100644 index 0000000..d8b70f1 --- /dev/null +++ b/v1/server/server.go @@ -0,0 +1,107 @@ +package server + +import ( + "fmt" + "log" + "net/http" + "strings" + + "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" + "github.com/gorilla/websocket" +) + +type SafeWebsocketServerBuilder struct { + baseHost *string `nil_checker:"required"` + basePort *uint16 `nil_checker:"required"` + upgrader *websocket.Upgrader `nil_checker:"required"` + mux *http.ServeMux `nil_checker:"required"` +} + +func NewSafeWebsocketServerBuilder() *SafeWebsocketServerBuilder { + h := internal.NewHub() + h.Run() + + return &SafeWebsocketServerBuilder{ + upgrader: &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + mux: http.NewServeMux(), + } +} + +func (b *SafeWebsocketServerBuilder) BaseHost(baseHost string) *SafeWebsocketServerBuilder { + b.baseHost = &baseHost + return b +} + +func (b *SafeWebsocketServerBuilder) BasePort(basePort uint16) *SafeWebsocketServerBuilder { + b.basePort = &basePort + return b +} + +func (b *SafeWebsocketServerBuilder) HandleFunc( + pattern string, fn func(http.ResponseWriter, *http.Request), +) *SafeWebsocketServerBuilder { + b.mux.HandleFunc(pattern, fn) + return b +} + +func (b *SafeWebsocketServerBuilder) HandleFuncWebsocket( + pattern string, subscribedPath string, writeFunc func(chan []byte), +) *SafeWebsocketServerBuilder { + h := internal.NewHub() + h.Run() + + b.mux.HandleFunc(pattern+subscribedPath, func(w http.ResponseWriter, r *http.Request) { + conn, err := b.upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, "upgrade failed", http.StatusBadRequest) + return + } + + subscribedPath := strings.TrimPrefix(r.URL.Path, pattern) + fmt.Println(subscribedPath) + if subscribedPath == "" { + http.Error(w, "invalid path", http.StatusBadRequest) + return + } + c := internal.NewClient(conn, subscribedPath) + h.Register <- c + internal.WritePump(c, h) + internal.ReadPump(c) + writeFunc(h.Broadcast) + }) + return b +} + +func (b *SafeWebsocketServerBuilder) Build() (*SafeWebsocketServer, error) { + if err := internal.NilChecker(b); err != nil { + return nil, err + } + + safeServer := SafeWebsocketServer{ + url: fmt.Sprintf("%s:%d", *b.baseHost, *b.basePort), + mux: b.mux, + mu: internal.NewCustomRwMutex(), + } + return &safeServer, nil +} + +type SafeWebsocketServer struct { + hub *internal.Hub + mux *http.ServeMux + url string + mu *internal.CustomRwMutex +} + +func (s *SafeWebsocketServer) ListenAndServe() error { + log.Printf("HTTP serve on %s\n", s.url) + if err := http.ListenAndServe(s.url, s.mux); err != nil { + return fmt.Errorf("failed to serve websocket: %w", err) + } + return nil +}