1
0
Fork 0
mirror of https://github.com/portainer/portainer.git synced 2025-07-18 21:09:40 +02:00

chore(code): reduce the code duplication EE-7278 (#11969)

This commit is contained in:
andres-portainer 2024-06-26 18:14:22 -03:00 committed by GitHub
parent 39bdfa4512
commit 9ee092aa5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
85 changed files with 520 additions and 618 deletions

View file

@ -10,7 +10,7 @@ import (
"time" "time"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/url" "github.com/portainer/portainer/api/url"
) )
// GetAgentVersionAndPlatform returns the agent version and platform // GetAgentVersionAndPlatform returns the agent version and platform

View file

@ -3,7 +3,6 @@ package apikey
import ( import (
"testing" "testing"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -34,7 +33,7 @@ func Test_generateRandomKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := securecookie.GenerateRandomKey(tt.wantLenth) got := GenerateRandomKey(tt.wantLenth)
is.Equal(tt.wantLenth, len(got)) is.Equal(tt.wantLenth, len(got))
}) })
} }
@ -42,7 +41,7 @@ func Test_generateRandomKey(t *testing.T) {
t.Run("Generated keys are unique", func(t *testing.T) { t.Run("Generated keys are unique", func(t *testing.T) {
keys := make(map[string]bool) keys := make(map[string]bool)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
key := securecookie.GenerateRandomKey(8) key := GenerateRandomKey(8)
_, ok := keys[string(key)] _, ok := keys[string(key)]
is.False(ok) is.False(ok)
keys[string(key)] = true keys[string(key)] = true

View file

@ -1,69 +1,79 @@
package apikey package apikey
import ( import (
lru "github.com/hashicorp/golang-lru"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
lru "github.com/hashicorp/golang-lru"
) )
const defaultAPIKeyCacheSize = 1024 const DefaultAPIKeyCacheSize = 1024
// entry is a tuple containing the user and API key associated to an API key digest // entry is a tuple containing the user and API key associated to an API key digest
type entry struct { type entry[T any] struct {
user portainer.User user T
apiKey portainer.APIKey apiKey portainer.APIKey
} }
// apiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips. type UserCompareFn[T any] func(T, portainer.UserID) bool
// ApiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
// We store the api-key digest (keys) and the associated user and key-data (values) in the cache. // We store the api-key digest (keys) and the associated user and key-data (values) in the cache.
// This is required because HTTP requests will contain only the api-key digest in the x-api-key request header; // This is required because HTTP requests will contain only the api-key digest in the x-api-key request header;
// digest value must be mapped to a portainer user (and respective key data) for validation. // digest value must be mapped to a portainer user (and respective key data) for validation.
// This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest. // This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest.
type apiKeyCache struct { type ApiKeyCache[T any] struct {
// cache type [string]entry cache (key: string(digest), value: user/key entry) // cache type [string]entry cache (key: string(digest), value: user/key entry)
// note: []byte keys are not supported by golang-lru Cache // note: []byte keys are not supported by golang-lru Cache
cache *lru.Cache cache *lru.Cache
userCmpFn UserCompareFn[T]
} }
// NewAPIKeyCache creates a new cache for API keys // NewAPIKeyCache creates a new cache for API keys
func NewAPIKeyCache(cacheSize int) *apiKeyCache { func NewAPIKeyCache[T any](cacheSize int, userCompareFn UserCompareFn[T]) *ApiKeyCache[T] {
cache, _ := lru.New(cacheSize) cache, _ := lru.New(cacheSize)
return &apiKeyCache{cache: cache}
return &ApiKeyCache[T]{cache: cache, userCmpFn: userCompareFn}
} }
// Get returns the user/key associated to an api-key's digest // Get returns the user/key associated to an api-key's digest
// This is required because HTTP requests will contain the digest of the API key in header, // This is required because HTTP requests will contain the digest of the API key in header,
// the digest value must be mapped to a portainer user. // the digest value must be mapped to a portainer user.
func (c *apiKeyCache) Get(digest string) (portainer.User, portainer.APIKey, bool) { func (c *ApiKeyCache[T]) Get(digest string) (T, portainer.APIKey, bool) {
val, ok := c.cache.Get(digest) val, ok := c.cache.Get(digest)
if !ok { if !ok {
return portainer.User{}, portainer.APIKey{}, false var t T
return t, portainer.APIKey{}, false
} }
tuple := val.(entry)
tuple := val.(entry[T])
return tuple.user, tuple.apiKey, true return tuple.user, tuple.apiKey, true
} }
// Set persists a user/key entry to the cache // Set persists a user/key entry to the cache
func (c *apiKeyCache) Set(digest string, user portainer.User, apiKey portainer.APIKey) { func (c *ApiKeyCache[T]) Set(digest string, user T, apiKey portainer.APIKey) {
c.cache.Add(digest, entry{ c.cache.Add(digest, entry[T]{
user: user, user: user,
apiKey: apiKey, apiKey: apiKey,
}) })
} }
// Delete evicts a digest's user/key entry key from the cache // Delete evicts a digest's user/key entry key from the cache
func (c *apiKeyCache) Delete(digest string) { func (c *ApiKeyCache[T]) Delete(digest string) {
c.cache.Remove(digest) c.cache.Remove(digest)
} }
// InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache // InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache
func (c *apiKeyCache) InvalidateUserKeyCache(userId portainer.UserID) bool { func (c *ApiKeyCache[T]) InvalidateUserKeyCache(userId portainer.UserID) bool {
present := false present := false
for _, k := range c.cache.Keys() { for _, k := range c.cache.Keys() {
user, _, _ := c.Get(k.(string)) user, _, _ := c.Get(k.(string))
if user.ID == userId { if c.userCmpFn(user, userId) {
present = c.cache.Remove(k) present = c.cache.Remove(k)
} }
} }
return present return present
} }

View file

@ -10,11 +10,11 @@ import (
func Test_apiKeyCacheGet(t *testing.T) { func Test_apiKeyCacheGet(t *testing.T) {
is := assert.New(t) is := assert.New(t)
keyCache := NewAPIKeyCache(10) keyCache := NewAPIKeyCache(10, compareUser)
// pre-populate cache // pre-populate cache
keyCache.cache.Add(string("foo"), entry{user: portainer.User{}, apiKey: portainer.APIKey{}}) keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string(""), entry{user: portainer.User{}, apiKey: portainer.APIKey{}}) keyCache.cache.Add(string(""), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
tests := []struct { tests := []struct {
digest string digest string
@ -45,7 +45,7 @@ func Test_apiKeyCacheGet(t *testing.T) {
func Test_apiKeyCacheSet(t *testing.T) { func Test_apiKeyCacheSet(t *testing.T) {
is := assert.New(t) is := assert.New(t)
keyCache := NewAPIKeyCache(10) keyCache := NewAPIKeyCache(10, compareUser)
// pre-populate cache // pre-populate cache
keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{}) keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{})
@ -57,23 +57,23 @@ func Test_apiKeyCacheSet(t *testing.T) {
val, ok := keyCache.cache.Get(string("bar")) val, ok := keyCache.cache.Get(string("bar"))
is.True(ok) is.True(ok)
tuple := val.(entry) tuple := val.(entry[portainer.User])
is.Equal(portainer.User{ID: 2}, tuple.user) is.Equal(portainer.User{ID: 2}, tuple.user)
val, ok = keyCache.cache.Get(string("foo")) val, ok = keyCache.cache.Get(string("foo"))
is.True(ok) is.True(ok)
tuple = val.(entry) tuple = val.(entry[portainer.User])
is.Equal(portainer.User{ID: 3}, tuple.user) is.Equal(portainer.User{ID: 3}, tuple.user)
} }
func Test_apiKeyCacheDelete(t *testing.T) { func Test_apiKeyCacheDelete(t *testing.T) {
is := assert.New(t) is := assert.New(t)
keyCache := NewAPIKeyCache(10) keyCache := NewAPIKeyCache(10, compareUser)
t.Run("Delete an existing entry", func(t *testing.T) { t.Run("Delete an existing entry", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.Delete("foo") keyCache.Delete("foo")
_, ok := keyCache.cache.Get(string("foo")) _, ok := keyCache.cache.Get(string("foo"))
@ -128,7 +128,7 @@ func Test_apiKeyCacheLRU(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
keyCache := NewAPIKeyCache(test.cacheLen) keyCache := NewAPIKeyCache(test.cacheLen, compareUser)
for _, key := range test.key { for _, key := range test.key {
keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{}) keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{})
@ -150,10 +150,10 @@ func Test_apiKeyCacheLRU(t *testing.T) {
func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) { func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
is := assert.New(t) is := assert.New(t)
keyCache := NewAPIKeyCache(10) keyCache := NewAPIKeyCache(10, compareUser)
t.Run("Removes users keys from cache", func(t *testing.T) { t.Run("Removes users keys from cache", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
ok := keyCache.InvalidateUserKeyCache(1) ok := keyCache.InvalidateUserKeyCache(1)
is.True(ok) is.True(ok)
@ -163,8 +163,8 @@ func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
}) })
t.Run("Does not affect other keys", func(t *testing.T) { t.Run("Does not affect other keys", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}}) keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("bar"), entry{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}}) keyCache.cache.Add(string("bar"), entry[portainer.User]{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
ok := keyCache.InvalidateUserKeyCache(1) ok := keyCache.InvalidateUserKeyCache(1)
is.True(ok) is.True(ok)

View file

@ -1,14 +1,15 @@
package apikey package apikey
import ( import (
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"time" "time"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -20,30 +21,45 @@ var ErrInvalidAPIKey = errors.New("Invalid API key")
type apiKeyService struct { type apiKeyService struct {
apiKeyRepository dataservices.APIKeyRepository apiKeyRepository dataservices.APIKeyRepository
userRepository dataservices.UserService userRepository dataservices.UserService
cache *apiKeyCache cache *ApiKeyCache[portainer.User]
}
// GenerateRandomKey generates a random key of specified length
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
func GenerateRandomKey(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
func compareUser(u portainer.User, id portainer.UserID) bool {
return u.ID == id
} }
func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService { func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService {
return &apiKeyService{ return &apiKeyService{
apiKeyRepository: apiKeyRepository, apiKeyRepository: apiKeyRepository,
userRepository: userRepository, userRepository: userRepository,
cache: NewAPIKeyCache(defaultAPIKeyCacheSize), cache: NewAPIKeyCache(DefaultAPIKeyCacheSize, compareUser),
} }
} }
// HashRaw computes a hash digest of provided raw API key. // HashRaw computes a hash digest of provided raw API key.
func (a *apiKeyService) HashRaw(rawKey string) string { func (a *apiKeyService) HashRaw(rawKey string) string {
hashDigest := sha256.Sum256([]byte(rawKey)) hashDigest := sha256.Sum256([]byte(rawKey))
return base64.StdEncoding.EncodeToString(hashDigest[:]) return base64.StdEncoding.EncodeToString(hashDigest[:])
} }
// GenerateApiKey generates a raw API key for a user (for one-time display). // GenerateApiKey generates a raw API key for a user (for one-time display).
// The generated API key is stored in the cache and database. // The generated API key is stored in the cache and database.
func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) { func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) {
randKey := securecookie.GenerateRandomKey(32) randKey := GenerateRandomKey(32)
encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey) encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey)
prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey
hashDigest := a.HashRaw(prefixedAPIKey) hashDigest := a.HashRaw(prefixedAPIKey)
apiKey := &portainer.APIKey{ apiKey := &portainer.APIKey{
@ -54,8 +70,7 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string)
Digest: hashDigest, Digest: hashDigest,
} }
err := a.apiKeyRepository.Create(apiKey) if err := a.apiKeyRepository.Create(apiKey); err != nil {
if err != nil {
return "", nil, errors.Wrap(err, "Unable to create API key") return "", nil, errors.Wrap(err, "Unable to create API key")
} }
@ -78,7 +93,6 @@ func (a *apiKeyService) GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey,
// GetDigestUserAndKey returns the user and api-key associated to a specified hash digest. // GetDigestUserAndKey returns the user and api-key associated to a specified hash digest.
// A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed. // A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed.
func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) { func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) {
// get api key from cache if possible
cachedUser, cachedKey, ok := a.cache.Get(digest) cachedUser, cachedKey, ok := a.cache.Get(digest)
if ok { if ok {
return cachedUser, cachedKey, nil return cachedUser, cachedKey, nil
@ -106,20 +120,21 @@ func (a *apiKeyService) UpdateAPIKey(apiKey *portainer.APIKey) error {
if err != nil { if err != nil {
return errors.Wrap(err, "Unable to retrieve API key") return errors.Wrap(err, "Unable to retrieve API key")
} }
a.cache.Set(apiKey.Digest, user, *apiKey) a.cache.Set(apiKey.Digest, user, *apiKey)
return a.apiKeyRepository.Update(apiKey.ID, apiKey) return a.apiKeyRepository.Update(apiKey.ID, apiKey)
} }
// DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache. // DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache.
func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error { func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error {
// get api-key digest to remove from cache
apiKey, err := a.apiKeyRepository.Read(apiKeyID) apiKey, err := a.apiKeyRepository.Read(apiKeyID)
if err != nil { if err != nil {
return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID)) return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID))
} }
// delete the user/api-key from cache
a.cache.Delete(apiKey.Digest) a.cache.Delete(apiKey.Digest)
return a.apiKeyRepository.Delete(apiKeyID) return a.apiKeyRepository.Delete(apiKeyID)
} }

View file

@ -17,17 +17,14 @@ import (
type Service struct{} type Service struct{}
var ( var (
errInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://") ErrInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
errSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe") ErrSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
errInvalidSnapshotInterval = errors.New("Invalid snapshot interval") ErrInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
errAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file") ErrAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
) )
// ParseFlags parse the CLI flags and return a portainer.Flags struct func CLIFlags() *portainer.CLIFlags {
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) { return &portainer.CLIFlags{
kingpin.Version(version)
flags := &portainer.CLIFlags{
Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(), Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(),
AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(), AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(),
TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(), TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(),
@ -63,6 +60,13 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"), LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"), LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
} }
}
// ParseFlags parse the CLI flags and return a portainer.Flags struct
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version)
flags := CLIFlags()
kingpin.Parse() kingpin.Parse()
@ -82,18 +86,16 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error { func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
displayDeprecationWarnings(flags) displayDeprecationWarnings(flags)
err := validateEndpointURL(*flags.EndpointURL) if err := validateEndpointURL(*flags.EndpointURL); err != nil {
if err != nil {
return err return err
} }
err = validateSnapshotInterval(*flags.SnapshotInterval) if err := validateSnapshotInterval(*flags.SnapshotInterval); err != nil {
if err != nil {
return err return err
} }
if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" { if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" {
return errAdminPassExcludeAdminPassFile return ErrAdminPassExcludeAdminPassFile
} }
return nil return nil
@ -115,15 +117,16 @@ func validateEndpointURL(endpointURL string) error {
} }
if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") { if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") {
return errInvalidEndpointProtocol return ErrInvalidEndpointProtocol
} }
if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") { if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") {
socketPath := strings.TrimPrefix(endpointURL, "unix://") socketPath := strings.TrimPrefix(endpointURL, "unix://")
socketPath = strings.TrimPrefix(socketPath, "npipe://") socketPath = strings.TrimPrefix(socketPath, "npipe://")
if _, err := os.Stat(socketPath); err != nil { if _, err := os.Stat(socketPath); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return errSocketOrNamedPipeNotFound return ErrSocketOrNamedPipeNotFound
} }
return err return err
@ -138,9 +141,8 @@ func validateSnapshotInterval(snapshotInterval string) error {
return nil return nil
} }
_, err := time.ParseDuration(snapshotInterval) if _, err := time.ParseDuration(snapshotInterval); err != nil {
if err != nil { return ErrInvalidSnapshotInterval
return errInvalidSnapshotInterval
} }
return nil return nil

View file

@ -56,14 +56,14 @@ import (
) )
func initCLI() *portainer.CLIFlags { func initCLI() *portainer.CLIFlags {
var cliService portainer.CLIService = &cli.Service{} cliService := &cli.Service{}
flags, err := cliService.ParseFlags(portainer.APIVersion) flags, err := cliService.ParseFlags(portainer.APIVersion)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed parsing flags") log.Fatal().Err(err).Msg("failed parsing flags")
} }
err = cliService.ValidateFlags(flags) if err := cliService.ValidateFlags(flags); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed validating flags") log.Fatal().Err(err).Msg("failed validating flags")
} }
@ -94,14 +94,14 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
} }
store := datastore.NewStore(*flags.Data, fileService, connection) store := datastore.NewStore(*flags.Data, fileService, connection)
isNew, err := store.Open() isNew, err := store.Open()
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed opening store") log.Fatal().Err(err).Msg("failed opening store")
} }
if *flags.Rollback { if *flags.Rollback {
err := store.Rollback(false) if err := store.Rollback(false); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed rolling back") log.Fatal().Err(err).Msg("failed rolling back")
} }
@ -110,8 +110,7 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
} }
// Init sets some defaults - it's basically a migration // Init sets some defaults - it's basically a migration
err = store.Init() if err := store.Init(); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed initializing data store") log.Fatal().Err(err).Msg("failed initializing data store")
} }
@ -133,25 +132,23 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
} }
store.VersionService.UpdateVersion(&v) store.VersionService.UpdateVersion(&v)
err = updateSettingsFromFlags(store, flags) if err := updateSettingsFromFlags(store, flags); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed updating settings from flags") log.Fatal().Err(err).Msg("failed updating settings from flags")
} }
} else { } else {
err = store.MigrateData() if err := store.MigrateData(); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed migration") log.Fatal().Err(err).Msg("failed migration")
} }
} }
err = updateSettingsFromFlags(store, flags) if err := updateSettingsFromFlags(store, flags); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed updating settings from flags") log.Fatal().Err(err).Msg("failed updating settings from flags")
} }
// this is for the db restore functionality - needs more tests. // this is for the db restore functionality - needs more tests.
go func() { go func() {
<-shutdownCtx.Done() <-shutdownCtx.Done()
defer connection.Close() defer connection.Close()
}() }()
@ -205,36 +202,16 @@ func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore)
userSessionTimeout = portainer.DefaultUserSessionTimeout userSessionTimeout = portainer.DefaultUserSessionTimeout
} }
jwtService, err := jwt.NewService(userSessionTimeout, dataStore) return jwt.NewService(userSessionTimeout, dataStore)
if err != nil {
return nil, err
}
return jwtService, nil
} }
func initDigitalSignatureService() portainer.DigitalSignatureService { func initDigitalSignatureService() portainer.DigitalSignatureService {
return crypto.NewECDSAService(os.Getenv("AGENT_SECRET")) return crypto.NewECDSAService(os.Getenv("AGENT_SECRET"))
} }
func initCryptoService() portainer.CryptoService {
return &crypto.Service{}
}
func initLDAPService() portainer.LDAPService {
return &ldap.Service{}
}
func initOAuthService() portainer.OAuthService {
return oauth.NewService()
}
func initGitService(ctx context.Context) portainer.GitService {
return git.NewService(ctx)
}
func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) { func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) {
slices := strings.Split(addr, ":") slices := strings.Split(addr, ":")
host := slices[0] host := slices[0]
if host == "" { if host == "" {
host = "0.0.0.0" host = "0.0.0.0"
@ -242,22 +219,13 @@ func initSSLService(addr, certPath, keyPath string, fileService portainer.FileSe
sslService := ssl.NewService(fileService, dataStore, shutdownTrigger) sslService := ssl.NewService(fileService, dataStore, shutdownTrigger)
err := sslService.Init(host, certPath, keyPath) if err := sslService.Init(host, certPath, keyPath); err != nil {
if err != nil {
return nil, err return nil, err
} }
return sslService, nil return sslService, nil
} }
func initDockerClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService) *dockerclient.ClientFactory {
return dockerclient.NewClientFactory(signatureService, reverseTunnelService)
}
func initKubernetesClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService, dataStore dataservices.DataStore, instanceID, addrHTTPS, userSessionTimeout string) (*kubecli.ClientFactory, error) {
return kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, addrHTTPS, userSessionTimeout)
}
func initSnapshotService( func initSnapshotService(
snapshotIntervalFromFlag string, snapshotIntervalFromFlag string,
dataStore dataservices.DataStore, dataStore dataservices.DataStore,
@ -310,14 +278,12 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer.
settings.BlackListedLabels = *flags.Labels settings.BlackListedLabels = *flags.Labels
} }
settings.AgentSecret = ""
if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok { if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok {
settings.AgentSecret = agentKey settings.AgentSecret = agentKey
} else {
settings.AgentSecret = ""
} }
err = dataStore.Settings().UpdateSettings(settings) if err := dataStore.Settings().UpdateSettings(settings); err != nil {
if err != nil {
return err return err
} }
@ -340,6 +306,7 @@ func loadAndParseKeyPair(fileService portainer.FileService, signatureService por
if err != nil { if err != nil {
return err return err
} }
return signatureService.ParseKeyPair(private, public) return signatureService.ParseKeyPair(private, public)
} }
@ -348,7 +315,9 @@ func generateAndStoreKeyPair(fileService portainer.FileService, signatureService
if err != nil { if err != nil {
return err return err
} }
privateHeader, publicHeader := signatureService.PEMHeaders() privateHeader, publicHeader := signatureService.PEMHeaders()
return fileService.StoreKeyPair(private, public, privateHeader, publicHeader) return fileService.StoreKeyPair(private, public, privateHeader, publicHeader)
} }
@ -361,6 +330,7 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
if existingKeyPair { if existingKeyPair {
return loadAndParseKeyPair(fileService, signatureService) return loadAndParseKeyPair(fileService, signatureService)
} }
return generateAndStoreKeyPair(fileService, signatureService) return generateAndStoreKeyPair(fileService, signatureService)
} }
@ -378,6 +348,7 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
// return a 32 byte hash of the secret (required for AES) // return a 32 byte hash of the secret (required for AES)
hash := sha256.Sum256(content) hash := sha256.Sum256(content)
return hash[:] return hash[:]
} }
@ -422,17 +393,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed initializing JWT service") log.Fatal().Err(err).Msg("failed initializing JWT service")
} }
ldapService := initLDAPService() ldapService := &ldap.Service{}
oauthService := initOAuthService() oauthService := oauth.NewService()
gitService := initGitService(shutdownCtx) gitService := git.NewService(shutdownCtx)
openAMTService := openamt.NewService() openAMTService := openamt.NewService()
cryptoService := initCryptoService() cryptoService := &crypto.Service{}
digitalSignatureService := initDigitalSignatureService() signatureService := initDigitalSignatureService()
edgeStacksService := edgestacks.NewService(dataStore) edgeStacksService := edgestacks.NewService(dataStore)
@ -446,15 +417,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed to get SSL settings") log.Fatal().Err(err).Msg("failed to get SSL settings")
} }
err = initKeyPair(fileService, digitalSignatureService) if err := initKeyPair(fileService, signatureService); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed initializing key pair") log.Fatal().Err(err).Msg("failed initializing key pair")
} }
reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService) reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService)
dockerClientFactory := initDockerClientFactory(digitalSignatureService, reverseTunnelService) dockerClientFactory := dockerclient.NewClientFactory(signatureService, reverseTunnelService)
kubernetesClientFactory, err := initKubernetesClientFactory(digitalSignatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
kubernetesClientFactory, err := kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing Kubernetes Client Factory service")
}
authorizationService := authorization.NewService(dataStore) authorizationService := authorization.NewService(dataStore)
authorizationService.K8sClientFactory = kubernetesClientFactory authorizationService.K8sClientFactory = kubernetesClientFactory
@ -476,12 +450,12 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
composeStackManager := initComposeStackManager(composeDeployer, proxyManager) composeStackManager := initComposeStackManager(composeDeployer, proxyManager)
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, digitalSignatureService, fileService, reverseTunnelService, dataStore) swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, signatureService, fileService, reverseTunnelService, dataStore)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed initializing swarm stack manager") log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
} }
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, digitalSignatureService, proxyManager, *flags.Assets) kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets)
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory) pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore)) pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
@ -492,17 +466,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed initializing snapshot service") log.Fatal().Err(err).Msg("failed initializing snapshot service")
} }
snapshotService.Start() snapshotService.Start()
proxyManager.NewProxyFactory(dataStore, digitalSignatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService) proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
helmPackageManager, err := initHelmPackageManager(*flags.Assets) helmPackageManager, err := initHelmPackageManager(*flags.Assets)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed initializing helm package manager") log.Fatal().Err(err).Msg("failed initializing helm package manager")
} }
err = edge.LoadEdgeJobs(dataStore, reverseTunnelService) if err := edge.LoadEdgeJobs(dataStore, reverseTunnelService); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed loading edge jobs from database") log.Fatal().Err(err).Msg("failed loading edge jobs from database")
} }
@ -514,6 +488,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService) go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService)
adminPasswordHash := "" adminPasswordHash := ""
if *flags.AdminPasswordFile != "" { if *flags.AdminPasswordFile != "" {
content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "") content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "")
if err != nil { if err != nil {
@ -536,14 +511,14 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
if len(users) == 0 { if len(users) == 0 {
log.Info().Msg("created admin user with the given password.") log.Info().Msg("created admin user with the given password.")
user := &portainer.User{ user := &portainer.User{
Username: "admin", Username: "admin",
Role: portainer.AdministratorRole, Role: portainer.AdministratorRole,
Password: adminPasswordHash, Password: adminPasswordHash,
} }
err := dataStore.User().Create(user) if err := dataStore.User().Create(user); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed creating admin user") log.Fatal().Err(err).Msg("failed creating admin user")
} }
@ -554,8 +529,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
} }
} }
err = reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService) if err := reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("failed starting tunnel server") log.Fatal().Err(err).Msg("failed starting tunnel server")
} }
@ -613,7 +587,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
ProxyManager: proxyManager, ProxyManager: proxyManager,
KubernetesTokenCacheManager: kubernetesTokenCacheManager, KubernetesTokenCacheManager: kubernetesTokenCacheManager,
KubeClusterAccessService: kubeClusterAccessService, KubeClusterAccessService: kubeClusterAccessService,
SignatureService: digitalSignatureService, SignatureService: signatureService,
SnapshotService: snapshotService, SnapshotService: snapshotService,
SSLService: sslService, SSLService: sslService,
DockerClientFactory: dockerClientFactory, DockerClientFactory: dockerClientFactory,
@ -639,6 +613,7 @@ func main() {
for { for {
server := buildServer(flags) server := buildServer(flags)
log.Info(). log.Info().
Str("version", portainer.APIVersion). Str("version", portainer.APIVersion).
Str("build_number", build.BuildNumber). Str("build_number", build.BuildNumber).

View file

@ -203,6 +203,7 @@ func (connection *DbConnection) ExportRaw(filename string) error {
func (connection *DbConnection) ConvertToKey(v int) []byte { func (connection *DbConnection) ConvertToKey(v int) []byte {
b := make([]byte, 8) b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(v)) binary.BigEndian.PutUint64(b, uint64(v))
return b return b
} }

View file

@ -46,8 +46,8 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{})
return errors.Wrap(err, "Failed decrypting object") return errors.Wrap(err, "Failed decrypting object")
} }
} }
e := json.Unmarshal(data, object)
if e != nil { if e := json.Unmarshal(data, object); e != nil {
// Special case for the VERSION bucket. Here we're not using json // Special case for the VERSION bucket. Here we're not using json
// So we need to return it as a string // So we need to return it as a string
s, ok := object.(*string) s, ok := object.(*string)
@ -57,6 +57,7 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{})
*s = string(data) *s = string(data)
} }
return err return err
} }
@ -71,7 +72,7 @@ func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error)
} }
nonce := make([]byte, gcm.NonceSize()) nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil { if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return encrypted, err return encrypted, err
} }

View file

@ -78,6 +78,7 @@ func (tx *DbTransaction) GetNextIdentifier(bucketName string) int {
id, err := bucket.NextSequence() id, err := bucket.NextSequence()
if err != nil { if err != nil {
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier") log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier")
return 0 return 0
} }

View file

@ -111,5 +111,6 @@ func (store *Store) finishMigrateLegacyVersion(versionToWrite *models.Version) e
store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey)) store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey))
store.connection.DeleteObject(bucketName, []byte(legacyEditionKey)) store.connection.DeleteObject(bucketName, []byte(legacyEditionKey))
store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey)) store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey))
return err return err
} }

View file

@ -39,20 +39,19 @@ func (m *Migrator) Migrate() error {
latestMigrations := m.LatestMigrations() latestMigrations := m.LatestMigrations()
if latestMigrations.Version.Equal(schemaVersion) && if latestMigrations.Version.Equal(schemaVersion) &&
version.MigratorCount != len(latestMigrations.MigrationFuncs) { version.MigratorCount != len(latestMigrations.MigrationFuncs) {
err := runMigrations(latestMigrations.MigrationFuncs) if err := runMigrations(latestMigrations.MigrationFuncs); err != nil {
if err != nil {
return err return err
} }
newMigratorCount = len(latestMigrations.MigrationFuncs) newMigratorCount = len(latestMigrations.MigrationFuncs)
} }
} else { } else {
// regular path when major/minor/patch versions differ // regular path when major/minor/patch versions differ
for _, migration := range m.migrations { for _, migration := range m.migrations {
if schemaVersion.LessThan(migration.Version) { if schemaVersion.LessThan(migration.Version) {
log.Info().Msgf("migrating data to %s", migration.Version.String()) log.Info().Msgf("migrating data to %s", migration.Version.String())
err := runMigrations(migration.MigrationFuncs)
if err != nil { if err := runMigrations(migration.MigrationFuncs); err != nil {
return err return err
} }
} }
@ -63,16 +62,14 @@ func (m *Migrator) Migrate() error {
} }
} }
err = m.Always() if err := m.Always(); err != nil {
if err != nil {
return migrationError(err, "Always migrations returned error") return migrationError(err, "Always migrations returned error")
} }
version.SchemaVersion = portainer.APIVersion version.SchemaVersion = portainer.APIVersion
version.MigratorCount = newMigratorCount version.MigratorCount = newMigratorCount
err = m.versionService.UpdateVersion(version) if err := m.versionService.UpdateVersion(version); err != nil {
if err != nil {
return migrationError(err, "StoreDBVersion") return migrationError(err, "StoreDBVersion")
} }
@ -99,6 +96,7 @@ func (m *Migrator) NeedsMigration() bool {
// In this particular instance we should log a fatal error // In this particular instance we should log a fatal error
if m.CurrentDBEdition() != portainer.PortainerCE { if m.CurrentDBEdition() != portainer.PortainerCE {
log.Fatal().Msg("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/") log.Fatal().Msg("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/")
return false return false
} }

View file

@ -7,6 +7,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/chisel/crypto" "github.com/portainer/portainer/api/chisel/crypto"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -37,9 +38,11 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
log.Info().Msg("ServerInfo object not found") log.Info().Msg("ServerInfo object not found")
return nil return nil
} }
log.Error(). log.Error().
Err(err). Err(err).
Msg("Failed to read ServerInfo from DB") Msg("Failed to read ServerInfo from DB")
return err return err
} }
@ -49,14 +52,15 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
log.Error(). log.Error().
Err(err). Err(err).
Msg("Failed to read ServerInfo from DB") Msg("Failed to read ServerInfo from DB")
return err return err
} }
err = m.fileService.StoreChiselPrivateKey(key) if err := m.fileService.StoreChiselPrivateKey(key); err != nil {
if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Msg("Failed to save Chisel private key to disk") Msg("Failed to save Chisel private key to disk")
return err return err
} }
} else { } else {
@ -64,14 +68,14 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
} }
serverInfo.PrivateKeySeed = "" serverInfo.PrivateKeySeed = ""
err = m.TunnelServerService.UpdateInfo(serverInfo) if err := m.TunnelServerService.UpdateInfo(serverInfo); err != nil {
if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Msg("Failed to clean private key seed in DB") Msg("Failed to clean private key seed in DB")
} else { } else {
log.Info().Msg("Success to migrate private key seed to private key file") log.Info().Msg("Success to migrate private key seed to private key file")
} }
return err return err
} }
@ -84,9 +88,8 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
} }
for _, edgeStack := range edgeStacks { for _, edgeStack := range edgeStacks {
for environmentID, environmentStatus := range edgeStack.Status { for environmentID, environmentStatus := range edgeStack.Status {
// skip if status is already updated // Skip if status is already updated
if len(environmentStatus.Status) > 0 { if len(environmentStatus.Status) > 0 {
continue continue
} }
@ -146,8 +149,7 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
edgeStack.Status[environmentID] = environmentStatus edgeStack.Status[environmentID] = environmentStatus
} }
err = m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack) if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil {
if err != nil {
return err return err
} }
} }

View file

@ -32,8 +32,8 @@ func (m *Migrator) updateStacksToDB24() error {
for idx := range stacks { for idx := range stacks {
stack := &stacks[idx] stack := &stacks[idx]
stack.Status = portainer.StackStatusActive stack.Status = portainer.StackStatusActive
err := m.stackService.Update(stack.ID, stack)
if err != nil { if err := m.stackService.Update(stack.ID, stack); err != nil {
return err return err
} }
} }

View file

@ -583,7 +583,6 @@
"AuthenticationMethod": 1, "AuthenticationMethod": 1,
"BlackListedLabels": [], "BlackListedLabels": [],
"Edge": { "Edge": {
"AsyncMode": false,
"CommandInterval": 0, "CommandInterval": 0,
"PingInterval": 0, "PingInterval": 0,
"SnapshotInterval": 0 "SnapshotInterval": 0

View file

@ -52,27 +52,24 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error)
} }
if init { if init {
err = store.Init() if err := store.Init(); err != nil {
if err != nil {
return newStore, nil, nil, err return newStore, nil, nil, err
} }
} }
if newStore { if newStore {
// from MigrateData // From MigrateData
v := models.Version{ v := models.Version{
SchemaVersion: portainer.APIVersion, SchemaVersion: portainer.APIVersion,
Edition: int(portainer.PortainerCE), Edition: int(portainer.PortainerCE),
} }
err = store.VersionService.UpdateVersion(&v) if err := store.VersionService.UpdateVersion(&v); err != nil {
if err != nil {
return newStore, nil, nil, err return newStore, nil, nil, err
} }
} }
teardown := func() { teardown := func() {
err := store.Close() if err := store.Close(); err != nil {
if err != nil {
log.Fatal().Err(err).Msg("") log.Fatal().Err(err).Msg("")
} }
} }

View file

@ -36,7 +36,6 @@ func (c *ContainerService) Recreate(ctx context.Context, endpoint *portainer.End
if err != nil { if err != nil {
return nil, errors.Wrap(err, "create client error") return nil, errors.Wrap(err, "create client error")
} }
defer cli.Close() defer cli.Close()
log.Debug().Str("container_id", containerId).Msg("starting to fetch container information") log.Debug().Str("container_id", containerId).Msg("starting to fetch container information")

View file

@ -5,10 +5,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/portainer/portainer/api/http/security"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
gorillacsrf "github.com/gorilla/csrf" gorillacsrf "github.com/gorilla/csrf"
"github.com/portainer/portainer/api/http/security"
"github.com/urfave/negroni" "github.com/urfave/negroni"
) )
@ -16,8 +16,7 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
handler = withSendCSRFToken(handler) handler = withSendCSRFToken(handler)
token := make([]byte, 32) token := make([]byte, 32)
_, err := rand.Read(token) if _, err := rand.Read(token); err != nil {
if err != nil {
return nil, fmt.Errorf("failed to generate CSRF token: %w", err) return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
} }
@ -32,7 +31,6 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
func withSendCSRFToken(handler http.Handler) http.Handler { func withSendCSRFToken(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sw := negroni.NewResponseWriter(w) sw := negroni.NewResponseWriter(w)
sw.Before(func(sw negroni.ResponseWriter) { sw.Before(func(sw negroni.ResponseWriter) {
@ -44,16 +42,15 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
}) })
handler.ServeHTTP(sw, r) handler.ServeHTTP(sw, r)
}) })
} }
func withSkipCSRF(handler http.Handler) http.Handler { func withSkipCSRF(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
skip, err := security.ShouldSkipCSRFCheck(r) skip, err := security.ShouldSkipCSRFCheck(r)
if err != nil { if err != nil {
httperror.WriteError(w, http.StatusForbidden, err.Error(), err) httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
return return
} }

View file

@ -56,8 +56,7 @@ func (payload *authenticatePayload) Validate(r *http.Request) error {
// @router /auth [post] // @router /auth [post]
func (handler *Handler) authenticate(rw http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) authenticate(rw http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload authenticatePayload var payload authenticatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload) if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
@ -104,8 +103,7 @@ func isUserInitialAdmin(user *portainer.User) bool {
} }
func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portainer.User, password string) *httperror.HandlerError { func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portainer.User, password string) *httperror.HandlerError {
err := handler.CryptoService.CompareHashAndData(user.Password, password) if err := handler.CryptoService.CompareHashAndData(user.Password, password); err != nil {
if err != nil {
return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized) return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized)
} }
@ -115,8 +113,7 @@ func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portai
} }
func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.User, username, password string, ldapSettings *portainer.LDAPSettings) *httperror.HandlerError { func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.User, username, password string, ldapSettings *portainer.LDAPSettings) *httperror.HandlerError {
err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings) if err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings); err != nil {
if err != nil {
if errors.Is(err, httperrors.ErrUnauthorized) { if errors.Is(err, httperrors.ErrUnauthorized) {
return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized) return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized)
} }
@ -131,14 +128,12 @@ func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.
PortainerAuthorizations: authorization.DefaultPortainerAuthorizations(), PortainerAuthorizations: authorization.DefaultPortainerAuthorizations(),
} }
err = handler.DataStore.User().Create(user) if err := handler.DataStore.User().Create(user); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to persist user inside the database", err) return httperror.InternalServerError("Unable to persist user inside the database", err)
} }
} }
err = handler.syncUserTeamsWithLDAPGroups(user, ldapSettings) if err := handler.syncUserTeamsWithLDAPGroups(user, ldapSettings); err != nil {
if err != nil {
log.Warn().Err(err).Msg("unable to automatically sync user teams with ldap") log.Warn().Err(err).Msg("unable to automatically sync user teams with ldap")
} }
@ -186,7 +181,6 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin
for _, team := range teams { for _, team := range teams {
if teamExists(team.Name, userGroups) { if teamExists(team.Name, userGroups) {
if teamMembershipExists(team.ID, userMemberships) { if teamMembershipExists(team.ID, userMemberships) {
continue continue
} }
@ -197,8 +191,7 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin
Role: portainer.TeamMember, Role: portainer.TeamMember,
} }
err := handler.DataStore.TeamMembership().Create(membership) if err := handler.DataStore.TeamMembership().Create(membership); err != nil {
if err != nil {
return err return err
} }
} }

View file

@ -41,5 +41,6 @@ func NewHandler(bouncer security.BouncerService, rateLimiter *security.RateLimit
rateLimiter.LimitAccess(bouncer.PublicAccess(httperror.LoggerHandler(h.authenticate)))).Methods(http.MethodPost) rateLimiter.LimitAccess(bouncer.PublicAccess(httperror.LoggerHandler(h.authenticate)))).Methods(http.MethodPost)
h.Handle("/auth/logout", h.Handle("/auth/logout",
bouncer.PublicAccess(httperror.LoggerHandler(h.logout))).Methods(http.MethodPost) bouncer.PublicAccess(httperror.LoggerHandler(h.logout))).Methods(http.MethodPost)
return h return h
} }

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/logoutcontext" "github.com/portainer/portainer/api/logoutcontext"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
) )

View file

@ -4,14 +4,15 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/slices" "github.com/portainer/portainer/api/slicesx"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/pkg/errors"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -70,7 +71,7 @@ func (handler *Handler) customTemplateList(w http.ResponseWriter, r *http.Reques
customTemplates = filterByType(customTemplates, templateTypes) customTemplates = filterByType(customTemplates, templateTypes)
if edge != nil { if edge != nil {
customTemplates = slices.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool { customTemplates = slicesx.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool {
return customTemplate.EdgeTemplate == *edge return customTemplate.EdgeTemplate == *edge
}) })
} }

View file

@ -6,7 +6,7 @@ import (
"github.com/portainer/portainer/api/docker/client" "github.com/portainer/portainer/api/docker/client"
"github.com/portainer/portainer/api/http/handler/docker/utils" "github.com/portainer/portainer/api/http/handler/docker/utils"
"github.com/portainer/portainer/api/internal/set" "github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"

View file

@ -7,7 +7,7 @@ import (
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/slices" "github.com/portainer/portainer/api/slicesx"
) )
// filterByResourceControl filters a list of items based on the user's role and the resource control associated to the item. // filterByResourceControl filters a list of items based on the user's role and the resource control associated to the item.
@ -16,7 +16,7 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy
return items, nil return items, nil
} }
userTeamIDs := slices.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID { userTeamIDs := slicesx.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID {
return membership.TeamID return membership.TeamID
}) })
@ -32,5 +32,6 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy
} }
} }
return filteredItems, nil return filteredItems, nil
} }

View file

@ -36,23 +36,25 @@ func (payload *edgeGroupCreatePayload) Validate(r *http.Request) error {
func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.EdgeGroup, endpoints []portainer.EndpointID, tagIDs []portainer.TagID) error { func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.EdgeGroup, endpoints []portainer.EndpointID, tagIDs []portainer.TagID) error {
if edgeGroup.Dynamic { if edgeGroup.Dynamic {
edgeGroup.TagIDs = tagIDs edgeGroup.TagIDs = tagIDs
} else {
endpointIDs := []portainer.EndpointID{}
for _, endpointID := range endpoints { return nil
endpoint, err := tx.Endpoint().Endpoint(endpointID) }
if err != nil {
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
}
if endpointutils.IsEdgeEndpoint(endpoint) { endpointIDs := []portainer.EndpointID{}
endpointIDs = append(endpointIDs, endpoint.ID)
} for _, endpointID := range endpoints {
endpoint, err := tx.Endpoint().Endpoint(endpointID)
if err != nil {
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
} }
edgeGroup.Endpoints = endpointIDs if endpointutils.IsEdgeEndpoint(endpoint) {
endpointIDs = append(endpointIDs, endpoint.ID)
}
} }
edgeGroup.Endpoints = endpointIDs
return nil return nil
} }
@ -71,13 +73,13 @@ func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.
// @router /edge_groups [post] // @router /edge_groups [post]
func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload edgeGroupCreatePayload var payload edgeGroupCreatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload) if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
var edgeGroup *portainer.EdgeGroup var edgeGroup *portainer.EdgeGroup
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
edgeGroups, err := tx.EdgeGroup().ReadAll() edgeGroups, err := tx.EdgeGroup().ReadAll()
if err != nil { if err != nil {
return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err) return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err)
@ -101,8 +103,7 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
return err return err
} }
err = tx.EdgeGroup().Create(edgeGroup) if err := tx.EdgeGroup().Create(edgeGroup); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to persist the Edge group inside the database", err) return httperror.InternalServerError("Unable to persist the Edge group inside the database", err)
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/unique" "github.com/portainer/portainer/api/slicesx"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
@ -113,7 +113,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request)
} }
newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups)
endpointsToUpdate := unique.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...)) endpointsToUpdate := slicesx.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...))
edgeJobs, err := tx.EdgeJob().ReadAll() edgeJobs, err := tx.EdgeJob().ReadAll()
if err != nil { if err != nil {

View file

@ -31,8 +31,7 @@ func setupHandler(t *testing.T) (*Handler, string) {
} }
user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole} user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole}
err = store.User().Create(user) if err := store.User().Create(user); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -66,8 +65,7 @@ func setupHandler(t *testing.T) (*Handler, string) {
} }
settings.EnableEdgeComputeFeatures = true settings.EnableEdgeComputeFeatures = true
err = handler.DataStore.Settings().UpdateSettings(settings) if err := handler.DataStore.Settings().UpdateSettings(settings); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -88,8 +86,7 @@ func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID
LastCheckInDate: time.Now().Unix(), LastCheckInDate: time.Now().Unix(),
} }
err := store.Endpoint().Create(&endpoint) if err := store.Endpoint().Create(&endpoint); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -112,8 +109,7 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
PartialMatch: false, PartialMatch: false,
} }
err := store.EdgeGroup().Create(&edgeGroup) if err := store.EdgeGroup().Create(&edgeGroup); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -138,13 +134,11 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
}, },
} }
err = store.EdgeStack().Create(edgeStack.ID, &edgeStack) if err := store.EdgeStack().Create(edgeStack.ID, &edgeStack); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = store.EndpointRelation().Create(&endpointRelation) if err := store.EndpointRelation().Create(&endpointRelation); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -6,7 +6,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/set" "github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"

