From 9816426780379b3409b302442a44a4dc905aaedf Mon Sep 17 00:00:00 2001 From: Roger Ferdinan Date: Fri, 21 Nov 2025 20:01:49 +0700 Subject: [PATCH] feat: adding header support for client --- v1/client/client.go | 114 ++++++++++++++++--------------------- v1/examples/client/main.go | 8 ++- v1/examples/server/main.go | 2 +- 3 files changed, 55 insertions(+), 69 deletions(-) diff --git a/v1/client/client.go b/v1/client/client.go index a6c663b..da5c818 100644 --- a/v1/client/client.go +++ b/v1/client/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net/http" "net/url" "strings" "time" @@ -36,6 +37,7 @@ type Message struct { type SafeWebsocketClientBuilder struct { baseHost *string `nil_checker:"required"` basePort *uint16 `nil_checker:"required"` + headers *map[string]string path *string rawQuery *string isDrop *bool @@ -58,6 +60,11 @@ func (b *SafeWebsocketClientBuilder) BasePort(port uint16) *SafeWebsocketClientB return b } +func (b *SafeWebsocketClientBuilder) Headers(headers map[string]string) *SafeWebsocketClientBuilder { + b.headers = &headers + return b +} + func (b *SafeWebsocketClientBuilder) UseTLS(useTLS bool) *SafeWebsocketClientBuilder { b.useTLS = &useTLS return b @@ -117,6 +124,7 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC wsClient := SafeWebsocketClient{ baseHost: *b.baseHost, basePort: *b.basePort, + headers: b.headers, useTLS: *b.useTLS, isDrop: *b.isDrop, path: b.path, @@ -142,6 +150,8 @@ func (b *SafeWebsocketClientBuilder) Build(ctx context.Context) (*SafeWebsocketC type SafeWebsocketClient struct { baseHost string basePort uint16 + headers *map[string]string + isDrop bool useTLS bool path *string @@ -182,7 +192,16 @@ func (wsClient *SafeWebsocketClient) connect() error { newURL.RawQuery = *wsClient.rawQuery } - conn, _, err := websocket.DefaultDialer.Dial(newURL.String(), nil) + header := make(http.Header) + + if wsClient.headers != nil { + for k, v := range *wsClient.headers { + fmt.Println(k, v) + header.Set(k, v) + } + } + + conn, _, err := websocket.DefaultDialer.Dial(newURL.String(), header) if err != nil { return fmt.Errorf("failed to connect to %s: %w", wsClient.baseHost, err) } @@ -268,78 +287,43 @@ func (wsClient *SafeWebsocketClient) readPump() { return nil }) - if wsClient.isDrop { - for { + for { + select { + case <-ctx.Done(): + log.Println("Reader canceled by context") + return + default: + if c == nil { + return + } + + if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + log.Printf("error on read deadline: %v\n", err) + return + } + + messageType, data, err := c.ReadMessage() + if err != nil { + log.Printf("error on read message: %v\n", err) + wsClient.triggerReconnect() + return + } + + if messageType != websocket.TextMessage { + continue + } + select { + case wsClient.dataChannel <- data: case <-ctx.Done(): - log.Println("Reader canceled by context") return default: - if c == nil { - return - } - - if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { - log.Printf("error on read deadline: %v\n", err) - return - } - - messageType, data, err := c.ReadMessage() - if err != nil { - log.Printf("error on read message: %v\n", err) - wsClient.triggerReconnect() - return - } - - if messageType != websocket.TextMessage { - continue - } - - select { - case wsClient.dataChannel <- data: - case <-ctx.Done(): - return - default: + if wsClient.isDrop { log.Println("Data channel full, dropping message") } } } - } else { - for { - select { - case <-ctx.Done(): - log.Println("Reader canceled by context") - return - default: - if c == nil { - return - } - - if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { - log.Printf("error on read deadline: %v\n", err) - return - } - - messageType, data, err := c.ReadMessage() - if err != nil { - log.Printf("error on read message: %v\n", err) - wsClient.triggerReconnect() - return - } - - if messageType != websocket.TextMessage { - continue - } - - select { - case wsClient.dataChannel <- data: - case <-ctx.Done(): - return - } - } - } } - } func (wsClient *SafeWebsocketClient) startPingTicker(ctx context.Context) { @@ -466,7 +450,7 @@ func (wsClient *SafeWebsocketClient) Close() error { wsClient.conn.Close() } wsClient.isConnected = false - close(wsClient.dataChannel) + // close(wsClient.dataChannel) return nil } diff --git a/v1/examples/client/main.go b/v1/examples/client/main.go index 8ae14c0..79a7808 100644 --- a/v1/examples/client/main.go +++ b/v1/examples/client/main.go @@ -27,7 +27,9 @@ func main() { wsClient, err := client.NewSafeWebsocketClientBuilder(). BaseHost("localhost"). BasePort(8080). - Path("/ws/test/data_1"). + Headers(map[string]string{ + "X-MBX-APIKEY": "abcd", + }).Path("/ws/test/data_1"). UseTLS(false). ChannelSize(1). Build(ctx) @@ -44,7 +46,7 @@ func main() { dataChannel := wsClient.DataChannel() for data := range dataChannel { - // _ = data - fmt.Println(string(data)) + _ = data + // fmt.Println(string(data)) } } diff --git a/v1/examples/server/main.go b/v1/examples/server/main.go index 3841ea7..3b57234 100644 --- a/v1/examples/server/main.go +++ b/v1/examples/server/main.go @@ -17,7 +17,7 @@ func main() { s, err := server.NewSafeWebsocketServerBuilder(). BaseHost("localhost"). BasePort(8080). - ApiKey(""). + ApiKey("abcd"). HandleFuncWebsocket("/ws/test/", "data_1", func(c chan []byte) { ticker := time.NewTicker(10 * time.Millisecond) for range ticker.C {