fix: fixing memory leak on client

This commit is contained in:
2026-02-05 07:30:49 +07:00
parent 07f7893a26
commit 2200657ba7

View File

@@ -16,8 +16,9 @@ import (
) )
const ( const (
pingPeriod = 10 * time.Second pingPeriod = 10 * time.Second
readDeadline = 30 * time.Second readDeadline = 30 * time.Second
writeDeadline = 10 * time.Second
) )
type MessageType uint type MessageType uint
@@ -182,17 +183,14 @@ func (wsClient *SafeWebsocketClient) connect() error {
Scheme: scheme, Scheme: scheme,
Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort), Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort),
} }
if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" { if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" {
newURL.Path = *wsClient.path newURL.Path = *wsClient.path
} }
if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" { if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" {
newURL.RawQuery = *wsClient.rawQuery newURL.RawQuery = *wsClient.rawQuery
} }
header := make(http.Header) header := make(http.Header)
if wsClient.headers != nil { if wsClient.headers != nil {
for k, v := range *wsClient.headers { for k, v := range *wsClient.headers {
header.Set(k, v) header.Set(k, v)
@@ -205,18 +203,22 @@ func (wsClient *SafeWebsocketClient) connect() error {
} }
pingCtx, pingCancel := context.WithCancel(context.Background()) pingCtx, pingCancel := context.WithCancel(context.Background())
pumpCtx, pumpCancel := context.WithCancel(context.Background())
wsClient.mu.WriteHandler(func() error { wsClient.mu.WriteHandler(func() error {
wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel) if wsClient.conn != nil {
wsClient.conn.Close()
}
wsClient.conn = conn
wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel, pumpCancel)
return nil return nil
}) })
go wsClient.startPingTicker(pingCtx) go wsClient.startPingTicker(pingCtx)
go wsClient.writePump(pumpCtx, conn)
if wsClient.conn != nil { go wsClient.readPump(pumpCtx, conn)
wsClient.conn.Close()
}
wsClient.conn = conn
wsClient.isConnected = true
conn.SetPingHandler(func(pingData string) error { conn.SetPingHandler(func(pingData string) error {
if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
@@ -239,24 +241,72 @@ func (wsClient *SafeWebsocketClient) connect() error {
return nil return nil
}) })
go wsClient.writePump() wsClient.isConnected = true
go wsClient.readPump()
return nil return nil
} }
func (wsClient *SafeWebsocketClient) writePump() { func (wsClient *SafeWebsocketClient) reconnectHandler() {
ctx, cancel := context.WithCancel(context.Background()) backoff := 1 * time.Second
wsClient.mu.WriteHandler(func() error { maxBackoff := 15 * time.Second
wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel)
return nil
})
var c *websocket.Conn for {
wsClient.mu.ReadHandler(func() error { select {
c = wsClient.conn case <-wsClient.reconnectCh:
return nil log.Println("Reconnect triggered")
})
wsClient.mu.WriteHandler(func() error {
if wsClient.cancelFuncs != nil {
for _, cancel := range wsClient.cancelFuncs {
cancel()
}
wsClient.cancelFuncs = nil
}
return nil
})
wsClient.isConnected = false
isInnerLoop := true
for isInnerLoop {
log.Printf("Attempting reconnect in %v...", backoff)
select {
case <-time.After(backoff):
if err := wsClient.connect(); err != nil {
log.Printf("Reconnect failed: %v", err)
if backoff < maxBackoff {
backoff *= 2
}
continue
}
log.Println("Reconnected successfully")
backoff = 1 * time.Second
isInnerLoop = false
continue
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
if wsClient.reconnectChans != nil {
for _, reconnectCh := range wsClient.reconnectChans {
select {
case reconnectCh <- struct{}{}:
default: // prevent blocking if chan is full
}
}
}
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
}
func (wsClient *SafeWebsocketClient) writePump(ctx context.Context, c *websocket.Conn) {
defer func() {
c.Close()
}()
for { for {
select { select {
@@ -264,30 +314,25 @@ func (wsClient *SafeWebsocketClient) writePump() {
log.Println("Writer canceled by context") log.Println("Writer canceled by context")
return return
case data := <-wsClient.writeChan: case data := <-wsClient.writeChan:
if c == nil { if err := c.SetWriteDeadline(time.Now().Add(writeDeadline)); err != nil {
log.Printf("error setting write deadline: %v", err)
return return
} }
if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil { if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil {
log.Printf("error on write message: %v\n", err) log.Printf("error on write message: %v\n", err)
wsClient.triggerReconnect() // Trigger reconnect on write failure
return return
} }
} }
} }
} }
func (wsClient *SafeWebsocketClient) readPump() { func (wsClient *SafeWebsocketClient) readPump(ctx context.Context, c *websocket.Conn) {
ctx, cancel := context.WithCancel(context.Background()) defer func() {
wsClient.mu.WriteHandler(func() error { wsClient.triggerReconnect()
wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel) c.Close()
return nil }()
})
var c *websocket.Conn
wsClient.mu.ReadHandler(func() error {
c = wsClient.conn
return nil
})
for { for {
select { select {
@@ -295,19 +340,14 @@ func (wsClient *SafeWebsocketClient) readPump() {
log.Println("Reader canceled by context") log.Println("Reader canceled by context")
return return
default: default:
if c == nil { // Set read deadline
return
}
if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
log.Printf("error on read deadline: %v\n", err)
return return
} }
messageType, data, err := c.ReadMessage() messageType, data, err := c.ReadMessage()
if err != nil { if err != nil {
log.Printf("error on read message: %v\n", err) log.Printf("error on read message: %v\n", err)
wsClient.triggerReconnect()
return return
} }
@@ -353,60 +393,6 @@ func (wsClient *SafeWebsocketClient) triggerReconnect() {
} }
} }
func (wsClient *SafeWebsocketClient) reconnectHandler() {
backoff := 1 * time.Second
maxBackoff := 30 * time.Second
for {
select {
case <-wsClient.reconnectCh:
log.Println("Reconnect triggered")
wsClient.mu.ReadHandler(func() error {
if wsClient.cancelFuncs != nil {
for _, cancel := range wsClient.cancelFuncs {
cancel()
}
}
return nil
})
wsClient.isConnected = false
isInnerLoop := true
for isInnerLoop {
log.Printf("Attempting reconnect in %v...", backoff)
select {
case <-time.After(backoff):
if err := wsClient.connect(); err != nil {
log.Printf("Reconnect failed: %v", err)
if backoff < maxBackoff {
backoff *= 2
}
continue
}
log.Println("Reconnected successfully")
backoff = 1 * time.Second
isInnerLoop = false
continue
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
if wsClient.reconnectChans != nil {
for _, reconnectCh := range wsClient.reconnectChans {
reconnectCh <- struct{}{}
}
}
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
}
func (wsClient *SafeWebsocketClient) ReconnectChannel() <-chan struct{} { func (wsClient *SafeWebsocketClient) ReconnectChannel() <-chan struct{} {
reconnectCh := make(chan struct{}, 1) reconnectCh := make(chan struct{}, 1)
wsClient.mu.WriteHandler(func() error { wsClient.mu.WriteHandler(func() error {