// Package server provides a builder for an HTTP server that exposes
// API-key-protected websocket broadcast endpoints alongside plain HTTP
// handlers.
package server

import (
	"crypto/subtle"
	"fmt"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"git.neurocipta.com/rogerferdinan/safe-web-socket/internal"
	"github.com/gorilla/websocket"
)

// readHeaderTimeout bounds how long a client may take to send request
// headers, protecting the listener against slow-header (Slowloris) clients.
// gorilla/websocket clears connection deadlines on upgrade, so this does not
// affect long-lived websocket connections.
const readHeaderTimeout = 10 * time.Second

// setMaxRLimit raises the soft limit on open file descriptors (RLIMIT_NOFILE)
// to the hard limit, so the process can hold many concurrent connections.
// It returns an error instead of panicking because this is library code and
// the call can legitimately fail on some platforms (e.g. an infinite hard
// limit on macOS).
func setMaxRLimit() error {
	var rLimit syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
		return fmt.Errorf("getting RLIMIT_NOFILE: %w", err)
	}
	rLimit.Cur = rLimit.Max
	if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
		return fmt.Errorf("setting RLIMIT_NOFILE to %d: %w", rLimit.Max, err)
	}
	return nil
}

// SafeWebsocketServerBuilder collects the configuration for a
// SafeWebsocketServer. Every field tagged `nil_checker:"required"` must be set
// before Build is called.
type SafeWebsocketServerBuilder struct {
	baseHost *string             `nil_checker:"required"`
	basePort *uint16             `nil_checker:"required"`
	apiKey   *string             `nil_checker:"required"`
	upgrader *websocket.Upgrader `nil_checker:"required"`
	mux      *http.ServeMux      `nil_checker:"required"`
}

// NewSafeWebsocketServerBuilder returns a builder with a default websocket
// upgrader (1 KiB read/write buffers) and an empty ServeMux. Host, port and
// API key must still be supplied.
func NewSafeWebsocketServerBuilder() *SafeWebsocketServerBuilder {
	return &SafeWebsocketServerBuilder{
		upgrader: &websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			// NOTE(review): every Origin is accepted. Requests are still gated
			// by the X-MBX-APIKEY header check in AuthMiddleware.
			CheckOrigin: func(r *http.Request) bool { return true },
		},
		mux: http.NewServeMux(),
	}
}

// BaseHost sets the host (name or IP literal) the server listens on.
func (b *SafeWebsocketServerBuilder) BaseHost(baseHost string) *SafeWebsocketServerBuilder {
	b.baseHost = &baseHost
	return b
}

// BasePort sets the TCP port the server listens on.
func (b *SafeWebsocketServerBuilder) BasePort(basePort uint16) *SafeWebsocketServerBuilder {
	b.basePort = &basePort
	return b
}

// ApiKey sets the value clients must send in the X-MBX-APIKEY header.
func (b *SafeWebsocketServerBuilder) ApiKey(apiKey string) *SafeWebsocketServerBuilder {
	b.apiKey = &apiKey
	return b
}

// HandleFunc registers a plain HTTP handler for pattern on the server's mux.
func (b *SafeWebsocketServerBuilder) HandleFunc(pattern string, fn func(http.ResponseWriter, *http.Request)) *SafeWebsocketServerBuilder {
	b.mux.HandleFunc(pattern, fn)
	return b
}

// HandleFuncWebsocket registers a websocket endpoint at pattern+subscribedPath.
// Each registration gets its own hub. writeFunc is the hub's message producer:
// it receives the hub's broadcast channel and is started once, when the first
// client connects.
func (b *SafeWebsocketServerBuilder) HandleFuncWebsocket(pattern string, subscribedPath string, writeFunc func(chan []byte)) *SafeWebsocketServerBuilder {
	h := internal.NewHub()
	h.Run()

	// writeFunc feeds the hub's single shared broadcast channel, so one
	// producer serves every client. Starting it per connection would emit
	// duplicate messages and leak a producer goroutine for each client.
	var startWriter sync.Once

	b.mux.HandleFunc(pattern+subscribedPath, func(w http.ResponseWriter, r *http.Request) {
		// Validate before upgrading: once the connection is hijacked by the
		// upgrade, http.Error can no longer reach the client.
		path := strings.TrimPrefix(r.URL.Path, pattern)
		if path == "" {
			http.Error(w, "invalid path", http.StatusBadRequest)
			return
		}

		conn, err := b.upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error.
			return
		}

		c := internal.NewClient(conn, path)
		h.Register <- c
		go internal.WritePump(c, h)
		go internal.ReadPump(c, h)
		startWriter.Do(func() { go writeFunc(h.Broadcast) })
	})
	return b
}

// Build validates that all required settings are present, raises the process
// file-descriptor limit, and returns the configured server. It returns an
// error if a required setting is missing or the limit cannot be raised.
func (b *SafeWebsocketServerBuilder) Build() (*SafeWebsocketServer, error) {
	if err := internal.NilChecker(b); err != nil {
		return nil, err
	}
	if err := setMaxRLimit(); err != nil {
		return nil, fmt.Errorf("raising file descriptor limit: %w", err)
	}
	return &SafeWebsocketServer{
		mux: b.mux,
		// JoinHostPort brackets IPv6 literals ("[::1]:8080"). Plain
		// "%s:%d" formatting produces an unparsable address for them.
		url:    net.JoinHostPort(*b.baseHost, strconv.Itoa(int(*b.basePort))),
		apiKey: *b.apiKey,
	}, nil
}

// SafeWebsocketServer is an HTTP server whose every route is protected by an
// X-MBX-APIKEY header check.
type SafeWebsocketServer struct {
	mux    *http.ServeMux
	url    string
	apiKey string
}

// AuthMiddleware rejects requests whose X-MBX-APIKEY header does not match the
// configured API key with 403 Forbidden. It forwards matching requests to next.
func (s *SafeWebsocketServer) AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		key := r.Header.Get("X-MBX-APIKEY")
		// Constant-time comparison avoids leaking key contents through timing.
		if subtle.ConstantTimeCompare([]byte(key), []byte(s.apiKey)) != 1 {
			internal.ErrorResponse(w, internal.NewStatusMessage().
				StatusCode(http.StatusForbidden).
				Message("X-MBX-APIKEY is missing").
				Build())
			// Without this return the request would still reach next,
			// bypassing authentication entirely.
			return
		}
		next.ServeHTTP(w, r)
	})
}

// ListenAndServe serves the auth-wrapped mux on the configured address until
// the listener fails, returning the wrapped error.
func (s *SafeWebsocketServer) ListenAndServe() error {
	log.Printf("HTTP serve on %s", s.url)
	srv := &http.Server{
		Addr:              s.url,
		Handler:           s.AuthMiddleware(s.mux),
		ReadHeaderTimeout: readHeaderTimeout,
	}
	if err := srv.ListenAndServe(); err != nil {
		return fmt.Errorf("failed to serve websocket: %w", err)
	}
	return nil
}