View file

@ -9,6 +9,7 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
) )
@ -24,8 +25,7 @@ func TestUpdateAndInspect(t *testing.T) {
endpointID := portainer.EndpointID(6) endpointID := portainer.EndpointID(6)
newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID) newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID)
err := handler.DataStore.Endpoint().Create(&newEndpoint) if err := handler.DataStore.Endpoint().Create(&newEndpoint); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -36,8 +36,7 @@ func TestUpdateAndInspect(t *testing.T) {
}, },
} }
err = handler.DataStore.EndpointRelation().Create(&endpointRelation) if err := handler.DataStore.EndpointRelation().Create(&endpointRelation); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -50,8 +49,7 @@ func TestUpdateAndInspect(t *testing.T) {
PartialMatch: false, PartialMatch: false,
} }
err = handler.DataStore.EdgeGroup().Create(&newEdgeGroup) if err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -96,8 +94,7 @@ func TestUpdateAndInspect(t *testing.T) {
} }
updatedStack := portainer.EdgeStack{} updatedStack := portainer.EdgeStack{}
err = json.NewDecoder(rec.Body).Decode(&updatedStack) if err := json.NewDecoder(rec.Body).Decode(&updatedStack); err != nil {
if err != nil {
t.Fatal("error decoding response:", err) t.Fatal("error decoding response:", err)
} }
@ -120,7 +117,6 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
endpoint := createEndpoint(t, handler.DataStore) endpoint := createEndpoint(t, handler.DataStore)
edgeStack := createEdgeStack(t, handler.DataStore, endpoint.ID) edgeStack := createEdgeStack(t, handler.DataStore, endpoint.ID)
//newEndpoint := createEndpoint(t, handler.DataStore)
newEdgeGroup := portainer.EdgeGroup{ newEdgeGroup := portainer.EdgeGroup{
ID: 2, ID: 2,
Name: "EdgeGroup 2", Name: "EdgeGroup 2",
@ -130,7 +126,8 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
PartialMatch: false, PartialMatch: false,
} }
handler.DataStore.EdgeGroup().Create(&newEdgeGroup) err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
require.NoError(t, err)
cases := []struct { cases := []struct {
Name string Name string

View file

@ -18,6 +18,7 @@ import (
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type endpointTestCase struct { type endpointTestCase struct {
@ -99,8 +100,7 @@ func mustSetupHandler(t *testing.T) *Handler {
} }
settings.TrustOnFirstConnect = true settings.TrustOnFirstConnect = true
err = store.Settings().UpdateSettings(settings) if err = store.Settings().UpdateSettings(settings); err != nil {
if err != nil {
t.Fatalf("could not update settings: %s", err) t.Fatalf("could not update settings: %s", err)
} }
@ -122,8 +122,7 @@ func createEndpoint(handler *Handler, endpoint portainer.Endpoint, endpointRelat
return nil return nil
} }
err = handler.DataStore.Endpoint().Create(&endpoint) if err := handler.DataStore.Endpoint().Create(&endpoint); err != nil {
if err != nil {
return err return err
} }
@ -134,14 +133,13 @@ func TestMissingEdgeIdentifier(t *testing.T) {
handler := mustSetupHandler(t) handler := mustSetupHandler(t)
endpointID := portainer.EndpointID(45) endpointID := portainer.EndpointID(45)
err := createEndpoint(handler, portainer.Endpoint{ if err := createEndpoint(handler, portainer.Endpoint{
ID: endpointID, ID: endpointID,
Name: "endpoint-id-45", Name: "endpoint-id-45",
Type: portainer.EdgeAgentOnDockerEnvironment, Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "https://portainer.io:9443", URL: "https://portainer.io:9443",
EdgeID: "edge-id", EdgeID: "edge-id",
}, portainer.EndpointRelation{EndpointID: endpointID}) }, portainer.EndpointRelation{EndpointID: endpointID}); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -201,8 +199,7 @@ func TestLastCheckInDateIncreases(t *testing.T) {
EndpointID: endpoint.ID, EndpointID: endpoint.ID,
} }
err := createEndpoint(handler, endpoint, endpointRelation) if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -212,6 +209,7 @@ func TestLastCheckInDateIncreases(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("request error:", err) t.Fatal("request error:", err)
} }
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id") req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -246,8 +244,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) {
EndpointID: endpoint.ID, EndpointID: endpoint.ID,
} }
err := createEndpoint(handler, endpoint, endpointRelation) if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -255,6 +252,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("request error:", err) t.Fatal("request error:", err)
} }
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, edgeId) req.Header.Set(portainer.PortainerAgentEdgeIDHeader, edgeId)
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -308,10 +306,11 @@ func TestEdgeStackStatus(t *testing.T) {
edgeStack.ID: true, edgeStack.ID: true,
}, },
} }
handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack)
err := createEndpoint(handler, endpoint, endpointRelation) err := handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack)
if err != nil { require.NoError(t, err)
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -319,6 +318,7 @@ func TestEdgeStackStatus(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("request error:", err) t.Fatal("request error:", err)
} }
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id") req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -330,8 +330,7 @@ func TestEdgeStackStatus(t *testing.T) {
} }
var data endpointEdgeStatusInspectResponse var data endpointEdgeStatusInspectResponse
err = json.NewDecoder(rec.Body).Decode(&data) if err := json.NewDecoder(rec.Body).Decode(&data); err != nil {
if err != nil {
t.Fatal("error decoding response:", err) t.Fatal("error decoding response:", err)
} }
@ -357,8 +356,7 @@ func TestEdgeJobsResponse(t *testing.T) {
EndpointID: endpoint.ID, EndpointID: endpoint.ID,
} }
err := createEndpoint(handler, endpoint, endpointRelation) if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -384,6 +382,7 @@ func TestEdgeJobsResponse(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("request error:", err) t.Fatal("request error:", err)
} }
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id") req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1") req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -395,8 +394,7 @@ func TestEdgeJobsResponse(t *testing.T) {
} }
var data endpointEdgeStatusInspectResponse var data endpointEdgeStatusInspectResponse
err = json.NewDecoder(rec.Body).Decode(&data) if err := json.NewDecoder(rec.Body).Decode(&data); err != nil {
if err != nil {
t.Fatal("error decoding response:", err) t.Fatal("error decoding response:", err)
} }

View file

@ -8,8 +8,8 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/tag"
"github.com/portainer/portainer/api/pendingactions/handlers" "github.com/portainer/portainer/api/pendingactions/handlers"
"github.com/portainer/portainer/api/tag"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/set" "github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
) )

View file

@ -73,8 +73,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
payload.GroupID = groupID payload.GroupID = groupID
var tagIDs []portainer.TagID var tagIDs []portainer.TagID
err = request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true) if err := request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true); err != nil {
if err != nil {
return errors.New("invalid TagIds parameter") return errors.New("invalid TagIds parameter")
} }
payload.TagIDs = tagIDs payload.TagIDs = tagIDs
@ -96,6 +95,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil { if err != nil {
return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly") return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly")
} }
payload.TLSCACertFile = caCert payload.TLSCACertFile = caCert
} }
@ -110,6 +110,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil { if err != nil {
return errors.New("invalid key file. Ensure that the file is uploaded correctly") return errors.New("invalid key file. Ensure that the file is uploaded correctly")
} }
payload.TLSKeyFile = key payload.TLSKeyFile = key
} }
} }
@ -120,6 +121,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil { if err != nil {
return errors.New("invalid Azure application ID") return errors.New("invalid Azure application ID")
} }
payload.AzureApplicationID = azureApplicationID payload.AzureApplicationID = azureApplicationID
azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false) azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false)
@ -139,6 +141,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) { if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) {
return errors.New("URL cannot be empty") return errors.New("URL cannot be empty")
} }
payload.URL = endpointURL payload.URL = endpointURL
publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true) publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true)
@ -156,10 +159,10 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
} }
gpus := make([]portainer.Pair, 0) gpus := make([]portainer.Pair, 0)
err = request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true) if err := request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true); err != nil {
if err != nil {
return errors.New("invalid Gpus parameter") return errors.New("invalid Gpus parameter")
} }
payload.Gpus = gpus payload.Gpus = gpus
edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true) edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true)
@ -206,8 +209,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
// @router /endpoints [post] // @router /endpoints [post]
func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
payload := &endpointCreatePayload{} payload := &endpointCreatePayload{}
err := payload.Validate(r) if err := payload.Validate(r); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
@ -268,8 +270,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
) )
} }
err = handler.DataStore.EndpointRelation().Create(relationObject) if err := handler.DataStore.EndpointRelation().Create(relationObject); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to persist the relation object inside the database", err) return httperror.InternalServerError("Unable to persist the relation object inside the database", err)
} }
@ -278,6 +279,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) {
var err error var err error
switch payload.EndpointCreationType { switch payload.EndpointCreationType {
case azureEnvironment: case azureEnvironment:
return handler.createAzureEndpoint(tx, payload) return handler.createAzureEndpoint(tx, payload)
@ -329,8 +331,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
} }
httpClient := client.NewHTTPClient() httpClient := client.NewHTTPClient()
_, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials) if _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials); err != nil {
if err != nil {
return nil, httperror.InternalServerError("Unable to authenticate against Azure", err) return nil, httperror.InternalServerError("Unable to authenticate against Azure", err)
} }
@ -352,8 +353,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
Kubernetes: portainer.KubernetesDefault(), Kubernetes: portainer.KubernetesDefault(),
} }
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
if err != nil {
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err) return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
} }
@ -405,8 +405,7 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
endpoint.EdgeID = edgeID.String() endpoint.EdgeID = edgeID.String()
} }
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
if err != nil {
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err) return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
} }
@ -443,8 +442,7 @@ func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, pay
Kubernetes: portainer.KubernetesDefault(), Kubernetes: portainer.KubernetesDefault(),
} }
err := handler.snapshotAndPersistEndpoint(tx, endpoint) if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -478,8 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, pa
Kubernetes: portainer.KubernetesDefault(), Kubernetes: portainer.KubernetesDefault(),
} }
err := handler.snapshotAndPersistEndpoint(tx, endpoint) if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -510,13 +507,11 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
endpoint.Agent.Version = agentVersion endpoint.Agent.Version = agentVersion
err := handler.storeTLSFiles(endpoint, payload) if err := handler.storeTLSFiles(endpoint, payload); err != nil {
if err != nil {
return nil, err return nil, err
} }
err = handler.snapshotAndPersistEndpoint(tx, endpoint) if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -524,17 +519,16 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
} }
func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError { func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError {
err := handler.SnapshotService.SnapshotEndpoint(endpoint) if err := handler.SnapshotService.SnapshotEndpoint(endpoint); err != nil {
if err != nil {
if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) || if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) ||
(endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) { (endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) {
err = errors.New("agent already paired with another Portainer instance") err = errors.New("agent already paired with another Portainer instance")
} }
return httperror.InternalServerError("Unable to initiate communications with environment", err) return httperror.InternalServerError("Unable to initiate communications with environment", err)
} }
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
if err != nil {
return httperror.InternalServerError("An error occurred while trying to create the environment", err) return httperror.InternalServerError("An error occurred while trying to create the environment", err)
} }
@ -555,16 +549,14 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.Data
AllowStackManagementForRegularUsers: true, AllowStackManagementForRegularUsers: true,
} }
err := tx.Endpoint().Create(endpoint) if err := tx.Endpoint().Create(endpoint); err != nil {
if err != nil {
return err return err
} }
for _, tagID := range endpoint.TagIDs { for _, tagID := range endpoint.TagIDs {
err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { if err := tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
tag.Endpoints[endpoint.ID] = true tag.Endpoints[endpoint.ID] = true
}) }); err != nil {
if err != nil {
return err return err
} }
} }
@ -580,22 +572,26 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
if err != nil { if err != nil {
return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err) return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err)
} }
endpoint.TLSConfig.TLSCACertPath = caCertPath endpoint.TLSConfig.TLSCACertPath = caCertPath
} }
if !payload.TLSSkipClientVerify { if payload.TLSSkipClientVerify {
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile) return nil
if err != nil {
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
}
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
}
endpoint.TLSConfig.TLSKeyPath = keyPath
} }
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
}
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
}
endpoint.TLSConfig.TLSKeyPath = keyPath
return nil return nil
} }

View file

@ -30,24 +30,22 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
for i := 0; i < endpointsCount; i++ { for i := 0; i < endpointsCount; i++ {
endpointID := portainer.EndpointID(i) + 1 endpointID := portainer.EndpointID(i) + 1
err := store.Endpoint().Create(&portainer.Endpoint{ if err := store.Endpoint().Create(&portainer.Endpoint{
ID: endpointID, ID: endpointID,
Name: "env-" + strconv.Itoa(int(endpointID)), Name: "env-" + strconv.Itoa(int(endpointID)),
Type: portainer.EdgeAgentOnDockerEnvironment, Type: portainer.EdgeAgentOnDockerEnvironment,
}) }); err != nil {
if err != nil {
t.Fatal("could not create endpoint:", err) t.Fatal("could not create endpoint:", err)
} }
endpointIDs = append(endpointIDs, endpointID) endpointIDs = append(endpointIDs, endpointID)
} }
err := store.EdgeGroup().Create(&portainer.EdgeGroup{ if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1, ID: 1,
Name: "edgegroup-1", Name: "edgegroup-1",
Endpoints: endpointIDs, Endpoints: endpointIDs,
}) }); err != nil {
if err != nil {
t.Fatal("could not create edge group:", err) t.Fatal("could not create edge group:", err)
} }

View file

@ -102,7 +102,6 @@ func Test_EndpointList_AgentVersion(t *testing.T) {
} }
func Test_endpointList_edgeFilter(t *testing.T) { func Test_endpointList_edgeFilter(t *testing.T) {
trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment} trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment} untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment} regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
@ -227,8 +226,7 @@ func doEndpointListRequest(req *http.Request, h *Handler, is *assert.Assertions)
} }
resp := []portainer.Endpoint{} resp := []portainer.Endpoint{}
err = json.Unmarshal(body, &resp) if err := json.Unmarshal(body, &resp); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -34,12 +34,10 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re
} }
var registries []portainer.Registry var registries []portainer.Registry
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID)) registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID))
return err return err
}) }); err != nil {
if err != nil {
var httpErr *httperror.HandlerError var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) { if errors.As(err, &httpErr) {
return httpErr return httpErr
@ -104,11 +102,9 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi
} }
if namespaceParam != "" { if namespaceParam != "" {
authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin) if authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin); err != nil {
if err != nil {
return nil, httperror.NotFound("Unable to check for namespace authorization", err) return nil, httperror.NotFound("Unable to check for namespace authorization", err)
} } else if !authorized {
if !authorized {
return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace")) return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace"))
} }

View file

@ -13,7 +13,7 @@ import (
"github.com/portainer/portainer/api/http/handler/edgegroups" "github.com/portainer/portainer/api/http/handler/edgegroups"
"github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/unique" "github.com/portainer/portainer/api/slicesx"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -254,6 +254,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
if err != nil { if err != nil {
return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database") return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database")
} }
if edgeGroup.Dynamic { if edgeGroup.Dynamic {
endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch) endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch)
if err != nil { if err != nil {
@ -261,6 +262,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
} }
edgeGroup.Endpoints = endpointIDs edgeGroup.Endpoints = endpointIDs
} }
envIds = append(envIds, edgeGroup.Endpoints...) envIds = append(envIds, edgeGroup.Endpoints...)
} }
@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
envIds = envIds[:n] envIds = envIds[:n]
} }
uniqueIds := unique.Unique(envIds) uniqueIds := slicesx.Unique(envIds)
filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds) filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds)
return filteredEndpoints, nil return filteredEndpoints, nil

View file

@ -5,8 +5,8 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/internal/testhelpers" "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -129,7 +129,7 @@ func Test_Filter_edgeFilter(t *testing.T) {
func Test_Filter_excludeIDs(t *testing.T) { func Test_Filter_excludeIDs(t *testing.T) {
ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9} ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9}
environments := slices.Map(ids, func(id portainer.EndpointID) portainer.Endpoint { environments := slicesx.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment} return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment}
}) })

View file

@ -4,7 +4,8 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/slices" "github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -162,7 +163,7 @@ func TestSortEndpointsByField(t *testing.T) {
} }
func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID { func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID {
return slices.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID { return slicesx.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
return environment.ID return environment.ID
}) })
} }

View file

@ -6,7 +6,7 @@ import (
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/set" "github.com/portainer/portainer/api/set"
) )
// updateEdgeRelations updates the edge stacks associated to an edge endpoint // updateEdgeRelations updates the edge stacks associated to an edge endpoint

View file

@ -6,7 +6,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/set" "github.com/portainer/portainer/api/set"
) )
func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) { func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) {

View file

@ -10,7 +10,6 @@ import (
) )
func Test_updateEdgeGroups(t *testing.T) { func Test_updateEdgeGroups(t *testing.T) {
createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) { createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) {
groups := make([]portainer.EdgeGroup, len(names)) groups := make([]portainer.EdgeGroup, len(names))
for index, name := range names { for index, name := range names {
@ -21,8 +20,7 @@ func Test_updateEdgeGroups(t *testing.T) {
Endpoints: make([]portainer.EndpointID, 0), Endpoints: make([]portainer.EndpointID, 0),
} }
err := store.EdgeGroup().Create(group) if err := store.EdgeGroup().Create(group); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -42,6 +40,7 @@ func Test_updateEdgeGroups(t *testing.T) {
return return
} }
} }
is.Fail("expected endpoint to be in group") is.Fail("expected endpoint to be in group")
} }
} }
@ -52,6 +51,7 @@ func Test_updateEdgeGroups(t *testing.T) {
for j, tag := range groups { for j, tag := range groups {
if tag.Name == tagName { if tag.Name == tagName {
result[i] = groups[j] result[i] = groups[j]
break break
} }
} }
@ -88,6 +88,7 @@ func Test_updateEdgeGroups(t *testing.T) {
} }
expectedGroups := groupsByName(groups, testCase.groupsToApply) expectedGroups := groupsByName(groups, testCase.groupsToApply)
expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups)) expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups))
for i, tag := range expectedGroups { for i, tag := range expectedGroups {
expectedIDs[i] = tag.ID expectedIDs[i] = tag.ID

View file

@ -4,7 +4,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/set" "github.com/portainer/portainer/api/set"
) )
// updateEnvironmentTags updates the tags associated to an environment // updateEnvironmentTags updates the tags associated to an environment

View file

@ -10,14 +10,14 @@ import (
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/exec/exectest" "github.com/portainer/portainer/api/exec/exectest"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
helper "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/jwt" "github.com/portainer/portainer/api/jwt"
"github.com/portainer/portainer/api/kubernetes" "github.com/portainer/portainer/api/kubernetes"
"github.com/portainer/portainer/pkg/libhelm/binary/test" "github.com/portainer/portainer/pkg/libhelm/binary/test"
"github.com/portainer/portainer/pkg/libhelm/options" "github.com/portainer/portainer/pkg/libhelm/options"
"github.com/stretchr/testify/assert"
"github.com/portainer/portainer/api/internal/testhelpers" "github.com/stretchr/testify/assert"
helper "github.com/portainer/portainer/api/internal/testhelpers"
) )
func Test_helmDelete(t *testing.T) { func Test_helmDelete(t *testing.T) {

View file

@ -97,13 +97,13 @@ func (handler *Handler) userHasRegistryAccess(r *http.Request) (hasAccess bool,
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID))
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint) if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
if err != nil {
return false, false, err return false, false, err
} }

View file

@ -71,6 +71,7 @@ func (handler *Handler) settingsPublic(w http.ResponseWriter, r *http.Request) *
} }
publicSettings := generatePublicSettings(settings) publicSettings := generatePublicSettings(settings)
return response.JSON(w, publicSettings) return response.JSON(w, publicSettings)
} }
@ -96,7 +97,7 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp
publicSettings.IsDockerDesktopExtension = appSettings.IsDockerDesktopExtension publicSettings.IsDockerDesktopExtension = appSettings.IsDockerDesktopExtension
//if OAuth authentication is on, compose the related fields from application settings // If OAuth authentication is on, compose the related fields from application settings
if publicSettings.AuthenticationMethod == portainer.AuthenticationOAuth { if publicSettings.AuthenticationMethod == portainer.AuthenticationOAuth {
publicSettings.OAuthLogoutURI = appSettings.OAuthSettings.LogoutURI publicSettings.OAuthLogoutURI = appSettings.OAuthSettings.LogoutURI
publicSettings.OAuthLoginURI = fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s&scope=%s", publicSettings.OAuthLoginURI = fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s&scope=%s",
@ -104,16 +105,18 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp
appSettings.OAuthSettings.ClientID, appSettings.OAuthSettings.ClientID,
appSettings.OAuthSettings.RedirectURI, appSettings.OAuthSettings.RedirectURI,
appSettings.OAuthSettings.Scopes) appSettings.OAuthSettings.Scopes)
//control prompt=login param according to the SSO setting
// Control prompt=login param according to the SSO setting
if !appSettings.OAuthSettings.SSO { if !appSettings.OAuthSettings.SSO {
publicSettings.OAuthLoginURI += "&prompt=login" publicSettings.OAuthLoginURI += "&prompt=login"
} }
} }
//if LDAP authentication is on, compose the related fields from application settings // If LDAP authentication is on, compose the related fields from application settings
if publicSettings.AuthenticationMethod == portainer.AuthenticationLDAP && appSettings.LDAPSettings.GroupSearchSettings != nil { if publicSettings.AuthenticationMethod == portainer.AuthenticationLDAP && appSettings.LDAPSettings.GroupSearchSettings != nil {
if len(appSettings.LDAPSettings.GroupSearchSettings) > 0 { if len(appSettings.LDAPSettings.GroupSearchSettings) > 0 {
publicSettings.TeamSync = len(appSettings.LDAPSettings.GroupSearchSettings[0].GroupBaseDN) > 0 publicSettings.TeamSync = len(appSettings.LDAPSettings.GroupSearchSettings[0].GroupBaseDN) > 0
} }
} }
return publicSettings return publicSettings
} }

View file

@ -40,14 +40,17 @@ func setup() {
func TestGeneratePublicSettingsWithSSO(t *testing.T) { func TestGeneratePublicSettingsWithSSO(t *testing.T) {
setup() setup()
mockAppSettings.OAuthSettings.SSO = true mockAppSettings.OAuthSettings.SSO = true
publicSettings := generatePublicSettings(mockAppSettings) publicSettings := generatePublicSettings(mockAppSettings)
if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth { if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth {
t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod) t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod)
} }
if publicSettings.OAuthLoginURI != dummyOAuthLoginURI { if publicSettings.OAuthLoginURI != dummyOAuthLoginURI {
t.Errorf("wrong OAuthLoginURI when SSO is switched on, want: %s, got: %s", dummyOAuthLoginURI, publicSettings.OAuthLoginURI) t.Errorf("wrong OAuthLoginURI when SSO is switched on, want: %s, got: %s", dummyOAuthLoginURI, publicSettings.OAuthLoginURI)
} }
if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI { if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI {
t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI) t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI)
} }
@ -55,15 +58,18 @@ func TestGeneratePublicSettingsWithSSO(t *testing.T) {
func TestGeneratePublicSettingsWithoutSSO(t *testing.T) { func TestGeneratePublicSettingsWithoutSSO(t *testing.T) {
setup() setup()
mockAppSettings.OAuthSettings.SSO = false mockAppSettings.OAuthSettings.SSO = false
publicSettings := generatePublicSettings(mockAppSettings) publicSettings := generatePublicSettings(mockAppSettings)
if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth { if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth {
t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod) t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod)
} }
expectedOAuthLoginURI := dummyOAuthLoginURI + "&prompt=login" expectedOAuthLoginURI := dummyOAuthLoginURI + "&prompt=login"
if publicSettings.OAuthLoginURI != expectedOAuthLoginURI { if publicSettings.OAuthLoginURI != expectedOAuthLoginURI {
t.Errorf("wrong OAuthLoginURI when SSO is switched off, want: %s, got: %s", expectedOAuthLoginURI, publicSettings.OAuthLoginURI) t.Errorf("wrong OAuthLoginURI when SSO is switched off, want: %s, got: %s", expectedOAuthLoginURI, publicSettings.OAuthLoginURI)
} }
if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI { if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI {
t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI) t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI)
} }

View file

@ -89,8 +89,7 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt
} }
if !isOrphaned { if !isOrphaned {
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint) if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
if err != nil {
return httperror.Forbidden("Permission denied to access endpoint", err) return httperror.Forbidden("Permission denied to access endpoint", err)
} }
@ -119,25 +118,21 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt
deployments.StopAutoupdate(stack.ID, stack.AutoUpdate.JobID, handler.Scheduler) deployments.StopAutoupdate(stack.ID, stack.AutoUpdate.JobID, handler.Scheduler)
} }
err = handler.deleteStack(securityContext.UserID, stack, endpoint) if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil {
if err != nil {
return httperror.InternalServerError(err.Error(), err) return httperror.InternalServerError(err.Error(), err)
} }
err = handler.DataStore.Stack().Delete(portainer.StackID(id)) if err := handler.DataStore.Stack().Delete(portainer.StackID(id)); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to remove the stack from the database", err) return httperror.InternalServerError("Unable to remove the stack from the database", err)
} }
if resourceControl != nil { if resourceControl != nil {
err = handler.DataStore.ResourceControl().Delete(resourceControl.ID) if err := handler.DataStore.ResourceControl().Delete(resourceControl.ID); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to remove the associated resource control from the database", err) return httperror.InternalServerError("Unable to remove the associated resource control from the database", err)
} }
} }
err = handler.FileService.RemoveDirectory(stack.ProjectPath) if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil {
if err != nil {
log.Warn().Err(err).Msg("Unable to remove stack files from disk") log.Warn().Err(err).Msg("Unable to remove stack files from disk")
} }
@ -169,8 +164,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit
return httperror.InternalServerError("Unable to find the endpoint associated to the stack inside the database", err) return httperror.InternalServerError("Unable to find the endpoint associated to the stack inside the database", err)
} }
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint) if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
if err != nil {
return httperror.Forbidden("Permission denied to access endpoint", err) return httperror.Forbidden("Permission denied to access endpoint", err)
} }
@ -179,8 +173,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit
Type: portainer.DockerSwarmStack, Type: portainer.DockerSwarmStack,
} }
err = handler.deleteStack(securityContext.UserID, stack, endpoint) if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to delete stack", err) return httperror.InternalServerError("Unable to delete stack", err)
} }
@ -255,6 +248,7 @@ func (handler *Handler) deleteStack(userID portainer.UserID, stack *portainer.St
} }
} }
} }
return errors.WithMessagef(err, "failed to remove kubernetes resources: %q", out) return errors.WithMessagef(err, "failed to remove kubernetes resources: %q", out)
} }
@ -369,18 +363,18 @@ func (handler *Handler) stackDeleteKubernetesByName(w http.ResponseWriter, r *ht
if err != nil { if err != nil {
log.Err(err).Msgf("Unable to delete Kubernetes stack `%d`", stack.ID) log.Err(err).Msgf("Unable to delete Kubernetes stack `%d`", stack.ID)
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
err = handler.DataStore.Stack().Delete(stack.ID) if err := handler.DataStore.Stack().Delete(stack.ID); err != nil {
if err != nil {
errors = append(errors, err) errors = append(errors, err)
log.Err(err).Msgf("Unable to remove the stack `%d` from the database", stack.ID) log.Err(err).Msgf("Unable to remove the stack `%d` from the database", stack.ID)
continue continue
} }
err = handler.FileService.RemoveDirectory(stack.ProjectPath) if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil {
if err != nil {
errors = append(errors, err) errors = append(errors, err)
log.Warn().Err(err).Msg("Unable to remove stack files from disk") log.Warn().Err(err).Msg("Unable to remove stack files from disk")
} }

View file

@ -18,8 +18,7 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, false) _, store := datastore.MustNewTestStore(t, true, false)
user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole} user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole}
err := store.User().Create(user) if err := store.User().Create(user); err != nil {
if err != nil {
t.Fatal("could not create admin user:", err) t.Fatal("could not create admin user:", err)
} }
@ -33,29 +32,28 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) {
for i := 0; i < tagsCount; i++ { for i := 0; i < tagsCount; i++ {
tagID := portainer.TagID(i) + 1 tagID := portainer.TagID(i) + 1
err = store.Tag().Create(&portainer.Tag{ if err := store.Tag().Create(&portainer.Tag{
ID: tagID, ID: tagID,
Name: "tag-" + strconv.Itoa(int(tagID)), Name: "tag-" + strconv.Itoa(int(tagID)),
}) }); err != nil {
if err != nil {
t.Fatal("could not create tag:", err) t.Fatal("could not create tag:", err)
} }
tagIDs = append(tagIDs, tagID) tagIDs = append(tagIDs, tagID)
} }
err = store.EdgeGroup().Create(&portainer.EdgeGroup{ if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1, ID: 1,
Name: "edgegroup-1", Name: "edgegroup-1",
TagIDs: tagIDs, TagIDs: tagIDs,
}) }); err != nil {
if err != nil {
t.Fatal("could not create edge group:", err) t.Fatal("could not create edge group:", err)
} }
// Remove the tags concurrently // Remove the tags concurrently
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tagIDs)) wg.Add(len(tagIDs))
for _, tagID := range tagIDs { for _, tagID := range tagIDs {

View file

@ -27,6 +27,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error {
if payload.Role != 1 && payload.Role != 2 { if payload.Role != 1 && payload.Role != 2 {
return errors.New("Invalid role value. Value must be one of: 1 (administrator) or 2 (regular user)") return errors.New("Invalid role value. Value must be one of: 1 (administrator) or 2 (regular user)")
} }
return nil return nil
} }
@ -49,8 +50,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error {
// @router /users [post] // @router /users [post]
func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload userCreatePayload var payload userCreatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload) if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
@ -89,11 +89,11 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http
} }
} }
err = handler.DataStore.User().Create(user) if err := handler.DataStore.User().Create(user); err != nil {
if err != nil {
return httperror.InternalServerError("Unable to persist user inside the database", err) return httperror.InternalServerError("Unable to persist user inside the database", err)
} }
hideFields(user) hideFields(user)
return response.JSON(w, user) return response.JSON(w, user)
} }

