mirror of
https://github.com/portainer/portainer.git
synced 2025-07-24 15:59:41 +02:00
chore(code): reduce the code duplication EE-7278 (#11969)
This commit is contained in:
parent
39bdfa4512
commit
9ee092aa5e
85 changed files with 520 additions and 618 deletions
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
|
|
@ -73,8 +73,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
payload.GroupID = groupID
|
||||
|
||||
var tagIDs []portainer.TagID
|
||||
err = request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true)
|
||||
if err != nil {
|
||||
if err := request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true); err != nil {
|
||||
return errors.New("invalid TagIds parameter")
|
||||
}
|
||||
payload.TagIDs = tagIDs
|
||||
|
@ -96,6 +95,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil {
|
||||
return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly")
|
||||
}
|
||||
|
||||
payload.TLSCACertFile = caCert
|
||||
}
|
||||
|
||||
|
@ -110,6 +110,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil {
|
||||
return errors.New("invalid key file. Ensure that the file is uploaded correctly")
|
||||
}
|
||||
|
||||
payload.TLSKeyFile = key
|
||||
}
|
||||
}
|
||||
|
@ -120,6 +121,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil {
|
||||
return errors.New("invalid Azure application ID")
|
||||
}
|
||||
|
||||
payload.AzureApplicationID = azureApplicationID
|
||||
|
||||
azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false)
|
||||
|
@ -139,6 +141,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) {
|
||||
return errors.New("URL cannot be empty")
|
||||
}
|
||||
|
||||
payload.URL = endpointURL
|
||||
|
||||
publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true)
|
||||
|
@ -156,10 +159,10 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
}
|
||||
|
||||
gpus := make([]portainer.Pair, 0)
|
||||
err = request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true)
|
||||
if err != nil {
|
||||
if err := request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true); err != nil {
|
||||
return errors.New("invalid Gpus parameter")
|
||||
}
|
||||
|
||||
payload.Gpus = gpus
|
||||
|
||||
edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true)
|
||||
|
@ -206,8 +209,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
// @router /endpoints [post]
|
||||
func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
payload := &endpointCreatePayload{}
|
||||
err := payload.Validate(r)
|
||||
if err != nil {
|
||||
if err := payload.Validate(r); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
|
@ -268,8 +270,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
|
|||
)
|
||||
}
|
||||
|
||||
err = handler.DataStore.EndpointRelation().Create(relationObject)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.EndpointRelation().Create(relationObject); err != nil {
|
||||
return httperror.InternalServerError("Unable to persist the relation object inside the database", err)
|
||||
}
|
||||
|
||||
|
@ -278,6 +279,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
|
|||
|
||||
func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) {
|
||||
var err error
|
||||
|
||||
switch payload.EndpointCreationType {
|
||||
case azureEnvironment:
|
||||
return handler.createAzureEndpoint(tx, payload)
|
||||
|
@ -329,8 +331,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
|
|||
}
|
||||
|
||||
httpClient := client.NewHTTPClient()
|
||||
_, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials)
|
||||
if err != nil {
|
||||
if _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials); err != nil {
|
||||
return nil, httperror.InternalServerError("Unable to authenticate against Azure", err)
|
||||
}
|
||||
|
||||
|
@ -352,8 +353,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
|
|||
Kubernetes: portainer.KubernetesDefault(),
|
||||
}
|
||||
|
||||
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
|
||||
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
|
||||
}
|
||||
|
||||
|
@ -405,8 +405,7 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
|
|||
endpoint.EdgeID = edgeID.String()
|
||||
}
|
||||
|
||||
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
|
||||
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
|
||||
}
|
||||
|
||||
|
@ -443,8 +442,7 @@ func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, pay
|
|||
Kubernetes: portainer.KubernetesDefault(),
|
||||
}
|
||||
|
||||
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -478,8 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, pa
|
|||
Kubernetes: portainer.KubernetesDefault(),
|
||||
}
|
||||
|
||||
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -510,13 +507,11 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
|
|||
|
||||
endpoint.Agent.Version = agentVersion
|
||||
|
||||
err := handler.storeTLSFiles(endpoint, payload)
|
||||
if err != nil {
|
||||
if err := handler.storeTLSFiles(endpoint, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = handler.snapshotAndPersistEndpoint(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -524,17 +519,16 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
|
|||
}
|
||||
|
||||
func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError {
|
||||
err := handler.SnapshotService.SnapshotEndpoint(endpoint)
|
||||
if err != nil {
|
||||
if err := handler.SnapshotService.SnapshotEndpoint(endpoint); err != nil {
|
||||
if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) ||
|
||||
(endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) {
|
||||
err = errors.New("agent already paired with another Portainer instance")
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unable to initiate communications with environment", err)
|
||||
}
|
||||
|
||||
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
|
||||
return httperror.InternalServerError("An error occurred while trying to create the environment", err)
|
||||
}
|
||||
|
||||
|
@ -555,16 +549,14 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.Data
|
|||
AllowStackManagementForRegularUsers: true,
|
||||
}
|
||||
|
||||
err := tx.Endpoint().Create(endpoint)
|
||||
if err != nil {
|
||||
if err := tx.Endpoint().Create(endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tagID := range endpoint.TagIDs {
|
||||
err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
|
||||
if err := tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
|
||||
tag.Endpoints[endpoint.ID] = true
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -580,22 +572,26 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
|
|||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err)
|
||||
}
|
||||
|
||||
endpoint.TLSConfig.TLSCACertPath = caCertPath
|
||||
}
|
||||
|
||||
if !payload.TLSSkipClientVerify {
|
||||
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
|
||||
}
|
||||
endpoint.TLSConfig.TLSCertPath = certPath
|
||||
|
||||
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
|
||||
}
|
||||
endpoint.TLSConfig.TLSKeyPath = keyPath
|
||||
if payload.TLSSkipClientVerify {
|
||||
return nil
|
||||
}
|
||||
|
||||
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
|
||||
}
|
||||
|
||||
endpoint.TLSConfig.TLSCertPath = certPath
|
||||
|
||||
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
|
||||
}
|
||||
endpoint.TLSConfig.TLSKeyPath = keyPath
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -30,24 +30,22 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
|
|||
for i := 0; i < endpointsCount; i++ {
|
||||
endpointID := portainer.EndpointID(i) + 1
|
||||
|
||||
err := store.Endpoint().Create(&portainer.Endpoint{
|
||||
if err := store.Endpoint().Create(&portainer.Endpoint{
|
||||
ID: endpointID,
|
||||
Name: "env-" + strconv.Itoa(int(endpointID)),
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatal("could not create endpoint:", err)
|
||||
}
|
||||
|
||||
endpointIDs = append(endpointIDs, endpointID)
|
||||
}
|
||||
|
||||
err := store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "edgegroup-1",
|
||||
Endpoints: endpointIDs,
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatal("could not create edge group:", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -102,7 +102,6 @@ func Test_EndpointList_AgentVersion(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_endpointList_edgeFilter(t *testing.T) {
|
||||
|
||||
trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
|
||||
untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
|
||||
regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
|
||||
|
@ -227,8 +226,7 @@ func doEndpointListRequest(req *http.Request, h *Handler, is *assert.Assertions)
|
|||
}
|
||||
|
||||
resp := []portainer.Endpoint{}
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -34,12 +34,10 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re
|
|||
}
|
||||
|
||||
var registries []portainer.Registry
|
||||
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID))
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
|
@ -104,11 +102,9 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi
|
|||
}
|
||||
|
||||
if namespaceParam != "" {
|
||||
authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin)
|
||||
if err != nil {
|
||||
if authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin); err != nil {
|
||||
return nil, httperror.NotFound("Unable to check for namespace authorization", err)
|
||||
}
|
||||
if !authorized {
|
||||
} else if !authorized {
|
||||
return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace"))
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/portainer/portainer/api/http/handler/edgegroups"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/unique"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -254,6 +254,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
|
|||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database")
|
||||
}
|
||||
|
||||
if edgeGroup.Dynamic {
|
||||
endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch)
|
||||
if err != nil {
|
||||
|
@ -261,6 +262,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
|
|||
}
|
||||
edgeGroup.Endpoints = endpointIDs
|
||||
}
|
||||
|
||||
envIds = append(envIds, edgeGroup.Endpoints...)
|
||||
}
|
||||
|
||||
|
@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
|
|||
envIds = envIds[:n]
|
||||
}
|
||||
|
||||
uniqueIds := unique.Unique(envIds)
|
||||
uniqueIds := slicesx.Unique(envIds)
|
||||
filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds)
|
||||
|
||||
return filteredEndpoints, nil
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/internal/slices"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -129,7 +129,7 @@ func Test_Filter_edgeFilter(t *testing.T) {
|
|||
func Test_Filter_excludeIDs(t *testing.T) {
|
||||
ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
environments := slices.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
|
||||
environments := slicesx.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
|
||||
return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment}
|
||||
})
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/internal/slices"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -162,7 +163,7 @@ func TestSortEndpointsByField(t *testing.T) {
|
|||
}
|
||||
|
||||
func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID {
|
||||
return slices.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
|
||||
return slicesx.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
|
||||
return environment.ID
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
)
|
||||
|
||||
// updateEdgeRelations updates the edge stacks associated to an edge endpoint
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
)
|
||||
|
||||
func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) {
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
)
|
||||
|
||||
func Test_updateEdgeGroups(t *testing.T) {
|
||||
|
||||
createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) {
|
||||
groups := make([]portainer.EdgeGroup, len(names))
|
||||
for index, name := range names {
|
||||
|
@ -21,8 +20,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
Endpoints: make([]portainer.EndpointID, 0),
|
||||
}
|
||||
|
||||
err := store.EdgeGroup().Create(group)
|
||||
if err != nil {
|
||||
if err := store.EdgeGroup().Create(group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -42,6 +40,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
is.Fail("expected endpoint to be in group")
|
||||
}
|
||||
}
|
||||
|
@ -52,6 +51,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
for j, tag := range groups {
|
||||
if tag.Name == tagName {
|
||||
result[i] = groups[j]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -88,6 +88,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
}
|
||||
|
||||
expectedGroups := groupsByName(groups, testCase.groupsToApply)
|
||||
|
||||
expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups))
|
||||
for i, tag := range expectedGroups {
|
||||
expectedIDs[i] = tag.ID
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
)
|
||||
|
||||
// updateEnvironmentTags updates the tags associated to an environment
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue