diff --git a/api/dataservices/base.go b/api/dataservices/base.go index 04af70b02..18839b60f 100644 --- a/api/dataservices/base.go +++ b/api/dataservices/base.go @@ -10,7 +10,7 @@ type BaseCRUD[T any, I constraints.Integer] interface { Create(element *T) error Read(ID I) (*T, error) Exists(ID I) (bool, error) - ReadAll() ([]T, error) + ReadAll(predicates ...func(T) bool) ([]T, error) Update(ID I, element *T) error Delete(ID I) error } @@ -56,12 +56,13 @@ func (service BaseDataService[T, I]) Exists(ID I) (bool, error) { return exists, err } -func (service BaseDataService[T, I]) ReadAll() ([]T, error) { +// ReadAll retrieves all the elements that satisfy all the provided predicates. +func (service BaseDataService[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) { var collection = make([]T, 0) return collection, service.Connection.ViewTx(func(tx portainer.Transaction) error { var err error - collection, err = service.Tx(tx).ReadAll() + collection, err = service.Tx(tx).ReadAll(predicates...) return err }) diff --git a/api/dataservices/base_test.go b/api/dataservices/base_test.go new file mode 100644 index 000000000..e97a09963 --- /dev/null +++ b/api/dataservices/base_test.go @@ -0,0 +1,92 @@ +package dataservices + +import ( + "strconv" + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/slicesx" + + "github.com/stretchr/testify/require" +) + +type testObject struct { + ID int + Value int +} + +type mockConnection struct { + store map[int]testObject + + portainer.Connection +} + +func (m mockConnection) UpdateObject(bucket string, key []byte, value interface{}) error { + obj := value.(*testObject) + + m.store[obj.ID] = *obj + + return nil +} + +func (m mockConnection) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error { + for _, v := range m.store { + if _, err := appendFn(&v); err != nil { + return err + } + } + + return nil +} + +func (m mockConnection) UpdateTx(fn func(portainer.Transaction) error) error { + return fn(m) +} + +func (m mockConnection) ViewTx(fn func(portainer.Transaction) error) error { + return fn(m) +} + +func (m mockConnection) ConvertToKey(v int) []byte { + return []byte(strconv.Itoa(v)) +} + +func TestReadAll(t *testing.T) { + service := BaseDataService[testObject, int]{ + Bucket: "testBucket", + Connection: mockConnection{store: make(map[int]testObject)}, + } + + data := []testObject{ + {ID: 1, Value: 1}, + {ID: 2, Value: 2}, + {ID: 3, Value: 3}, + {ID: 4, Value: 4}, + {ID: 5, Value: 5}, + } + + for _, item := range data { + err := service.Update(item.ID, &item) + require.NoError(t, err) + } + + // ReadAll without predicates + result, err := service.ReadAll() + require.NoError(t, err) + + expected := append([]testObject{}, data...) + + require.ElementsMatch(t, expected, result) + + // ReadAll with predicates + hasLowID := func(obj testObject) bool { return obj.ID < 3 } + isEven := func(obj testObject) bool { return obj.Value%2 == 0 } + + result, err = service.ReadAll(hasLowID, isEven) + require.NoError(t, err) + + expected = slicesx.Filter(expected, hasLowID) + expected = slicesx.Filter(expected, isEven) + + require.ElementsMatch(t, expected, result) +} diff --git a/api/dataservices/base_tx.go b/api/dataservices/base_tx.go index d9915b64c..5d7e7eee0 100644 --- a/api/dataservices/base_tx.go +++ b/api/dataservices/base_tx.go @@ -34,13 +34,32 @@ func (service BaseDataServiceTx[T, I]) Exists(ID I) (bool, error) { return service.Tx.KeyExists(service.Bucket, identifier) } -func (service BaseDataServiceTx[T, I]) ReadAll() ([]T, error) { +// ReadAll retrieves all the elements that satisfy all the provided predicates. +func (service BaseDataServiceTx[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) { var collection = make([]T, 0) + if len(predicates) == 0 { + return collection, service.Tx.GetAll( + service.Bucket, + new(T), + AppendFn(&collection), + ) + } + + filterFn := func(element T) bool { + for _, p := range predicates { + if !p(element) { + return false + } + } + + return true + } + return collection, service.Tx.GetAll( service.Bucket, new(T), - AppendFn(&collection), + FilterFn(&collection, filterFn), ) } diff --git a/api/internal/testhelpers/datastore.go b/api/internal/testhelpers/datastore.go index 392f21e97..19254f540 100644 --- a/api/internal/testhelpers/datastore.go +++ b/api/internal/testhelpers/datastore.go @@ -7,6 +7,7 @@ import ( "github.com/portainer/portainer/api/database" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices/errors" + "github.com/portainer/portainer/api/slicesx" ) var _ dataservices.DataStore = &testDatastore{} @@ -152,8 +153,17 @@ type stubUserService struct { users []portainer.User } -func (s *stubUserService) BucketName() string { return "users" } -func (s *stubUserService) ReadAll() ([]portainer.User, error) { return s.users, nil } +func (s *stubUserService) BucketName() string { return "users" } +func (s *stubUserService) ReadAll(predicates ...func(portainer.User) bool) ([]portainer.User, error) { + filtered := s.users + + for _, p := range predicates { + filtered = slicesx.Filter(filtered, p) + } + + return filtered, nil +} + func (s *stubUserService) UsersByRole(role portainer.UserRole) ([]portainer.User, error) { return s.users, nil } @@ -171,8 +181,16 @@ type stubEdgeJobService struct { jobs []portainer.EdgeJob } -func (s *stubEdgeJobService) BucketName() string { return "edgejobs" } -func (s *stubEdgeJobService) ReadAll() ([]portainer.EdgeJob, error) { return s.jobs, nil } +func (s *stubEdgeJobService) BucketName() string { return "edgejobs" } +func (s *stubEdgeJobService) ReadAll(predicates ...func(portainer.EdgeJob) bool) ([]portainer.EdgeJob, error) { + filtered := s.jobs + + for _, p := range predicates { + filtered = slicesx.Filter(filtered, p) + } + + return filtered, nil +} // WithEdgeJobs option will instruct testDatastore to return provided jobs func WithEdgeJobs(js []portainer.EdgeJob) datastoreOption { @@ -362,8 +380,14 @@ func (s *stubStacksService) Read(ID portainer.StackID) (*portainer.Stack, error) return nil, errors.ErrObjectNotFound } -func (s *stubStacksService) ReadAll() ([]portainer.Stack, error) { - return s.stacks, nil +func (s *stubStacksService) ReadAll(predicates ...func(portainer.Stack) bool) ([]portainer.Stack, error) { + filtered := s.stacks + + for _, p := range predicates { + filtered = slicesx.Filter(filtered, p) + } + + return filtered, nil } func (s *stubStacksService) StacksByEndpointID(endpointID portainer.EndpointID) ([]portainer.Stack, error) {