View file

@ -26,12 +26,12 @@ func Test_userList(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true) _, store := datastore.MustNewTestStore(t, true, true)
// create admin and standard user(s) // Create admin and standard user(s)
adminUser := &portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole} adminUser := &portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}
err := store.User().Create(adminUser) err := store.User().Create(adminUser)
is.NoError(err, "error creating admin user") is.NoError(err, "error creating admin user")
// setup services // Setup services
jwtService, err := jwt.NewService("1h", store) jwtService, err := jwt.NewService("1h", store)
is.NoError(err, "Error initiating jwt service") is.NoError(err, "Error initiating jwt service")
apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User()) apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User())
@ -42,7 +42,7 @@ func Test_userList(t *testing.T) {
h := NewHandler(requestBouncer, rateLimiter, apiKeyService, passwordChecker) h := NewHandler(requestBouncer, rateLimiter, apiKeyService, passwordChecker)
h.DataStore = store h.DataStore = store
// generate admin user tokens // Generate admin user tokens
adminJWT, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: adminUser.ID, Username: adminUser.Username, Role: adminUser.Role}) adminJWT, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: adminUser.ID, Username: adminUser.Username, Role: adminUser.Role})
// Case 1: the user is given the endpoint access directly // Case 1: the user is given the endpoint access directly
@ -54,12 +54,12 @@ func Test_userList(t *testing.T) {
err = store.User().Create(userWithoutEndpointAccess) err = store.User().Create(userWithoutEndpointAccess)
is.NoError(err, "error creating user") is.NoError(err, "error creating user")
// create environment group // Create environment group
endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "default-endpoint-group"} endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "default-endpoint-group"}
err = store.EndpointGroup().Create(endpointGroup) err = store.EndpointGroup().Create(endpointGroup)
is.NoError(err, "error creating endpoint group") is.NoError(err, "error creating endpoint group")
// create endpoint and user access policies // Create endpoint and user access policies
userAccessPolicies := make(portainer.UserAccessPolicies, 0) userAccessPolicies := make(portainer.UserAccessPolicies, 0)
userAccessPolicies[userWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userWithEndpointAccess.Role)} userAccessPolicies[userWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userWithEndpointAccess.Role)}
@ -129,7 +129,7 @@ func Test_userList(t *testing.T) {
err = store.User().Create(userUnderGroup) err = store.User().Create(userUnderGroup)
is.NoError(err, "error creating user") is.NoError(err, "error creating user")
// create environment group including a user // Create environment group including a user
userAccessPoliciesUnderGroup := make(portainer.UserAccessPolicies, 0) userAccessPoliciesUnderGroup := make(portainer.UserAccessPolicies, 0)
userAccessPoliciesUnderGroup[userUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderGroup.Role)} userAccessPoliciesUnderGroup[userUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderGroup.Role)}
@ -137,7 +137,7 @@ func Test_userList(t *testing.T) {
err = store.EndpointGroup().Create(endpointGroupWithUser) err = store.EndpointGroup().Create(endpointGroupWithUser)
is.NoError(err, "error creating endpoint group") is.NoError(err, "error creating endpoint group")
// create endpoint // Create endpoint
endpointUnderGroupWithUser := &portainer.Endpoint{ID: 2, GroupID: endpointGroupWithUser.ID} endpointUnderGroupWithUser := &portainer.Endpoint{ID: 2, GroupID: endpointGroupWithUser.ID}
err = store.Endpoint().Create(endpointUnderGroupWithUser) err = store.Endpoint().Create(endpointUnderGroupWithUser)
is.NoError(err, "error creating endpoint") is.NoError(err, "error creating endpoint")
@ -182,7 +182,7 @@ func Test_userList(t *testing.T) {
err = store.TeamMembership().Create(teamMembership) err = store.TeamMembership().Create(teamMembership)
is.NoError(err, "error creating team membership") is.NoError(err, "error creating team membership")
// create environment group including a team // Create environment group including a team
teamAccessPoliciesUnderGroup := make(portainer.TeamAccessPolicies, 0) teamAccessPoliciesUnderGroup := make(portainer.TeamAccessPolicies, 0)
teamAccessPoliciesUnderGroup[teamUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeam.Role)} teamAccessPoliciesUnderGroup[teamUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeam.Role)}
@ -190,7 +190,7 @@ func Test_userList(t *testing.T) {
err = store.EndpointGroup().Create(endpointGroupWithTeam) err = store.EndpointGroup().Create(endpointGroupWithTeam)
is.NoError(err, "error creating endpoint group") is.NoError(err, "error creating endpoint group")
// create endpoint // Create endpoint
endpointUnderGroupWithTeam := &portainer.Endpoint{ID: 3, GroupID: endpointGroupWithTeam.ID} endpointUnderGroupWithTeam := &portainer.Endpoint{ID: 3, GroupID: endpointGroupWithTeam.ID}
err = store.Endpoint().Create(endpointUnderGroupWithTeam) err = store.Endpoint().Create(endpointUnderGroupWithTeam)
is.NoError(err, "error creating endpoint") is.NoError(err, "error creating endpoint")
@ -233,12 +233,12 @@ func Test_userList(t *testing.T) {
err = store.TeamMembership().Create(teamMembershipWithEndpointAccess) err = store.TeamMembership().Create(teamMembershipWithEndpointAccess)
is.NoError(err, "error creating team membership") is.NoError(err, "error creating team membership")
// create environment group // Create environment group
endpointGroupWithoutTeam := &portainer.EndpointGroup{ID: 4, Name: "endpoint-group-without-team"} endpointGroupWithoutTeam := &portainer.EndpointGroup{ID: 4, Name: "endpoint-group-without-team"}
err = store.EndpointGroup().Create(endpointGroupWithoutTeam) err = store.EndpointGroup().Create(endpointGroupWithoutTeam)
is.NoError(err, "error creating endpoint group") is.NoError(err, "error creating endpoint group")
// create endpoint and team access policies // Create endpoint and team access policies
teamAccessPolicies := make(portainer.TeamAccessPolicies, 0) teamAccessPolicies := make(portainer.TeamAccessPolicies, 0)
teamAccessPolicies[teamWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeamWithEndpointAccess.Role)} teamAccessPolicies[teamWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeamWithEndpointAccess.Role)}

View file

@ -19,12 +19,12 @@ func Test_updateUserRemovesAccessTokens(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true) _, store := datastore.MustNewTestStore(t, true, true)
// create standard user // Create standard user
user := &portainer.User{ID: 2, Username: "standard", Role: portainer.StandardUserRole} user := &portainer.User{ID: 2, Username: "standard", Role: portainer.StandardUserRole}
err := store.User().Create(user) err := store.User().Create(user)
is.NoError(err, "error creating user") is.NoError(err, "error creating user")
// setup services // Setup services
jwtService, err := jwt.NewService("1h", store) jwtService, err := jwt.NewService("1h", store)
is.NoError(err, "Error initiating jwt service") is.NoError(err, "Error initiating jwt service")
apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User()) apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User())

View file

@ -8,12 +8,12 @@ import (
"net/url" "net/url"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/logoutcontext" "github.com/portainer/portainer/api/logoutcontext"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/koding/websocketproxy" "github.com/koding/websocketproxy"
"github.com/portainer/portainer/api/crypto"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )

View file

@ -9,7 +9,7 @@ import (
"github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/proxy/factory/agent" "github.com/portainer/portainer/api/http/proxy/factory/agent"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/url" "github.com/portainer/portainer/api/url"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"

View file

@ -54,6 +54,7 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re
portainerMetadata := object["Portainer"].(map[string]interface{}) portainerMetadata := object["Portainer"].(map[string]interface{})
portainerMetadata["ResourceControl"] = resourceControl portainerMetadata["ResourceControl"] = resourceControl
return object return object
} }
@ -64,8 +65,7 @@ func (transport *Transport) createPrivateResourceControl(
resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID) resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID)
err := transport.dataStore.ResourceControl().Create(resourceControl) if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
if err != nil {
log.Error(). log.Error().
Str("resource", resourceIdentifier). Str("resource", resourceIdentifier).
Err(err). Err(err).
@ -84,6 +84,7 @@ func (transport *Transport) userCanDeleteContainerGroup(request *http.Request, c
resourceIdentifier := request.URL.Path resourceIdentifier := request.URL.Path
resourceControl := transport.findResourceControl(resourceIdentifier, context) resourceControl := transport.findResourceControl(resourceIdentifier, context)
return authorization.UserCanAccessResource(context.userID, context.userTeamIDs, resourceControl) return authorization.UserCanAccessResource(context.userID, context.userTeamIDs, resourceControl)
} }
@ -136,20 +137,19 @@ func (transport *Transport) filterContainerGroups(containerGroups []interface{},
func (transport *Transport) removeResourceControl(containerGroup map[string]interface{}, context *azureRequestContext) error { func (transport *Transport) removeResourceControl(containerGroup map[string]interface{}, context *azureRequestContext) error {
containerGroupID, ok := containerGroup["id"].(string) containerGroupID, ok := containerGroup["id"].(string)
if ok { if !ok {
resourceControl := transport.findResourceControl(containerGroupID, context)
if resourceControl != nil {
err := transport.dataStore.ResourceControl().Delete(resourceControl.ID)
return err
}
} else {
log.Debug().Msg("missing ID in container group") log.Debug().Msg("missing ID in container group")
return nil
}
if resourceControl := transport.findResourceControl(containerGroupID, context); resourceControl != nil {
return transport.dataStore.ResourceControl().Delete(resourceControl.ID)
} }
return nil return nil
} }
func (transport *Transport) findResourceControl(containerGroupId string, context *azureRequestContext) *portainer.ResourceControl { func (transport *Transport) findResourceControl(containerGroupId string, context *azureRequestContext) *portainer.ResourceControl {
resourceControl := authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls) return authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls)
return resourceControl
} }

View file

@ -8,7 +8,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/proxy/factory/docker" "github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/api/internal/url" "github.com/portainer/portainer/api/url"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"

View file

@ -105,8 +105,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m
resourceControl := authorization.NewRestrictedResourceControl(resourceID, resourceType, userIDs, teamIDs) resourceControl := authorization.NewRestrictedResourceControl(resourceID, resourceType, userIDs, teamIDs)
err := transport.dataStore.ResourceControl().Create(resourceControl) if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -119,8 +118,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m
func (transport *Transport) createPrivateResourceControl(resourceIdentifier string, resourceType portainer.ResourceControlType, userID portainer.UserID) (*portainer.ResourceControl, error) { func (transport *Transport) createPrivateResourceControl(resourceIdentifier string, resourceType portainer.ResourceControlType, userID portainer.UserID) (*portainer.ResourceControl, error) {
resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID) resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID)
err := transport.dataStore.ResourceControl().Create(resourceControl) if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
if err != nil {
log.Error(). log.Error().
Str("resource", resourceIdentifier). Str("resource", resourceIdentifier).
Err(err). Err(err).
@ -170,6 +168,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe
systemResourceControl := findSystemNetworkResourceControl(responseObject) systemResourceControl := findSystemNetworkResourceControl(responseObject)
if systemResourceControl != nil { if systemResourceControl != nil {
responseObject = decorateObject(responseObject, systemResourceControl) responseObject = decorateObject(responseObject, systemResourceControl)
return utils.RewriteResponse(response, responseObject, http.StatusOK) return utils.RewriteResponse(response, responseObject, http.StatusOK)
} }
} }
@ -188,6 +187,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe
if executor.operationContext.isAdmin || (resourceControl != nil && authorization.UserCanAccessResource(executor.operationContext.userID, executor.operationContext.userTeamIDs, resourceControl)) { if executor.operationContext.isAdmin || (resourceControl != nil && authorization.UserCanAccessResource(executor.operationContext.userID, executor.operationContext.userTeamIDs, resourceControl)) {
responseObject = decorateObject(responseObject, resourceControl) responseObject = decorateObject(responseObject, resourceControl)
return utils.RewriteResponse(response, responseObject, http.StatusOK) return utils.RewriteResponse(response, responseObject, http.StatusOK)
} }
@ -221,6 +221,7 @@ func (transport *Transport) decorateResourceList(parameters *resourceOperationPa
if systemResourceControl != nil { if systemResourceControl != nil {
resourceObject = decorateObject(resourceObject, systemResourceControl) resourceObject = decorateObject(resourceObject, systemResourceControl)
decoratedResourceData = append(decoratedResourceData, resourceObject) decoratedResourceData = append(decoratedResourceData, resourceObject)
continue continue
} }
} }
@ -264,6 +265,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara
if systemResourceControl != nil { if systemResourceControl != nil {
resourceObject = decorateObject(resourceObject, systemResourceControl) resourceObject = decorateObject(resourceObject, systemResourceControl)
filteredResourceData = append(filteredResourceData, resourceObject) filteredResourceData = append(filteredResourceData, resourceObject)
continue continue
} }
} }
@ -277,6 +279,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara
if context.isAdmin { if context.isAdmin {
filteredResourceData = append(filteredResourceData, resourceObject) filteredResourceData = append(filteredResourceData, resourceObject)
} }
continue continue
} }
@ -334,11 +337,13 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
func getStackResourceIDFromLabels(resourceLabelsObject map[string]string, endpointID portainer.EndpointID) string { func getStackResourceIDFromLabels(resourceLabelsObject map[string]string, endpointID portainer.EndpointID) string {
if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != "" { if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != "" {
stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName] stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName]
return stackutils.ResourceControlID(endpointID, stackName) return stackutils.ResourceControlID(endpointID, stackName)
} }
if resourceLabelsObject[resourceLabelForDockerComposeStackName] != "" { if resourceLabelsObject[resourceLabelForDockerComposeStackName] != "" {
stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName] stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName]
return stackutils.ResourceControlID(endpointID, stackName) return stackutils.ResourceControlID(endpointID, stackName)
} }
@ -352,5 +357,6 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re
portainerMetadata := object["Portainer"].(map[string]interface{}) portainerMetadata := object["Portainer"].(map[string]interface{})
portainerMetadata["ResourceControl"] = resourceControl portainerMetadata["ResourceControl"] = resourceControl
return object return object
} }

