diff --git a/api/chisel/service.go b/api/chisel/service.go index c7e1cb662..9ef172589 100644 --- a/api/chisel/service.go +++ b/api/chisel/service.go @@ -75,10 +75,11 @@ func (service *Service) KeepTunnelAlive(endpointID portainer.EndpointID, ctx con log.Debug(). Int("endpoint_id", int(endpointID)). Float64("max_alive_minutes", maxAlive.Minutes()). - Msg("start") + Msg("KeepTunnelAlive: start") maxAliveTicker := time.NewTicker(maxAlive) defer maxAliveTicker.Stop() + pingTicker := time.NewTicker(tunnelCleanupInterval) defer pingTicker.Stop() @@ -91,13 +92,13 @@ func (service *Service) KeepTunnelAlive(endpointID portainer.EndpointID, ctx con log.Debug(). Int("endpoint_id", int(endpointID)). Err(err). - Msg("ping agent") + Msg("KeepTunnelAlive: ping agent") } case <-maxAliveTicker.C: log.Debug(). Int("endpoint_id", int(endpointID)). Float64("timeout_minutes", maxAlive.Minutes()). - Msg("tunnel keep alive timeout") + Msg("KeepTunnelAlive: tunnel keep alive timeout") return case <-ctx.Done(): @@ -105,7 +106,7 @@ func (service *Service) KeepTunnelAlive(endpointID portainer.EndpointID, ctx con log.Debug(). Int("endpoint_id", int(endpointID)). Err(err). - Msg("tunnel stop") + Msg("KeepTunnelAlive: tunnel stop") return } diff --git a/api/http/handler/auth/logout.go b/api/http/handler/auth/logout.go index 95508b1ed..dc4088259 100644 --- a/api/http/handler/auth/logout.go +++ b/api/http/handler/auth/logout.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/internal/logoutcontext" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/response" ) @@ -25,5 +26,7 @@ func (handler *Handler) logout(w http.ResponseWriter, r *http.Request) *httperro handler.KubernetesTokenCacheManager.RemoveUserFromCache(tokenData.ID) + logoutcontext.Cancel(tokenData.Token) + return response.Empty(w) } diff --git a/api/http/handler/websocket/attach.go b/api/http/handler/websocket/attach.go index a4b3e19c0..b023f7bb5 100644 --- a/api/http/handler/websocket/attach.go +++ b/api/http/handler/websocket/attach.go @@ -1,6 +1,8 @@ package websocket import ( + "github.com/portainer/portainer/api/http/security" + "github.com/rs/zerolog/log" "net" "net/http" "net/http/httputil" @@ -74,6 +76,13 @@ func (handler *Handler) websocketAttach(w http.ResponseWriter, r *http.Request) } func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { + tokenData, err := security.RetrieveTokenData(r) + if err != nil { + log.Warn(). + Err(err). + Msg("unable to retrieve user details from authentication token") + return err + } r.Header.Del("Origin") @@ -89,10 +98,15 @@ func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Reque } defer websocketConn.Close() - return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID) + return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) } -func hijackAttachStartOperation(websocketConn *websocket.Conn, endpoint *portainer.Endpoint, attachID string) error { +func hijackAttachStartOperation( + websocketConn *websocket.Conn, + endpoint *portainer.Endpoint, + attachID string, + token string, +) error { dial, err := initDial(endpoint) if err != nil { return err @@ -116,7 +130,7 @@ func hijackAttachStartOperation(websocketConn *websocket.Conn, endpoint *portain return err } - return hijackRequest(websocketConn, httpConn, attachStartRequest) + return hijackRequest(websocketConn, httpConn, attachStartRequest, token) } func createAttachStartRequest(attachID string) (*http.Request, error) { diff --git a/api/http/handler/websocket/exec.go b/api/http/handler/websocket/exec.go index 46de75777..6a1df05e2 100644 --- a/api/http/handler/websocket/exec.go +++ b/api/http/handler/websocket/exec.go @@ -3,6 +3,8 @@ package websocket import ( "bytes" "encoding/json" + "github.com/portainer/portainer/api/http/security" + "github.com/rs/zerolog/log" "net" "net/http" "net/http/httputil" @@ -80,6 +82,14 @@ func (handler *Handler) websocketExec(w http.ResponseWriter, r *http.Request) *h } func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { + tokenData, err := security.RetrieveTokenData(r) + if err != nil { + log.Warn(). + Err(err). + Msg("unable to retrieve user details from authentication token") + return err + } + r.Header.Del("Origin") if params.endpoint.Type == portainer.AgentOnDockerEnvironment { @@ -94,10 +104,15 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request } defer websocketConn.Close() - return hijackExecStartOperation(websocketConn, params.endpoint, params.ID) + return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) } -func hijackExecStartOperation(websocketConn *websocket.Conn, endpoint *portainer.Endpoint, execID string) error { +func hijackExecStartOperation( + websocketConn *websocket.Conn, + endpoint *portainer.Endpoint, + execID string, + token string, +) error { dial, err := initDial(endpoint) if err != nil { return err @@ -121,7 +136,7 @@ func hijackExecStartOperation(websocketConn *websocket.Conn, endpoint *portainer return err } - return hijackRequest(websocketConn, httpConn, execStartRequest) + return hijackRequest(websocketConn, httpConn, execStartRequest, token) } func createExecStartRequest(execID string) (*http.Request, error) { diff --git a/api/http/handler/websocket/hijack.go b/api/http/handler/websocket/hijack.go index ca9ae26cc..ccda8ac8f 100644 --- a/api/http/handler/websocket/hijack.go +++ b/api/http/handler/websocket/hijack.go @@ -7,9 +7,15 @@ import ( "net/http/httputil" "github.com/gorilla/websocket" + "github.com/portainer/portainer/api/internal/logoutcontext" ) -func hijackRequest(websocketConn *websocket.Conn, httpConn *httputil.ClientConn, request *http.Request) error { +func hijackRequest( + websocketConn *websocket.Conn, + httpConn *httputil.ClientConn, + request *http.Request, + token string, +) error { // Server hijacks the connection, error 'connection closed' expected resp, err := httpConn.Do(request) if !errors.Is(err, httputil.ErrPersistEOF) { @@ -29,9 +35,15 @@ func hijackRequest(websocketConn *websocket.Conn, httpConn *httputil.ClientConn, go streamFromReaderToWebsocket(websocketConn, brw, errorChan) go streamFromWebsocketToWriter(websocketConn, tcpConn, errorChan) - err = <-errorChan - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - return err + logoutCtx := logoutcontext.GetContext(token) + + select { + case <-logoutCtx.Done(): + return fmt.Errorf("Your session has been logged out.") + case err = <-errorChan: + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { + return err + } } return nil diff --git a/api/http/handler/websocket/proxy.go b/api/http/handler/websocket/proxy.go index d03d5a113..f6ebf0888 100644 --- a/api/http/handler/websocket/proxy.go +++ b/api/http/handler/websocket/proxy.go @@ -1,15 +1,20 @@ package websocket import ( + "context" "fmt" + "net" "net/http" "net/url" portainer "github.com/portainer/portainer/api" - "github.com/portainer/portainer/api/crypto" + "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/internal/logoutcontext" "github.com/gorilla/websocket" "github.com/koding/websocketproxy" + "github.com/portainer/portainer/api/crypto" + "github.com/rs/zerolog/log" ) func (handler *Handler) proxyEdgeAgentWebsocketRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { @@ -18,33 +23,12 @@ func (handler *Handler) proxyEdgeAgentWebsocketRequest(w http.ResponseWriter, r return err } - endpointURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", tunnel.Port)) + agentURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", tunnel.Port)) if err != nil { return err } - endpointURL.Scheme = "ws" - proxy := websocketproxy.NewProxy(endpointURL) - - signature, err := handler.SignatureService.CreateSignature(portainer.PortainerAgentSignatureMessage) - if err != nil { - return err - } - - proxy.Director = func(incoming *http.Request, out http.Header) { - out.Set(portainer.PortainerAgentPublicKeyHeader, handler.SignatureService.EncodedPublicKey()) - out.Set(portainer.PortainerAgentSignatureHeader, signature) - out.Set(portainer.PortainerAgentTargetHeader, params.nodeName) - out.Set(portainer.PortainerAgentKubernetesSATokenHeader, params.token) - } - - handler.ReverseTunnelService.SetTunnelStatusToActive(params.endpoint.ID) - - handler.ReverseTunnelService.KeepTunnelAlive(params.endpoint.ID, r.Context(), portainer.WebSocketKeepAlive) - - proxy.ServeHTTP(w, r) - - return nil + return handler.doProxyWebsocketRequest(w, r, params, agentURL, true) } func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { @@ -59,17 +43,41 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt } agentURL.Scheme = "ws" - proxy := websocketproxy.NewProxy(agentURL) + return handler.doProxyWebsocketRequest(w, r, params, agentURL, false) +} - if params.endpoint.TLSConfig.TLS || params.endpoint.TLSConfig.TLSSkipVerify { +func (handler *Handler) doProxyWebsocketRequest( + w http.ResponseWriter, + r *http.Request, + params *webSocketRequestParams, + agentURL *url.URL, + isEdge bool, +) error { + tokenData, err := security.RetrieveTokenData(r) + if err != nil { + log. + Warn(). + Err(err). + Msg("unable to retrieve user details from authentication token") + return err + } + + enableTLS := !isEdge && (params.endpoint.TLSConfig.TLS || params.endpoint.TLSConfig.TLSSkipVerify) + + agentURL.Scheme = "ws" + if enableTLS { agentURL.Scheme = "wss" + } + proxy := websocketproxy.NewProxy(agentURL) + proxyDialer := *websocket.DefaultDialer + proxy.Dialer = &proxyDialer + + if enableTLS { tlsConfig := crypto.CreateTLSConfiguration() tlsConfig.InsecureSkipVerify = params.endpoint.TLSConfig.TLSSkipVerify - proxy.Dialer = &websocket.Dialer{ - TLSClientConfig: tlsConfig, - } + proxyDialer.TLSClientConfig = tlsConfig } signature, err := handler.SignatureService.CreateSignature(portainer.PortainerAgentSignatureMessage) @@ -84,7 +92,46 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt out.Set(portainer.PortainerAgentKubernetesSATokenHeader, params.token) } + if isEdge { + handler.ReverseTunnelService.SetTunnelStatusToActive(params.endpoint.ID) + handler.ReverseTunnelService.KeepTunnelAlive(params.endpoint.ID, r.Context(), portainer.WebSocketKeepAlive) + } + + abortProxyOnLogout(r.Context(), proxy, tokenData.Token) + proxy.ServeHTTP(w, r) return nil } + +func abortProxyOnLogout(ctx context.Context, proxy *websocketproxy.WebsocketProxy, token string) { + var wsConn net.Conn + + proxy.Dialer.NetDial = func(network, addr string) (net.Conn, error) { + netDialer := &net.Dialer{} + + conn, err := netDialer.DialContext(context.Background(), network, addr) + wsConn = conn + + return conn, err + } + + logoutCtx := logoutcontext.GetContext(token) + + go func() { + log.Debug(). + Msg("logout watcher for websocket proxy started") + + select { + case <-logoutCtx.Done(): + log.Debug(). + Msg("logout watcher for websocket proxy stopped as user logged out") + if wsConn != nil { + wsConn.Close() + } + case <-ctx.Done(): + log.Debug(). + Msg("logout watcher for websocket proxy stopped as the ws connection closed") + } + }() +} diff --git a/api/internal/logoutcontext/logout_context.go b/api/internal/logoutcontext/logout_context.go new file mode 100644 index 000000000..ea172bafe --- /dev/null +++ b/api/internal/logoutcontext/logout_context.go @@ -0,0 +1,20 @@ +package logoutcontext + +import ( + "context" +) + +const LogoutPrefix = "logout-" + +func GetContext(token string) context.Context { + return GetService(logoutToken(token)).GetLogoutCtx() +} + +func Cancel(token string) { + GetService(logoutToken(token)).Cancel() + RemoveService(logoutToken(token)) +} + +func logoutToken(token string) string { + return LogoutPrefix + token +} diff --git a/api/internal/logoutcontext/service.go b/api/internal/logoutcontext/service.go new file mode 100644 index 000000000..d608b0e3a --- /dev/null +++ b/api/internal/logoutcontext/service.go @@ -0,0 +1,28 @@ +package logoutcontext + +import ( + "context" +) + +type ( + Service struct { + ctx context.Context + cancel context.CancelFunc + } +) + +func NewService() *Service { + ctx, cancel := context.WithCancel(context.Background()) + return &Service{ + ctx: ctx, + cancel: cancel, + } +} + +func (s *Service) Cancel() { + s.cancel() +} + +func (s *Service) GetLogoutCtx() context.Context { + return s.ctx +} diff --git a/api/internal/logoutcontext/service_factory.go b/api/internal/logoutcontext/service_factory.go new file mode 100644 index 000000000..c01e4aeda --- /dev/null +++ b/api/internal/logoutcontext/service_factory.go @@ -0,0 +1,34 @@ +package logoutcontext + +import "sync" + +type ( + ServiceFactory struct { + mu sync.Mutex + services map[string]*Service + } +) + +var serviceFactory = ServiceFactory{ + services: make(map[string]*Service), +} + +func GetService(token string) *Service { + serviceFactory.mu.Lock() + defer serviceFactory.mu.Unlock() + + service, ok := serviceFactory.services[token] + if !ok { + service = NewService() + serviceFactory.services[token] = service + } + + return service +} + +func RemoveService(token string) { + serviceFactory.mu.Lock() + defer serviceFactory.mu.Unlock() + + delete(serviceFactory.services, token) +} diff --git a/api/jwt/jwt.go b/api/jwt/jwt.go index 4ae0f1820..5fbcb02f7 100644 --- a/api/jwt/jwt.go +++ b/api/jwt/jwt.go @@ -137,6 +137,7 @@ func (service *Service) ParseAndVerifyToken(token string) (*portainer.TokenData, ID: portainer.UserID(cl.UserID), Username: cl.Username, Role: portainer.UserRole(cl.Role), + Token: token, }, nil } } diff --git a/api/portainer.go b/api/portainer.go index 703bcfe82..e998976c2 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -1278,6 +1278,7 @@ type ( Username string Role UserRole ForceChangePassword bool + Token string } // TunnelDetails represents information associated to a tunnel diff --git a/app/docker/views/containers/console/containerConsoleController.js b/app/docker/views/containers/console/containerConsoleController.js index 575c2fd18..bc4bf14d7 100644 --- a/app/docker/views/containers/console/containerConsoleController.js +++ b/app/docker/views/containers/console/containerConsoleController.js @@ -67,7 +67,6 @@ angular.module('portainer.docker').controller('ContainerConsoleController', [ } const params = { - token: LocalStorage.getJWT(), endpointId: $state.params.endpointId, id: attachId, }; @@ -108,7 +107,6 @@ angular.module('portainer.docker').controller('ContainerConsoleController', [ ContainerService.createExec(execConfig) .then(function success(data) { const params = { - token: LocalStorage.getJWT(), endpointId: $state.params.endpointId, id: data.Id, }; @@ -167,6 +165,9 @@ angular.module('portainer.docker').controller('ContainerConsoleController', [ if ($transition$.params().nodeName) { url += '&nodeName=' + $transition$.params().nodeName; } + + url += '&token=' + LocalStorage.getJWT(); + if (url.indexOf('https') > -1) { url = url.replace('https://', 'wss://'); } else { diff --git a/app/portainer/__module.js b/app/portainer/__module.js index 81d77fcc5..6367dbdbe 100644 --- a/app/portainer/__module.js +++ b/app/portainer/__module.js @@ -154,7 +154,7 @@ angular url: '/logout', params: { error: '', - performApiLogout: false, + performApiLogout: true, }, views: { 'content@': { diff --git a/app/portainer/services/authentication.js b/app/portainer/services/authentication.js index 79abfcd19..9a182f84d 100644 --- a/app/portainer/services/authentication.js +++ b/app/portainer/services/authentication.js @@ -41,7 +41,7 @@ angular.module('portainer.app').factory('Authentication', [ } async function logoutAsync(performApiLogout) { - if (performApiLogout) { + if (performApiLogout && isAuthenticated()) { await Auth.logout().$promise; } diff --git a/app/react/kubernetes/applications/ConsoleView/ConsoleView.tsx b/app/react/kubernetes/applications/ConsoleView/ConsoleView.tsx index b61f0f027..5a2d833cc 100644 --- a/app/react/kubernetes/applications/ConsoleView/ConsoleView.tsx +++ b/app/react/kubernetes/applications/ConsoleView/ConsoleView.tsx @@ -3,7 +3,7 @@ import { useCurrentStateAndParams } from '@uirouter/react'; import { Terminal as TerminalIcon } from 'lucide-react'; import { Terminal } from 'xterm'; -import { useLocalStorage } from '@/react/hooks/useLocalStorage'; +import { get } from '@/react/hooks/useLocalStorage'; import { baseHref } from '@/portainer/helpers/pathHelper'; import { notifyError } from '@/portainer/services/notifications'; @@ -27,7 +27,6 @@ export function ConsoleView() { }, } = useCurrentStateAndParams(); - const [jwtToken] = useLocalStorage('JWT', ''); const [command, setCommand] = useState('/bin/sh'); const [connectionStatus, setConnectionStatus] = useState('closed'); const [terminal, setTerminal] = useState(null as Terminal | null); @@ -170,6 +169,8 @@ export function ConsoleView() { ); function connectConsole() { + const jwtToken = get('JWT', ''); + const params: StringDictionary = { token: jwtToken, endpointId: environmentId,