package server import ( "context" "crypto/subtle" "fmt" "log" "net/http" "strings" "syscall" "time" "git.neurocipta.com/rogerferdinan/safe-web-socket/internal" "github.com/gorilla/websocket" ) func setMaxRLimit() { var rLimit syscall.Rlimit if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil { panic(err) } rLimit.Cur = rLimit.Max if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil { panic(err) } } type SafeWebsocketServerBuilder struct { baseHost *string `nil_checker:"required"` basePort *uint16 `nil_checker:"required"` apiKey *string `nil_checker:"required"` upgrader *websocket.Upgrader `nil_checker:"required"` mux *http.ServeMux `nil_checker:"required"` ctx context.Context } func NewSafeWebsocketServerBuilder() *SafeWebsocketServerBuilder { return &SafeWebsocketServerBuilder{ upgrader: &websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true }, }, mux: http.NewServeMux(), ctx: context.Background(), } } func (b *SafeWebsocketServerBuilder) BaseHost(baseHost string) *SafeWebsocketServerBuilder { b.baseHost = &baseHost return b } func (b *SafeWebsocketServerBuilder) BasePort(basePort uint16) *SafeWebsocketServerBuilder { b.basePort = &basePort return b } func (b *SafeWebsocketServerBuilder) ApiKey(apiKey string) *SafeWebsocketServerBuilder { b.apiKey = &apiKey return b } // Context sets the lifecycle context for all hub and writeFunc goroutines. // When ctx is cancelled, hubs stop dispatching and all active connections // are unblocked so their goroutines can exit cleanly. // Call this before HandleFuncWebsocket. Defaults to context.Background(). func (b *SafeWebsocketServerBuilder) Context(ctx context.Context) *SafeWebsocketServerBuilder { b.ctx = ctx return b } func (b *SafeWebsocketServerBuilder) HandleFunc(pattern string, fn func(http.ResponseWriter, *http.Request)) *SafeWebsocketServerBuilder { b.mux.HandleFunc(pattern, fn) return b } // HandleFuncWebsocket registers a WebSocket endpoint. func (b *SafeWebsocketServerBuilder) HandleFuncWebsocket(pattern string, subscribedPath string, maxClients int, writeFunc func(ctx context.Context, writeChannel chan []byte)) *SafeWebsocketServerBuilder { h := internal.NewHub(pattern+subscribedPath, maxClients) h.Run(b.ctx) go writeFunc(b.ctx, h.Broadcast) b.mux.HandleFunc(pattern+subscribedPath, func(w http.ResponseWriter, r *http.Request) { conn, err := b.upgrader.Upgrade(w, r, nil) if err != nil { http.Error(w, "upgrade failed", http.StatusBadRequest) return } subscribedPath := strings.TrimPrefix(r.URL.Path, pattern) if subscribedPath == "" { http.Error(w, "invalid path", http.StatusBadRequest) return } c := internal.NewClient(conn, subscribedPath) h.Register <- c go internal.WritePump(c, h) go internal.ReadPump(c, h) }) return b } func (b *SafeWebsocketServerBuilder) Build() (*SafeWebsocketServer, error) { if err := internal.NilChecker(b); err != nil { return nil, err } setMaxRLimit() safeServer := SafeWebsocketServer{ mux: b.mux, url: fmt.Sprintf("%s:%d", *b.baseHost, *b.basePort), apiKey: *b.apiKey, ctx: b.ctx, } return &safeServer, nil } type SafeWebsocketServer struct { mux *http.ServeMux url string apiKey string ctx context.Context } func (s *SafeWebsocketServer) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { providedKey := r.Header.Get("X-MBX-APIKEY") expectedKey := s.apiKey if subtle.ConstantTimeCompare([]byte(providedKey), []byte(expectedKey)) != 1 { internal.ErrorResponse(w, internal.NewStatusMessage(). StatusCode(http.StatusForbidden). Message("X-MBX-APIKEY is missing"). Build()) return } next.ServeHTTP(w, r) }) } func (s *SafeWebsocketServer) ListenAndServe() error { srv := &http.Server{ Addr: s.url, Handler: s.AuthMiddleware(s.mux), } errCh := make(chan error, 1) go func() { log.Printf("HTTP serve on %s\n", s.url) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { errCh <- fmt.Errorf("failed to serve websocket: %w", err) } close(errCh) }() select { case err := <-errCh: return err case <-s.ctx.Done(): log.Println("Server shutting down...") shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { return fmt.Errorf("server shutdown: %w", err) } return <-errCh } }