From 5ee6efb145cecbe87a6e8fe147cc87d9d7addca3 Mon Sep 17 00:00:00 2001 From: Dakota Walsh <101994734+dakota-portainer@users.noreply.github.com> Date: Thu, 1 Feb 2024 11:41:32 +1300 Subject: [PATCH] fix(backup): restore over network share EE-6578 (#11044) --- api/backup/backup.go | 10 ++-------- api/dataservices/interface.go | 4 +--- api/datastore/backup.go | 11 +++++++++-- api/datastore/backup_test.go | 6 +++--- api/datastore/datastore.go | 2 +- api/datastore/migrate_data.go | 3 +-- api/datastore/migrate_data_test.go | 6 +++--- api/internal/testhelpers/datastore.go | 18 +++++++++++++++--- 8 files changed, 35 insertions(+), 25 deletions(-) diff --git a/api/backup/backup.go b/api/backup/backup.go index 1fd50565c..8670ae738 100644 --- a/api/backup/backup.go +++ b/api/backup/backup.go @@ -82,14 +82,8 @@ func CreateBackupArchive(password string, gate *offlinegate.OfflineGate, datasto } func backupDb(backupDirPath string, datastore dataservices.DataStore) error { - backupWriter, err := os.Create(filepath.Join(backupDirPath, "portainer.db")) - if err != nil { - return err - } - if err = datastore.BackupTo(backupWriter); err != nil { - return err - } - return backupWriter.Close() + _, err := datastore.Backup(filepath.Join(backupDirPath, "portainer.db")) + return err } func encrypt(path string, passphrase string) (string, error) { diff --git a/api/dataservices/interface.go b/api/dataservices/interface.go index d2e81053b..22e233ced 100644 --- a/api/dataservices/interface.go +++ b/api/dataservices/interface.go @@ -1,8 +1,6 @@ package dataservices import ( - "io" - portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/database/models" ) @@ -46,7 +44,7 @@ type ( MigrateData() error Rollback(force bool) error CheckCurrentEdition() error - BackupTo(w io.Writer) error + Backup(path string) (string, error) Export(filename string) (err error) DataStoreTx diff --git a/api/datastore/backup.go b/api/datastore/backup.go index 5c198f775..17388b3e6 100644 --- a/api/datastore/backup.go +++ b/api/datastore/backup.go @@ -9,12 +9,19 @@ import ( "github.com/rs/zerolog/log" ) -func (store *Store) Backup() (string, error) { +// Backup takes an optional output path and creates a backup of the database. +// The database connection is stopped before running the backup to avoid any +// corruption and if a path is not given a default is used. +// The path or an error are returned. +func (store *Store) Backup(path string) (string, error) { if err := store.createBackupPath(); err != nil { return "", err } backupFilename := store.backupFilename() + if path != "" { + backupFilename = path + } log.Info().Str("from", store.connection.GetDatabaseFilePath()).Str("to", backupFilename).Msgf("Backing up database") // Close the store before backing up @@ -69,7 +76,7 @@ func (store *Store) RestoreFromFile(backupFilename string) error { func (store *Store) createBackupPath() error { backupDir := path.Join(store.connection.GetStorePath(), "backups") if exists, _ := store.fileService.FileExists(backupDir); !exists { - if err := os.MkdirAll(backupDir, 0700); err != nil { + if err := os.MkdirAll(backupDir, 0o700); err != nil { return fmt.Errorf("unable to create backup folder: %w", err) } } diff --git a/api/datastore/backup_test.go b/api/datastore/backup_test.go index b09c3cf89..98c8e5829 100644 --- a/api/datastore/backup_test.go +++ b/api/datastore/backup_test.go @@ -39,7 +39,7 @@ func TestBackup(t *testing.T) { SchemaVersion: portainer.APIVersion, } store.VersionService.UpdateVersion(&v) - store.Backup() + store.Backup("") if !isFileExist(backupFileName) { t.Errorf("Expect backup file to be created %s", backupFileName) @@ -55,7 +55,7 @@ func TestRestore(t *testing.T) { updateEdition(store, portainer.PortainerCE) updateVersion(store, "2.4") - store.Backup() + store.Backup("") updateVersion(store, "2.16") testVersion(store, "2.16", t) store.Restore() @@ -68,7 +68,7 @@ func TestRestore(t *testing.T) { // override and set initial db version and edition updateEdition(store, portainer.PortainerCE) updateVersion(store, "2.4") - store.Backup() + store.Backup("") updateVersion(store, "2.14") updateVersion(store, "2.16") testVersion(store, "2.16", t) diff --git a/api/datastore/datastore.go b/api/datastore/datastore.go index ceebbd0b0..643a90108 100644 --- a/api/datastore/datastore.go +++ b/api/datastore/datastore.go @@ -31,7 +31,7 @@ func (store *Store) Open() (newStore bool, err error) { } if encryptionReq { - backupFilename, err := store.Backup() + backupFilename, err := store.Backup("") if err != nil { return false, fmt.Errorf("failed to backup database prior to encrypting: %w", err) } diff --git a/api/datastore/migrate_data.go b/api/datastore/migrate_data.go index 3cbfe2521..f430051e8 100644 --- a/api/datastore/migrate_data.go +++ b/api/datastore/migrate_data.go @@ -40,7 +40,7 @@ func (store *Store) MigrateData() error { } // before we alter anything in the DB, create a backup - _, err = store.Backup() + _, err = store.Backup("") if err != nil { return errors.Wrap(err, "while backing up database") } @@ -131,7 +131,6 @@ func (store *Store) FailSafeMigrate(migrator *migrator.Migrator, version *models // Rollback to a pre-upgrade backup copy/snapshot of portainer.db func (store *Store) connectionRollback(force bool) error { - if !force { confirmed, err := cli.Confirm("Are you sure you want to rollback your database to the previous backup?") if err != nil || !confirmed { diff --git a/api/datastore/migrate_data_test.go b/api/datastore/migrate_data_test.go index 6e0662a2b..447090501 100644 --- a/api/datastore/migrate_data_test.go +++ b/api/datastore/migrate_data_test.go @@ -165,7 +165,7 @@ func TestRollback(t *testing.T) { _, store := MustNewTestStore(t, false, false) store.VersionService.UpdateVersion(&v) - _, err := store.Backup() + _, err := store.Backup("") if err != nil { log.Fatal().Err(err).Msg("") } @@ -199,7 +199,7 @@ func TestRollback(t *testing.T) { _, store := MustNewTestStore(t, true, false) store.VersionService.UpdateVersion(&v) - _, err := store.Backup() + _, err := store.Backup("") if err != nil { log.Fatal().Err(err).Msg("") } @@ -305,7 +305,7 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc os.WriteFile( gotPath, gotJSON, - 0600, + 0o600, ) t.Errorf( "migrate data from %s to %s failed\nwrote migrated input to %s\nmismatch (-want +got):\n%s", diff --git a/api/internal/testhelpers/datastore.go b/api/internal/testhelpers/datastore.go index c684213bd..fdc5b5bd3 100644 --- a/api/internal/testhelpers/datastore.go +++ b/api/internal/testhelpers/datastore.go @@ -1,7 +1,6 @@ package testhelpers import ( - "io" "time" portainer "github.com/portainer/portainer/api" @@ -37,7 +36,7 @@ type testDatastore struct { pendingActionsService dataservices.PendingActionsService } -func (d *testDatastore) BackupTo(io.Writer) error { return nil } +func (d *testDatastore) Backup(path string) (string, error) { return "", nil } func (d *testDatastore) Open() (bool, error) { return false, nil } func (d *testDatastore) Init() error { return nil } func (d *testDatastore) Close() error { return nil } @@ -57,9 +56,11 @@ func (d *testDatastore) EndpointGroup() dataservices.EndpointGroupService { re func (d *testDatastore) FDOProfile() dataservices.FDOProfileService { return d.fdoProfile } + func (d *testDatastore) EndpointRelation() dataservices.EndpointRelationService { return d.endpointRelation } + func (d *testDatastore) HelmUserRepository() dataservices.HelmUserRepositoryService { return d.helmUserRepository } @@ -94,6 +95,7 @@ func (d *testDatastore) IsErrObjectNotFound(e error) bool { func (d *testDatastore) Export(filename string) (err error) { return nil } + func (d *testDatastore) Import(filename string) (err error) { return nil } @@ -119,10 +121,12 @@ func (s *stubSettingsService) BucketName() string { return "settings" } func (s *stubSettingsService) Settings() (*portainer.Settings, error) { return s.settings, nil } + func (s *stubSettingsService) UpdateSettings(settings *portainer.Settings) error { s.settings = settings return nil } + func WithSettingsService(settings *portainer.Settings) datastoreOption { return func(d *testDatastore) { d.settings = &stubSettingsService{ @@ -162,15 +166,19 @@ func (s *stubEdgeJobService) ReadAll() ([]portainer.EdgeJob, error) { return s.j func (s *stubEdgeJobService) Read(ID portainer.EdgeJobID) (*portainer.EdgeJob, error) { return nil, nil } + func (s *stubEdgeJobService) Create(edgeJob *portainer.EdgeJob) error { return nil } + func (s *stubEdgeJobService) CreateWithID(ID portainer.EdgeJobID, edgeJob *portainer.EdgeJob) error { return nil } + func (s *stubEdgeJobService) Update(ID portainer.EdgeJobID, edgeJob *portainer.EdgeJob) error { return nil } + func (s *stubEdgeJobService) UpdateEdgeJobFunc(ID portainer.EdgeJobID, updateFunc func(edgeJob *portainer.EdgeJob)) error { return nil } @@ -192,6 +200,7 @@ func (s *stubEndpointRelationService) BucketName() string { return "endpoint_rel func (s *stubEndpointRelationService) EndpointRelations() ([]portainer.EndpointRelation, error) { return s.relations, nil } + func (s *stubEndpointRelationService) EndpointRelation(ID portainer.EndpointID) (*portainer.EndpointRelation, error) { for _, relation := range s.relations { if relation.EndpointID == ID { @@ -201,9 +210,11 @@ func (s *stubEndpointRelationService) EndpointRelation(ID portainer.EndpointID) return nil, errors.ErrObjectNotFound } + func (s *stubEndpointRelationService) Create(EndpointRelation *portainer.EndpointRelation) error { return nil } + func (s *stubEndpointRelationService) UpdateEndpointRelation(ID portainer.EndpointID, relation *portainer.EndpointRelation) error { for i, r := range s.relations { if r.EndpointID == ID { @@ -213,6 +224,7 @@ func (s *stubEndpointRelationService) UpdateEndpointRelation(ID portainer.Endpoi return nil } + func (s *stubEndpointRelationService) DeleteEndpointRelation(ID portainer.EndpointID) error { return nil } @@ -307,7 +319,7 @@ func (s *stubEndpointService) GetNextIdentifier() int { } func (s *stubEndpointService) EndpointsByTeamID(teamID portainer.TeamID) ([]portainer.Endpoint, error) { - var endpoints = make([]portainer.Endpoint, 0) + endpoints := make([]portainer.Endpoint, 0) for _, e := range s.endpoints { for t := range e.TeamAccessPolicies {