View file

@ -11,9 +11,7 @@ import (
"github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/internal/authorization"
) )
const ( const configObjectIdentifier = "ID"
configObjectIdentifier = "ID"
)
func getInheritedResourceControlFromConfigLabels(dockerClient *client.Client, endpointID portainer.EndpointID, configID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) { func getInheritedResourceControlFromConfigLabels(dockerClient *client.Client, endpointID portainer.EndpointID, configID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
config, _, err := dockerClient.ConfigInspectWithRaw(context.Background(), configID) config, _, err := dockerClient.ConfigInspectWithRaw(context.Background(), configID)
@ -78,10 +76,9 @@ func (transport *Transport) configInspectOperation(response *http.Response, exec
// https://docs.docker.com/engine/api/v1.37/#operation/ConfigList // https://docs.docker.com/engine/api/v1.37/#operation/ConfigList
// https://docs.docker.com/engine/api/v1.37/#operation/ConfigInspect // https://docs.docker.com/engine/api/v1.37/#operation/ConfigInspect
func selectorConfigLabels(responseObject map[string]interface{}) map[string]interface{} { func selectorConfigLabels(responseObject map[string]interface{}) map[string]interface{} {
secretSpec := utils.GetJSONObject(responseObject, "Spec") if secretSpec := utils.GetJSONObject(responseObject, "Spec"); secretSpec != nil {
if secretSpec != nil { return utils.GetJSONObject(secretSpec, "Labels")
secretLabelsObject := utils.GetJSONObject(secretSpec, "Labels")
return secretLabelsObject
} }
return nil return nil
} }

View file

@ -7,9 +7,7 @@ import (
"github.com/portainer/portainer/api/http/proxy/factory/utils" "github.com/portainer/portainer/api/http/proxy/factory/utils"
) )
const ( const taskServiceObjectIdentifier = "ServiceID"
taskServiceObjectIdentifier = "ServiceID"
)
// taskListOperation extracts the response as a JSON array, loop through the tasks array // taskListOperation extracts the response as a JSON array, loop through the tasks array
// and filter the containers based on resource controls before rewriting the response. // and filter the containers based on resource controls before rewriting the response.
@ -46,5 +44,6 @@ func selectorTaskLabels(responseObject map[string]interface{}) map[string]interf
return utils.GetJSONObject(containerSpecObject, "Labels") return utils.GetJSONObject(containerSpecObject, "Labels")
} }
} }
return nil return nil
} }

View file

@ -7,19 +7,17 @@ import (
"net/http" "net/http"
"path" "path"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/utils" "github.com/portainer/portainer/api/http/proxy/factory/utils"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/snapshot" "github.com/portainer/portainer/api/internal/snapshot"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
) )
const ( const volumeObjectIdentifier = "ResourceID"
volumeObjectIdentifier = "ResourceID"
)
func getInheritedResourceControlFromVolumeLabels(dockerClient *client.Client, endpointID portainer.EndpointID, volumeID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) { func getInheritedResourceControlFromVolumeLabels(dockerClient *client.Client, endpointID portainer.EndpointID, volumeID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
volume, err := dockerClient.VolumeInspect(context.Background(), volumeID) volume, err := dockerClient.VolumeInspect(context.Background(), volumeID)
@ -57,14 +55,13 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo
Msg("snapshot is not filled into the endpoint.") Msg("snapshot is not filled into the endpoint.")
} }
} }
for _, volumeObject := range volumeData { for _, volumeObject := range volumeData {
volume := volumeObject.(map[string]interface{}) volume := volumeObject.(map[string]interface{})
err = transport.decorateVolumeResponseWithResourceID(volume) if err := transport.decorateVolumeResponseWithResourceID(volume); err != nil {
if err != nil {
return fmt.Errorf("failed decorating volume response: %w", err) return fmt.Errorf("failed decorating volume response: %w", err)
} }
} }
resourceOperationParameters := &resourceOperationParameters{ resourceOperationParameters := &resourceOperationParameters{
@ -77,6 +74,7 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo
if err != nil { if err != nil {
return err return err
} }
// Overwrite the original volume list // Overwrite the original volume list
responseObject["Volumes"] = volumeData responseObject["Volumes"] = volumeData
} }
@ -94,8 +92,7 @@ func (transport *Transport) volumeInspectOperation(response *http.Response, exec
return err return err
} }
err = transport.decorateVolumeResponseWithResourceID(responseObject) if err := transport.decorateVolumeResponseWithResourceID(responseObject); err != nil {
if err != nil {
return fmt.Errorf("failed decorating volume response: %w", err) return fmt.Errorf("failed decorating volume response: %w", err)
} }
@ -148,8 +145,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt
} }
defer cli.Close() defer cli.Close()
_, err = cli.VolumeInspect(context.Background(), volumeID) if _, err = cli.VolumeInspect(context.Background(), volumeID); err == nil {
if err == nil {
return &http.Response{ return &http.Response{
StatusCode: http.StatusConflict, StatusCode: http.StatusConflict,
}, errors.New("a volume with the same name already exists") }, errors.New("a volume with the same name already exists")
@ -164,6 +160,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt
if response.StatusCode == http.StatusCreated { if response.StatusCode == http.StatusCreated {
err = transport.decorateVolumeCreationResponse(response, resourceType, tokenData.ID) err = transport.decorateVolumeCreationResponse(response, resourceType, tokenData.ID)
} }
return response, err return response, err
} }
@ -195,7 +192,6 @@ func (transport *Transport) decorateVolumeCreationResponse(response *http.Respon
} }
func (transport *Transport) restrictedVolumeOperation(requestPath string, request *http.Request) (*http.Response, error) { func (transport *Transport) restrictedVolumeOperation(requestPath string, request *http.Request) (*http.Response, error) {
if request.Method == http.MethodGet { if request.Method == http.MethodGet {
return transport.rewriteOperation(request, transport.volumeInspectOperation) return transport.rewriteOperation(request, transport.volumeInspectOperation)
} }
@ -210,6 +206,7 @@ func (transport *Transport) restrictedVolumeOperation(requestPath string, reques
if request.Method == http.MethodDelete { if request.Method == http.MethodDelete {
return transport.executeGenericResourceDeletionOperation(request, resourceID, volumeName, portainer.VolumeResourceControl) return transport.executeGenericResourceDeletionOperation(request, resourceID, volumeName, portainer.VolumeResourceControl)
} }
return transport.restrictedResourceOperation(request, resourceID, volumeName, portainer.VolumeResourceControl, false) return transport.restrictedResourceOperation(request, resourceID, volumeName, portainer.VolumeResourceControl, false)
} }
@ -218,6 +215,7 @@ func (transport *Transport) getVolumeResourceID(volumeName string) (string, erro
if err != nil { if err != nil {
return "", fmt.Errorf("failed fetching docker id: %w", err) return "", fmt.Errorf("failed fetching docker id: %w", err)
} }
return fmt.Sprintf("%s_%s", volumeName, dockerID), nil return fmt.Sprintf("%s_%s", volumeName, dockerID), nil
} }

View file

@ -4,7 +4,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/tag" "github.com/portainer/portainer/api/tag"
) )
// EdgeGroupRelatedEndpoints returns a list of environments(endpoints) related to this Edge group // EdgeGroupRelatedEndpoints returns a list of environments(endpoints) related to this Edge group

View file

@ -37,12 +37,12 @@ func (service *Service) BuildEdgeStack(
registries []portainer.RegistryID, registries []portainer.RegistryID,
useManifestNamespaces bool, useManifestNamespaces bool,
) (*portainer.EdgeStack, error) { ) (*portainer.EdgeStack, error) {
err := validateUniqueName(tx.EdgeStack().EdgeStacks, name) if err := validateUniqueName(tx.EdgeStack().EdgeStacks, name); err != nil {
if err != nil {
return nil, err return nil, err
} }
stackID := tx.EdgeStack().GetNextIdentifier() stackID := tx.EdgeStack().GetNextIdentifier()
return &portainer.EdgeStack{ return &portainer.EdgeStack{
ID: portainer.EdgeStackID(stackID), ID: portainer.EdgeStackID(stackID),
Name: name, Name: name,
@ -77,7 +77,6 @@ func (service *Service) PersistEdgeStack(
storeManifest edgetypes.StoreManifestFunc) (*portainer.EdgeStack, error) { storeManifest edgetypes.StoreManifestFunc) (*portainer.EdgeStack, error) {
relationConfig, err := edge.FetchEndpointRelationsConfig(tx) relationConfig, err := edge.FetchEndpointRelationsConfig(tx)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to find environment relations in database: %w", err) return nil, fmt.Errorf("unable to find environment relations in database: %w", err)
} }
@ -87,6 +86,7 @@ func (service *Service) PersistEdgeStack(
if errors.Is(err, edge.ErrEdgeGroupNotFound) { if errors.Is(err, edge.ErrEdgeGroupNotFound) {
return nil, httperrors.NewInvalidPayloadError(err.Error()) return nil, httperrors.NewInvalidPayloadError(err.Error())
} }
return nil, fmt.Errorf("unable to persist environment relation in database: %w", err) return nil, fmt.Errorf("unable to persist environment relation in database: %w", err)
} }
@ -101,13 +101,11 @@ func (service *Service) PersistEdgeStack(
stack.EntryPoint = composePath stack.EntryPoint = composePath
stack.NumDeployments = len(relatedEndpointIds) stack.NumDeployments = len(relatedEndpointIds)
err = service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds) if err := service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds); err != nil {
if err != nil {
return nil, fmt.Errorf("unable to update endpoint relations: %w", err) return nil, fmt.Errorf("unable to update endpoint relations: %w", err)
} }
err = tx.EdgeStack().Create(stack.ID, stack) if err := tx.EdgeStack().Create(stack.ID, stack); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -126,8 +124,7 @@ func (service *Service) updateEndpointRelations(tx dataservices.DataStoreTx, edg
relation.EdgeStacks[edgeStackID] = true relation.EdgeStacks[edgeStackID] = true
err = endpointRelationService.UpdateEndpointRelation(endpointID, relation) if err := endpointRelationService.UpdateEndpointRelation(endpointID, relation); err != nil {
if err != nil {
return fmt.Errorf("unable to persist endpoint relation in database: %w", err) return fmt.Errorf("unable to persist endpoint relation in database: %w", err)
} }
} }
@ -155,14 +152,12 @@ func (service *Service) DeleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID
delete(relation.EdgeStacks, edgeStackID) delete(relation.EdgeStacks, edgeStackID)
err = tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation) if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil {
if err != nil {
return errors.WithMessage(err, "Unable to persist environment relation in database") return errors.WithMessage(err, "Unable to persist environment relation in database")
} }
} }
err = tx.EdgeStack().DeleteEdgeStack(edgeStackID) if err := tx.EdgeStack().DeleteEdgeStack(edgeStackID); err != nil {
if err != nil {
return errors.WithMessage(err, "Unable to remove the edge stack from the database") return errors.WithMessage(err, "Unable to remove the edge stack from the database")
} }

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View file

@ -8,6 +8,7 @@ import (
// NodesCount returns the total node number of all environments // NodesCount returns the total node number of all environments
func NodesCount(endpoints []portainer.Endpoint) int { func NodesCount(endpoints []portainer.Endpoint) int {
nodes := 0 nodes := 0
for _, env := range endpoints { for _, env := range endpoints {
if !endpointutils.IsEdgeEndpoint(&env) || env.UserTrusted { if !endpointutils.IsEdgeEndpoint(&env) || env.UserTrusted {
nodes += countNodes(&env) nodes += countNodes(&env)
@ -28,11 +29,3 @@ func countNodes(endpoint *portainer.Endpoint) int {
return 1 return 1
} }
func max(a, b int) int {
if a > b {
return a
}
return b
}

View file

@ -1,16 +0,0 @@
package securecookie
import (
"crypto/rand"
"io"
)
// GenerateRandomKey generates a random key of specified length
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
func GenerateRandomKey(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}

View file

@ -1,23 +0,0 @@
package slices
// Map applies the given function to each element of the slice and returns a new slice with the results
func Map[T, U any](s []T, f func(T) U) []U {
result := make([]U, len(s))
for i, v := range s {
result[i] = f(v)
}
return result
}
// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true
func Filter[T any](s []T, predicate func(T) bool) []T {
n := 0
for _, v := range s {
if predicate(v) {
s[n] = v
n++
}
}
return s[:n]
}

View file

@ -1,41 +0,0 @@
package unique
func Unique[T comparable](items []T) []T {
return UniqueBy(items, func(item T) T {
return item
})
}
func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType {
includedItems := make(map[ComparableType]bool)
result := []ItemType{}
for _, item := range items {
if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded {
includedItems[accessorFunc(item)] = true
result = append(result, item)
}
}
return result
}
/**
type someType struct {
id int
fn func()
}
func Test() {
ids := []int{1, 2, 3, 3}
_ = UniqueBy(ids, func(id int) int { return id })
_ = Unique(ids) // shorthand for UniqueBy Identity/self
as := []someType{{id: 1}, {id: 2}, {id: 3}, {id: 3}}
_ = UniqueBy(as, func(item someType) int { return item.id }) // no error
_ = UniqueBy(as, func(item someType) someType { return item }) // compile error - someType is not comparable
_ = Unique(as) // compile error - shorthand fails for the same reason
}
*/

View file

@ -6,10 +6,10 @@ import (
"time" "time"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/apikey"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -51,7 +51,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (*
return nil, err return nil, err
} }
secret := securecookie.GenerateRandomKey(32) secret := apikey.GenerateRandomKey(32)
if secret == nil { if secret == nil {
return nil, errSecretGeneration return nil, errSecretGeneration
} }
@ -69,6 +69,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (*
userSessionTimeout, userSessionTimeout,
dataStore, dataStore,
} }
return service, nil return service, nil
} }
@ -80,16 +81,18 @@ func getOrCreateKubeSecret(dataStore dataservices.DataStore) ([]byte, error) {
kubeSecret := settings.OAuthSettings.KubeSecretKey kubeSecret := settings.OAuthSettings.KubeSecretKey
if kubeSecret == nil { if kubeSecret == nil {
kubeSecret = securecookie.GenerateRandomKey(32) kubeSecret = apikey.GenerateRandomKey(32)
if kubeSecret == nil { if kubeSecret == nil {
return nil, errSecretGeneration return nil, errSecretGeneration
} }
settings.OAuthSettings.KubeSecretKey = kubeSecret settings.OAuthSettings.KubeSecretKey = kubeSecret
err = dataStore.Settings().UpdateSettings(settings)
if err != nil { if err := dataStore.Settings().UpdateSettings(settings); err != nil {
return nil, err return nil, err
} }
} }
return kubeSecret, nil return kubeSecret, nil
} }

View file

@ -3,8 +3,8 @@ package cli
import ( import (
"context" "context"
"github.com/portainer/portainer/api/concurrent"
models "github.com/portainer/portainer/api/http/models/kubernetes" models "github.com/portainer/portainer/api/http/models/kubernetes"
"github.com/portainer/portainer/api/internal/concurrent"
"k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/errors"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

View file

@ -4,6 +4,7 @@ import (
"context" "context"
models "github.com/portainer/portainer/api/http/models/kubernetes" models "github.com/portainer/portainer/api/http/models/kubernetes"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
labels "k8s.io/apimachinery/pkg/labels" labels "k8s.io/apimachinery/pkg/labels"
@ -67,9 +68,7 @@ func (kcl *KubeClient) GetServices(namespace string, lookupApplications bool) ([
return result, nil return result, nil
} }
// CreateService creates a new service in a given namespace in a k8s endpoint. func (kcl *KubeClient) fillService(info models.K8sServiceInfo) v1.Service {
func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error {
ServiceClient := kcl.cli.CoreV1().Services(namespace)
var service v1.Service var service v1.Service
service.Name = info.Name service.Name = info.Name
@ -93,16 +92,21 @@ func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInf
// Set ingresses. // Set ingresses.
for _, i := range info.IngressStatus { for _, i := range info.IngressStatus {
var ing v1.LoadBalancerIngress
ing.IP = i.IP
ing.Hostname = i.Host
service.Status.LoadBalancer.Ingress = append( service.Status.LoadBalancer.Ingress = append(
service.Status.LoadBalancer.Ingress, service.Status.LoadBalancer.Ingress,
ing, v1.LoadBalancerIngress{IP: i.IP, Hostname: i.Host},
) )
} }
_, err := ServiceClient.Create(context.Background(), &service, metav1.CreateOptions{}) return service
}
// CreateService creates a new service in a given namespace in a k8s endpoint.
func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error {
serviceClient := kcl.cli.CoreV1().Services(namespace)
service := kcl.fillService(info)
_, err := serviceClient.Create(context.Background(), &service, metav1.CreateOptions{})
return err return err
} }
@ -120,45 +124,16 @@ func (kcl *KubeClient) DeleteServices(reqs models.K8sServiceDeleteRequests) erro
) )
} }
} }
return err return err
} }
// UpdateService updates service in a given namespace in a k8s endpoint. // UpdateService updates service in a given namespace in a k8s endpoint.
func (kcl *KubeClient) UpdateService(namespace string, info models.K8sServiceInfo) error { func (kcl *KubeClient) UpdateService(namespace string, info models.K8sServiceInfo) error {
ServiceClient := kcl.cli.CoreV1().Services(namespace) serviceClient := kcl.cli.CoreV1().Services(namespace)
var service v1.Service service := kcl.fillService(info)
service.Name = info.Name _, err := serviceClient.Update(context.Background(), &service, metav1.UpdateOptions{})
service.Spec.Type = v1.ServiceType(info.Type)
service.Namespace = info.Namespace
service.Annotations = info.Annotations
service.Labels = info.Labels
service.Spec.AllocateLoadBalancerNodePorts = info.AllocateLoadBalancerNodePorts
service.Spec.Selector = info.Selector
// Set ports.
for _, p := range info.Ports {
var port v1.ServicePort
port.Name = p.Name
port.NodePort = int32(p.NodePort)
port.Port = int32(p.Port)
port.Protocol = v1.Protocol(p.Protocol)
port.TargetPort = intstr.FromString(p.TargetPort)
service.Spec.Ports = append(service.Spec.Ports, port)
}
// Set ingresses.
for _, i := range info.IngressStatus {
var ing v1.LoadBalancerIngress
ing.IP = i.IP
ing.Hostname = i.Host
service.Status.LoadBalancer.Ingress = append(
service.Status.LoadBalancer.Ingress,
ing,
)
}
_, err := ServiceClient.Update(context.Background(), &service, metav1.UpdateOptions{})
return err return err
} }
@ -210,5 +185,4 @@ func makeApplication(meta metav1.Object) []models.K8sApplication {
Name: ownerReference.Name, Name: ownerReference.Name,
}, },
} }
} }

View file

@ -16,8 +16,7 @@ func Test_getOAuthToken(t *testing.T) {
t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) { t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) {
code := "" code := ""
_, err := getOAuthToken(code, config) if _, err := getOAuthToken(code, config); err == nil {
if err == nil {
t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code) t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code)
} }
}) })
@ -91,22 +90,19 @@ func Test_getResource(t *testing.T) {
defer srv.Close() defer srv.Close()
t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) { t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) {
_, err := getResource("", config) if _, err := getResource("", config); err == nil {
if err == nil {
t.Errorf("getResource should fail if access token is not provided in auth bearer header") t.Errorf("getResource should fail if access token is not provided in auth bearer header")
} }
}) })
t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) { t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) {
_, err := getResource("incorrect-token", config) if _, err := getResource("incorrect-token", config); err == nil {
if err == nil {
t.Errorf("getResource should fail if incorrect access token provided in auth bearer header") t.Errorf("getResource should fail if incorrect access token provided in auth bearer header")
} }
}) })
t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) { t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) {
_, err := getResource(oauthtest.AccessToken, config) if _, err := getResource(oauthtest.AccessToken, config); err != nil {
if err != nil {
t.Errorf("getResource should succeed if correct access token provided in auth bearer header") t.Errorf("getResource should succeed if correct access token provided in auth bearer header")
} }
}) })
@ -120,8 +116,7 @@ func Test_Authenticate(t *testing.T) {
srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{}) srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{})
defer srv.Close() defer srv.Close()
_, err := authService.Authenticate(code, config) if _, err := authService.Authenticate(code, config); err == nil {
if err == nil {
t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided") t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided")
} }
}) })

View file

@ -12,6 +12,7 @@ import (
gittypes "github.com/portainer/portainer/api/git/types" gittypes "github.com/portainer/portainer/api/git/types"
models "github.com/portainer/portainer/api/http/models/kubernetes" models "github.com/portainer/portainer/api/http/models/kubernetes"
"github.com/portainer/portainer/pkg/featureflags" "github.com/portainer/portainer/pkg/featureflags"
"golang.org/x/oauth2" "golang.org/x/oauth2"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/version" "k8s.io/apimachinery/pkg/version"
@ -322,14 +323,14 @@ type (
Name string `json:"Name"` Name string `json:"Name"`
Status map[EndpointID]EdgeStackStatus `json:"Status"` Status map[EndpointID]EdgeStackStatus `json:"Status"`
// StatusArray map[EndpointID][]EdgeStackStatus `json:"StatusArray"` // StatusArray map[EndpointID][]EdgeStackStatus `json:"StatusArray"`
CreationDate int64 `json:"CreationDate"` CreationDate int64 `json:"CreationDate"`
EdgeGroups []EdgeGroupID `json:"EdgeGroups"` EdgeGroups []EdgeGroupID `json:"EdgeGroups"`
ProjectPath string `json:"ProjectPath"` ProjectPath string `json:"ProjectPath"`
EntryPoint string `json:"EntryPoint"` EntryPoint string `json:"EntryPoint"`
Version int `json:"Version"` Version int `json:"Version"`
NumDeployments int `json:"NumDeployments"` NumDeployments int `json:"NumDeployments"`
ManifestPath string ManifestPath string `json:"ManifestPath"`
DeploymentType EdgeStackDeploymentType DeploymentType EdgeStackDeploymentType `json:"DeploymentType"`
// Uses the manifest's namespaces instead of the default one // Uses the manifest's namespaces instead of the default one
UseManifestNamespaces bool UseManifestNamespaces bool
@ -554,23 +555,22 @@ type (
// Extension represents a deprecated Portainer extension // Extension represents a deprecated Portainer extension
Extension struct { Extension struct {
// Extension Identifier ID ExtensionID `json:"Id" example:"1"`
ID ExtensionID `json:"Id" example:"1"` Enabled bool `json:"Enabled"`
Enabled bool `json:"Enabled"` Name string `json:"Name,omitempty"`
Name string `json:"Name,omitempty"` ShortDescription string `json:"ShortDescription,omitempty"`
ShortDescription string `json:"ShortDescription,omitempty"` Description string `json:"Description,omitempty"`
Description string `json:"Description,omitempty"` DescriptionURL string `json:"DescriptionURL,omitempty"`
DescriptionURL string `json:"DescriptionURL,omitempty"` Price string `json:"Price,omitempty"`
Price string `json:"Price,omitempty"` PriceDescription string `json:"PriceDescription,omitempty"`
PriceDescription string `json:"PriceDescription,omitempty"` Deal bool `json:"Deal,omitempty"`
Deal bool `json:"Deal,omitempty"` Available bool `json:"Available,omitempty"`
Available bool `json:"Available,omitempty"` License ExtensionLicenseInformation `json:"License,omitempty"`
License LicenseInformation `json:"License,omitempty"` Version string `json:"Version"`
Version string `json:"Version"` UpdateAvailable bool `json:"UpdateAvailable"`
UpdateAvailable bool `json:"UpdateAvailable"` ShopURL string `json:"ShopURL,omitempty"`
ShopURL string `json:"ShopURL,omitempty"` Images []string `json:"Images,omitempty"`
Images []string `json:"Images,omitempty"` Logo string `json:"Logo,omitempty"`
Logo string `json:"Logo,omitempty"`
} }
// ExtensionID represents a extension identifier // ExtensionID represents a extension identifier
@ -737,8 +737,8 @@ type (
Groups []string Groups []string
} }
// LicenseInformation represents information about an extension license // ExtensionLicenseInformation represents information about an extension license
LicenseInformation struct { ExtensionLicenseInformation struct {
LicenseKey string `json:"LicenseKey,omitempty"` LicenseKey string `json:"LicenseKey,omitempty"`
Company string `json:"Company,omitempty"` Company string `json:"Company,omitempty"`
Expiration string `json:"Expiration,omitempty"` Expiration string `json:"Expiration,omitempty"`
@ -939,6 +939,18 @@ type (
HideStacksFunctionality bool `json:"hideStacksFunctionality" example:"false"` HideStacksFunctionality bool `json:"hideStacksFunctionality" example:"false"`
} }
Edge struct {
// The command list interval for edge agent - used in edge async mode (in seconds)
CommandInterval int `json:"CommandInterval" example:"5"`
// The ping interval for edge agent - used in edge async mode (in seconds)
PingInterval int `json:"PingInterval" example:"5"`
// The snapshot interval for edge agent - used in edge async mode (in seconds)
SnapshotInterval int `json:"SnapshotInterval" example:"5"`
// Deprecated 2.18
AsyncMode bool `json:"AsyncMode,omitempty" example:"false"`
}
// Settings represents the application settings // Settings represents the application settings
Settings struct { Settings struct {
// URL to a logo that will be displayed on the login page as well as on top of the sidebar. Will use default Portainer logo when value is empty string // URL to a logo that will be displayed on the login page as well as on top of the sidebar. Will use default Portainer logo when value is empty string
@ -984,17 +996,7 @@ type (
// EdgePortainerURL is the URL that is exposed to edge agents // EdgePortainerURL is the URL that is exposed to edge agents
EdgePortainerURL string `json:"EdgePortainerUrl"` EdgePortainerURL string `json:"EdgePortainerUrl"`
Edge struct { Edge Edge `json:"Edge"`
// The command list interval for edge agent - used in edge async mode (in seconds)
CommandInterval int `json:"CommandInterval" example:"5"`
// The ping interval for edge agent - used in edge async mode (in seconds)
PingInterval int `json:"PingInterval" example:"5"`
// The snapshot interval for edge agent - used in edge async mode (in seconds)
SnapshotInterval int `json:"SnapshotInterval" example:"5"`
// Deprecated 2.18
AsyncMode bool
}
// Deprecated fields // Deprecated fields
DisplayDonationHeader bool `json:"DisplayDonationHeader,omitempty"` DisplayDonationHeader bool `json:"DisplayDonationHeader,omitempty"`

View file

@ -58,8 +58,8 @@ func (s Set[T]) Copy() Set[T] {
} }
// Difference returns a new set containing the keys that are in the first set but not in the second set. // Difference returns a new set containing the keys that are in the first set but not in the second set.
func (set Set[T]) Difference(second Set[T]) Set[T] { func (s Set[T]) Difference(second Set[T]) Set[T] {
difference := set.Copy() difference := s.Copy()
for key := range second { for key := range second {
difference.Remove(key) difference.Remove(key)

43
api/slicesx/slices.go Normal file
View file

@ -0,0 +1,43 @@
package slicesx
// Map applies the given function to each element of the slice and returns a new slice with the results
func Map[T, U any](s []T, f func(T) U) []U {
result := make([]U, len(s))
for i, v := range s {
result[i] = f(v)
}
return result
}
// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true
func Filter[T any](s []T, predicate func(T) bool) []T {
n := 0
for _, v := range s {
if predicate(v) {
s[n] = v
n++
}
}
return s[:n]
}
func Unique[T comparable](items []T) []T {
return UniqueBy(items, func(item T) T {
return item
})
}
func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType {
includedItems := make(map[ComparableType]bool)
result := []ItemType{}
for _, item := range items {
if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded {
includedItems[accessorFunc(item)] = true
result = append(result, item)
}
}
return result
}

View file

@ -1,4 +1,4 @@
package slices package slicesx
import ( import (
"strconv" "strconv"