1
0
Fork 0
mirror of https://github.com/portainer/portainer.git synced 2025-07-24 15:59:41 +02:00

chore(code): reduce the code duplication EE-7278 (#11969)

This commit is contained in:
andres-portainer 2024-06-26 18:14:22 -03:00 committed by GitHub
parent 39bdfa4512
commit 9ee092aa5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
85 changed files with 520 additions and 618 deletions

View file

@ -4,7 +4,7 @@ import (
"net/http"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response"
)

View file

@ -73,8 +73,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
payload.GroupID = groupID
var tagIDs []portainer.TagID
err = request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true)
if err != nil {
if err := request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true); err != nil {
return errors.New("invalid TagIds parameter")
}
payload.TagIDs = tagIDs
@ -96,6 +95,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil {
return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly")
}
payload.TLSCACertFile = caCert
}
@ -110,6 +110,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil {
return errors.New("invalid key file. Ensure that the file is uploaded correctly")
}
payload.TLSKeyFile = key
}
}
@ -120,6 +121,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil {
return errors.New("invalid Azure application ID")
}
payload.AzureApplicationID = azureApplicationID
azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false)
@ -139,6 +141,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) {
return errors.New("URL cannot be empty")
}
payload.URL = endpointURL
publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true)
@ -156,10 +159,10 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
}
gpus := make([]portainer.Pair, 0)
err = request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true)
if err != nil {
if err := request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true); err != nil {
return errors.New("invalid Gpus parameter")
}
payload.Gpus = gpus
edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true)
@ -206,8 +209,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
// @router /endpoints [post]
func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
payload := &endpointCreatePayload{}
err := payload.Validate(r)
if err != nil {
if err := payload.Validate(r); err != nil {
return httperror.BadRequest("Invalid request payload", err)
}
@ -268,8 +270,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
)
}
err = handler.DataStore.EndpointRelation().Create(relationObject)
if err != nil {
if err := handler.DataStore.EndpointRelation().Create(relationObject); err != nil {
return httperror.InternalServerError("Unable to persist the relation object inside the database", err)
}
@ -278,6 +279,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) {
var err error
switch payload.EndpointCreationType {
case azureEnvironment:
return handler.createAzureEndpoint(tx, payload)
@ -329,8 +331,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
}
httpClient := client.NewHTTPClient()
_, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials)
if err != nil {
if _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials); err != nil {
return nil, httperror.InternalServerError("Unable to authenticate against Azure", err)
}
@ -352,8 +353,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
Kubernetes: portainer.KubernetesDefault(),
}
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
if err != nil {
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
}
@ -405,8 +405,7 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
endpoint.EdgeID = edgeID.String()
}
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
if err != nil {
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
}
@ -443,8 +442,7 @@ func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, pay
Kubernetes: portainer.KubernetesDefault(),
}
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
if err != nil {
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
return nil, err
}
@ -478,8 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, pa
Kubernetes: portainer.KubernetesDefault(),
}
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
if err != nil {
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
return nil, err
}
@ -510,13 +507,11 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
endpoint.Agent.Version = agentVersion
err := handler.storeTLSFiles(endpoint, payload)
if err != nil {
if err := handler.storeTLSFiles(endpoint, payload); err != nil {
return nil, err
}
err = handler.snapshotAndPersistEndpoint(tx, endpoint)
if err != nil {
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
return nil, err
}
@ -524,17 +519,16 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
}
func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError {
err := handler.SnapshotService.SnapshotEndpoint(endpoint)
if err != nil {
if err := handler.SnapshotService.SnapshotEndpoint(endpoint); err != nil {
if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) ||
(endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) {
err = errors.New("agent already paired with another Portainer instance")
}
return httperror.InternalServerError("Unable to initiate communications with environment", err)
}
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
if err != nil {
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
return httperror.InternalServerError("An error occurred while trying to create the environment", err)
}
@ -555,16 +549,14 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.Data
AllowStackManagementForRegularUsers: true,
}
err := tx.Endpoint().Create(endpoint)
if err != nil {
if err := tx.Endpoint().Create(endpoint); err != nil {
return err
}
for _, tagID := range endpoint.TagIDs {
err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
if err := tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
tag.Endpoints[endpoint.ID] = true
})
if err != nil {
}); err != nil {
return err
}
}
@ -580,22 +572,26 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
if err != nil {
return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err)
}
endpoint.TLSConfig.TLSCACertPath = caCertPath
}
if !payload.TLSSkipClientVerify {
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
}
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
}
endpoint.TLSConfig.TLSKeyPath = keyPath
if payload.TLSSkipClientVerify {
return nil
}
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
}
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
}
endpoint.TLSConfig.TLSKeyPath = keyPath
return nil
}

