diff --git a/api/chisel/schedules.go b/api/chisel/schedules.go index 6bcba574d..a00562d9a 100644 --- a/api/chisel/schedules.go +++ b/api/chisel/schedules.go @@ -1,14 +1,13 @@ package chisel import ( - "strconv" - portainer "github.com/portainer/portainer/api" ) // AddEdgeJob register an EdgeJob inside the tunnel details associated to an environment(endpoint). func (service *Service) AddEdgeJob(endpointID portainer.EndpointID, edgeJob *portainer.EdgeJob) { - tunnel := service.GetTunnelDetails(endpointID) + service.mu.Lock() + tunnel := service.getTunnelDetails(endpointID) existingJobIndex := -1 for idx, existingJob := range tunnel.Jobs { @@ -24,24 +23,25 @@ func (service *Service) AddEdgeJob(endpointID portainer.EndpointID, edgeJob *por tunnel.Jobs[existingJobIndex] = *edgeJob } - key := strconv.Itoa(int(endpointID)) - service.tunnelDetailsMap.Set(key, tunnel) + service.mu.Unlock() } // RemoveEdgeJob will remove the specified Edge job from each tunnel it was registered with. func (service *Service) RemoveEdgeJob(edgeJobID portainer.EdgeJobID) { - for item := range service.tunnelDetailsMap.IterBuffered() { - tunnelDetails := item.Val.(*portainer.TunnelDetails) + service.mu.Lock() - updatedJobs := make([]portainer.EdgeJob, 0) - for _, edgeJob := range tunnelDetails.Jobs { - if edgeJob.ID == edgeJobID { - continue + for _, tunnel := range service.tunnelDetailsMap { + // Filter in-place + n := 0 + for _, edgeJob := range tunnel.Jobs { + if edgeJob.ID != edgeJobID { + tunnel.Jobs[n] = edgeJob + n++ } - updatedJobs = append(updatedJobs, edgeJob) } - tunnelDetails.Jobs = updatedJobs - service.tunnelDetailsMap.Set(item.Key, tunnelDetails) + tunnel.Jobs = tunnel.Jobs[:n] } + + service.mu.Unlock() } diff --git a/api/chisel/service.go b/api/chisel/service.go index b2813bc7c..d45bc2b06 100644 --- a/api/chisel/service.go +++ b/api/chisel/service.go @@ -3,15 +3,16 @@ package chisel import ( "context" "fmt" - "github.com/portainer/portainer/api/http/proxy" "log" "net/http" "strconv" + "sync" "time" + "github.com/portainer/portainer/api/http/proxy" + "github.com/dchest/uniuri" chserver "github.com/jpillora/chisel/server" - cmap "github.com/orcaman/concurrent-map" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" ) @@ -28,18 +29,19 @@ const ( type Service struct { serverFingerprint string serverPort string - tunnelDetailsMap cmap.ConcurrentMap + tunnelDetailsMap map[string]*portainer.TunnelDetails dataStore dataservices.DataStore snapshotService portainer.SnapshotService chiselServer *chserver.Server shutdownCtx context.Context ProxyManager *proxy.Manager + mu sync.Mutex } // NewService returns a pointer to a new instance of Service func NewService(dataStore dataservices.DataStore, shutdownCtx context.Context) *Service { return &Service{ - tunnelDetailsMap: cmap.New(), + tunnelDetailsMap: make(map[string]*portainer.TunnelDetails), dataStore: dataStore, shutdownCtx: shutdownCtx, } @@ -58,11 +60,7 @@ func (service *Service) pingAgent(endpointID portainer.EndpointID) error { Timeout: 3 * time.Second, } _, err = httpClient.Do(req) - if err != nil { - return err - } - - return nil + return err } // KeepTunnelAlive keeps the tunnel of the given environment for maxAlive duration, or until ctx is done @@ -185,46 +183,48 @@ func (service *Service) startTunnelVerificationLoop() { } func (service *Service) checkTunnels() { - for item := range service.tunnelDetailsMap.IterBuffered() { - tunnel := item.Val.(*portainer.TunnelDetails) + service.mu.Lock() + for key, tunnel := range service.tunnelDetailsMap { if tunnel.LastActivity.IsZero() || tunnel.Status == portainer.EdgeAgentIdle { continue } elapsed := time.Since(tunnel.LastActivity) - log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %s] [status: %s] [status_time_seconds: %f] [message: environment tunnel monitoring]", item.Key, tunnel.Status, elapsed.Seconds()) + log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %s] [status: %s] [status_time_seconds: %f] [message: environment tunnel monitoring]", key, tunnel.Status, elapsed.Seconds()) if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed.Seconds() < requiredTimeout.Seconds() { continue } else if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed.Seconds() > requiredTimeout.Seconds() { - log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %s] [status: %s] [status_time_seconds: %f] [timeout_seconds: %f] [message: REQUIRED state timeout exceeded]", item.Key, tunnel.Status, elapsed.Seconds(), requiredTimeout.Seconds()) + log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %s] [status: %s] [status_time_seconds: %f] [timeout_seconds: %f] [message: REQUIRED state timeout exceeded]", key, tunnel.Status, elapsed.Seconds(), requiredTimeout.Seconds()) } if tunnel.Status == portainer.EdgeAgentActive && elapsed.Seconds() < activeTimeout.Seconds() { continue } else if tunnel.Status == portainer.EdgeAgentActive && elapsed.Seconds() > activeTimeout.Seconds() { - log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %s] [status: %s] [status_time_seconds: %f] [timeout_seconds: %f] [message: ACTIVE state timeout exceeded]", item.Key, tunnel.Status, elapsed.Seconds(), activeTimeout.Seconds()) + log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %s] [status: %s] [status_time_seconds: %f] [timeout_seconds: %f] [message: ACTIVE state timeout exceeded]", key, tunnel.Status, elapsed.Seconds(), activeTimeout.Seconds()) - endpointID, err := strconv.Atoi(item.Key) + endpointID, err := strconv.Atoi(key) if err != nil { - log.Printf("[ERROR] [chisel,snapshot,conversion] Invalid environment identifier (id: %s): %s", item.Key, err) + log.Printf("[ERROR] [chisel,snapshot,conversion] Invalid environment identifier (id: %s): %s", key, err) } err = service.snapshotEnvironment(portainer.EndpointID(endpointID), tunnel.Port) if err != nil { - log.Printf("[ERROR] [snapshot] Unable to snapshot Edge environment (id: %s): %s", item.Key, err) + log.Printf("[ERROR] [snapshot] Unable to snapshot Edge environment (id: %s): %s", key, err) } } - endpointID, err := strconv.Atoi(item.Key) + endpointID, err := strconv.Atoi(key) if err != nil { - log.Printf("[ERROR] [chisel,conversion] Invalid environment identifier (id: %s): %s", item.Key, err) + log.Printf("[ERROR] [chisel,conversion] Invalid environment identifier (id: %s): %s", key, err) continue } - service.SetTunnelStatusToIdle(portainer.EndpointID(endpointID)) + service.setTunnelStatusToIdle(portainer.EndpointID(endpointID)) } + + service.mu.Unlock() } func (service *Service) snapshotEnvironment(endpointID portainer.EndpointID, tunnelPort int) error { diff --git a/api/chisel/tunnel.go b/api/chisel/tunnel.go index 884c08baa..177367f2d 100644 --- a/api/chisel/tunnel.go +++ b/api/chisel/tunnel.go @@ -24,8 +24,7 @@ const ( func (service *Service) getUnusedPort() int { port := randomInt(minAvailablePort, maxAvailablePort) - for item := range service.tunnelDetailsMap.IterBuffered() { - tunnel := item.Val.(*portainer.TunnelDetails) + for _, tunnel := range service.tunnelDetailsMap { if tunnel.Port == port { return service.getUnusedPort() } @@ -38,26 +37,33 @@ func randomInt(min, max int) int { return min + rand.Intn(max-min) } -// GetTunnelDetails returns information about the tunnel associated to an environment(endpoint). -func (service *Service) GetTunnelDetails(endpointID portainer.EndpointID) *portainer.TunnelDetails { +// NOTE: it needs to be called with the lock acquired +func (service *Service) getTunnelDetails(endpointID portainer.EndpointID) *portainer.TunnelDetails { key := strconv.Itoa(int(endpointID)) - if item, ok := service.tunnelDetailsMap.Get(key); ok { - tunnelDetails := item.(*portainer.TunnelDetails) - return tunnelDetails + if tunnel, ok := service.tunnelDetailsMap[key]; ok { + return tunnel } - jobs := make([]portainer.EdgeJob, 0) - return &portainer.TunnelDetails{ - Status: portainer.EdgeAgentIdle, - Port: 0, - Jobs: jobs, - Credentials: "", + tunnel := &portainer.TunnelDetails{ + Status: portainer.EdgeAgentIdle, } + + service.tunnelDetailsMap[key] = tunnel + + return tunnel +} + +// GetTunnelDetails returns information about the tunnel associated to an environment(endpoint). +func (service *Service) GetTunnelDetails(endpointID portainer.EndpointID) portainer.TunnelDetails { + service.mu.Lock() + defer service.mu.Unlock() + + return *service.getTunnelDetails(endpointID) } // GetActiveTunnel retrieves an active tunnel which allows communicating with edge agent -func (service *Service) GetActiveTunnel(endpoint *portainer.Endpoint) (*portainer.TunnelDetails, error) { +func (service *Service) GetActiveTunnel(endpoint *portainer.Endpoint) (portainer.TunnelDetails, error) { tunnel := service.GetTunnelDetails(endpoint.ID) if tunnel.Status == portainer.EdgeAgentActive { @@ -68,13 +74,13 @@ func (service *Service) GetActiveTunnel(endpoint *portainer.Endpoint) (*portaine if tunnel.Status == portainer.EdgeAgentIdle || tunnel.Status == portainer.EdgeAgentManagementRequired { err := service.SetTunnelStatusToRequired(endpoint.ID) if err != nil { - return nil, fmt.Errorf("failed opening tunnel to endpoint: %w", err) + return portainer.TunnelDetails{}, fmt.Errorf("failed opening tunnel to endpoint: %w", err) } if endpoint.EdgeCheckinInterval == 0 { settings, err := service.dataStore.Settings().Settings() if err != nil { - return nil, fmt.Errorf("failed fetching settings from db: %w", err) + return portainer.TunnelDetails{}, fmt.Errorf("failed fetching settings from db: %w", err) } endpoint.EdgeCheckinInterval = settings.EdgeAgentCheckinInterval @@ -83,29 +89,23 @@ func (service *Service) GetActiveTunnel(endpoint *portainer.Endpoint) (*portaine time.Sleep(2 * time.Duration(endpoint.EdgeCheckinInterval) * time.Second) } - tunnel = service.GetTunnelDetails(endpoint.ID) - - return tunnel, nil + return service.GetTunnelDetails(endpoint.ID), nil } // SetTunnelStatusToActive update the status of the tunnel associated to the specified environment(endpoint). // It sets the status to ACTIVE. func (service *Service) SetTunnelStatusToActive(endpointID portainer.EndpointID) { - tunnel := service.GetTunnelDetails(endpointID) + service.mu.Lock() + tunnel := service.getTunnelDetails(endpointID) tunnel.Status = portainer.EdgeAgentActive tunnel.Credentials = "" tunnel.LastActivity = time.Now() - - key := strconv.Itoa(int(endpointID)) - service.tunnelDetailsMap.Set(key, tunnel) + service.mu.Unlock() } -// SetTunnelStatusToIdle update the status of the tunnel associated to the specified environment(endpoint). -// It sets the status to IDLE. -// It removes any existing credentials associated to the tunnel. -func (service *Service) SetTunnelStatusToIdle(endpointID portainer.EndpointID) { - tunnel := service.GetTunnelDetails(endpointID) - +// NOTE: it needs to be called with the lock acquired +func (service *Service) setTunnelStatusToIdle(endpointID portainer.EndpointID) { + tunnel := service.getTunnelDetails(endpointID) tunnel.Status = portainer.EdgeAgentIdle tunnel.Port = 0 tunnel.LastActivity = time.Now() @@ -116,19 +116,28 @@ func (service *Service) SetTunnelStatusToIdle(endpointID portainer.EndpointID) { service.chiselServer.DeleteUser(strings.Split(credentials, ":")[0]) } - key := strconv.Itoa(int(endpointID)) - service.tunnelDetailsMap.Set(key, tunnel) - service.ProxyManager.DeleteEndpointProxy(endpointID) } +// SetTunnelStatusToIdle update the status of the tunnel associated to the specified environment(endpoint). +// It sets the status to IDLE. +// It removes any existing credentials associated to the tunnel. +func (service *Service) SetTunnelStatusToIdle(endpointID portainer.EndpointID) { + service.mu.Lock() + service.setTunnelStatusToIdle(endpointID) + service.mu.Unlock() +} + // SetTunnelStatusToRequired update the status of the tunnel associated to the specified environment(endpoint). // It sets the status to REQUIRED. // If no port is currently associated to the tunnel, it will associate a random unused port to the tunnel // and generate temporary credentials that can be used to establish a reverse tunnel on that port. // Credentials are encrypted using the Edge ID associated to the environment(endpoint). func (service *Service) SetTunnelStatusToRequired(endpointID portainer.EndpointID) error { - tunnel := service.GetTunnelDetails(endpointID) + tunnel := service.getTunnelDetails(endpointID) + + service.mu.Lock() + defer service.mu.Unlock() if tunnel.Port == 0 { endpoint, err := service.dataStore.Endpoint().Endpoint(endpointID) @@ -152,9 +161,6 @@ func (service *Service) SetTunnelStatusToRequired(endpointID portainer.EndpointI return err } tunnel.Credentials = credentials - - key := strconv.Itoa(int(endpointID)) - service.tunnelDetailsMap.Set(key, tunnel) } return nil diff --git a/api/internal/testhelpers/reverse_tunnel_service.go b/api/internal/testhelpers/reverse_tunnel_service.go deleted file mode 100644 index 0dbc19d19..000000000 --- a/api/internal/testhelpers/reverse_tunnel_service.go +++ /dev/null @@ -1,23 +0,0 @@ -package testhelpers - -import portainer "github.com/portainer/portainer/api" - -type ReverseTunnelService struct{} - -func (r ReverseTunnelService) StartTunnelServer(addr, port string, snapshotService portainer.SnapshotService) error { - return nil -} -func (r ReverseTunnelService) GenerateEdgeKey(url, host string, endpointIdentifier int) string { - return "nil" -} -func (r ReverseTunnelService) SetTunnelStatusToActive(endpointID portainer.EndpointID) {} -func (r ReverseTunnelService) SetTunnelStatusToRequired(endpointID portainer.EndpointID) error { - return nil -} -func (r ReverseTunnelService) SetTunnelStatusToIdle(endpointID portainer.EndpointID) {} -func (r ReverseTunnelService) GetTunnelDetails(endpointID portainer.EndpointID) *portainer.TunnelDetails { - return nil -} -func (r ReverseTunnelService) AddEdgeJob(endpointID portainer.EndpointID, edgeJob *portainer.EdgeJob) { -} -func (r ReverseTunnelService) RemoveEdgeJob(edgeJobID portainer.EdgeJobID) {} diff --git a/api/portainer.go b/api/portainer.go index d0fb5b74d..773cf56fd 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -1305,8 +1305,8 @@ type ( SetTunnelStatusToRequired(endpointID EndpointID) error SetTunnelStatusToIdle(endpointID EndpointID) KeepTunnelAlive(endpointID EndpointID, ctx context.Context, maxKeepAlive time.Duration) - GetTunnelDetails(endpointID EndpointID) *TunnelDetails - GetActiveTunnel(endpoint *Endpoint) (*TunnelDetails, error) + GetTunnelDetails(endpointID EndpointID) TunnelDetails + GetActiveTunnel(endpoint *Endpoint) (TunnelDetails, error) AddEdgeJob(endpointID EndpointID, edgeJob *EdgeJob) RemoveEdgeJob(edgeJobID EdgeJobID) }