mirror of
https://github.com/portainer/portainer.git
synced 2025-07-19 13:29:41 +02:00
feat(websocket): improve websocket code sharing BE-11340 (#61)
Some checks failed
Label Conflicts / triage (push) Has been cancelled
Some checks failed
Label Conflicts / triage (push) Has been cancelled
This commit is contained in:
parent
b2d67795b3
commit
1d037f2f1f
6 changed files with 122 additions and 153 deletions
145
api/ws/hijack.go
Normal file
145
api/ws/hijack.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer
|
||||
WriteWait = 10 * time.Second
|
||||
|
||||
// Send pings to peer with this period
|
||||
PingPeriod = 50 * time.Second
|
||||
)
|
||||
|
||||
func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.Request) error {
|
||||
resp, err := sendHTTPRequest(conn, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if the response status code indicates an upgrade (101 Switching Protocols)
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return fmt.Errorf("unexpected response status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
|
||||
go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)
|
||||
|
||||
err = <-errorChan
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Debug().Err(err).Msg("unexpected close error")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msg("session ended")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendHTTPRequest sends an HTTP request over the provided net.Conn and parses the response.
|
||||
func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) {
|
||||
// Send the HTTP request to the server
|
||||
if err := req.Write(conn); err != nil {
|
||||
return nil, fmt.Errorf("error writing request: %w", err)
|
||||
}
|
||||
|
||||
// Read the response from the server
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reader io.Reader, errorChan chan error) {
|
||||
out := make([]byte, ReaderBufferSize)
|
||||
input := make(chan string)
|
||||
pingTicker := time.NewTicker(PingPeriod)
|
||||
defer pingTicker.Stop()
|
||||
defer websocketConn.Close()
|
||||
|
||||
mu.Lock()
|
||||
websocketConn.SetReadLimit(ReaderBufferSize)
|
||||
websocketConn.SetPongHandler(func(string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
websocketConn.SetPingHandler(func(data string) error {
|
||||
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
|
||||
|
||||
return websocketConn.WriteMessage(websocket.PongMessage, []byte(data))
|
||||
})
|
||||
mu.Unlock()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
n, err := reader.Read(out)
|
||||
if err != nil {
|
||||
errorChan <- err
|
||||
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Debug().Err(err).Msg("error reading from server")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
processedOutput := ValidString(string(out[:n]))
|
||||
input <- processedOutput
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-input:
|
||||
if err := wsWrite(websocketConn, mu, msg); err != nil {
|
||||
log.Debug().Err(err).Msg("error writing to websocket")
|
||||
errorChan <- err
|
||||
|
||||
return
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
if err := wsPing(websocketConn, mu); err != nil {
|
||||
log.Debug().Err(err).Msg("error writing to websocket during pong response")
|
||||
errorChan <- err
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wsWrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
|
||||
|
||||
return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||
}
|
||||
|
||||
func wsPing(websocketConn *websocket.Conn, mu *sync.Mutex) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
|
||||
|
||||
return websocketConn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue