1
0
Fork 0
mirror of https://github.com/portainer/portainer.git synced 2025-07-24 15:59:41 +02:00

fix(websocket): abort websocket when logout EE-6058 (#10372)

This commit is contained in:
cmeng 2023-09-29 12:13:09 +13:00 committed by GitHub
parent 9440aa733d
commit 56ab19433a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 227 additions and 49 deletions

View file

@ -1,6 +1,8 @@
package websocket
import (
"github.com/portainer/portainer/api/http/security"
"github.com/rs/zerolog/log"
"net"
"net/http"
"net/http/httputil"
@ -74,6 +76,13 @@ func (handler *Handler) websocketAttach(w http.ResponseWriter, r *http.Request)
}
func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
tokenData, err := security.RetrieveTokenData(r)
if err != nil {
log.Warn().
Err(err).
Msg("unable to retrieve user details from authentication token")
return err
}
r.Header.Del("Origin")
@ -89,10 +98,15 @@ func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Reque
}
defer websocketConn.Close()
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID)
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token)
}
func hijackAttachStartOperation(websocketConn *websocket.Conn, endpoint *portainer.Endpoint, attachID string) error {
func hijackAttachStartOperation(
websocketConn *websocket.Conn,
endpoint *portainer.Endpoint,
attachID string,
token string,
) error {
dial, err := initDial(endpoint)
if err != nil {
return err
@ -116,7 +130,7 @@ func hijackAttachStartOperation(websocketConn *websocket.Conn, endpoint *portain
return err
}
return hijackRequest(websocketConn, httpConn, attachStartRequest)
return hijackRequest(websocketConn, httpConn, attachStartRequest, token)
}
func createAttachStartRequest(attachID string) (*http.Request, error) {

View file

@ -3,6 +3,8 @@ package websocket
import (
"bytes"
"encoding/json"
"github.com/portainer/portainer/api/http/security"
"github.com/rs/zerolog/log"
"net"
"net/http"
"net/http/httputil"
@ -80,6 +82,14 @@ func (handler *Handler) websocketExec(w http.ResponseWriter, r *http.Request) *h
}
func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
tokenData, err := security.RetrieveTokenData(r)
if err != nil {
log.Warn().
Err(err).
Msg("unable to retrieve user details from authentication token")
return err
}
r.Header.Del("Origin")
if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
@ -94,10 +104,15 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request
}
defer websocketConn.Close()
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID)
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token)
}
func hijackExecStartOperation(websocketConn *websocket.Conn, endpoint *portainer.Endpoint, execID string) error {
func hijackExecStartOperation(
websocketConn *websocket.Conn,
endpoint *portainer.Endpoint,
execID string,
token string,
) error {
dial, err := initDial(endpoint)
if err != nil {
return err
@ -121,7 +136,7 @@ func hijackExecStartOperation(websocketConn *websocket.Conn, endpoint *portainer
return err
}
return hijackRequest(websocketConn, httpConn, execStartRequest)
return hijackRequest(websocketConn, httpConn, execStartRequest, token)
}
func createExecStartRequest(execID string) (*http.Request, error) {

View file

@ -7,9 +7,15 @@ import (
"net/http/httputil"
"github.com/gorilla/websocket"
"github.com/portainer/portainer/api/internal/logoutcontext"
)
func hijackRequest(websocketConn *websocket.Conn, httpConn *httputil.ClientConn, request *http.Request) error {
func hijackRequest(
websocketConn *websocket.Conn,
httpConn *httputil.ClientConn,
request *http.Request,
token string,
) error {
// Server hijacks the connection, error 'connection closed' expected
resp, err := httpConn.Do(request)
if !errors.Is(err, httputil.ErrPersistEOF) {
@ -29,9 +35,15 @@ func hijackRequest(websocketConn *websocket.Conn, httpConn *httputil.ClientConn,
go streamFromReaderToWebsocket(websocketConn, brw, errorChan)
go streamFromWebsocketToWriter(websocketConn, tcpConn, errorChan)
err = <-errorChan
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
return err
logoutCtx := logoutcontext.GetContext(token)
select {
case <-logoutCtx.Done():
return fmt.Errorf("Your session has been logged out.")
case err = <-errorChan:
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
return err
}
}
return nil

View file

@ -1,15 +1,20 @@
package websocket
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/logoutcontext"
"github.com/gorilla/websocket"
"github.com/koding/websocketproxy"
"github.com/portainer/portainer/api/crypto"
"github.com/rs/zerolog/log"
)
func (handler *Handler) proxyEdgeAgentWebsocketRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
@ -18,33 +23,12 @@ func (handler *Handler) proxyEdgeAgentWebsocketRequest(w http.ResponseWriter, r
return err
}
endpointURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", tunnel.Port))
agentURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", tunnel.Port))
if err != nil {
return err
}
endpointURL.Scheme = "ws"
proxy := websocketproxy.NewProxy(endpointURL)
signature, err := handler.SignatureService.CreateSignature(portainer.PortainerAgentSignatureMessage)
if err != nil {
return err
}
proxy.Director = func(incoming *http.Request, out http.Header) {
out.Set(portainer.PortainerAgentPublicKeyHeader, handler.SignatureService.EncodedPublicKey())
out.Set(portainer.PortainerAgentSignatureHeader, signature)
out.Set(portainer.PortainerAgentTargetHeader, params.nodeName)
out.Set(portainer.PortainerAgentKubernetesSATokenHeader, params.token)
}
handler.ReverseTunnelService.SetTunnelStatusToActive(params.endpoint.ID)
handler.ReverseTunnelService.KeepTunnelAlive(params.endpoint.ID, r.Context(), portainer.WebSocketKeepAlive)
proxy.ServeHTTP(w, r)
return nil
return handler.doProxyWebsocketRequest(w, r, params, agentURL, true)
}
func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
@ -59,17 +43,41 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt
}
agentURL.Scheme = "ws"
proxy := websocketproxy.NewProxy(agentURL)
return handler.doProxyWebsocketRequest(w, r, params, agentURL, false)
}
if params.endpoint.TLSConfig.TLS || params.endpoint.TLSConfig.TLSSkipVerify {
func (handler *Handler) doProxyWebsocketRequest(
w http.ResponseWriter,
r *http.Request,
params *webSocketRequestParams,
agentURL *url.URL,
isEdge bool,
) error {
tokenData, err := security.RetrieveTokenData(r)
if err != nil {
log.
Warn().
Err(err).
Msg("unable to retrieve user details from authentication token")
return err
}
enableTLS := !isEdge && (params.endpoint.TLSConfig.TLS || params.endpoint.TLSConfig.TLSSkipVerify)
agentURL.Scheme = "ws"
if enableTLS {
agentURL.Scheme = "wss"
}
proxy := websocketproxy.NewProxy(agentURL)
proxyDialer := *websocket.DefaultDialer
proxy.Dialer = &proxyDialer
if enableTLS {
tlsConfig := crypto.CreateTLSConfiguration()
tlsConfig.InsecureSkipVerify = params.endpoint.TLSConfig.TLSSkipVerify
proxy.Dialer = &websocket.Dialer{
TLSClientConfig: tlsConfig,
}
proxyDialer.TLSClientConfig = tlsConfig
}
signature, err := handler.SignatureService.CreateSignature(portainer.PortainerAgentSignatureMessage)
@ -84,7 +92,46 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt
out.Set(portainer.PortainerAgentKubernetesSATokenHeader, params.token)
}
if isEdge {
handler.ReverseTunnelService.SetTunnelStatusToActive(params.endpoint.ID)
handler.ReverseTunnelService.KeepTunnelAlive(params.endpoint.ID, r.Context(), portainer.WebSocketKeepAlive)
}
abortProxyOnLogout(r.Context(), proxy, tokenData.Token)
proxy.ServeHTTP(w, r)
return nil
}
func abortProxyOnLogout(ctx context.Context, proxy *websocketproxy.WebsocketProxy, token string) {
var wsConn net.Conn
proxy.Dialer.NetDial = func(network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{}
conn, err := netDialer.DialContext(context.Background(), network, addr)
wsConn = conn
return conn, err
}
logoutCtx := logoutcontext.GetContext(token)
go func() {
log.Debug().
Msg("logout watcher for websocket proxy started")
select {
case <-logoutCtx.Done():
log.Debug().
Msg("logout watcher for websocket proxy stopped as user logged out")
if wsConn != nil {
wsConn.Close()
}
case <-ctx.Done():
log.Debug().
Msg("logout watcher for websocket proxy stopped as the ws connection closed")
}
}()
}