1
0
Fork 0
mirror of https://github.com/portainer/portainer.git synced 2025-07-24 15:59:41 +02:00

feat(api): ping the endpoint at creation time (#1817)

This commit is contained in:
Anthony Lapenna 2018-04-16 13:19:24 +02:00 committed by GitHub
parent 031b428e0c
commit 05d6abf57b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 270 additions and 109 deletions

View file

@ -1,7 +1,13 @@
package handler
import (
"bytes"
"crypto/tls"
"strings"
"time"
"github.com/portainer/portainer"
"github.com/portainer/portainer/crypto"
httperror "github.com/portainer/portainer/http/error"
"github.com/portainer/portainer/http/proxy"
"github.com/portainer/portainer/http/security"
@ -56,15 +62,6 @@ func NewEndpointHandler(bouncer *security.RequestBouncer, authorizeEndpointManag
}
type (
postEndpointsRequest struct {
Name string `valid:"required"`
URL string `valid:"required"`
PublicURL string `valid:"-"`
TLS bool
TLSSkipVerify bool
TLSSkipClientVerify bool
}
postEndpointsResponse struct {
ID int `json:"Id"`
}
@ -82,6 +79,18 @@ type (
TLSSkipVerify bool `valid:"-"`
TLSSkipClientVerify bool `valid:"-"`
}
postEndpointPayload struct {
name string
url string
publicURL string
useTLS bool
skipTLSServerVerification bool
skipTLSClientVerification bool
caCert []byte
cert []byte
key []byte
}
)
// handleGetEndpoints handles GET requests on /endpoints
@ -107,32 +116,48 @@ func (handler *EndpointHandler) handleGetEndpoints(w http.ResponseWriter, r *htt
encodeJSON(w, filteredEndpoints, handler.Logger)
}
// handlePostEndpoints handles POST requests on /endpoints
func (handler *EndpointHandler) handlePostEndpoints(w http.ResponseWriter, r *http.Request) {
if !handler.authorizeEndpointManagement {
httperror.WriteErrorResponse(w, ErrEndpointManagementDisabled, http.StatusServiceUnavailable, handler.Logger)
return
func sendPingRequest(host string, tlsConfig *tls.Config) error {
transport := &http.Transport{}
scheme := "http"
if tlsConfig != nil {
transport.TLSClientConfig = tlsConfig
scheme = "https"
}
var req postEndpointsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httperror.WriteErrorResponse(w, ErrInvalidJSON, http.StatusBadRequest, handler.Logger)
return
client := &http.Client{
Timeout: time.Second * 3,
Transport: transport,
}
_, err := govalidator.ValidateStruct(req)
pingOperationURL := strings.Replace(host, "tcp://", scheme+"://", 1) + "/_ping"
_, err := client.Get(pingOperationURL)
if err != nil {
httperror.WriteErrorResponse(w, ErrInvalidRequestFormat, http.StatusBadRequest, handler.Logger)
return
return err
}
return nil
}
func (handler *EndpointHandler) createTLSSecuredEndpoint(payload *postEndpointPayload) (*portainer.Endpoint, error) {
tlsConfig, err := crypto.CreateTLSConfig(payload.caCert, payload.cert, payload.key, payload.skipTLSClientVerification, payload.skipTLSServerVerification)
if err != nil {
return nil, err
}
err = sendPingRequest(payload.url, tlsConfig)
if err != nil {
return nil, err
}
endpoint := &portainer.Endpoint{
Name: req.Name,
URL: req.URL,
PublicURL: req.PublicURL,
Name: payload.name,
URL: payload.url,
PublicURL: payload.publicURL,
TLSConfig: portainer.TLSConfiguration{
TLS: req.TLS,
TLSSkipVerify: req.TLSSkipVerify,
TLS: payload.useTLS,
TLSSkipVerify: payload.skipTLSServerVerification,
},
AuthorizedUsers: []portainer.UserID{},
AuthorizedTeams: []portainer.TeamID{},
@ -141,30 +166,147 @@ func (handler *EndpointHandler) handlePostEndpoints(w http.ResponseWriter, r *ht
err = handler.EndpointService.CreateEndpoint(endpoint)
if err != nil {
httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
return nil, err
}
folder := strconv.Itoa(int(endpoint.ID))
if !payload.skipTLSServerVerification {
r := bytes.NewReader(payload.caCert)
// TODO: review the API exposed by the FileService to store
// a file from a byte slice and return the path to the stored file instead
// of using multiple legacy calls (StoreTLSFile, GetPathForTLSFile) here.
err = handler.FileService.StoreTLSFile(folder, portainer.TLSFileCA, r)
if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return nil, err
}
caCertPath, _ := handler.FileService.GetPathForTLSFile(folder, portainer.TLSFileCA)
endpoint.TLSConfig.TLSCACertPath = caCertPath
}
if !payload.skipTLSClientVerification {
r := bytes.NewReader(payload.cert)
err = handler.FileService.StoreTLSFile(folder, portainer.TLSFileCert, r)
if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return nil, err
}
certPath, _ := handler.FileService.GetPathForTLSFile(folder, portainer.TLSFileCert)
endpoint.TLSConfig.TLSCertPath = certPath
r = bytes.NewReader(payload.key)
err = handler.FileService.StoreTLSFile(folder, portainer.TLSFileKey, r)
if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return nil, err
}
keyPath, _ := handler.FileService.GetPathForTLSFile(folder, portainer.TLSFileKey)
endpoint.TLSConfig.TLSKeyPath = keyPath
}
err = handler.EndpointService.UpdateEndpoint(endpoint.ID, endpoint)
if err != nil {
return nil, err
}
return endpoint, nil
}
func (handler *EndpointHandler) createUnsecuredEndpoint(payload *postEndpointPayload) (*portainer.Endpoint, error) {
if !strings.HasPrefix(payload.url, "unix://") {
err := sendPingRequest(payload.url, nil)
if err != nil {
return nil, err
}
}
endpoint := &portainer.Endpoint{
Name: payload.name,
URL: payload.url,
PublicURL: payload.publicURL,
TLSConfig: portainer.TLSConfiguration{
TLS: false,
},
AuthorizedUsers: []portainer.UserID{},
AuthorizedTeams: []portainer.TeamID{},
Extensions: []portainer.EndpointExtension{},
}
err := handler.EndpointService.CreateEndpoint(endpoint)
if err != nil {
return nil, err
}
return endpoint, nil
}
func (handler *EndpointHandler) createEndpoint(payload *postEndpointPayload) (*portainer.Endpoint, error) {
if payload.useTLS {
return handler.createTLSSecuredEndpoint(payload)
}
return handler.createUnsecuredEndpoint(payload)
}
func convertPostEndpointRequestToPayload(r *http.Request) (*postEndpointPayload, error) {
payload := &postEndpointPayload{}
payload.name = r.FormValue("Name")
payload.url = r.FormValue("URL")
payload.publicURL = r.FormValue("PublicURL")
if payload.name == "" || payload.url == "" {
return nil, ErrInvalidRequestFormat
}
payload.useTLS = r.FormValue("TLS") == "true"
if payload.useTLS {
payload.skipTLSServerVerification = r.FormValue("TLSSkipVerify") == "true"
payload.skipTLSClientVerification = r.FormValue("TLSSkipClientVerify") == "true"
if !payload.skipTLSServerVerification {
caCert, err := getUploadedFileContent(r, "TLSCACertFile")
if err != nil {
return nil, err
}
payload.caCert = caCert
}
if !payload.skipTLSClientVerification {
cert, err := getUploadedFileContent(r, "TLSCertFile")
if err != nil {
return nil, err
}
payload.cert = cert
key, err := getUploadedFileContent(r, "TLSKeyFile")
if err != nil {
return nil, err
}
payload.key = key
}
}
return payload, nil
}
// handlePostEndpoints handles POST requests on /endpoints
func (handler *EndpointHandler) handlePostEndpoints(w http.ResponseWriter, r *http.Request) {
if !handler.authorizeEndpointManagement {
httperror.WriteErrorResponse(w, ErrEndpointManagementDisabled, http.StatusServiceUnavailable, handler.Logger)
return
}
if req.TLS {
folder := strconv.Itoa(int(endpoint.ID))
payload, err := convertPostEndpointRequestToPayload(r)
if err != nil {
httperror.WriteErrorResponse(w, err, http.StatusBadRequest, handler.Logger)
return
}
if !req.TLSSkipVerify {
caCertPath, _ := handler.FileService.GetPathForTLSFile(folder, portainer.TLSFileCA)
endpoint.TLSConfig.TLSCACertPath = caCertPath
}
if !req.TLSSkipClientVerify {
certPath, _ := handler.FileService.GetPathForTLSFile(folder, portainer.TLSFileCert)
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, _ := handler.FileService.GetPathForTLSFile(folder, portainer.TLSFileKey)
endpoint.TLSConfig.TLSKeyPath = keyPath
}
err = handler.EndpointService.UpdateEndpoint(endpoint.ID, endpoint)
if err != nil {
httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
return
}
endpoint, err := handler.createEndpoint(payload)
if err != nil {
httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
return
}
encodeJSON(w, &postEndpointsResponse{ID: int(endpoint.ID)}, handler.Logger)