From 9ee092aa5ef6770223c693901f7b7f06ecefe503 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:14:22 -0300 Subject: [PATCH] chore(code): reduce the code duplication EE-7278 (#11969) --- api/agent/version.go | 2 +- api/apikey/apikey_test.go | 5 +- api/apikey/cache.go | 44 ++++--- api/apikey/cache_test.go | 26 ++-- api/apikey/service.go | 35 ++++-- api/cli/cli.go | 40 ++++--- api/cmd/portainer/main.go | 111 +++++++----------- api/{internal => }/concurrent/concurrent.go | 0 api/database/boltdb/db.go | 1 + api/database/boltdb/json.go | 7 +- api/database/boltdb/tx.go | 1 + api/datastore/migrate_legacyversion.go | 1 + api/datastore/migrator/migrate_ce.go | 16 ++- .../migrator/migrate_dbversion100.go | 18 +-- api/datastore/migrator/migrate_dbversion23.go | 4 +- .../test_data/output_24_to_latest.json | 1 - api/datastore/teststore.go | 11 +- api/docker/container.go | 1 - api/http/csrf/csrf.go | 9 +- api/http/handler/auth/authenticate.go | 19 +-- api/http/handler/auth/handler.go | 1 + api/http/handler/auth/logout.go | 2 +- .../customtemplates/customtemplate_list.go | 7 +- api/http/handler/docker/images/images_list.go | 2 +- .../handler/docker/utils/filter_by_uac.go | 5 +- .../handler/edgegroups/edgegroup_create.go | 33 +++--- .../handler/edgegroups/edgegroup_update.go | 4 +- api/http/handler/edgestacks/edgestack_test.go | 18 +-- .../handler/edgestacks/edgestack_update.go | 2 +- .../edgestacks/edgestack_update_test.go | 17 ++- .../endpointedge_status_inspect_test.go | 38 +++--- .../endpointgroups/endpointgroup_update.go | 2 +- .../endpoints/endpoint_agent_versions.go | 2 +- api/http/handler/endpoints/endpoint_create.go | 82 ++++++------- .../handler/endpoints/endpoint_delete_test.go | 10 +- .../handler/endpoints/endpoint_list_test.go | 4 +- .../endpoints/endpoint_registries_list.go | 12 +- api/http/handler/endpoints/filter.go | 6 +- api/http/handler/endpoints/filter_test.go | 4 +- api/http/handler/endpoints/sort_test.go | 5 +- .../endpoints/update_edge_relations.go | 2 +- .../endpoints/utils_update_edge_groups.go | 2 +- .../utils_update_edge_groups_test.go | 7 +- .../handler/endpoints/utils_update_tags.go | 2 +- api/http/handler/helm/helm_delete_test.go | 6 +- api/http/handler/registries/handler.go | 4 +- api/http/handler/settings/settings_public.go | 9 +- .../handler/settings/settings_public_test.go | 6 + api/http/handler/stacks/stack_delete.go | 30 ++--- api/http/handler/tags/tag_delete_test.go | 14 +-- api/http/handler/users/user_create.go | 8 +- api/http/handler/users/user_list_test.go | 22 ++-- api/http/handler/users/user_update_test.go | 4 +- api/http/handler/websocket/proxy.go | 4 +- api/http/proxy/factory/agent.go | 2 +- .../proxy/factory/azure/access_control.go | 22 ++-- api/http/proxy/factory/docker.go | 2 +- .../proxy/factory/docker/access_control.go | 14 ++- api/http/proxy/factory/docker/configs.go | 11 +- api/http/proxy/factory/docker/tasks.go | 5 +- api/http/proxy/factory/docker/volumes.go | 26 ++-- api/internal/edge/edgegroup.go | 2 +- api/internal/edge/edgestacks/service.go | 21 ++-- api/internal/endpointutils/endpoint_test.go | 1 + api/internal/nodes/nodes.go | 9 +- api/internal/securecookie/securecookie.go | 16 --- api/internal/slices/slices.go | 23 ---- api/internal/unique/unique.go | 41 ------- api/jwt/jwt.go | 13 +- api/kubernetes/cli/dashboard.go | 2 +- api/kubernetes/cli/service.go | 58 +++------ api/kubernetes/{contants.go => constants.go} | 0 .../logoutcontext/logout_context.go | 0 api/{internal => }/logoutcontext/service.go | 0 .../logoutcontext/service_factory.go | 0 api/oauth/oauth_test.go | 15 +-- api/portainer.go | 78 ++++++------ api/{internal => }/set/set.go | 4 +- api/slicesx/slices.go | 43 +++++++ .../slices => slicesx}/slices_test.go | 2 +- api/{internal => }/tag/tag.go | 0 api/{internal => }/tag/tag_match.go | 0 api/{internal => }/tag/tag_match_test.go | 0 api/{internal => }/tag/tag_test.go | 0 api/{internal => }/url/url.go | 0 85 files changed, 520 insertions(+), 618 deletions(-) rename api/{internal => }/concurrent/concurrent.go (100%) delete mode 100644 api/internal/securecookie/securecookie.go delete mode 100644 api/internal/slices/slices.go delete mode 100644 api/internal/unique/unique.go rename api/kubernetes/{contants.go => constants.go} (100%) rename api/{internal => }/logoutcontext/logout_context.go (100%) rename api/{internal => }/logoutcontext/service.go (100%) rename api/{internal => }/logoutcontext/service_factory.go (100%) rename api/{internal => }/set/set.go (95%) create mode 100644 api/slicesx/slices.go rename api/{internal/slices => slicesx}/slices_test.go (99%) rename api/{internal => }/tag/tag.go (100%) rename api/{internal => }/tag/tag_match.go (100%) rename api/{internal => }/tag/tag_match_test.go (100%) rename api/{internal => }/tag/tag_test.go (100%) rename api/{internal => }/url/url.go (100%) diff --git a/api/agent/version.go b/api/agent/version.go index 03d3cf4f6..a9480850a 100644 --- a/api/agent/version.go +++ b/api/agent/version.go @@ -10,7 +10,7 @@ import ( "time" portainer "github.com/portainer/portainer/api" - "github.com/portainer/portainer/api/internal/url" + "github.com/portainer/portainer/api/url" ) // GetAgentVersionAndPlatform returns the agent version and platform diff --git a/api/apikey/apikey_test.go b/api/apikey/apikey_test.go index b41329f72..ace83d62d 100644 --- a/api/apikey/apikey_test.go +++ b/api/apikey/apikey_test.go @@ -3,7 +3,6 @@ package apikey import ( "testing" - "github.com/portainer/portainer/api/internal/securecookie" "github.com/stretchr/testify/assert" ) @@ -34,7 +33,7 @@ func Test_generateRandomKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := securecookie.GenerateRandomKey(tt.wantLenth) + got := GenerateRandomKey(tt.wantLenth) is.Equal(tt.wantLenth, len(got)) }) } @@ -42,7 +41,7 @@ func Test_generateRandomKey(t *testing.T) { t.Run("Generated keys are unique", func(t *testing.T) { keys := make(map[string]bool) for i := 0; i < 100; i++ { - key := securecookie.GenerateRandomKey(8) + key := GenerateRandomKey(8) _, ok := keys[string(key)] is.False(ok) keys[string(key)] = true diff --git a/api/apikey/cache.go b/api/apikey/cache.go index e36a05a17..0c24da8a9 100644 --- a/api/apikey/cache.go +++ b/api/apikey/cache.go @@ -1,69 +1,79 @@ package apikey import ( - lru "github.com/hashicorp/golang-lru" portainer "github.com/portainer/portainer/api" + + lru "github.com/hashicorp/golang-lru" ) -const defaultAPIKeyCacheSize = 1024 +const DefaultAPIKeyCacheSize = 1024 // entry is a tuple containing the user and API key associated to an API key digest -type entry struct { - user portainer.User +type entry[T any] struct { + user T apiKey portainer.APIKey } -// apiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips. +type UserCompareFn[T any] func(T, portainer.UserID) bool + +// ApiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips. // We store the api-key digest (keys) and the associated user and key-data (values) in the cache. // This is required because HTTP requests will contain only the api-key digest in the x-api-key request header; // digest value must be mapped to a portainer user (and respective key data) for validation. // This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest. -type apiKeyCache struct { +type ApiKeyCache[T any] struct { // cache type [string]entry cache (key: string(digest), value: user/key entry) // note: []byte keys are not supported by golang-lru Cache - cache *lru.Cache + cache *lru.Cache + userCmpFn UserCompareFn[T] } // NewAPIKeyCache creates a new cache for API keys -func NewAPIKeyCache(cacheSize int) *apiKeyCache { +func NewAPIKeyCache[T any](cacheSize int, userCompareFn UserCompareFn[T]) *ApiKeyCache[T] { cache, _ := lru.New(cacheSize) - return &apiKeyCache{cache: cache} + + return &ApiKeyCache[T]{cache: cache, userCmpFn: userCompareFn} } // Get returns the user/key associated to an api-key's digest // This is required because HTTP requests will contain the digest of the API key in header, // the digest value must be mapped to a portainer user. -func (c *apiKeyCache) Get(digest string) (portainer.User, portainer.APIKey, bool) { +func (c *ApiKeyCache[T]) Get(digest string) (T, portainer.APIKey, bool) { val, ok := c.cache.Get(digest) if !ok { - return portainer.User{}, portainer.APIKey{}, false + var t T + + return t, portainer.APIKey{}, false } - tuple := val.(entry) + + tuple := val.(entry[T]) return tuple.user, tuple.apiKey, true } // Set persists a user/key entry to the cache -func (c *apiKeyCache) Set(digest string, user portainer.User, apiKey portainer.APIKey) { - c.cache.Add(digest, entry{ +func (c *ApiKeyCache[T]) Set(digest string, user T, apiKey portainer.APIKey) { + c.cache.Add(digest, entry[T]{ user: user, apiKey: apiKey, }) } // Delete evicts a digest's user/key entry key from the cache -func (c *apiKeyCache) Delete(digest string) { +func (c *ApiKeyCache[T]) Delete(digest string) { c.cache.Remove(digest) } // InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache -func (c *apiKeyCache) InvalidateUserKeyCache(userId portainer.UserID) bool { +func (c *ApiKeyCache[T]) InvalidateUserKeyCache(userId portainer.UserID) bool { present := false + for _, k := range c.cache.Keys() { user, _, _ := c.Get(k.(string)) - if user.ID == userId { + if c.userCmpFn(user, userId) { present = c.cache.Remove(k) } } + return present } diff --git a/api/apikey/cache_test.go b/api/apikey/cache_test.go index 0821bec35..040423a63 100644 --- a/api/apikey/cache_test.go +++ b/api/apikey/cache_test.go @@ -10,11 +10,11 @@ import ( func Test_apiKeyCacheGet(t *testing.T) { is := assert.New(t) - keyCache := NewAPIKeyCache(10) + keyCache := NewAPIKeyCache(10, compareUser) // pre-populate cache - keyCache.cache.Add(string("foo"), entry{user: portainer.User{}, apiKey: portainer.APIKey{}}) - keyCache.cache.Add(string(""), entry{user: portainer.User{}, apiKey: portainer.APIKey{}}) + keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}}) + keyCache.cache.Add(string(""), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}}) tests := []struct { digest string @@ -45,7 +45,7 @@ func Test_apiKeyCacheGet(t *testing.T) { func Test_apiKeyCacheSet(t *testing.T) { is := assert.New(t) - keyCache := NewAPIKeyCache(10) + keyCache := NewAPIKeyCache(10, compareUser) // pre-populate cache keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{}) @@ -57,23 +57,23 @@ func Test_apiKeyCacheSet(t *testing.T) { val, ok := keyCache.cache.Get(string("bar")) is.True(ok) - tuple := val.(entry) + tuple := val.(entry[portainer.User]) is.Equal(portainer.User{ID: 2}, tuple.user) val, ok = keyCache.cache.Get(string("foo")) is.True(ok) - tuple = val.(entry) + tuple = val.(entry[portainer.User]) is.Equal(portainer.User{ID: 3}, tuple.user) } func Test_apiKeyCacheDelete(t *testing.T) { is := assert.New(t) - keyCache := NewAPIKeyCache(10) + keyCache := NewAPIKeyCache(10, compareUser) t.Run("Delete an existing entry", func(t *testing.T) { - keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) + keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) keyCache.Delete("foo") _, ok := keyCache.cache.Get(string("foo")) @@ -128,7 +128,7 @@ func Test_apiKeyCacheLRU(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - keyCache := NewAPIKeyCache(test.cacheLen) + keyCache := NewAPIKeyCache(test.cacheLen, compareUser) for _, key := range test.key { keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{}) @@ -150,10 +150,10 @@ func Test_apiKeyCacheLRU(t *testing.T) { func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) { is := assert.New(t) - keyCache := NewAPIKeyCache(10) + keyCache := NewAPIKeyCache(10, compareUser) t.Run("Removes users keys from cache", func(t *testing.T) { - keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) + keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) ok := keyCache.InvalidateUserKeyCache(1) is.True(ok) @@ -163,8 +163,8 @@ func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) { }) t.Run("Does not affect other keys", func(t *testing.T) { - keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) - keyCache.cache.Add(string("bar"), entry{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}}) + keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) + keyCache.cache.Add(string("bar"), entry[portainer.User]{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}}) ok := keyCache.InvalidateUserKeyCache(1) is.True(ok) diff --git a/api/apikey/service.go b/api/apikey/service.go index fc3c7f739..88b4e86f3 100644 --- a/api/apikey/service.go +++ b/api/apikey/service.go @@ -1,14 +1,15 @@ package apikey import ( + "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" + "io" "time" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" - "github.com/portainer/portainer/api/internal/securecookie" "github.com/pkg/errors" ) @@ -20,30 +21,45 @@ var ErrInvalidAPIKey = errors.New("Invalid API key") type apiKeyService struct { apiKeyRepository dataservices.APIKeyRepository userRepository dataservices.UserService - cache *apiKeyCache + cache *ApiKeyCache[portainer.User] +} + +// GenerateRandomKey generates a random key of specified length +// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515 +func GenerateRandomKey(length int) []byte { + k := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return nil + } + + return k +} + +func compareUser(u portainer.User, id portainer.UserID) bool { + return u.ID == id } func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService { return &apiKeyService{ apiKeyRepository: apiKeyRepository, userRepository: userRepository, - cache: NewAPIKeyCache(defaultAPIKeyCacheSize), + cache: NewAPIKeyCache(DefaultAPIKeyCacheSize, compareUser), } } // HashRaw computes a hash digest of provided raw API key. func (a *apiKeyService) HashRaw(rawKey string) string { hashDigest := sha256.Sum256([]byte(rawKey)) + return base64.StdEncoding.EncodeToString(hashDigest[:]) } // GenerateApiKey generates a raw API key for a user (for one-time display). // The generated API key is stored in the cache and database. func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) { - randKey := securecookie.GenerateRandomKey(32) + randKey := GenerateRandomKey(32) encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey) prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey - hashDigest := a.HashRaw(prefixedAPIKey) apiKey := &portainer.APIKey{ @@ -54,8 +70,7 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) Digest: hashDigest, } - err := a.apiKeyRepository.Create(apiKey) - if err != nil { + if err := a.apiKeyRepository.Create(apiKey); err != nil { return "", nil, errors.Wrap(err, "Unable to create API key") } @@ -78,7 +93,6 @@ func (a *apiKeyService) GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey, // GetDigestUserAndKey returns the user and api-key associated to a specified hash digest. // A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed. func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) { - // get api key from cache if possible cachedUser, cachedKey, ok := a.cache.Get(digest) if ok { return cachedUser, cachedKey, nil @@ -106,20 +120,21 @@ func (a *apiKeyService) UpdateAPIKey(apiKey *portainer.APIKey) error { if err != nil { return errors.Wrap(err, "Unable to retrieve API key") } + a.cache.Set(apiKey.Digest, user, *apiKey) + return a.apiKeyRepository.Update(apiKey.ID, apiKey) } // DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache. func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error { - // get api-key digest to remove from cache apiKey, err := a.apiKeyRepository.Read(apiKeyID) if err != nil { return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID)) } - // delete the user/api-key from cache a.cache.Delete(apiKey.Digest) + return a.apiKeyRepository.Delete(apiKeyID) } diff --git a/api/cli/cli.go b/api/cli/cli.go index d1b15975d..5bd9d07f9 100644 --- a/api/cli/cli.go +++ b/api/cli/cli.go @@ -17,17 +17,14 @@ import ( type Service struct{} var ( - errInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://") - errSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe") - errInvalidSnapshotInterval = errors.New("Invalid snapshot interval") - errAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file") + ErrInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://") + ErrSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe") + ErrInvalidSnapshotInterval = errors.New("Invalid snapshot interval") + ErrAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file") ) -// ParseFlags parse the CLI flags and return a portainer.Flags struct -func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) { - kingpin.Version(version) - - flags := &portainer.CLIFlags{ +func CLIFlags() *portainer.CLIFlags { + return &portainer.CLIFlags{ Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(), AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(), TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(), @@ -63,6 +60,13 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) { LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"), LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"), } +} + +// ParseFlags parse the CLI flags and return a portainer.Flags struct +func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) { + kingpin.Version(version) + + flags := CLIFlags() kingpin.Parse() @@ -82,18 +86,16 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) { func (*Service) ValidateFlags(flags *portainer.CLIFlags) error { displayDeprecationWarnings(flags) - err := validateEndpointURL(*flags.EndpointURL) - if err != nil { + if err := validateEndpointURL(*flags.EndpointURL); err != nil { return err } - err = validateSnapshotInterval(*flags.SnapshotInterval) - if err != nil { + if err := validateSnapshotInterval(*flags.SnapshotInterval); err != nil { return err } if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" { - return errAdminPassExcludeAdminPassFile + return ErrAdminPassExcludeAdminPassFile } return nil @@ -115,15 +117,16 @@ func validateEndpointURL(endpointURL string) error { } if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") { - return errInvalidEndpointProtocol + return ErrInvalidEndpointProtocol } if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") { socketPath := strings.TrimPrefix(endpointURL, "unix://") socketPath = strings.TrimPrefix(socketPath, "npipe://") + if _, err := os.Stat(socketPath); err != nil { if os.IsNotExist(err) { - return errSocketOrNamedPipeNotFound + return ErrSocketOrNamedPipeNotFound } return err @@ -138,9 +141,8 @@ func validateSnapshotInterval(snapshotInterval string) error { return nil } - _, err := time.ParseDuration(snapshotInterval) - if err != nil { - return errInvalidSnapshotInterval + if _, err := time.ParseDuration(snapshotInterval); err != nil { + return ErrInvalidSnapshotInterval } return nil diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 16bec1a03..ef7f1b224 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -56,14 +56,14 @@ import ( ) func initCLI() *portainer.CLIFlags { - var cliService portainer.CLIService = &cli.Service{} + cliService := &cli.Service{} + flags, err := cliService.ParseFlags(portainer.APIVersion) if err != nil { log.Fatal().Err(err).Msg("failed parsing flags") } - err = cliService.ValidateFlags(flags) - if err != nil { + if err := cliService.ValidateFlags(flags); err != nil { log.Fatal().Err(err).Msg("failed validating flags") } @@ -94,14 +94,14 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port } store := datastore.NewStore(*flags.Data, fileService, connection) + isNew, err := store.Open() if err != nil { log.Fatal().Err(err).Msg("failed opening store") } if *flags.Rollback { - err := store.Rollback(false) - if err != nil { + if err := store.Rollback(false); err != nil { log.Fatal().Err(err).Msg("failed rolling back") } @@ -110,8 +110,7 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port } // Init sets some defaults - it's basically a migration - err = store.Init() - if err != nil { + if err := store.Init(); err != nil { log.Fatal().Err(err).Msg("failed initializing data store") } @@ -133,25 +132,23 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port } store.VersionService.UpdateVersion(&v) - err = updateSettingsFromFlags(store, flags) - if err != nil { + if err := updateSettingsFromFlags(store, flags); err != nil { log.Fatal().Err(err).Msg("failed updating settings from flags") } } else { - err = store.MigrateData() - if err != nil { + if err := store.MigrateData(); err != nil { log.Fatal().Err(err).Msg("failed migration") } } - err = updateSettingsFromFlags(store, flags) - if err != nil { + if err := updateSettingsFromFlags(store, flags); err != nil { log.Fatal().Err(err).Msg("failed updating settings from flags") } // this is for the db restore functionality - needs more tests. go func() { <-shutdownCtx.Done() + defer connection.Close() }() @@ -205,36 +202,16 @@ func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore) userSessionTimeout = portainer.DefaultUserSessionTimeout } - jwtService, err := jwt.NewService(userSessionTimeout, dataStore) - if err != nil { - return nil, err - } - - return jwtService, nil + return jwt.NewService(userSessionTimeout, dataStore) } func initDigitalSignatureService() portainer.DigitalSignatureService { return crypto.NewECDSAService(os.Getenv("AGENT_SECRET")) } -func initCryptoService() portainer.CryptoService { - return &crypto.Service{} -} - -func initLDAPService() portainer.LDAPService { - return &ldap.Service{} -} - -func initOAuthService() portainer.OAuthService { - return oauth.NewService() -} - -func initGitService(ctx context.Context) portainer.GitService { - return git.NewService(ctx) -} - func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) { slices := strings.Split(addr, ":") + host := slices[0] if host == "" { host = "0.0.0.0" @@ -242,22 +219,13 @@ func initSSLService(addr, certPath, keyPath string, fileService portainer.FileSe sslService := ssl.NewService(fileService, dataStore, shutdownTrigger) - err := sslService.Init(host, certPath, keyPath) - if err != nil { + if err := sslService.Init(host, certPath, keyPath); err != nil { return nil, err } return sslService, nil } -func initDockerClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService) *dockerclient.ClientFactory { - return dockerclient.NewClientFactory(signatureService, reverseTunnelService) -} - -func initKubernetesClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService, dataStore dataservices.DataStore, instanceID, addrHTTPS, userSessionTimeout string) (*kubecli.ClientFactory, error) { - return kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, addrHTTPS, userSessionTimeout) -} - func initSnapshotService( snapshotIntervalFromFlag string, dataStore dataservices.DataStore, @@ -310,14 +278,12 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer. settings.BlackListedLabels = *flags.Labels } + settings.AgentSecret = "" if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok { settings.AgentSecret = agentKey - } else { - settings.AgentSecret = "" } - err = dataStore.Settings().UpdateSettings(settings) - if err != nil { + if err := dataStore.Settings().UpdateSettings(settings); err != nil { return err } @@ -340,6 +306,7 @@ func loadAndParseKeyPair(fileService portainer.FileService, signatureService por if err != nil { return err } + return signatureService.ParseKeyPair(private, public) } @@ -348,7 +315,9 @@ func generateAndStoreKeyPair(fileService portainer.FileService, signatureService if err != nil { return err } + privateHeader, publicHeader := signatureService.PEMHeaders() + return fileService.StoreKeyPair(private, public, privateHeader, publicHeader) } @@ -361,6 +330,7 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D if existingKeyPair { return loadAndParseKeyPair(fileService, signatureService) } + return generateAndStoreKeyPair(fileService, signatureService) } @@ -378,6 +348,7 @@ func loadEncryptionSecretKey(keyfilename string) []byte { // return a 32 byte hash of the secret (required for AES) hash := sha256.Sum256(content) + return hash[:] } @@ -422,17 +393,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { log.Fatal().Err(err).Msg("failed initializing JWT service") } - ldapService := initLDAPService() + ldapService := &ldap.Service{} - oauthService := initOAuthService() + oauthService := oauth.NewService() - gitService := initGitService(shutdownCtx) + gitService := git.NewService(shutdownCtx) openAMTService := openamt.NewService() - cryptoService := initCryptoService() + cryptoService := &crypto.Service{} - digitalSignatureService := initDigitalSignatureService() + signatureService := initDigitalSignatureService() edgeStacksService := edgestacks.NewService(dataStore) @@ -446,15 +417,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { log.Fatal().Err(err).Msg("failed to get SSL settings") } - err = initKeyPair(fileService, digitalSignatureService) - if err != nil { + if err := initKeyPair(fileService, signatureService); err != nil { log.Fatal().Err(err).Msg("failed initializing key pair") } reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService) - dockerClientFactory := initDockerClientFactory(digitalSignatureService, reverseTunnelService) - kubernetesClientFactory, err := initKubernetesClientFactory(digitalSignatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout) + dockerClientFactory := dockerclient.NewClientFactory(signatureService, reverseTunnelService) + + kubernetesClientFactory, err := kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout) + if err != nil { + log.Fatal().Err(err).Msg("failed initializing Kubernetes Client Factory service") + } authorizationService := authorization.NewService(dataStore) authorizationService.K8sClientFactory = kubernetesClientFactory @@ -476,12 +450,12 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { composeStackManager := initComposeStackManager(composeDeployer, proxyManager) - swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, digitalSignatureService, fileService, reverseTunnelService, dataStore) + swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, signatureService, fileService, reverseTunnelService, dataStore) if err != nil { log.Fatal().Err(err).Msg("failed initializing swarm stack manager") } - kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, digitalSignatureService, proxyManager, *flags.Assets) + kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets) pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory) pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore)) @@ -492,17 +466,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { if err != nil { log.Fatal().Err(err).Msg("failed initializing snapshot service") } + snapshotService.Start() - proxyManager.NewProxyFactory(dataStore, digitalSignatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService) + proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService) helmPackageManager, err := initHelmPackageManager(*flags.Assets) if err != nil { log.Fatal().Err(err).Msg("failed initializing helm package manager") } - err = edge.LoadEdgeJobs(dataStore, reverseTunnelService) - if err != nil { + if err := edge.LoadEdgeJobs(dataStore, reverseTunnelService); err != nil { log.Fatal().Err(err).Msg("failed loading edge jobs from database") } @@ -514,6 +488,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService) adminPasswordHash := "" + if *flags.AdminPasswordFile != "" { content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "") if err != nil { @@ -536,14 +511,14 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { if len(users) == 0 { log.Info().Msg("created admin user with the given password.") + user := &portainer.User{ Username: "admin", Role: portainer.AdministratorRole, Password: adminPasswordHash, } - err := dataStore.User().Create(user) - if err != nil { + if err := dataStore.User().Create(user); err != nil { log.Fatal().Err(err).Msg("failed creating admin user") } @@ -554,8 +529,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { } } - err = reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService) - if err != nil { + if err := reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService); err != nil { log.Fatal().Err(err).Msg("failed starting tunnel server") } @@ -613,7 +587,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { ProxyManager: proxyManager, KubernetesTokenCacheManager: kubernetesTokenCacheManager, KubeClusterAccessService: kubeClusterAccessService, - SignatureService: digitalSignatureService, + SignatureService: signatureService, SnapshotService: snapshotService, SSLService: sslService, DockerClientFactory: dockerClientFactory, @@ -639,6 +613,7 @@ func main() { for { server := buildServer(flags) + log.Info(). Str("version", portainer.APIVersion). Str("build_number", build.BuildNumber). diff --git a/api/internal/concurrent/concurrent.go b/api/concurrent/concurrent.go similarity index 100% rename from api/internal/concurrent/concurrent.go rename to api/concurrent/concurrent.go diff --git a/api/database/boltdb/db.go b/api/database/boltdb/db.go index a0ef4df63..7d9d9786d 100644 --- a/api/database/boltdb/db.go +++ b/api/database/boltdb/db.go @@ -203,6 +203,7 @@ func (connection *DbConnection) ExportRaw(filename string) error { func (connection *DbConnection) ConvertToKey(v int) []byte { b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(v)) + return b } diff --git a/api/database/boltdb/json.go b/api/database/boltdb/json.go index 47db59ad7..8970082e7 100644 --- a/api/database/boltdb/json.go +++ b/api/database/boltdb/json.go @@ -46,8 +46,8 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{}) return errors.Wrap(err, "Failed decrypting object") } } - e := json.Unmarshal(data, object) - if e != nil { + + if e := json.Unmarshal(data, object); e != nil { // Special case for the VERSION bucket. Here we're not using json // So we need to return it as a string s, ok := object.(*string) @@ -57,6 +57,7 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{}) *s = string(data) } + return err } @@ -71,7 +72,7 @@ func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error) } nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return encrypted, err } diff --git a/api/database/boltdb/tx.go b/api/database/boltdb/tx.go index be10359bc..9aede519a 100644 --- a/api/database/boltdb/tx.go +++ b/api/database/boltdb/tx.go @@ -78,6 +78,7 @@ func (tx *DbTransaction) GetNextIdentifier(bucketName string) int { id, err := bucket.NextSequence() if err != nil { log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier") + return 0 } diff --git a/api/datastore/migrate_legacyversion.go b/api/datastore/migrate_legacyversion.go index e85b5d276..6a42f2c31 100644 --- a/api/datastore/migrate_legacyversion.go +++ b/api/datastore/migrate_legacyversion.go @@ -111,5 +111,6 @@ func (store *Store) finishMigrateLegacyVersion(versionToWrite *models.Version) e store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey)) store.connection.DeleteObject(bucketName, []byte(legacyEditionKey)) store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey)) + return err } diff --git a/api/datastore/migrator/migrate_ce.go b/api/datastore/migrator/migrate_ce.go index 10b7c1425..176957caf 100644 --- a/api/datastore/migrator/migrate_ce.go +++ b/api/datastore/migrator/migrate_ce.go @@ -39,20 +39,19 @@ func (m *Migrator) Migrate() error { latestMigrations := m.LatestMigrations() if latestMigrations.Version.Equal(schemaVersion) && version.MigratorCount != len(latestMigrations.MigrationFuncs) { - err := runMigrations(latestMigrations.MigrationFuncs) - if err != nil { + if err := runMigrations(latestMigrations.MigrationFuncs); err != nil { return err } + newMigratorCount = len(latestMigrations.MigrationFuncs) } } else { // regular path when major/minor/patch versions differ for _, migration := range m.migrations { if schemaVersion.LessThan(migration.Version) { - log.Info().Msgf("migrating data to %s", migration.Version.String()) - err := runMigrations(migration.MigrationFuncs) - if err != nil { + + if err := runMigrations(migration.MigrationFuncs); err != nil { return err } } @@ -63,16 +62,14 @@ func (m *Migrator) Migrate() error { } } - err = m.Always() - if err != nil { + if err := m.Always(); err != nil { return migrationError(err, "Always migrations returned error") } version.SchemaVersion = portainer.APIVersion version.MigratorCount = newMigratorCount - err = m.versionService.UpdateVersion(version) - if err != nil { + if err := m.versionService.UpdateVersion(version); err != nil { return migrationError(err, "StoreDBVersion") } @@ -99,6 +96,7 @@ func (m *Migrator) NeedsMigration() bool { // In this particular instance we should log a fatal error if m.CurrentDBEdition() != portainer.PortainerCE { log.Fatal().Msg("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/") + return false } diff --git a/api/datastore/migrator/migrate_dbversion100.go b/api/datastore/migrator/migrate_dbversion100.go index 14896fd48..458c10c95 100644 --- a/api/datastore/migrator/migrate_dbversion100.go +++ b/api/datastore/migrator/migrate_dbversion100.go @@ -7,6 +7,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/chisel/crypto" "github.com/portainer/portainer/api/dataservices" + "github.com/rs/zerolog/log" ) @@ -37,9 +38,11 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error { log.Info().Msg("ServerInfo object not found") return nil } + log.Error(). Err(err). Msg("Failed to read ServerInfo from DB") + return err } @@ -49,14 +52,15 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error { log.Error(). Err(err). Msg("Failed to read ServerInfo from DB") + return err } - err = m.fileService.StoreChiselPrivateKey(key) - if err != nil { + if err := m.fileService.StoreChiselPrivateKey(key); err != nil { log.Error(). Err(err). Msg("Failed to save Chisel private key to disk") + return err } } else { @@ -64,14 +68,14 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error { } serverInfo.PrivateKeySeed = "" - err = m.TunnelServerService.UpdateInfo(serverInfo) - if err != nil { + if err := m.TunnelServerService.UpdateInfo(serverInfo); err != nil { log.Error(). Err(err). Msg("Failed to clean private key seed in DB") } else { log.Info().Msg("Success to migrate private key seed to private key file") } + return err } @@ -84,9 +88,8 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error { } for _, edgeStack := range edgeStacks { - for environmentID, environmentStatus := range edgeStack.Status { - // skip if status is already updated + // Skip if status is already updated if len(environmentStatus.Status) > 0 { continue } @@ -146,8 +149,7 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error { edgeStack.Status[environmentID] = environmentStatus } - err = m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack) - if err != nil { + if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil { return err } } diff --git a/api/datastore/migrator/migrate_dbversion23.go b/api/datastore/migrator/migrate_dbversion23.go index 0ccb08de0..f5010582d 100644 --- a/api/datastore/migrator/migrate_dbversion23.go +++ b/api/datastore/migrator/migrate_dbversion23.go @@ -32,8 +32,8 @@ func (m *Migrator) updateStacksToDB24() error { for idx := range stacks { stack := &stacks[idx] stack.Status = portainer.StackStatusActive - err := m.stackService.Update(stack.ID, stack) - if err != nil { + + if err := m.stackService.Update(stack.ID, stack); err != nil { return err } } diff --git a/api/datastore/test_data/output_24_to_latest.json b/api/datastore/test_data/output_24_to_latest.json index 424921170..ca6694872 100644 --- a/api/datastore/test_data/output_24_to_latest.json +++ b/api/datastore/test_data/output_24_to_latest.json @@ -583,7 +583,6 @@ "AuthenticationMethod": 1, "BlackListedLabels": [], "Edge": { - "AsyncMode": false, "CommandInterval": 0, "PingInterval": 0, "SnapshotInterval": 0 diff --git a/api/datastore/teststore.go b/api/datastore/teststore.go index 5df7dccc3..d581f25b3 100644 --- a/api/datastore/teststore.go +++ b/api/datastore/teststore.go @@ -52,27 +52,24 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error) } if init { - err = store.Init() - if err != nil { + if err := store.Init(); err != nil { return newStore, nil, nil, err } } if newStore { - // from MigrateData + // From MigrateData v := models.Version{ SchemaVersion: portainer.APIVersion, Edition: int(portainer.PortainerCE), } - err = store.VersionService.UpdateVersion(&v) - if err != nil { + if err := store.VersionService.UpdateVersion(&v); err != nil { return newStore, nil, nil, err } } teardown := func() { - err := store.Close() - if err != nil { + if err := store.Close(); err != nil { log.Fatal().Err(err).Msg("") } } diff --git a/api/docker/container.go b/api/docker/container.go index f5167afbf..242f1cce9 100644 --- a/api/docker/container.go +++ b/api/docker/container.go @@ -36,7 +36,6 @@ func (c *ContainerService) Recreate(ctx context.Context, endpoint *portainer.End if err != nil { return nil, errors.Wrap(err, "create client error") } - defer cli.Close() log.Debug().Str("container_id", containerId).Msg("starting to fetch container information") diff --git a/api/http/csrf/csrf.go b/api/http/csrf/csrf.go index e2e641cef..5a230ceac 100644 --- a/api/http/csrf/csrf.go +++ b/api/http/csrf/csrf.go @@ -5,10 +5,10 @@ import ( "fmt" "net/http" + "github.com/portainer/portainer/api/http/security" httperror "github.com/portainer/portainer/pkg/libhttp/error" gorillacsrf "github.com/gorilla/csrf" - "github.com/portainer/portainer/api/http/security" "github.com/urfave/negroni" ) @@ -16,8 +16,7 @@ func WithProtect(handler http.Handler) (http.Handler, error) { handler = withSendCSRFToken(handler) token := make([]byte, 32) - _, err := rand.Read(token) - if err != nil { + if _, err := rand.Read(token); err != nil { return nil, fmt.Errorf("failed to generate CSRF token: %w", err) } @@ -32,7 +31,6 @@ func WithProtect(handler http.Handler) (http.Handler, error) { func withSendCSRFToken(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sw := negroni.NewResponseWriter(w) sw.Before(func(sw negroni.ResponseWriter) { @@ -44,16 +42,15 @@ func withSendCSRFToken(handler http.Handler) http.Handler { }) handler.ServeHTTP(sw, r) - }) } func withSkipCSRF(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - skip, err := security.ShouldSkipCSRFCheck(r) if err != nil { httperror.WriteError(w, http.StatusForbidden, err.Error(), err) + return } diff --git a/api/http/handler/auth/authenticate.go b/api/http/handler/auth/authenticate.go index 4547d7795..b342f7b71 100644 --- a/api/http/handler/auth/authenticate.go +++ b/api/http/handler/auth/authenticate.go @@ -56,8 +56,7 @@ func (payload *authenticatePayload) Validate(r *http.Request) error { // @router /auth [post] func (handler *Handler) authenticate(rw http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload authenticatePayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } @@ -104,8 +103,7 @@ func isUserInitialAdmin(user *portainer.User) bool { } func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portainer.User, password string) *httperror.HandlerError { - err := handler.CryptoService.CompareHashAndData(user.Password, password) - if err != nil { + if err := handler.CryptoService.CompareHashAndData(user.Password, password); err != nil { return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized) } @@ -115,8 +113,7 @@ func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portai } func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.User, username, password string, ldapSettings *portainer.LDAPSettings) *httperror.HandlerError { - err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings) - if err != nil { + if err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings); err != nil { if errors.Is(err, httperrors.ErrUnauthorized) { return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized) } @@ -131,14 +128,12 @@ func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer. PortainerAuthorizations: authorization.DefaultPortainerAuthorizations(), } - err = handler.DataStore.User().Create(user) - if err != nil { + if err := handler.DataStore.User().Create(user); err != nil { return httperror.InternalServerError("Unable to persist user inside the database", err) } } - err = handler.syncUserTeamsWithLDAPGroups(user, ldapSettings) - if err != nil { + if err := handler.syncUserTeamsWithLDAPGroups(user, ldapSettings); err != nil { log.Warn().Err(err).Msg("unable to automatically sync user teams with ldap") } @@ -186,7 +181,6 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin for _, team := range teams { if teamExists(team.Name, userGroups) { - if teamMembershipExists(team.ID, userMemberships) { continue } @@ -197,8 +191,7 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin Role: portainer.TeamMember, } - err := handler.DataStore.TeamMembership().Create(membership) - if err != nil { + if err := handler.DataStore.TeamMembership().Create(membership); err != nil { return err } } diff --git a/api/http/handler/auth/handler.go b/api/http/handler/auth/handler.go index 78a1aa2fb..3b7210fbf 100644 --- a/api/http/handler/auth/handler.go +++ b/api/http/handler/auth/handler.go @@ -41,5 +41,6 @@ func NewHandler(bouncer security.BouncerService, rateLimiter *security.RateLimit rateLimiter.LimitAccess(bouncer.PublicAccess(httperror.LoggerHandler(h.authenticate)))).Methods(http.MethodPost) h.Handle("/auth/logout", bouncer.PublicAccess(httperror.LoggerHandler(h.logout))).Methods(http.MethodPost) + return h } diff --git a/api/http/handler/auth/logout.go b/api/http/handler/auth/logout.go index f551cbcf1..d4117b283 100644 --- a/api/http/handler/auth/logout.go +++ b/api/http/handler/auth/logout.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/portainer/portainer/api/http/security" - "github.com/portainer/portainer/api/internal/logoutcontext" + "github.com/portainer/portainer/api/logoutcontext" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/response" ) diff --git a/api/http/handler/customtemplates/customtemplate_list.go b/api/http/handler/customtemplates/customtemplate_list.go index 7ed8fb6a7..581b219ae 100644 --- a/api/http/handler/customtemplates/customtemplate_list.go +++ b/api/http/handler/customtemplates/customtemplate_list.go @@ -4,14 +4,15 @@ import ( "net/http" "strconv" - "github.com/pkg/errors" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/internal/authorization" - "github.com/portainer/portainer/api/internal/slices" + "github.com/portainer/portainer/api/slicesx" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + + "github.com/pkg/errors" "github.com/rs/zerolog/log" ) @@ -70,7 +71,7 @@ func (handler *Handler) customTemplateList(w http.ResponseWriter, r *http.Reques customTemplates = filterByType(customTemplates, templateTypes) if edge != nil { - customTemplates = slices.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool { + customTemplates = slicesx.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool { return customTemplate.EdgeTemplate == *edge }) } diff --git a/api/http/handler/docker/images/images_list.go b/api/http/handler/docker/images/images_list.go index b9995e5c2..23f7d5398 100644 --- a/api/http/handler/docker/images/images_list.go +++ b/api/http/handler/docker/images/images_list.go @@ -6,7 +6,7 @@ import ( "github.com/portainer/portainer/api/docker/client" "github.com/portainer/portainer/api/http/handler/docker/utils" - "github.com/portainer/portainer/api/internal/set" + "github.com/portainer/portainer/api/set" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" diff --git a/api/http/handler/docker/utils/filter_by_uac.go b/api/http/handler/docker/utils/filter_by_uac.go index bb8b41529..a01eec70a 100644 --- a/api/http/handler/docker/utils/filter_by_uac.go +++ b/api/http/handler/docker/utils/filter_by_uac.go @@ -7,7 +7,7 @@ import ( "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/internal/authorization" - "github.com/portainer/portainer/api/internal/slices" + "github.com/portainer/portainer/api/slicesx" ) // filterByResourceControl filters a list of items based on the user's role and the resource control associated to the item. @@ -16,7 +16,7 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy return items, nil } - userTeamIDs := slices.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID { + userTeamIDs := slicesx.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID { return membership.TeamID }) @@ -32,5 +32,6 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy } } + return filteredItems, nil } diff --git a/api/http/handler/edgegroups/edgegroup_create.go b/api/http/handler/edgegroups/edgegroup_create.go index 28443883b..df12fc4c0 100644 --- a/api/http/handler/edgegroups/edgegroup_create.go +++ b/api/http/handler/edgegroups/edgegroup_create.go @@ -36,23 +36,25 @@ func (payload *edgeGroupCreatePayload) Validate(r *http.Request) error { func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.EdgeGroup, endpoints []portainer.EndpointID, tagIDs []portainer.TagID) error { if edgeGroup.Dynamic { edgeGroup.TagIDs = tagIDs - } else { - endpointIDs := []portainer.EndpointID{} - for _, endpointID := range endpoints { - endpoint, err := tx.Endpoint().Endpoint(endpointID) - if err != nil { - return httperror.InternalServerError("Unable to retrieve environment from the database", err) - } + return nil + } - if endpointutils.IsEdgeEndpoint(endpoint) { - endpointIDs = append(endpointIDs, endpoint.ID) - } + endpointIDs := []portainer.EndpointID{} + + for _, endpointID := range endpoints { + endpoint, err := tx.Endpoint().Endpoint(endpointID) + if err != nil { + return httperror.InternalServerError("Unable to retrieve environment from the database", err) } - edgeGroup.Endpoints = endpointIDs + if endpointutils.IsEdgeEndpoint(endpoint) { + endpointIDs = append(endpointIDs, endpoint.ID) + } } + edgeGroup.Endpoints = endpointIDs + return nil } @@ -71,13 +73,13 @@ func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer. // @router /edge_groups [post] func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload edgeGroupCreatePayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } var edgeGroup *portainer.EdgeGroup - err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + + err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { edgeGroups, err := tx.EdgeGroup().ReadAll() if err != nil { return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err) @@ -101,8 +103,7 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) return err } - err = tx.EdgeGroup().Create(edgeGroup) - if err != nil { + if err := tx.EdgeGroup().Create(edgeGroup); err != nil { return httperror.InternalServerError("Unable to persist the Edge group inside the database", err) } diff --git a/api/http/handler/edgegroups/edgegroup_update.go b/api/http/handler/edgegroups/edgegroup_update.go index 38f975739..f4aaa5b92 100644 --- a/api/http/handler/edgegroups/edgegroup_update.go +++ b/api/http/handler/edgegroups/edgegroup_update.go @@ -9,7 +9,7 @@ import ( "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/endpointutils" - "github.com/portainer/portainer/api/internal/unique" + "github.com/portainer/portainer/api/slicesx" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" @@ -113,7 +113,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request) } newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) - endpointsToUpdate := unique.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...)) + endpointsToUpdate := slicesx.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...)) edgeJobs, err := tx.EdgeJob().ReadAll() if err != nil { diff --git a/api/http/handler/edgestacks/edgestack_test.go b/api/http/handler/edgestacks/edgestack_test.go index a2c619a31..1db87a920 100644 --- a/api/http/handler/edgestacks/edgestack_test.go +++ b/api/http/handler/edgestacks/edgestack_test.go @@ -31,8 +31,7 @@ func setupHandler(t *testing.T) (*Handler, string) { } user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole} - err = store.User().Create(user) - if err != nil { + if err := store.User().Create(user); err != nil { t.Fatal(err) } @@ -66,8 +65,7 @@ func setupHandler(t *testing.T) (*Handler, string) { } settings.EnableEdgeComputeFeatures = true - err = handler.DataStore.Settings().UpdateSettings(settings) - if err != nil { + if err := handler.DataStore.Settings().UpdateSettings(settings); err != nil { t.Fatal(err) } @@ -88,8 +86,7 @@ func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID LastCheckInDate: time.Now().Unix(), } - err := store.Endpoint().Create(&endpoint) - if err != nil { + if err := store.Endpoint().Create(&endpoint); err != nil { t.Fatal(err) } @@ -112,8 +109,7 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port PartialMatch: false, } - err := store.EdgeGroup().Create(&edgeGroup) - if err != nil { + if err := store.EdgeGroup().Create(&edgeGroup); err != nil { t.Fatal(err) } @@ -138,13 +134,11 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port }, } - err = store.EdgeStack().Create(edgeStack.ID, &edgeStack) - if err != nil { + if err := store.EdgeStack().Create(edgeStack.ID, &edgeStack); err != nil { t.Fatal(err) } - err = store.EndpointRelation().Create(&endpointRelation) - if err != nil { + if err := store.EndpointRelation().Create(&endpointRelation); err != nil { t.Fatal(err) } diff --git a/api/http/handler/edgestacks/edgestack_update.go b/api/http/handler/edgestacks/edgestack_update.go index d9563bf16..5d0a52232 100644 --- a/api/http/handler/edgestacks/edgestack_update.go +++ b/api/http/handler/edgestacks/edgestack_update.go @@ -6,7 +6,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge" - "github.com/portainer/portainer/api/internal/set" + "github.com/portainer/portainer/api/set" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" diff --git a/api/http/handler/edgestacks/edgestack_update_test.go b/api/http/handler/edgestacks/edgestack_update_test.go index 973cec42d..7e4a9b23c 100644 --- a/api/http/handler/edgestacks/edgestack_update_test.go +++ b/api/http/handler/edgestacks/edgestack_update_test.go @@ -9,6 +9,7 @@ import ( "testing" portainer "github.com/portainer/portainer/api" + "github.com/stretchr/testify/require" "github.com/segmentio/encoding/json" ) @@ -24,8 +25,7 @@ func TestUpdateAndInspect(t *testing.T) { endpointID := portainer.EndpointID(6) newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID) - err := handler.DataStore.Endpoint().Create(&newEndpoint) - if err != nil { + if err := handler.DataStore.Endpoint().Create(&newEndpoint); err != nil { t.Fatal(err) } @@ -36,8 +36,7 @@ func TestUpdateAndInspect(t *testing.T) { }, } - err = handler.DataStore.EndpointRelation().Create(&endpointRelation) - if err != nil { + if err := handler.DataStore.EndpointRelation().Create(&endpointRelation); err != nil { t.Fatal(err) } @@ -50,8 +49,7 @@ func TestUpdateAndInspect(t *testing.T) { PartialMatch: false, } - err = handler.DataStore.EdgeGroup().Create(&newEdgeGroup) - if err != nil { + if err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup); err != nil { t.Fatal(err) } @@ -96,8 +94,7 @@ func TestUpdateAndInspect(t *testing.T) { } updatedStack := portainer.EdgeStack{} - err = json.NewDecoder(rec.Body).Decode(&updatedStack) - if err != nil { + if err := json.NewDecoder(rec.Body).Decode(&updatedStack); err != nil { t.Fatal("error decoding response:", err) } @@ -120,7 +117,6 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) { endpoint := createEndpoint(t, handler.DataStore) edgeStack := createEdgeStack(t, handler.DataStore, endpoint.ID) - //newEndpoint := createEndpoint(t, handler.DataStore) newEdgeGroup := portainer.EdgeGroup{ ID: 2, Name: "EdgeGroup 2", @@ -130,7 +126,8 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) { PartialMatch: false, } - handler.DataStore.EdgeGroup().Create(&newEdgeGroup) + err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup) + require.NoError(t, err) cases := []struct { Name string diff --git a/api/http/handler/endpointedge/endpointedge_status_inspect_test.go b/api/http/handler/endpointedge/endpointedge_status_inspect_test.go index 41b8eb2de..b173446e7 100644 --- a/api/http/handler/endpointedge/endpointedge_status_inspect_test.go +++ b/api/http/handler/endpointedge/endpointedge_status_inspect_test.go @@ -18,6 +18,7 @@ import ( "github.com/segmentio/encoding/json" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type endpointTestCase struct { @@ -99,8 +100,7 @@ func mustSetupHandler(t *testing.T) *Handler { } settings.TrustOnFirstConnect = true - err = store.Settings().UpdateSettings(settings) - if err != nil { + if err = store.Settings().UpdateSettings(settings); err != nil { t.Fatalf("could not update settings: %s", err) } @@ -122,8 +122,7 @@ func createEndpoint(handler *Handler, endpoint portainer.Endpoint, endpointRelat return nil } - err = handler.DataStore.Endpoint().Create(&endpoint) - if err != nil { + if err := handler.DataStore.Endpoint().Create(&endpoint); err != nil { return err } @@ -134,14 +133,13 @@ func TestMissingEdgeIdentifier(t *testing.T) { handler := mustSetupHandler(t) endpointID := portainer.EndpointID(45) - err := createEndpoint(handler, portainer.Endpoint{ + if err := createEndpoint(handler, portainer.Endpoint{ ID: endpointID, Name: "endpoint-id-45", Type: portainer.EdgeAgentOnDockerEnvironment, URL: "https://portainer.io:9443", EdgeID: "edge-id", - }, portainer.EndpointRelation{EndpointID: endpointID}) - if err != nil { + }, portainer.EndpointRelation{EndpointID: endpointID}); err != nil { t.Fatal(err) } @@ -201,8 +199,7 @@ func TestLastCheckInDateIncreases(t *testing.T) { EndpointID: endpoint.ID, } - err := createEndpoint(handler, endpoint, endpointRelation) - if err != nil { + if err := createEndpoint(handler, endpoint, endpointRelation); err != nil { t.Fatal(err) } @@ -212,6 +209,7 @@ func TestLastCheckInDateIncreases(t *testing.T) { if err != nil { t.Fatal("request error:", err) } + req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") @@ -246,8 +244,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) { EndpointID: endpoint.ID, } - err := createEndpoint(handler, endpoint, endpointRelation) - if err != nil { + if err := createEndpoint(handler, endpoint, endpointRelation); err != nil { t.Fatal(err) } @@ -255,6 +252,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) { if err != nil { t.Fatal("request error:", err) } + req.Header.Set(portainer.PortainerAgentEdgeIDHeader, edgeId) req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") @@ -308,10 +306,11 @@ func TestEdgeStackStatus(t *testing.T) { edgeStack.ID: true, }, } - handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack) - err := createEndpoint(handler, endpoint, endpointRelation) - if err != nil { + err := handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack) + require.NoError(t, err) + + if err := createEndpoint(handler, endpoint, endpointRelation); err != nil { t.Fatal(err) } @@ -319,6 +318,7 @@ func TestEdgeStackStatus(t *testing.T) { if err != nil { t.Fatal("request error:", err) } + req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") @@ -330,8 +330,7 @@ func TestEdgeStackStatus(t *testing.T) { } var data endpointEdgeStatusInspectResponse - err = json.NewDecoder(rec.Body).Decode(&data) - if err != nil { + if err := json.NewDecoder(rec.Body).Decode(&data); err != nil { t.Fatal("error decoding response:", err) } @@ -357,8 +356,7 @@ func TestEdgeJobsResponse(t *testing.T) { EndpointID: endpoint.ID, } - err := createEndpoint(handler, endpoint, endpointRelation) - if err != nil { + if err := createEndpoint(handler, endpoint, endpointRelation); err != nil { t.Fatal(err) } @@ -384,6 +382,7 @@ func TestEdgeJobsResponse(t *testing.T) { if err != nil { t.Fatal("request error:", err) } + req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") @@ -395,8 +394,7 @@ func TestEdgeJobsResponse(t *testing.T) { } var data endpointEdgeStatusInspectResponse - err = json.NewDecoder(rec.Body).Decode(&data) - if err != nil { + if err := json.NewDecoder(rec.Body).Decode(&data); err != nil { t.Fatal("error decoding response:", err) } diff --git a/api/http/handler/endpointgroups/endpointgroup_update.go b/api/http/handler/endpointgroups/endpointgroup_update.go index 2a4e2576d..19c68658d 100644 --- a/api/http/handler/endpointgroups/endpointgroup_update.go +++ b/api/http/handler/endpointgroups/endpointgroup_update.go @@ -8,8 +8,8 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/endpointutils" - "github.com/portainer/portainer/api/internal/tag" "github.com/portainer/portainer/api/pendingactions/handlers" + "github.com/portainer/portainer/api/tag" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" diff --git a/api/http/handler/endpoints/endpoint_agent_versions.go b/api/http/handler/endpoints/endpoint_agent_versions.go index 00e83c640..ced7d0b06 100644 --- a/api/http/handler/endpoints/endpoint_agent_versions.go +++ b/api/http/handler/endpoints/endpoint_agent_versions.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/portainer/portainer/api/http/security" - "github.com/portainer/portainer/api/internal/set" + "github.com/portainer/portainer/api/set" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/response" ) diff --git a/api/http/handler/endpoints/endpoint_create.go b/api/http/handler/endpoints/endpoint_create.go index a4c0798ca..c1ef2f393 100644 --- a/api/http/handler/endpoints/endpoint_create.go +++ b/api/http/handler/endpoints/endpoint_create.go @@ -73,8 +73,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { payload.GroupID = groupID var tagIDs []portainer.TagID - err = request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true) - if err != nil { + if err := request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true); err != nil { return errors.New("invalid TagIds parameter") } payload.TagIDs = tagIDs @@ -96,6 +95,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { if err != nil { return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly") } + payload.TLSCACertFile = caCert } @@ -110,6 +110,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { if err != nil { return errors.New("invalid key file. Ensure that the file is uploaded correctly") } + payload.TLSKeyFile = key } } @@ -120,6 +121,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { if err != nil { return errors.New("invalid Azure application ID") } + payload.AzureApplicationID = azureApplicationID azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false) @@ -139,6 +141,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) { return errors.New("URL cannot be empty") } + payload.URL = endpointURL publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true) @@ -156,10 +159,10 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { } gpus := make([]portainer.Pair, 0) - err = request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true) - if err != nil { + if err := request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true); err != nil { return errors.New("invalid Gpus parameter") } + payload.Gpus = gpus edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true) @@ -206,8 +209,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error { // @router /endpoints [post] func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { payload := &endpointCreatePayload{} - err := payload.Validate(r) - if err != nil { + if err := payload.Validate(r); err != nil { return httperror.BadRequest("Invalid request payload", err) } @@ -268,8 +270,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) * ) } - err = handler.DataStore.EndpointRelation().Create(relationObject) - if err != nil { + if err := handler.DataStore.EndpointRelation().Create(relationObject); err != nil { return httperror.InternalServerError("Unable to persist the relation object inside the database", err) } @@ -278,6 +279,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) * func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { var err error + switch payload.EndpointCreationType { case azureEnvironment: return handler.createAzureEndpoint(tx, payload) @@ -329,8 +331,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload } httpClient := client.NewHTTPClient() - _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials) - if err != nil { + if _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials); err != nil { return nil, httperror.InternalServerError("Unable to authenticate against Azure", err) } @@ -352,8 +353,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload Kubernetes: portainer.KubernetesDefault(), } - err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) - if err != nil { + if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil { return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err) } @@ -405,8 +405,7 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay endpoint.EdgeID = edgeID.String() } - err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) - if err != nil { + if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil { return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err) } @@ -443,8 +442,7 @@ func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, pay Kubernetes: portainer.KubernetesDefault(), } - err := handler.snapshotAndPersistEndpoint(tx, endpoint) - if err != nil { + if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil { return nil, err } @@ -478,8 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, pa Kubernetes: portainer.KubernetesDefault(), } - err := handler.snapshotAndPersistEndpoint(tx, endpoint) - if err != nil { + if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil { return nil, err } @@ -510,13 +507,11 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa endpoint.Agent.Version = agentVersion - err := handler.storeTLSFiles(endpoint, payload) - if err != nil { + if err := handler.storeTLSFiles(endpoint, payload); err != nil { return nil, err } - err = handler.snapshotAndPersistEndpoint(tx, endpoint) - if err != nil { + if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil { return nil, err } @@ -524,17 +519,16 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa } func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError { - err := handler.SnapshotService.SnapshotEndpoint(endpoint) - if err != nil { + if err := handler.SnapshotService.SnapshotEndpoint(endpoint); err != nil { if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) || (endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) { err = errors.New("agent already paired with another Portainer instance") } + return httperror.InternalServerError("Unable to initiate communications with environment", err) } - err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) - if err != nil { + if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil { return httperror.InternalServerError("An error occurred while trying to create the environment", err) } @@ -555,16 +549,14 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.Data AllowStackManagementForRegularUsers: true, } - err := tx.Endpoint().Create(endpoint) - if err != nil { + if err := tx.Endpoint().Create(endpoint); err != nil { return err } for _, tagID := range endpoint.TagIDs { - err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + if err := tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { tag.Endpoints[endpoint.ID] = true - }) - if err != nil { + }); err != nil { return err } } @@ -580,22 +572,26 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end if err != nil { return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err) } + endpoint.TLSConfig.TLSCACertPath = caCertPath } - if !payload.TLSSkipClientVerify { - certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile) - if err != nil { - return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err) - } - endpoint.TLSConfig.TLSCertPath = certPath - - keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile) - if err != nil { - return httperror.InternalServerError("Unable to persist TLS key file on disk", err) - } - endpoint.TLSConfig.TLSKeyPath = keyPath + if payload.TLSSkipClientVerify { + return nil } + certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile) + if err != nil { + return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err) + } + + endpoint.TLSConfig.TLSCertPath = certPath + + keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile) + if err != nil { + return httperror.InternalServerError("Unable to persist TLS key file on disk", err) + } + endpoint.TLSConfig.TLSKeyPath = keyPath + return nil } diff --git a/api/http/handler/endpoints/endpoint_delete_test.go b/api/http/handler/endpoints/endpoint_delete_test.go index e0117a9f3..4b9034a31 100644 --- a/api/http/handler/endpoints/endpoint_delete_test.go +++ b/api/http/handler/endpoints/endpoint_delete_test.go @@ -30,24 +30,22 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) { for i := 0; i < endpointsCount; i++ { endpointID := portainer.EndpointID(i) + 1 - err := store.Endpoint().Create(&portainer.Endpoint{ + if err := store.Endpoint().Create(&portainer.Endpoint{ ID: endpointID, Name: "env-" + strconv.Itoa(int(endpointID)), Type: portainer.EdgeAgentOnDockerEnvironment, - }) - if err != nil { + }); err != nil { t.Fatal("could not create endpoint:", err) } endpointIDs = append(endpointIDs, endpointID) } - err := store.EdgeGroup().Create(&portainer.EdgeGroup{ + if err := store.EdgeGroup().Create(&portainer.EdgeGroup{ ID: 1, Name: "edgegroup-1", Endpoints: endpointIDs, - }) - if err != nil { + }); err != nil { t.Fatal("could not create edge group:", err) } diff --git a/api/http/handler/endpoints/endpoint_list_test.go b/api/http/handler/endpoints/endpoint_list_test.go index 9ddedccef..94a058875 100644 --- a/api/http/handler/endpoints/endpoint_list_test.go +++ b/api/http/handler/endpoints/endpoint_list_test.go @@ -102,7 +102,6 @@ func Test_EndpointList_AgentVersion(t *testing.T) { } func Test_endpointList_edgeFilter(t *testing.T) { - trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment} untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment} regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment} @@ -227,8 +226,7 @@ func doEndpointListRequest(req *http.Request, h *Handler, is *assert.Assertions) } resp := []portainer.Endpoint{} - err = json.Unmarshal(body, &resp) - if err != nil { + if err := json.Unmarshal(body, &resp); err != nil { return nil, err } diff --git a/api/http/handler/endpoints/endpoint_registries_list.go b/api/http/handler/endpoints/endpoint_registries_list.go index aad91ecea..806f3e25c 100644 --- a/api/http/handler/endpoints/endpoint_registries_list.go +++ b/api/http/handler/endpoints/endpoint_registries_list.go @@ -34,12 +34,10 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re } var registries []portainer.Registry - err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { + if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID)) return err - }) - - if err != nil { + }); err != nil { var httpErr *httperror.HandlerError if errors.As(err, &httpErr) { return httpErr @@ -104,11 +102,9 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi } if namespaceParam != "" { - authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin) - if err != nil { + if authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin); err != nil { return nil, httperror.NotFound("Unable to check for namespace authorization", err) - } - if !authorized { + } else if !authorized { return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace")) } diff --git a/api/http/handler/endpoints/filter.go b/api/http/handler/endpoints/filter.go index 950780fc1..ee2029aa2 100644 --- a/api/http/handler/endpoints/filter.go +++ b/api/http/handler/endpoints/filter.go @@ -13,7 +13,7 @@ import ( "github.com/portainer/portainer/api/http/handler/edgegroups" "github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/endpointutils" - "github.com/portainer/portainer/api/internal/unique" + "github.com/portainer/portainer/api/slicesx" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/pkg/errors" @@ -254,6 +254,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port if err != nil { return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database") } + if edgeGroup.Dynamic { endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch) if err != nil { @@ -261,6 +262,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port } edgeGroup.Endpoints = endpointIDs } + envIds = append(envIds, edgeGroup.Endpoints...) } @@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port envIds = envIds[:n] } - uniqueIds := unique.Unique(envIds) + uniqueIds := slicesx.Unique(envIds) filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds) return filteredEndpoints, nil diff --git a/api/http/handler/endpoints/filter_test.go b/api/http/handler/endpoints/filter_test.go index f05f00cf5..fda61acee 100644 --- a/api/http/handler/endpoints/filter_test.go +++ b/api/http/handler/endpoints/filter_test.go @@ -5,8 +5,8 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/datastore" - "github.com/portainer/portainer/api/internal/slices" "github.com/portainer/portainer/api/internal/testhelpers" + "github.com/portainer/portainer/api/slicesx" "github.com/stretchr/testify/assert" ) @@ -129,7 +129,7 @@ func Test_Filter_edgeFilter(t *testing.T) { func Test_Filter_excludeIDs(t *testing.T) { ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9} - environments := slices.Map(ids, func(id portainer.EndpointID) portainer.Endpoint { + environments := slicesx.Map(ids, func(id portainer.EndpointID) portainer.Endpoint { return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment} }) diff --git a/api/http/handler/endpoints/sort_test.go b/api/http/handler/endpoints/sort_test.go index aad83f751..2eb748e33 100644 --- a/api/http/handler/endpoints/sort_test.go +++ b/api/http/handler/endpoints/sort_test.go @@ -4,7 +4,8 @@ import ( "testing" portainer "github.com/portainer/portainer/api" - "github.com/portainer/portainer/api/internal/slices" + "github.com/portainer/portainer/api/slicesx" + "github.com/stretchr/testify/assert" ) @@ -162,7 +163,7 @@ func TestSortEndpointsByField(t *testing.T) { } func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID { - return slices.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID { + return slicesx.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID { return environment.ID }) } diff --git a/api/http/handler/endpoints/update_edge_relations.go b/api/http/handler/endpoints/update_edge_relations.go index 23f6302d4..6e1aa1861 100644 --- a/api/http/handler/endpoints/update_edge_relations.go +++ b/api/http/handler/endpoints/update_edge_relations.go @@ -6,7 +6,7 @@ import ( "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/endpointutils" - "github.com/portainer/portainer/api/internal/set" + "github.com/portainer/portainer/api/set" ) // updateEdgeRelations updates the edge stacks associated to an edge endpoint diff --git a/api/http/handler/endpoints/utils_update_edge_groups.go b/api/http/handler/endpoints/utils_update_edge_groups.go index 42e1dba76..bd9c413d7 100644 --- a/api/http/handler/endpoints/utils_update_edge_groups.go +++ b/api/http/handler/endpoints/utils_update_edge_groups.go @@ -6,7 +6,7 @@ import ( "github.com/pkg/errors" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" - "github.com/portainer/portainer/api/internal/set" + "github.com/portainer/portainer/api/set" ) func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) { diff --git a/api/http/handler/endpoints/utils_update_edge_groups_test.go b/api/http/handler/endpoints/utils_update_edge_groups_test.go index 8f34816a3..e89d501fb 100644 --- a/api/http/handler/endpoints/utils_update_edge_groups_test.go +++ b/api/http/handler/endpoints/utils_update_edge_groups_test.go @@ -10,7 +10,6 @@ import ( ) func Test_updateEdgeGroups(t *testing.T) { - createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) { groups := make([]portainer.EdgeGroup, len(names)) for index, name := range names { @@ -21,8 +20,7 @@ func Test_updateEdgeGroups(t *testing.T) { Endpoints: make([]portainer.EndpointID, 0), } - err := store.EdgeGroup().Create(group) - if err != nil { + if err := store.EdgeGroup().Create(group); err != nil { return nil, err } @@ -42,6 +40,7 @@ func Test_updateEdgeGroups(t *testing.T) { return } } + is.Fail("expected endpoint to be in group") } } @@ -52,6 +51,7 @@ func Test_updateEdgeGroups(t *testing.T) { for j, tag := range groups { if tag.Name == tagName { result[i] = groups[j] + break } } @@ -88,6 +88,7 @@ func Test_updateEdgeGroups(t *testing.T) { } expectedGroups := groupsByName(groups, testCase.groupsToApply) + expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups)) for i, tag := range expectedGroups { expectedIDs[i] = tag.ID diff --git a/api/http/handler/endpoints/utils_update_tags.go b/api/http/handler/endpoints/utils_update_tags.go index a792f32b9..b2d5b5d02 100644 --- a/api/http/handler/endpoints/utils_update_tags.go +++ b/api/http/handler/endpoints/utils_update_tags.go @@ -4,7 +4,7 @@ import ( "github.com/pkg/errors" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" - "github.com/portainer/portainer/api/internal/set" + "github.com/portainer/portainer/api/set" ) // updateEnvironmentTags updates the tags associated to an environment diff --git a/api/http/handler/helm/helm_delete_test.go b/api/http/handler/helm/helm_delete_test.go index cb2aca2d3..bfe4f66ae 100644 --- a/api/http/handler/helm/helm_delete_test.go +++ b/api/http/handler/helm/helm_delete_test.go @@ -10,14 +10,14 @@ import ( "github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/exec/exectest" "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/internal/testhelpers" + helper "github.com/portainer/portainer/api/internal/testhelpers" "github.com/portainer/portainer/api/jwt" "github.com/portainer/portainer/api/kubernetes" "github.com/portainer/portainer/pkg/libhelm/binary/test" "github.com/portainer/portainer/pkg/libhelm/options" - "github.com/stretchr/testify/assert" - "github.com/portainer/portainer/api/internal/testhelpers" - helper "github.com/portainer/portainer/api/internal/testhelpers" + "github.com/stretchr/testify/assert" ) func Test_helmDelete(t *testing.T) { diff --git a/api/http/handler/registries/handler.go b/api/http/handler/registries/handler.go index 089853afe..67274b0a9 100644 --- a/api/http/handler/registries/handler.go +++ b/api/http/handler/registries/handler.go @@ -97,13 +97,13 @@ func (handler *Handler) userHasRegistryAccess(r *http.Request) (hasAccess bool, if err != nil { return false, false, err } + endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) if err != nil { return false, false, err } - err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint) - if err != nil { + if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil { return false, false, err } diff --git a/api/http/handler/settings/settings_public.go b/api/http/handler/settings/settings_public.go index fd45748b2..61f27c025 100644 --- a/api/http/handler/settings/settings_public.go +++ b/api/http/handler/settings/settings_public.go @@ -71,6 +71,7 @@ func (handler *Handler) settingsPublic(w http.ResponseWriter, r *http.Request) * } publicSettings := generatePublicSettings(settings) + return response.JSON(w, publicSettings) } @@ -96,7 +97,7 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp publicSettings.IsDockerDesktopExtension = appSettings.IsDockerDesktopExtension - //if OAuth authentication is on, compose the related fields from application settings + // If OAuth authentication is on, compose the related fields from application settings if publicSettings.AuthenticationMethod == portainer.AuthenticationOAuth { publicSettings.OAuthLogoutURI = appSettings.OAuthSettings.LogoutURI publicSettings.OAuthLoginURI = fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s&scope=%s", @@ -104,16 +105,18 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp appSettings.OAuthSettings.ClientID, appSettings.OAuthSettings.RedirectURI, appSettings.OAuthSettings.Scopes) - //control prompt=login param according to the SSO setting + + // Control prompt=login param according to the SSO setting if !appSettings.OAuthSettings.SSO { publicSettings.OAuthLoginURI += "&prompt=login" } } - //if LDAP authentication is on, compose the related fields from application settings + // If LDAP authentication is on, compose the related fields from application settings if publicSettings.AuthenticationMethod == portainer.AuthenticationLDAP && appSettings.LDAPSettings.GroupSearchSettings != nil { if len(appSettings.LDAPSettings.GroupSearchSettings) > 0 { publicSettings.TeamSync = len(appSettings.LDAPSettings.GroupSearchSettings[0].GroupBaseDN) > 0 } } + return publicSettings } diff --git a/api/http/handler/settings/settings_public_test.go b/api/http/handler/settings/settings_public_test.go index 65d331747..166aebb07 100644 --- a/api/http/handler/settings/settings_public_test.go +++ b/api/http/handler/settings/settings_public_test.go @@ -40,14 +40,17 @@ func setup() { func TestGeneratePublicSettingsWithSSO(t *testing.T) { setup() + mockAppSettings.OAuthSettings.SSO = true publicSettings := generatePublicSettings(mockAppSettings) if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth { t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod) } + if publicSettings.OAuthLoginURI != dummyOAuthLoginURI { t.Errorf("wrong OAuthLoginURI when SSO is switched on, want: %s, got: %s", dummyOAuthLoginURI, publicSettings.OAuthLoginURI) } + if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI { t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI) } @@ -55,15 +58,18 @@ func TestGeneratePublicSettingsWithSSO(t *testing.T) { func TestGeneratePublicSettingsWithoutSSO(t *testing.T) { setup() + mockAppSettings.OAuthSettings.SSO = false publicSettings := generatePublicSettings(mockAppSettings) if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth { t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod) } + expectedOAuthLoginURI := dummyOAuthLoginURI + "&prompt=login" if publicSettings.OAuthLoginURI != expectedOAuthLoginURI { t.Errorf("wrong OAuthLoginURI when SSO is switched off, want: %s, got: %s", expectedOAuthLoginURI, publicSettings.OAuthLoginURI) } + if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI { t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI) } diff --git a/api/http/handler/stacks/stack_delete.go b/api/http/handler/stacks/stack_delete.go index b28237ff5..7af27e213 100644 --- a/api/http/handler/stacks/stack_delete.go +++ b/api/http/handler/stacks/stack_delete.go @@ -89,8 +89,7 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt } if !isOrphaned { - err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint) - if err != nil { + if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil { return httperror.Forbidden("Permission denied to access endpoint", err) } @@ -119,25 +118,21 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt deployments.StopAutoupdate(stack.ID, stack.AutoUpdate.JobID, handler.Scheduler) } - err = handler.deleteStack(securityContext.UserID, stack, endpoint) - if err != nil { + if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil { return httperror.InternalServerError(err.Error(), err) } - err = handler.DataStore.Stack().Delete(portainer.StackID(id)) - if err != nil { + if err := handler.DataStore.Stack().Delete(portainer.StackID(id)); err != nil { return httperror.InternalServerError("Unable to remove the stack from the database", err) } if resourceControl != nil { - err = handler.DataStore.ResourceControl().Delete(resourceControl.ID) - if err != nil { + if err := handler.DataStore.ResourceControl().Delete(resourceControl.ID); err != nil { return httperror.InternalServerError("Unable to remove the associated resource control from the database", err) } } - err = handler.FileService.RemoveDirectory(stack.ProjectPath) - if err != nil { + if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil { log.Warn().Err(err).Msg("Unable to remove stack files from disk") } @@ -169,8 +164,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit return httperror.InternalServerError("Unable to find the endpoint associated to the stack inside the database", err) } - err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint) - if err != nil { + if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil { return httperror.Forbidden("Permission denied to access endpoint", err) } @@ -179,8 +173,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit Type: portainer.DockerSwarmStack, } - err = handler.deleteStack(securityContext.UserID, stack, endpoint) - if err != nil { + if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil { return httperror.InternalServerError("Unable to delete stack", err) } @@ -255,6 +248,7 @@ func (handler *Handler) deleteStack(userID portainer.UserID, stack *portainer.St } } } + return errors.WithMessagef(err, "failed to remove kubernetes resources: %q", out) } @@ -369,18 +363,18 @@ func (handler *Handler) stackDeleteKubernetesByName(w http.ResponseWriter, r *ht if err != nil { log.Err(err).Msgf("Unable to delete Kubernetes stack `%d`", stack.ID) errors = append(errors, err) + continue } - err = handler.DataStore.Stack().Delete(stack.ID) - if err != nil { + if err := handler.DataStore.Stack().Delete(stack.ID); err != nil { errors = append(errors, err) log.Err(err).Msgf("Unable to remove the stack `%d` from the database", stack.ID) + continue } - err = handler.FileService.RemoveDirectory(stack.ProjectPath) - if err != nil { + if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil { errors = append(errors, err) log.Warn().Err(err).Msg("Unable to remove stack files from disk") } diff --git a/api/http/handler/tags/tag_delete_test.go b/api/http/handler/tags/tag_delete_test.go index 9f5e80e4c..3c55ac509 100644 --- a/api/http/handler/tags/tag_delete_test.go +++ b/api/http/handler/tags/tag_delete_test.go @@ -18,8 +18,7 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) { _, store := datastore.MustNewTestStore(t, true, false) user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole} - err := store.User().Create(user) - if err != nil { + if err := store.User().Create(user); err != nil { t.Fatal("could not create admin user:", err) } @@ -33,29 +32,28 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) { for i := 0; i < tagsCount; i++ { tagID := portainer.TagID(i) + 1 - err = store.Tag().Create(&portainer.Tag{ + if err := store.Tag().Create(&portainer.Tag{ ID: tagID, Name: "tag-" + strconv.Itoa(int(tagID)), - }) - if err != nil { + }); err != nil { t.Fatal("could not create tag:", err) } tagIDs = append(tagIDs, tagID) } - err = store.EdgeGroup().Create(&portainer.EdgeGroup{ + if err := store.EdgeGroup().Create(&portainer.EdgeGroup{ ID: 1, Name: "edgegroup-1", TagIDs: tagIDs, - }) - if err != nil { + }); err != nil { t.Fatal("could not create edge group:", err) } // Remove the tags concurrently var wg sync.WaitGroup + wg.Add(len(tagIDs)) for _, tagID := range tagIDs { diff --git a/api/http/handler/users/user_create.go b/api/http/handler/users/user_create.go index d932b8c43..13da0f94c 100644 --- a/api/http/handler/users/user_create.go +++ b/api/http/handler/users/user_create.go @@ -27,6 +27,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error { if payload.Role != 1 && payload.Role != 2 { return errors.New("Invalid role value. Value must be one of: 1 (administrator) or 2 (regular user)") } + return nil } @@ -49,8 +50,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error { // @router /users [post] func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload userCreatePayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } @@ -89,11 +89,11 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http } } - err = handler.DataStore.User().Create(user) - if err != nil { + if err := handler.DataStore.User().Create(user); err != nil { return httperror.InternalServerError("Unable to persist user inside the database", err) } hideFields(user) + return response.JSON(w, user) } diff --git a/api/http/handler/users/user_list_test.go b/api/http/handler/users/user_list_test.go index 2fde780bb..33bd5f04b 100644 --- a/api/http/handler/users/user_list_test.go +++ b/api/http/handler/users/user_list_test.go @@ -26,12 +26,12 @@ func Test_userList(t *testing.T) { _, store := datastore.MustNewTestStore(t, true, true) - // create admin and standard user(s) + // Create admin and standard user(s) adminUser := &portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole} err := store.User().Create(adminUser) is.NoError(err, "error creating admin user") - // setup services + // Setup services jwtService, err := jwt.NewService("1h", store) is.NoError(err, "Error initiating jwt service") apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User()) @@ -42,7 +42,7 @@ func Test_userList(t *testing.T) { h := NewHandler(requestBouncer, rateLimiter, apiKeyService, passwordChecker) h.DataStore = store - // generate admin user tokens + // Generate admin user tokens adminJWT, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: adminUser.ID, Username: adminUser.Username, Role: adminUser.Role}) // Case 1: the user is given the endpoint access directly @@ -54,12 +54,12 @@ func Test_userList(t *testing.T) { err = store.User().Create(userWithoutEndpointAccess) is.NoError(err, "error creating user") - // create environment group + // Create environment group endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "default-endpoint-group"} err = store.EndpointGroup().Create(endpointGroup) is.NoError(err, "error creating endpoint group") - // create endpoint and user access policies + // Create endpoint and user access policies userAccessPolicies := make(portainer.UserAccessPolicies, 0) userAccessPolicies[userWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userWithEndpointAccess.Role)} @@ -129,7 +129,7 @@ func Test_userList(t *testing.T) { err = store.User().Create(userUnderGroup) is.NoError(err, "error creating user") - // create environment group including a user + // Create environment group including a user userAccessPoliciesUnderGroup := make(portainer.UserAccessPolicies, 0) userAccessPoliciesUnderGroup[userUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderGroup.Role)} @@ -137,7 +137,7 @@ func Test_userList(t *testing.T) { err = store.EndpointGroup().Create(endpointGroupWithUser) is.NoError(err, "error creating endpoint group") - // create endpoint + // Create endpoint endpointUnderGroupWithUser := &portainer.Endpoint{ID: 2, GroupID: endpointGroupWithUser.ID} err = store.Endpoint().Create(endpointUnderGroupWithUser) is.NoError(err, "error creating endpoint") @@ -182,7 +182,7 @@ func Test_userList(t *testing.T) { err = store.TeamMembership().Create(teamMembership) is.NoError(err, "error creating team membership") - // create environment group including a team + // Create environment group including a team teamAccessPoliciesUnderGroup := make(portainer.TeamAccessPolicies, 0) teamAccessPoliciesUnderGroup[teamUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeam.Role)} @@ -190,7 +190,7 @@ func Test_userList(t *testing.T) { err = store.EndpointGroup().Create(endpointGroupWithTeam) is.NoError(err, "error creating endpoint group") - // create endpoint + // Create endpoint endpointUnderGroupWithTeam := &portainer.Endpoint{ID: 3, GroupID: endpointGroupWithTeam.ID} err = store.Endpoint().Create(endpointUnderGroupWithTeam) is.NoError(err, "error creating endpoint") @@ -233,12 +233,12 @@ func Test_userList(t *testing.T) { err = store.TeamMembership().Create(teamMembershipWithEndpointAccess) is.NoError(err, "error creating team membership") - // create environment group + // Create environment group endpointGroupWithoutTeam := &portainer.EndpointGroup{ID: 4, Name: "endpoint-group-without-team"} err = store.EndpointGroup().Create(endpointGroupWithoutTeam) is.NoError(err, "error creating endpoint group") - // create endpoint and team access policies + // Create endpoint and team access policies teamAccessPolicies := make(portainer.TeamAccessPolicies, 0) teamAccessPolicies[teamWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeamWithEndpointAccess.Role)} diff --git a/api/http/handler/users/user_update_test.go b/api/http/handler/users/user_update_test.go index b0728a89c..17eb231c6 100644 --- a/api/http/handler/users/user_update_test.go +++ b/api/http/handler/users/user_update_test.go @@ -19,12 +19,12 @@ func Test_updateUserRemovesAccessTokens(t *testing.T) { _, store := datastore.MustNewTestStore(t, true, true) - // create standard user + // Create standard user user := &portainer.User{ID: 2, Username: "standard", Role: portainer.StandardUserRole} err := store.User().Create(user) is.NoError(err, "error creating user") - // setup services + // Setup services jwtService, err := jwt.NewService("1h", store) is.NoError(err, "Error initiating jwt service") apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User()) diff --git a/api/http/handler/websocket/proxy.go b/api/http/handler/websocket/proxy.go index 3d6858ff3..be1b94b06 100644 --- a/api/http/handler/websocket/proxy.go +++ b/api/http/handler/websocket/proxy.go @@ -8,12 +8,12 @@ import ( "net/url" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/http/security" - "github.com/portainer/portainer/api/internal/logoutcontext" + "github.com/portainer/portainer/api/logoutcontext" "github.com/gorilla/websocket" "github.com/koding/websocketproxy" - "github.com/portainer/portainer/api/crypto" "github.com/rs/zerolog/log" ) diff --git a/api/http/proxy/factory/agent.go b/api/http/proxy/factory/agent.go index ded36b90a..bd08efd6c 100644 --- a/api/http/proxy/factory/agent.go +++ b/api/http/proxy/factory/agent.go @@ -9,7 +9,7 @@ import ( "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/http/proxy/factory/agent" "github.com/portainer/portainer/api/internal/endpointutils" - "github.com/portainer/portainer/api/internal/url" + "github.com/portainer/portainer/api/url" "github.com/pkg/errors" "github.com/rs/zerolog/log" diff --git a/api/http/proxy/factory/azure/access_control.go b/api/http/proxy/factory/azure/access_control.go index bf6f237f8..b9c625a34 100644 --- a/api/http/proxy/factory/azure/access_control.go +++ b/api/http/proxy/factory/azure/access_control.go @@ -54,6 +54,7 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re portainerMetadata := object["Portainer"].(map[string]interface{}) portainerMetadata["ResourceControl"] = resourceControl + return object } @@ -64,8 +65,7 @@ func (transport *Transport) createPrivateResourceControl( resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID) - err := transport.dataStore.ResourceControl().Create(resourceControl) - if err != nil { + if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil { log.Error(). Str("resource", resourceIdentifier). Err(err). @@ -84,6 +84,7 @@ func (transport *Transport) userCanDeleteContainerGroup(request *http.Request, c resourceIdentifier := request.URL.Path resourceControl := transport.findResourceControl(resourceIdentifier, context) + return authorization.UserCanAccessResource(context.userID, context.userTeamIDs, resourceControl) } @@ -136,20 +137,19 @@ func (transport *Transport) filterContainerGroups(containerGroups []interface{}, func (transport *Transport) removeResourceControl(containerGroup map[string]interface{}, context *azureRequestContext) error { containerGroupID, ok := containerGroup["id"].(string) - if ok { - resourceControl := transport.findResourceControl(containerGroupID, context) - if resourceControl != nil { - err := transport.dataStore.ResourceControl().Delete(resourceControl.ID) - return err - } - } else { + if !ok { log.Debug().Msg("missing ID in container group") + + return nil + } + + if resourceControl := transport.findResourceControl(containerGroupID, context); resourceControl != nil { + return transport.dataStore.ResourceControl().Delete(resourceControl.ID) } return nil } func (transport *Transport) findResourceControl(containerGroupId string, context *azureRequestContext) *portainer.ResourceControl { - resourceControl := authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls) - return resourceControl + return authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls) } diff --git a/api/http/proxy/factory/docker.go b/api/http/proxy/factory/docker.go index 0e040cac5..e558bebb9 100644 --- a/api/http/proxy/factory/docker.go +++ b/api/http/proxy/factory/docker.go @@ -8,7 +8,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/http/proxy/factory/docker" - "github.com/portainer/portainer/api/internal/url" + "github.com/portainer/portainer/api/url" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/rs/zerolog/log" diff --git a/api/http/proxy/factory/docker/access_control.go b/api/http/proxy/factory/docker/access_control.go index 4df8f7412..6e13a6f45 100644 --- a/api/http/proxy/factory/docker/access_control.go +++ b/api/http/proxy/factory/docker/access_control.go @@ -105,8 +105,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m resourceControl := authorization.NewRestrictedResourceControl(resourceID, resourceType, userIDs, teamIDs) - err := transport.dataStore.ResourceControl().Create(resourceControl) - if err != nil { + if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil { return nil, err } @@ -119,8 +118,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m func (transport *Transport) createPrivateResourceControl(resourceIdentifier string, resourceType portainer.ResourceControlType, userID portainer.UserID) (*portainer.ResourceControl, error) { resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID) - err := transport.dataStore.ResourceControl().Create(resourceControl) - if err != nil { + if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil { log.Error(). Str("resource", resourceIdentifier). Err(err). @@ -170,6 +168,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe systemResourceControl := findSystemNetworkResourceControl(responseObject) if systemResourceControl != nil { responseObject = decorateObject(responseObject, systemResourceControl) + return utils.RewriteResponse(response, responseObject, http.StatusOK) } } @@ -188,6 +187,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe if executor.operationContext.isAdmin || (resourceControl != nil && authorization.UserCanAccessResource(executor.operationContext.userID, executor.operationContext.userTeamIDs, resourceControl)) { responseObject = decorateObject(responseObject, resourceControl) + return utils.RewriteResponse(response, responseObject, http.StatusOK) } @@ -221,6 +221,7 @@ func (transport *Transport) decorateResourceList(parameters *resourceOperationPa if systemResourceControl != nil { resourceObject = decorateObject(resourceObject, systemResourceControl) decoratedResourceData = append(decoratedResourceData, resourceObject) + continue } } @@ -264,6 +265,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara if systemResourceControl != nil { resourceObject = decorateObject(resourceObject, systemResourceControl) filteredResourceData = append(filteredResourceData, resourceObject) + continue } } @@ -277,6 +279,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara if context.isAdmin { filteredResourceData = append(filteredResourceData, resourceObject) } + continue } @@ -334,11 +337,13 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou func getStackResourceIDFromLabels(resourceLabelsObject map[string]string, endpointID portainer.EndpointID) string { if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != "" { stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName] + return stackutils.ResourceControlID(endpointID, stackName) } if resourceLabelsObject[resourceLabelForDockerComposeStackName] != "" { stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName] + return stackutils.ResourceControlID(endpointID, stackName) } @@ -352,5 +357,6 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re portainerMetadata := object["Portainer"].(map[string]interface{}) portainerMetadata["ResourceControl"] = resourceControl + return object } diff --git a/api/http/proxy/factory/docker/configs.go b/api/http/proxy/factory/docker/configs.go index 4820b74c6..4b9a2fe3a 100644 --- a/api/http/proxy/factory/docker/configs.go +++ b/api/http/proxy/factory/docker/configs.go @@ -11,9 +11,7 @@ import ( "github.com/portainer/portainer/api/internal/authorization" ) -const ( - configObjectIdentifier = "ID" -) +const configObjectIdentifier = "ID" func getInheritedResourceControlFromConfigLabels(dockerClient *client.Client, endpointID portainer.EndpointID, configID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) { config, _, err := dockerClient.ConfigInspectWithRaw(context.Background(), configID) @@ -78,10 +76,9 @@ func (transport *Transport) configInspectOperation(response *http.Response, exec // https://docs.docker.com/engine/api/v1.37/#operation/ConfigList // https://docs.docker.com/engine/api/v1.37/#operation/ConfigInspect func selectorConfigLabels(responseObject map[string]interface{}) map[string]interface{} { - secretSpec := utils.GetJSONObject(responseObject, "Spec") - if secretSpec != nil { - secretLabelsObject := utils.GetJSONObject(secretSpec, "Labels") - return secretLabelsObject + if secretSpec := utils.GetJSONObject(responseObject, "Spec"); secretSpec != nil { + return utils.GetJSONObject(secretSpec, "Labels") } + return nil } diff --git a/api/http/proxy/factory/docker/tasks.go b/api/http/proxy/factory/docker/tasks.go index f91c1a81c..d67774a69 100644 --- a/api/http/proxy/factory/docker/tasks.go +++ b/api/http/proxy/factory/docker/tasks.go @@ -7,9 +7,7 @@ import ( "github.com/portainer/portainer/api/http/proxy/factory/utils" ) -const ( - taskServiceObjectIdentifier = "ServiceID" -) +const taskServiceObjectIdentifier = "ServiceID" // taskListOperation extracts the response as a JSON array, loop through the tasks array // and filter the containers based on resource controls before rewriting the response. @@ -46,5 +44,6 @@ func selectorTaskLabels(responseObject map[string]interface{}) map[string]interf return utils.GetJSONObject(containerSpecObject, "Labels") } } + return nil } diff --git a/api/http/proxy/factory/docker/volumes.go b/api/http/proxy/factory/docker/volumes.go index 5b60ea63a..a77abf2c0 100644 --- a/api/http/proxy/factory/docker/volumes.go +++ b/api/http/proxy/factory/docker/volumes.go @@ -7,19 +7,17 @@ import ( "net/http" "path" - "github.com/docker/docker/client" - "github.com/rs/zerolog/log" - portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/http/proxy/factory/utils" "github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/internal/snapshot" + + "github.com/docker/docker/client" + "github.com/rs/zerolog/log" ) -const ( - volumeObjectIdentifier = "ResourceID" -) +const volumeObjectIdentifier = "ResourceID" func getInheritedResourceControlFromVolumeLabels(dockerClient *client.Client, endpointID portainer.EndpointID, volumeID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) { volume, err := dockerClient.VolumeInspect(context.Background(), volumeID) @@ -57,14 +55,13 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo Msg("snapshot is not filled into the endpoint.") } } + for _, volumeObject := range volumeData { volume := volumeObject.(map[string]interface{}) - err = transport.decorateVolumeResponseWithResourceID(volume) - if err != nil { + if err := transport.decorateVolumeResponseWithResourceID(volume); err != nil { return fmt.Errorf("failed decorating volume response: %w", err) } - } resourceOperationParameters := &resourceOperationParameters{ @@ -77,6 +74,7 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo if err != nil { return err } + // Overwrite the original volume list responseObject["Volumes"] = volumeData } @@ -94,8 +92,7 @@ func (transport *Transport) volumeInspectOperation(response *http.Response, exec return err } - err = transport.decorateVolumeResponseWithResourceID(responseObject) - if err != nil { + if err := transport.decorateVolumeResponseWithResourceID(responseObject); err != nil { return fmt.Errorf("failed decorating volume response: %w", err) } @@ -148,8 +145,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt } defer cli.Close() - _, err = cli.VolumeInspect(context.Background(), volumeID) - if err == nil { + if _, err = cli.VolumeInspect(context.Background(), volumeID); err == nil { return &http.Response{ StatusCode: http.StatusConflict, }, errors.New("a volume with the same name already exists") @@ -164,6 +160,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt if response.StatusCode == http.StatusCreated { err = transport.decorateVolumeCreationResponse(response, resourceType, tokenData.ID) } + return response, err } @@ -195,7 +192,6 @@ func (transport *Transport) decorateVolumeCreationResponse(response *http.Respon } func (transport *Transport) restrictedVolumeOperation(requestPath string, request *http.Request) (*http.Response, error) { - if request.Method == http.MethodGet { return transport.rewriteOperation(request, transport.volumeInspectOperation) } @@ -210,6 +206,7 @@ func (transport *Transport) restrictedVolumeOperation(requestPath string, reques if request.Method == http.MethodDelete { return transport.executeGenericResourceDeletionOperation(request, resourceID, volumeName, portainer.VolumeResourceControl) } + return transport.restrictedResourceOperation(request, resourceID, volumeName, portainer.VolumeResourceControl, false) } @@ -218,6 +215,7 @@ func (transport *Transport) getVolumeResourceID(volumeName string) (string, erro if err != nil { return "", fmt.Errorf("failed fetching docker id: %w", err) } + return fmt.Sprintf("%s_%s", volumeName, dockerID), nil } diff --git a/api/internal/edge/edgegroup.go b/api/internal/edge/edgegroup.go index edbe6eaa9..519e4fd7a 100644 --- a/api/internal/edge/edgegroup.go +++ b/api/internal/edge/edgegroup.go @@ -4,7 +4,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/endpointutils" - "github.com/portainer/portainer/api/internal/tag" + "github.com/portainer/portainer/api/tag" ) // EdgeGroupRelatedEndpoints returns a list of environments(endpoints) related to this Edge group diff --git a/api/internal/edge/edgestacks/service.go b/api/internal/edge/edgestacks/service.go index a9e8d1a6b..a000fdd79 100644 --- a/api/internal/edge/edgestacks/service.go +++ b/api/internal/edge/edgestacks/service.go @@ -37,12 +37,12 @@ func (service *Service) BuildEdgeStack( registries []portainer.RegistryID, useManifestNamespaces bool, ) (*portainer.EdgeStack, error) { - err := validateUniqueName(tx.EdgeStack().EdgeStacks, name) - if err != nil { + if err := validateUniqueName(tx.EdgeStack().EdgeStacks, name); err != nil { return nil, err } stackID := tx.EdgeStack().GetNextIdentifier() + return &portainer.EdgeStack{ ID: portainer.EdgeStackID(stackID), Name: name, @@ -77,7 +77,6 @@ func (service *Service) PersistEdgeStack( storeManifest edgetypes.StoreManifestFunc) (*portainer.EdgeStack, error) { relationConfig, err := edge.FetchEndpointRelationsConfig(tx) - if err != nil { return nil, fmt.Errorf("unable to find environment relations in database: %w", err) } @@ -87,6 +86,7 @@ func (service *Service) PersistEdgeStack( if errors.Is(err, edge.ErrEdgeGroupNotFound) { return nil, httperrors.NewInvalidPayloadError(err.Error()) } + return nil, fmt.Errorf("unable to persist environment relation in database: %w", err) } @@ -101,13 +101,11 @@ func (service *Service) PersistEdgeStack( stack.EntryPoint = composePath stack.NumDeployments = len(relatedEndpointIds) - err = service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds) - if err != nil { + if err := service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds); err != nil { return nil, fmt.Errorf("unable to update endpoint relations: %w", err) } - err = tx.EdgeStack().Create(stack.ID, stack) - if err != nil { + if err := tx.EdgeStack().Create(stack.ID, stack); err != nil { return nil, err } @@ -126,8 +124,7 @@ func (service *Service) updateEndpointRelations(tx dataservices.DataStoreTx, edg relation.EdgeStacks[edgeStackID] = true - err = endpointRelationService.UpdateEndpointRelation(endpointID, relation) - if err != nil { + if err := endpointRelationService.UpdateEndpointRelation(endpointID, relation); err != nil { return fmt.Errorf("unable to persist endpoint relation in database: %w", err) } } @@ -155,14 +152,12 @@ func (service *Service) DeleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID delete(relation.EdgeStacks, edgeStackID) - err = tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation) - if err != nil { + if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil { return errors.WithMessage(err, "Unable to persist environment relation in database") } } - err = tx.EdgeStack().DeleteEdgeStack(edgeStackID) - if err != nil { + if err := tx.EdgeStack().DeleteEdgeStack(edgeStackID); err != nil { return errors.WithMessage(err, "Unable to remove the edge stack from the database") } diff --git a/api/internal/endpointutils/endpoint_test.go b/api/internal/endpointutils/endpoint_test.go index 208d79dfd..93e5db3f4 100644 --- a/api/internal/endpointutils/endpoint_test.go +++ b/api/internal/endpointutils/endpoint_test.go @@ -4,6 +4,7 @@ import ( "testing" portainer "github.com/portainer/portainer/api" + "github.com/stretchr/testify/assert" ) diff --git a/api/internal/nodes/nodes.go b/api/internal/nodes/nodes.go index a95895f98..9e8168cc0 100644 --- a/api/internal/nodes/nodes.go +++ b/api/internal/nodes/nodes.go @@ -8,6 +8,7 @@ import ( // NodesCount returns the total node number of all environments func NodesCount(endpoints []portainer.Endpoint) int { nodes := 0 + for _, env := range endpoints { if !endpointutils.IsEdgeEndpoint(&env) || env.UserTrusted { nodes += countNodes(&env) @@ -28,11 +29,3 @@ func countNodes(endpoint *portainer.Endpoint) int { return 1 } - -func max(a, b int) int { - if a > b { - return a - } - - return b -} diff --git a/api/internal/securecookie/securecookie.go b/api/internal/securecookie/securecookie.go deleted file mode 100644 index 263171504..000000000 --- a/api/internal/securecookie/securecookie.go +++ /dev/null @@ -1,16 +0,0 @@ -package securecookie - -import ( - "crypto/rand" - "io" -) - -// GenerateRandomKey generates a random key of specified length -// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515 -func GenerateRandomKey(length int) []byte { - k := make([]byte, length) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return nil - } - return k -} diff --git a/api/internal/slices/slices.go b/api/internal/slices/slices.go deleted file mode 100644 index d6022afca..000000000 --- a/api/internal/slices/slices.go +++ /dev/null @@ -1,23 +0,0 @@ -package slices - -// Map applies the given function to each element of the slice and returns a new slice with the results -func Map[T, U any](s []T, f func(T) U) []U { - result := make([]U, len(s)) - for i, v := range s { - result[i] = f(v) - } - return result -} - -// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true -func Filter[T any](s []T, predicate func(T) bool) []T { - n := 0 - for _, v := range s { - if predicate(v) { - s[n] = v - n++ - } - } - - return s[:n] -} diff --git a/api/internal/unique/unique.go b/api/internal/unique/unique.go deleted file mode 100644 index 54c23af01..000000000 --- a/api/internal/unique/unique.go +++ /dev/null @@ -1,41 +0,0 @@ -package unique - -func Unique[T comparable](items []T) []T { - return UniqueBy(items, func(item T) T { - return item - }) -} - -func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType { - includedItems := make(map[ComparableType]bool) - result := []ItemType{} - - for _, item := range items { - if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded { - includedItems[accessorFunc(item)] = true - result = append(result, item) - } - } - - return result -} - -/** - -type someType struct { - id int - fn func() -} - -func Test() { - ids := []int{1, 2, 3, 3} - _ = UniqueBy(ids, func(id int) int { return id }) - _ = Unique(ids) // shorthand for UniqueBy Identity/self - - as := []someType{{id: 1}, {id: 2}, {id: 3}, {id: 3}} - _ = UniqueBy(as, func(item someType) int { return item.id }) // no error - _ = UniqueBy(as, func(item someType) someType { return item }) // compile error - someType is not comparable - _ = Unique(as) // compile error - shorthand fails for the same reason -} - -*/ diff --git a/api/jwt/jwt.go b/api/jwt/jwt.go index 1ac607a40..85adb5cca 100644 --- a/api/jwt/jwt.go +++ b/api/jwt/jwt.go @@ -6,10 +6,10 @@ import ( "time" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/apikey" "github.com/portainer/portainer/api/dataservices" "github.com/golang-jwt/jwt/v4" - "github.com/portainer/portainer/api/internal/securecookie" "github.com/rs/zerolog/log" ) @@ -51,7 +51,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (* return nil, err } - secret := securecookie.GenerateRandomKey(32) + secret := apikey.GenerateRandomKey(32) if secret == nil { return nil, errSecretGeneration } @@ -69,6 +69,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (* userSessionTimeout, dataStore, } + return service, nil } @@ -80,16 +81,18 @@ func getOrCreateKubeSecret(dataStore dataservices.DataStore) ([]byte, error) { kubeSecret := settings.OAuthSettings.KubeSecretKey if kubeSecret == nil { - kubeSecret = securecookie.GenerateRandomKey(32) + kubeSecret = apikey.GenerateRandomKey(32) if kubeSecret == nil { return nil, errSecretGeneration } + settings.OAuthSettings.KubeSecretKey = kubeSecret - err = dataStore.Settings().UpdateSettings(settings) - if err != nil { + + if err := dataStore.Settings().UpdateSettings(settings); err != nil { return nil, err } } + return kubeSecret, nil } diff --git a/api/kubernetes/cli/dashboard.go b/api/kubernetes/cli/dashboard.go index c96910365..32cf71c9d 100644 --- a/api/kubernetes/cli/dashboard.go +++ b/api/kubernetes/cli/dashboard.go @@ -3,8 +3,8 @@ package cli import ( "context" + "github.com/portainer/portainer/api/concurrent" models "github.com/portainer/portainer/api/http/models/kubernetes" - "github.com/portainer/portainer/api/internal/concurrent" "k8s.io/apimachinery/pkg/api/errors" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/api/kubernetes/cli/service.go b/api/kubernetes/cli/service.go index c75c631b0..f7c94e94a 100644 --- a/api/kubernetes/cli/service.go +++ b/api/kubernetes/cli/service.go @@ -4,6 +4,7 @@ import ( "context" models "github.com/portainer/portainer/api/http/models/kubernetes" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" labels "k8s.io/apimachinery/pkg/labels" @@ -67,9 +68,7 @@ func (kcl *KubeClient) GetServices(namespace string, lookupApplications bool) ([ return result, nil } -// CreateService creates a new service in a given namespace in a k8s endpoint. -func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error { - ServiceClient := kcl.cli.CoreV1().Services(namespace) +func (kcl *KubeClient) fillService(info models.K8sServiceInfo) v1.Service { var service v1.Service service.Name = info.Name @@ -93,16 +92,21 @@ func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInf // Set ingresses. for _, i := range info.IngressStatus { - var ing v1.LoadBalancerIngress - ing.IP = i.IP - ing.Hostname = i.Host service.Status.LoadBalancer.Ingress = append( service.Status.LoadBalancer.Ingress, - ing, + v1.LoadBalancerIngress{IP: i.IP, Hostname: i.Host}, ) } - _, err := ServiceClient.Create(context.Background(), &service, metav1.CreateOptions{}) + return service +} + +// CreateService creates a new service in a given namespace in a k8s endpoint. +func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error { + serviceClient := kcl.cli.CoreV1().Services(namespace) + service := kcl.fillService(info) + + _, err := serviceClient.Create(context.Background(), &service, metav1.CreateOptions{}) return err } @@ -120,45 +124,16 @@ func (kcl *KubeClient) DeleteServices(reqs models.K8sServiceDeleteRequests) erro ) } } + return err } // UpdateService updates service in a given namespace in a k8s endpoint. func (kcl *KubeClient) UpdateService(namespace string, info models.K8sServiceInfo) error { - ServiceClient := kcl.cli.CoreV1().Services(namespace) - var service v1.Service + serviceClient := kcl.cli.CoreV1().Services(namespace) + service := kcl.fillService(info) - service.Name = info.Name - service.Spec.Type = v1.ServiceType(info.Type) - service.Namespace = info.Namespace - service.Annotations = info.Annotations - service.Labels = info.Labels - service.Spec.AllocateLoadBalancerNodePorts = info.AllocateLoadBalancerNodePorts - service.Spec.Selector = info.Selector - - // Set ports. - for _, p := range info.Ports { - var port v1.ServicePort - port.Name = p.Name - port.NodePort = int32(p.NodePort) - port.Port = int32(p.Port) - port.Protocol = v1.Protocol(p.Protocol) - port.TargetPort = intstr.FromString(p.TargetPort) - service.Spec.Ports = append(service.Spec.Ports, port) - } - - // Set ingresses. - for _, i := range info.IngressStatus { - var ing v1.LoadBalancerIngress - ing.IP = i.IP - ing.Hostname = i.Host - service.Status.LoadBalancer.Ingress = append( - service.Status.LoadBalancer.Ingress, - ing, - ) - } - - _, err := ServiceClient.Update(context.Background(), &service, metav1.UpdateOptions{}) + _, err := serviceClient.Update(context.Background(), &service, metav1.UpdateOptions{}) return err } @@ -210,5 +185,4 @@ func makeApplication(meta metav1.Object) []models.K8sApplication { Name: ownerReference.Name, }, } - } diff --git a/api/kubernetes/contants.go b/api/kubernetes/constants.go similarity index 100% rename from api/kubernetes/contants.go rename to api/kubernetes/constants.go diff --git a/api/internal/logoutcontext/logout_context.go b/api/logoutcontext/logout_context.go similarity index 100% rename from api/internal/logoutcontext/logout_context.go rename to api/logoutcontext/logout_context.go diff --git a/api/internal/logoutcontext/service.go b/api/logoutcontext/service.go similarity index 100% rename from api/internal/logoutcontext/service.go rename to api/logoutcontext/service.go diff --git a/api/internal/logoutcontext/service_factory.go b/api/logoutcontext/service_factory.go similarity index 100% rename from api/internal/logoutcontext/service_factory.go rename to api/logoutcontext/service_factory.go diff --git a/api/oauth/oauth_test.go b/api/oauth/oauth_test.go index 0fb10e587..92782b96b 100644 --- a/api/oauth/oauth_test.go +++ b/api/oauth/oauth_test.go @@ -16,8 +16,7 @@ func Test_getOAuthToken(t *testing.T) { t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) { code := "" - _, err := getOAuthToken(code, config) - if err == nil { + if _, err := getOAuthToken(code, config); err == nil { t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code) } }) @@ -91,22 +90,19 @@ func Test_getResource(t *testing.T) { defer srv.Close() t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) { - _, err := getResource("", config) - if err == nil { + if _, err := getResource("", config); err == nil { t.Errorf("getResource should fail if access token is not provided in auth bearer header") } }) t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) { - _, err := getResource("incorrect-token", config) - if err == nil { + if _, err := getResource("incorrect-token", config); err == nil { t.Errorf("getResource should fail if incorrect access token provided in auth bearer header") } }) t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) { - _, err := getResource(oauthtest.AccessToken, config) - if err != nil { + if _, err := getResource(oauthtest.AccessToken, config); err != nil { t.Errorf("getResource should succeed if correct access token provided in auth bearer header") } }) @@ -120,8 +116,7 @@ func Test_Authenticate(t *testing.T) { srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{}) defer srv.Close() - _, err := authService.Authenticate(code, config) - if err == nil { + if _, err := authService.Authenticate(code, config); err == nil { t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided") } }) diff --git a/api/portainer.go b/api/portainer.go index e0694a6c2..0f5aba27a 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -12,6 +12,7 @@ import ( gittypes "github.com/portainer/portainer/api/git/types" models "github.com/portainer/portainer/api/http/models/kubernetes" "github.com/portainer/portainer/pkg/featureflags" + "golang.org/x/oauth2" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/version" @@ -322,14 +323,14 @@ type ( Name string `json:"Name"` Status map[EndpointID]EdgeStackStatus `json:"Status"` // StatusArray map[EndpointID][]EdgeStackStatus `json:"StatusArray"` - CreationDate int64 `json:"CreationDate"` - EdgeGroups []EdgeGroupID `json:"EdgeGroups"` - ProjectPath string `json:"ProjectPath"` - EntryPoint string `json:"EntryPoint"` - Version int `json:"Version"` - NumDeployments int `json:"NumDeployments"` - ManifestPath string - DeploymentType EdgeStackDeploymentType + CreationDate int64 `json:"CreationDate"` + EdgeGroups []EdgeGroupID `json:"EdgeGroups"` + ProjectPath string `json:"ProjectPath"` + EntryPoint string `json:"EntryPoint"` + Version int `json:"Version"` + NumDeployments int `json:"NumDeployments"` + ManifestPath string `json:"ManifestPath"` + DeploymentType EdgeStackDeploymentType `json:"DeploymentType"` // Uses the manifest's namespaces instead of the default one UseManifestNamespaces bool @@ -554,23 +555,22 @@ type ( // Extension represents a deprecated Portainer extension Extension struct { - // Extension Identifier - ID ExtensionID `json:"Id" example:"1"` - Enabled bool `json:"Enabled"` - Name string `json:"Name,omitempty"` - ShortDescription string `json:"ShortDescription,omitempty"` - Description string `json:"Description,omitempty"` - DescriptionURL string `json:"DescriptionURL,omitempty"` - Price string `json:"Price,omitempty"` - PriceDescription string `json:"PriceDescription,omitempty"` - Deal bool `json:"Deal,omitempty"` - Available bool `json:"Available,omitempty"` - License LicenseInformation `json:"License,omitempty"` - Version string `json:"Version"` - UpdateAvailable bool `json:"UpdateAvailable"` - ShopURL string `json:"ShopURL,omitempty"` - Images []string `json:"Images,omitempty"` - Logo string `json:"Logo,omitempty"` + ID ExtensionID `json:"Id" example:"1"` + Enabled bool `json:"Enabled"` + Name string `json:"Name,omitempty"` + ShortDescription string `json:"ShortDescription,omitempty"` + Description string `json:"Description,omitempty"` + DescriptionURL string `json:"DescriptionURL,omitempty"` + Price string `json:"Price,omitempty"` + PriceDescription string `json:"PriceDescription,omitempty"` + Deal bool `json:"Deal,omitempty"` + Available bool `json:"Available,omitempty"` + License ExtensionLicenseInformation `json:"License,omitempty"` + Version string `json:"Version"` + UpdateAvailable bool `json:"UpdateAvailable"` + ShopURL string `json:"ShopURL,omitempty"` + Images []string `json:"Images,omitempty"` + Logo string `json:"Logo,omitempty"` } // ExtensionID represents a extension identifier @@ -737,8 +737,8 @@ type ( Groups []string } - // LicenseInformation represents information about an extension license - LicenseInformation struct { + // ExtensionLicenseInformation represents information about an extension license + ExtensionLicenseInformation struct { LicenseKey string `json:"LicenseKey,omitempty"` Company string `json:"Company,omitempty"` Expiration string `json:"Expiration,omitempty"` @@ -939,6 +939,18 @@ type ( HideStacksFunctionality bool `json:"hideStacksFunctionality" example:"false"` } + Edge struct { + // The command list interval for edge agent - used in edge async mode (in seconds) + CommandInterval int `json:"CommandInterval" example:"5"` + // The ping interval for edge agent - used in edge async mode (in seconds) + PingInterval int `json:"PingInterval" example:"5"` + // The snapshot interval for edge agent - used in edge async mode (in seconds) + SnapshotInterval int `json:"SnapshotInterval" example:"5"` + + // Deprecated 2.18 + AsyncMode bool `json:"AsyncMode,omitempty" example:"false"` + } + // Settings represents the application settings Settings struct { // URL to a logo that will be displayed on the login page as well as on top of the sidebar. Will use default Portainer logo when value is empty string @@ -984,17 +996,7 @@ type ( // EdgePortainerURL is the URL that is exposed to edge agents EdgePortainerURL string `json:"EdgePortainerUrl"` - Edge struct { - // The command list interval for edge agent - used in edge async mode (in seconds) - CommandInterval int `json:"CommandInterval" example:"5"` - // The ping interval for edge agent - used in edge async mode (in seconds) - PingInterval int `json:"PingInterval" example:"5"` - // The snapshot interval for edge agent - used in edge async mode (in seconds) - SnapshotInterval int `json:"SnapshotInterval" example:"5"` - - // Deprecated 2.18 - AsyncMode bool - } + Edge Edge `json:"Edge"` // Deprecated fields DisplayDonationHeader bool `json:"DisplayDonationHeader,omitempty"` diff --git a/api/internal/set/set.go b/api/set/set.go similarity index 95% rename from api/internal/set/set.go rename to api/set/set.go index 4ef31a0d7..13768428d 100644 --- a/api/internal/set/set.go +++ b/api/set/set.go @@ -58,8 +58,8 @@ func (s Set[T]) Copy() Set[T] { } // Difference returns a new set containing the keys that are in the first set but not in the second set. -func (set Set[T]) Difference(second Set[T]) Set[T] { - difference := set.Copy() +func (s Set[T]) Difference(second Set[T]) Set[T] { + difference := s.Copy() for key := range second { difference.Remove(key) diff --git a/api/slicesx/slices.go b/api/slicesx/slices.go new file mode 100644 index 000000000..b7e0aa0ef --- /dev/null +++ b/api/slicesx/slices.go @@ -0,0 +1,43 @@ +package slicesx + +// Map applies the given function to each element of the slice and returns a new slice with the results +func Map[T, U any](s []T, f func(T) U) []U { + result := make([]U, len(s)) + for i, v := range s { + result[i] = f(v) + } + return result +} + +// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true +func Filter[T any](s []T, predicate func(T) bool) []T { + n := 0 + for _, v := range s { + if predicate(v) { + s[n] = v + n++ + } + } + + return s[:n] +} + +func Unique[T comparable](items []T) []T { + return UniqueBy(items, func(item T) T { + return item + }) +} + +func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType { + includedItems := make(map[ComparableType]bool) + result := []ItemType{} + + for _, item := range items { + if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded { + includedItems[accessorFunc(item)] = true + result = append(result, item) + } + } + + return result +} diff --git a/api/internal/slices/slices_test.go b/api/slicesx/slices_test.go similarity index 99% rename from api/internal/slices/slices_test.go rename to api/slicesx/slices_test.go index 887a5be45..d75f9b559 100644 --- a/api/internal/slices/slices_test.go +++ b/api/slicesx/slices_test.go @@ -1,4 +1,4 @@ -package slices +package slicesx import ( "strconv" diff --git a/api/internal/tag/tag.go b/api/tag/tag.go similarity index 100% rename from api/internal/tag/tag.go rename to api/tag/tag.go diff --git a/api/internal/tag/tag_match.go b/api/tag/tag_match.go similarity index 100% rename from api/internal/tag/tag_match.go rename to api/tag/tag_match.go diff --git a/api/internal/tag/tag_match_test.go b/api/tag/tag_match_test.go similarity index 100% rename from api/internal/tag/tag_match_test.go rename to api/tag/tag_match_test.go diff --git a/api/internal/tag/tag_test.go b/api/tag/tag_test.go similarity index 100% rename from api/internal/tag/tag_test.go rename to api/tag/tag_test.go diff --git a/api/internal/url/url.go b/api/url/url.go similarity index 100% rename from api/internal/url/url.go rename to api/url/url.go