View file

@ -30,24 +30,22 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
for i := 0; i < endpointsCount; i++ {
endpointID := portainer.EndpointID(i) + 1
err := store.Endpoint().Create(&portainer.Endpoint{
if err := store.Endpoint().Create(&portainer.Endpoint{
ID: endpointID,
Name: "env-" + strconv.Itoa(int(endpointID)),
Type: portainer.EdgeAgentOnDockerEnvironment,
})
if err != nil {
}); err != nil {
t.Fatal("could not create endpoint:", err)
}
endpointIDs = append(endpointIDs, endpointID)
}
err := store.EdgeGroup().Create(&portainer.EdgeGroup{
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "edgegroup-1",
Endpoints: endpointIDs,
})
if err != nil {
}); err != nil {
t.Fatal("could not create edge group:", err)
}

View file

@ -102,7 +102,6 @@ func Test_EndpointList_AgentVersion(t *testing.T) {
}
func Test_endpointList_edgeFilter(t *testing.T) {
trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
@ -227,8 +226,7 @@ func doEndpointListRequest(req *http.Request, h *Handler, is *assert.Assertions)
}
resp := []portainer.Endpoint{}
err = json.Unmarshal(body, &resp)
if err != nil {
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}

View file

@ -34,12 +34,10 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re
}
var registries []portainer.Registry
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID))
return err
})
if err != nil {
}); err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
@ -104,11 +102,9 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi
}
if namespaceParam != "" {
authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin)
if err != nil {
if authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin); err != nil {
return nil, httperror.NotFound("Unable to check for namespace authorization", err)
}
if !authorized {
} else if !authorized {
return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace"))
}

View file

@ -13,7 +13,7 @@ import (
"github.com/portainer/portainer/api/http/handler/edgegroups"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/unique"
"github.com/portainer/portainer/api/slicesx"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/pkg/errors"
@ -254,6 +254,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
if err != nil {
return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database")
}
if edgeGroup.Dynamic {
endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch)
if err != nil {
@ -261,6 +262,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
}
edgeGroup.Endpoints = endpointIDs
}
envIds = append(envIds, edgeGroup.Endpoints...)
}
@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
envIds = envIds[:n]
}
uniqueIds := unique.Unique(envIds)
uniqueIds := slicesx.Unique(envIds)
filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds)
return filteredEndpoints, nil

View file

@ -5,8 +5,8 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert"
)
@ -129,7 +129,7 @@ func Test_Filter_edgeFilter(t *testing.T) {
func Test_Filter_excludeIDs(t *testing.T) {
ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9}
environments := slices.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
environments := slicesx.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment}
})

View file

@ -4,7 +4,8 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert"
)
@ -162,7 +163,7 @@ func TestSortEndpointsByField(t *testing.T) {
}
func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID {
return slices.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
return slicesx.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
return environment.ID
})
}

View file

@ -6,7 +6,7 @@ import (
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
)
// updateEdgeRelations updates the edge stacks associated to an edge endpoint

View file

@ -6,7 +6,7 @@ import (
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
)
func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) {

View file

@ -10,7 +10,6 @@ import (
)
func Test_updateEdgeGroups(t *testing.T) {
createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) {
groups := make([]portainer.EdgeGroup, len(names))
for index, name := range names {
@ -21,8 +20,7 @@ func Test_updateEdgeGroups(t *testing.T) {
Endpoints: make([]portainer.EndpointID, 0),
}
err := store.EdgeGroup().Create(group)
if err != nil {
if err := store.EdgeGroup().Create(group); err != nil {
return nil, err
}
@ -42,6 +40,7 @@ func Test_updateEdgeGroups(t *testing.T) {
return
}
}
is.Fail("expected endpoint to be in group")
}
}
@ -52,6 +51,7 @@ func Test_updateEdgeGroups(t *testing.T) {
for j, tag := range groups {
if tag.Name == tagName {
result[i] = groups[j]
break
}
}
@ -88,6 +88,7 @@ func Test_updateEdgeGroups(t *testing.T) {
}
expectedGroups := groupsByName(groups, testCase.groupsToApply)
expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups))
for i, tag := range expectedGroups {
expectedIDs[i] = tag.ID

View file

@ -4,7 +4,7 @@ import (
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
)
// updateEnvironmentTags updates the tags associated to an environment