mirror of
https://github.com/portainer/portainer.git
synced 2025-07-24 15:59:41 +02:00
fix(websocket): abort websocket when logout EE-6058 (#10372)
This commit is contained in:
parent
9440aa733d
commit
56ab19433a
15 changed files with 227 additions and 49 deletions
|
@ -1,6 +1,8 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/rs/zerolog/log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
@ -74,6 +76,13 @@ func (handler *Handler) websocketAttach(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
||||
tokenData, err := security.RetrieveTokenData(r)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Msg("unable to retrieve user details from authentication token")
|
||||
return err
|
||||
}
|
||||
|
||||
r.Header.Del("Origin")
|
||||
|
||||
|
@ -89,10 +98,15 @@ func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Reque
|
|||
}
|
||||
defer websocketConn.Close()
|
||||
|
||||
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID)
|
||||
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token)
|
||||
}
|
||||
|
||||
func hijackAttachStartOperation(websocketConn *websocket.Conn, endpoint *portainer.Endpoint, attachID string) error {
|
||||
func hijackAttachStartOperation(
|
||||
websocketConn *websocket.Conn,
|
||||
endpoint *portainer.Endpoint,
|
||||
attachID string,
|
||||
token string,
|
||||
) error {
|
||||
dial, err := initDial(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -116,7 +130,7 @@ func hijackAttachStartOperation(websocketConn *websocket.Conn, endpoint *portain
|
|||
return err
|
||||
}
|
||||
|
||||
return hijackRequest(websocketConn, httpConn, attachStartRequest)
|
||||
return hijackRequest(websocketConn, httpConn, attachStartRequest, token)
|
||||
}
|
||||
|
||||
func createAttachStartRequest(attachID string) (*http.Request, error) {
|
||||
|
|
|
@ -3,6 +3,8 @@ package websocket
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/rs/zerolog/log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
@ -80,6 +82,14 @@ func (handler *Handler) websocketExec(w http.ResponseWriter, r *http.Request) *h
|
|||
}
|
||||
|
||||
func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
||||
tokenData, err := security.RetrieveTokenData(r)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Msg("unable to retrieve user details from authentication token")
|
||||
return err
|
||||
}
|
||||
|
||||
r.Header.Del("Origin")
|
||||
|
||||
if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
|
||||
|
@ -94,10 +104,15 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
defer websocketConn.Close()
|
||||
|
||||
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID)
|
||||
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token)
|
||||
}
|
||||
|
||||
func hijackExecStartOperation(websocketConn *websocket.Conn, endpoint *portainer.Endpoint, execID string) error {
|
||||
func hijackExecStartOperation(
|
||||
websocketConn *websocket.Conn,
|
||||
endpoint *portainer.Endpoint,
|
||||
execID string,
|
||||
token string,
|
||||
) error {
|
||||
dial, err := initDial(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -121,7 +136,7 @@ func hijackExecStartOperation(websocketConn *websocket.Conn, endpoint *portainer
|
|||
return err
|
||||
}
|
||||
|
||||
return hijackRequest(websocketConn, httpConn, execStartRequest)
|
||||
return hijackRequest(websocketConn, httpConn, execStartRequest, token)
|
||||
}
|
||||
|
||||
func createExecStartRequest(execID string) (*http.Request, error) {
|
||||
|
|
|
@ -7,9 +7,15 @@ import (
|
|||
"net/http/httputil"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/portainer/portainer/api/internal/logoutcontext"
|
||||
)
|
||||
|
||||
func hijackRequest(websocketConn *websocket.Conn, httpConn *httputil.ClientConn, request *http.Request) error {
|
||||
func hijackRequest(
|
||||
websocketConn *websocket.Conn,
|
||||
httpConn *httputil.ClientConn,
|
||||
request *http.Request,
|
||||
token string,
|
||||
) error {
|
||||
// Server hijacks the connection, error 'connection closed' expected
|
||||
resp, err := httpConn.Do(request)
|
||||
if !errors.Is(err, httputil.ErrPersistEOF) {
|
||||
|
@ -29,9 +35,15 @@ func hijackRequest(websocketConn *websocket.Conn, httpConn *httputil.ClientConn,
|
|||
go streamFromReaderToWebsocket(websocketConn, brw, errorChan)
|
||||
go streamFromWebsocketToWriter(websocketConn, tcpConn, errorChan)
|
||||
|
||||
err = <-errorChan
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
return err
|
||||
logoutCtx := logoutcontext.GetContext(token)
|
||||
|
||||
select {
|
||||
case <-logoutCtx.Done():
|
||||
return fmt.Errorf("Your session has been logged out.")
|
||||
case err = <-errorChan:
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/logoutcontext"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/koding/websocketproxy"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func (handler *Handler) proxyEdgeAgentWebsocketRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
||||
|
@ -18,33 +23,12 @@ func (handler *Handler) proxyEdgeAgentWebsocketRequest(w http.ResponseWriter, r
|
|||
return err
|
||||
}
|
||||
|
||||
endpointURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", tunnel.Port))
|
||||
agentURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", tunnel.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
endpointURL.Scheme = "ws"
|
||||
proxy := websocketproxy.NewProxy(endpointURL)
|
||||
|
||||
signature, err := handler.SignatureService.CreateSignature(portainer.PortainerAgentSignatureMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy.Director = func(incoming *http.Request, out http.Header) {
|
||||
out.Set(portainer.PortainerAgentPublicKeyHeader, handler.SignatureService.EncodedPublicKey())
|
||||
out.Set(portainer.PortainerAgentSignatureHeader, signature)
|
||||
out.Set(portainer.PortainerAgentTargetHeader, params.nodeName)
|
||||
out.Set(portainer.PortainerAgentKubernetesSATokenHeader, params.token)
|
||||
}
|
||||
|
||||
handler.ReverseTunnelService.SetTunnelStatusToActive(params.endpoint.ID)
|
||||
|
||||
handler.ReverseTunnelService.KeepTunnelAlive(params.endpoint.ID, r.Context(), portainer.WebSocketKeepAlive)
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
|
||||
return nil
|
||||
return handler.doProxyWebsocketRequest(w, r, params, agentURL, true)
|
||||
}
|
||||
|
||||
func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
||||
|
@ -59,17 +43,41 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt
|
|||
}
|
||||
|
||||
agentURL.Scheme = "ws"
|
||||
proxy := websocketproxy.NewProxy(agentURL)
|
||||
return handler.doProxyWebsocketRequest(w, r, params, agentURL, false)
|
||||
}
|
||||
|
||||
if params.endpoint.TLSConfig.TLS || params.endpoint.TLSConfig.TLSSkipVerify {
|
||||
func (handler *Handler) doProxyWebsocketRequest(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
params *webSocketRequestParams,
|
||||
agentURL *url.URL,
|
||||
isEdge bool,
|
||||
) error {
|
||||
tokenData, err := security.RetrieveTokenData(r)
|
||||
if err != nil {
|
||||
log.
|
||||
Warn().
|
||||
Err(err).
|
||||
Msg("unable to retrieve user details from authentication token")
|
||||
return err
|
||||
}
|
||||
|
||||
enableTLS := !isEdge && (params.endpoint.TLSConfig.TLS || params.endpoint.TLSConfig.TLSSkipVerify)
|
||||
|
||||
agentURL.Scheme = "ws"
|
||||
if enableTLS {
|
||||
agentURL.Scheme = "wss"
|
||||
}
|
||||
|
||||
proxy := websocketproxy.NewProxy(agentURL)
|
||||
proxyDialer := *websocket.DefaultDialer
|
||||
proxy.Dialer = &proxyDialer
|
||||
|
||||
if enableTLS {
|
||||
tlsConfig := crypto.CreateTLSConfiguration()
|
||||
tlsConfig.InsecureSkipVerify = params.endpoint.TLSConfig.TLSSkipVerify
|
||||
|
||||
proxy.Dialer = &websocket.Dialer{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
proxyDialer.TLSClientConfig = tlsConfig
|
||||
}
|
||||
|
||||
signature, err := handler.SignatureService.CreateSignature(portainer.PortainerAgentSignatureMessage)
|
||||
|
@ -84,7 +92,46 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt
|
|||
out.Set(portainer.PortainerAgentKubernetesSATokenHeader, params.token)
|
||||
}
|
||||
|
||||
if isEdge {
|
||||
handler.ReverseTunnelService.SetTunnelStatusToActive(params.endpoint.ID)
|
||||
handler.ReverseTunnelService.KeepTunnelAlive(params.endpoint.ID, r.Context(), portainer.WebSocketKeepAlive)
|
||||
}
|
||||
|
||||
abortProxyOnLogout(r.Context(), proxy, tokenData.Token)
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func abortProxyOnLogout(ctx context.Context, proxy *websocketproxy.WebsocketProxy, token string) {
|
||||
var wsConn net.Conn
|
||||
|
||||
proxy.Dialer.NetDial = func(network, addr string) (net.Conn, error) {
|
||||
netDialer := &net.Dialer{}
|
||||
|
||||
conn, err := netDialer.DialContext(context.Background(), network, addr)
|
||||
wsConn = conn
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
logoutCtx := logoutcontext.GetContext(token)
|
||||
|
||||
go func() {
|
||||
log.Debug().
|
||||
Msg("logout watcher for websocket proxy started")
|
||||
|
||||
select {
|
||||
case <-logoutCtx.Done():
|
||||
log.Debug().
|
||||
Msg("logout watcher for websocket proxy stopped as user logged out")
|
||||
if wsConn != nil {
|
||||
wsConn.Close()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Debug().
|
||||
Msg("logout watcher for websocket proxy stopped as the ws connection closed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue