fix: fixing memory leak on client

This commit is contained in:
2026-02-05 07:30:49 +07:00
parent 07f7893a26
commit 2200657ba7

View File

@@ -16,8 +16,9 @@ import (
)
const (
pingPeriod = 10 * time.Second
readDeadline = 30 * time.Second
pingPeriod = 10 * time.Second
readDeadline = 30 * time.Second
writeDeadline = 10 * time.Second
)
type MessageType uint
@@ -182,17 +183,14 @@ func (wsClient *SafeWebsocketClient) connect() error {
Scheme: scheme,
Host: fmt.Sprintf("%s:%d", wsClient.baseHost, wsClient.basePort),
}
if wsClient.path != nil && strings.TrimSpace(*wsClient.path) != "" {
newURL.Path = *wsClient.path
}
if wsClient.rawQuery != nil && strings.TrimSpace(*wsClient.rawQuery) != "" {
newURL.RawQuery = *wsClient.rawQuery
}
header := make(http.Header)
if wsClient.headers != nil {
for k, v := range *wsClient.headers {
header.Set(k, v)
@@ -205,18 +203,22 @@ func (wsClient *SafeWebsocketClient) connect() error {
}
pingCtx, pingCancel := context.WithCancel(context.Background())
pumpCtx, pumpCancel := context.WithCancel(context.Background())
wsClient.mu.WriteHandler(func() error {
wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel)
if wsClient.conn != nil {
wsClient.conn.Close()
}
wsClient.conn = conn
wsClient.cancelFuncs = append(wsClient.cancelFuncs, pingCancel, pumpCancel)
return nil
})
go wsClient.startPingTicker(pingCtx)
if wsClient.conn != nil {
wsClient.conn.Close()
}
wsClient.conn = conn
wsClient.isConnected = true
go wsClient.writePump(pumpCtx, conn)
go wsClient.readPump(pumpCtx, conn)
conn.SetPingHandler(func(pingData string) error {
if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
@@ -239,24 +241,72 @@ func (wsClient *SafeWebsocketClient) connect() error {
return nil
})
go wsClient.writePump()
go wsClient.readPump()
wsClient.isConnected = true
return nil
}
func (wsClient *SafeWebsocketClient) writePump() {
ctx, cancel := context.WithCancel(context.Background())
wsClient.mu.WriteHandler(func() error {
wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel)
return nil
})
func (wsClient *SafeWebsocketClient) reconnectHandler() {
backoff := 1 * time.Second
maxBackoff := 15 * time.Second
var c *websocket.Conn
wsClient.mu.ReadHandler(func() error {
c = wsClient.conn
return nil
})
for {
select {
case <-wsClient.reconnectCh:
log.Println("Reconnect triggered")
wsClient.mu.WriteHandler(func() error {
if wsClient.cancelFuncs != nil {
for _, cancel := range wsClient.cancelFuncs {
cancel()
}
wsClient.cancelFuncs = nil
}
return nil
})
wsClient.isConnected = false
isInnerLoop := true
for isInnerLoop {
log.Printf("Attempting reconnect in %v...", backoff)
select {
case <-time.After(backoff):
if err := wsClient.connect(); err != nil {
log.Printf("Reconnect failed: %v", err)
if backoff < maxBackoff {
backoff *= 2
}
continue
}
log.Println("Reconnected successfully")
backoff = 1 * time.Second
isInnerLoop = false
continue
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
if wsClient.reconnectChans != nil {
for _, reconnectCh := range wsClient.reconnectChans {
select {
case reconnectCh <- struct{}{}:
default: // prevent blocking if chan is full
}
}
}
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
}
func (wsClient *SafeWebsocketClient) writePump(ctx context.Context, c *websocket.Conn) {
defer func() {
c.Close()
}()
for {
select {
@@ -264,30 +314,25 @@ func (wsClient *SafeWebsocketClient) writePump() {
log.Println("Writer canceled by context")
return
case data := <-wsClient.writeChan:
if c == nil {
if err := c.SetWriteDeadline(time.Now().Add(writeDeadline)); err != nil {
log.Printf("error setting write deadline: %v", err)
return
}
if err := c.WriteMessage(int(data.MessageType), data.Data); err != nil {
log.Printf("error on write message: %v\n", err)
wsClient.triggerReconnect() // Trigger reconnect on write failure
return
}
}
}
}
func (wsClient *SafeWebsocketClient) readPump() {
ctx, cancel := context.WithCancel(context.Background())
wsClient.mu.WriteHandler(func() error {
wsClient.cancelFuncs = append(wsClient.cancelFuncs, cancel)
return nil
})
var c *websocket.Conn
wsClient.mu.ReadHandler(func() error {
c = wsClient.conn
return nil
})
func (wsClient *SafeWebsocketClient) readPump(ctx context.Context, c *websocket.Conn) {
defer func() {
wsClient.triggerReconnect()
c.Close()
}()
for {
select {
@@ -295,19 +340,14 @@ func (wsClient *SafeWebsocketClient) readPump() {
log.Println("Reader canceled by context")
return
default:
if c == nil {
return
}
// Set read deadline
if err := c.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
log.Printf("error on read deadline: %v\n", err)
return
}
messageType, data, err := c.ReadMessage()
if err != nil {
log.Printf("error on read message: %v\n", err)
wsClient.triggerReconnect()
return
}
@@ -353,60 +393,6 @@ func (wsClient *SafeWebsocketClient) triggerReconnect() {
}
}
func (wsClient *SafeWebsocketClient) reconnectHandler() {
backoff := 1 * time.Second
maxBackoff := 30 * time.Second
for {
select {
case <-wsClient.reconnectCh:
log.Println("Reconnect triggered")
wsClient.mu.ReadHandler(func() error {
if wsClient.cancelFuncs != nil {
for _, cancel := range wsClient.cancelFuncs {
cancel()
}
}
return nil
})
wsClient.isConnected = false
isInnerLoop := true
for isInnerLoop {
log.Printf("Attempting reconnect in %v...", backoff)
select {
case <-time.After(backoff):
if err := wsClient.connect(); err != nil {
log.Printf("Reconnect failed: %v", err)
if backoff < maxBackoff {
backoff *= 2
}
continue
}
log.Println("Reconnected successfully")
backoff = 1 * time.Second
isInnerLoop = false
continue
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
if wsClient.reconnectChans != nil {
for _, reconnectCh := range wsClient.reconnectChans {
reconnectCh <- struct{}{}
}
}
case <-wsClient.ctx.Done():
log.Println("reconnect handler stopped due to client shutdown")
wsClient.Close()
return
}
}
}
func (wsClient *SafeWebsocketClient) ReconnectChannel() <-chan struct{} {
reconnectCh := make(chan struct{}, 1)
wsClient.mu.WriteHandler(func() error {