From 1c55555ad03902175cbe9aa86584ca3ffe4dfc72 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Mon, 27 Apr 2026 12:32:44 -0300 Subject: [PATCH] chore(tests): increase code coverage BE-12877 (#2431) --- api/agent/version_test.go | 119 ++++++++++ api/backup/backup_test.go | 274 ++++++++++++++++++++++++ api/concurrent/concurrent_test.go | 149 +++++++++++++ api/logoutcontext/logoutcontext_test.go | 106 +++++++++ api/logs/log_test.go | 111 ++++++++++ api/platform/platform_test.go | 155 ++++++++++++++ api/set/set_test.go | 203 ++++++++++++++++++ api/url/url_test.go | 67 ++++++ 8 files changed, 1184 insertions(+) create mode 100644 api/agent/version_test.go create mode 100644 api/backup/backup_test.go create mode 100644 api/concurrent/concurrent_test.go create mode 100644 api/logoutcontext/logoutcontext_test.go create mode 100644 api/logs/log_test.go create mode 100644 api/platform/platform_test.go create mode 100644 api/set/set_test.go create mode 100644 api/url/url_test.go diff --git a/api/agent/version_test.go b/api/agent/version_test.go new file mode 100644 index 0000000000..ce51c2f5a1 --- /dev/null +++ b/api/agent/version_test.go @@ -0,0 +1,119 @@ +package agent + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + portainer "github.com/portainer/portainer/api" + + "github.com/stretchr/testify/require" +) + +func tlsServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + srv := httptest.NewTLSServer(handler) + t.Cleanup(srv.Close) + + return srv +} + +func TestGetAgentVersionAndPlatform_Success(t *testing.T) { + t.Parallel() + + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(portainer.PortainerAgentHeader, "2.19.0") + w.Header().Set(portainer.HTTPResponseAgentPlatform, "1") + w.WriteHeader(http.StatusNoContent) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + platform, version, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.NoError(t, err) + require.Equal(t, portainer.AgentPlatformDocker, platform) + require.Equal(t, "2.19.0", version) +} + +func TestGetAgentVersionAndPlatform_NonOKStatus(t *testing.T) { + t.Parallel() + + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + _, _, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.Error(t, err) +} + +func TestGetAgentVersionAndPlatform_MissingVersionHeader(t *testing.T) { + t.Parallel() + + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(portainer.HTTPResponseAgentPlatform, "1") + w.WriteHeader(http.StatusNoContent) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + _, _, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.Error(t, err) +} + +func TestGetAgentVersionAndPlatform_MissingPlatformHeader(t *testing.T) { + t.Parallel() + + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(portainer.PortainerAgentHeader, "2.19.0") + w.WriteHeader(http.StatusNoContent) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + _, _, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.Error(t, err) +} + +func TestGetAgentVersionAndPlatform_InvalidPlatformZero(t *testing.T) { + t.Parallel() + + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(portainer.PortainerAgentHeader, "2.19.0") + w.Header().Set(portainer.HTTPResponseAgentPlatform, "0") + w.WriteHeader(http.StatusNoContent) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + _, _, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.Error(t, err) +} + +func TestGetAgentVersionAndPlatform_NonNumericPlatform(t *testing.T) { + t.Parallel() + + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(portainer.PortainerAgentHeader, "2.19.0") + w.Header().Set(portainer.HTTPResponseAgentPlatform, "docker") + w.WriteHeader(http.StatusNoContent) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + _, _, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.Error(t, err) +} + +func TestGetAgentVersionAndPlatform_PingPathAppended(t *testing.T) { + t.Parallel() + + var gotPath string + srv := tlsServer(t, func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set(portainer.PortainerAgentHeader, "2.19.0") + w.Header().Set(portainer.HTTPResponseAgentPlatform, strconv.Itoa(int(portainer.AgentPlatformKubernetes))) + w.WriteHeader(http.StatusNoContent) + }) + + tlsCfg := srv.Client().Transport.(*http.Transport).TLSClientConfig + _, _, err := GetAgentVersionAndPlatform(srv.URL, tlsCfg) + require.NoError(t, err) + require.Equal(t, "/ping", gotPath) +} diff --git a/api/backup/backup_test.go b/api/backup/backup_test.go new file mode 100644 index 0000000000..43889df69c --- /dev/null +++ b/api/backup/backup_test.go @@ -0,0 +1,274 @@ +package backup + +import ( + "bytes" + "context" + "io" + "os" + "path/filepath" + "testing" + + "github.com/portainer/portainer/api/archive" + "github.com/portainer/portainer/api/crypto" + "github.com/portainer/portainer/api/datastore" + "github.com/portainer/portainer/api/filesystem" + "github.com/portainer/portainer/api/http/offlinegate" + "github.com/portainer/portainer/pkg/fips" + + "github.com/stretchr/testify/require" +) + +func init() { + fips.InitFIPS(false) +} + +func TestGetRestoreSourcePath_DBAtRoot(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + err := os.WriteFile(filesystem.JoinPaths(dir, "portainer.db"), []byte("db"), 0o600) + require.NoError(t, err) + + result, err := getRestoreSourcePath(dir) + require.NoError(t, err) + require.Equal(t, dir, result) +} + +func TestGetRestoreSourcePath_EncryptedDBAtRoot(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + err := os.WriteFile(filesystem.JoinPaths(dir, "portainer.edb"), []byte("db"), 0o600) + require.NoError(t, err) + + result, err := getRestoreSourcePath(dir) + require.NoError(t, err) + require.Equal(t, dir, result) +} + +func TestGetRestoreSourcePath_DBInSubdirectory(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + sub := filesystem.JoinPaths(dir, "backup-2024-01-01") + err := os.Mkdir(sub, 0o700) + require.NoError(t, err) + + err = os.WriteFile(filesystem.JoinPaths(sub, "portainer.db"), []byte("db"), 0o600) + require.NoError(t, err) + + result, err := getRestoreSourcePath(dir) + require.NoError(t, err) + require.Equal(t, sub, result) +} + +func TestGetRestoreSourcePath_NoDBFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + err := os.WriteFile(filesystem.JoinPaths(dir, "other.file"), []byte("data"), 0o600) + require.NoError(t, err) + + result, err := getRestoreSourcePath(dir) + require.NoError(t, err) + require.Equal(t, dir, result) +} + +func TestGetRestoreSourcePath_EmptyDir(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + result, err := getRestoreSourcePath(dir) + require.NoError(t, err) + require.Equal(t, dir, result) +} + +func TestEncryptDecrypt_RoundTrip(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + plaintext := []byte("sensitive portainer backup data") + + srcPath := filesystem.JoinPaths(dir, "archive.tar.gz") + err := os.WriteFile(srcPath, plaintext, 0o600) + require.NoError(t, err) + + encryptedPath, err := encrypt(srcPath, "mysecretpassword") + require.NoError(t, err) + require.Equal(t, srcPath+".encrypted", encryptedPath) + + encryptedData, err := os.ReadFile(encryptedPath) + require.NoError(t, err) + + decryptedReader, err := crypto.AesDecrypt(bytes.NewReader(encryptedData), []byte("mysecretpassword")) + require.NoError(t, err) + + decrypted, err := io.ReadAll(decryptedReader) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) +} + +func TestEncryptDecrypt_WrongPassword(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + srcPath := filesystem.JoinPaths(dir, "archive.tar.gz") + err := os.WriteFile(srcPath, []byte("data"), 0o600) + require.NoError(t, err) + + encryptedPath, err := encrypt(srcPath, "correctpassword") + require.NoError(t, err) + + encryptedData, err := os.ReadFile(encryptedPath) + require.NoError(t, err) + + _, err = crypto.AesDecrypt(bytes.NewReader(encryptedData), []byte("wrongpassword")) + require.Error(t, err) +} + +func TestCreateBackupArchive_NoPassword(t *testing.T) { + t.Parallel() + + _, store := datastore.MustNewTestStore(t, true, false) + storePath := store.GetConnection().GetStorePath() + gate := offlinegate.NewOfflineGate() + + archivePath, err := CreateBackupArchive("", gate, store, storePath) + require.NoError(t, err) + + f, err := os.Open(archivePath) + require.NoError(t, err) + t.Cleanup(func() { + err := f.Close() + require.NoError(t, err) + }) + + extractDir := t.TempDir() + err = archive.ExtractTarGz(f, extractDir) + require.NoError(t, err) + + dbFound := false + err = filepath.Walk(extractDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.Name() == "portainer.db" { + dbFound = true + } + + return nil + }) + require.NoError(t, err) + require.True(t, dbFound, "archive should contain portainer.db") +} + +func TestCreateBackupArchive_WithPassword(t *testing.T) { + t.Parallel() + + _, store := datastore.MustNewTestStore(t, true, false) + storePath := store.GetConnection().GetStorePath() + gate := offlinegate.NewOfflineGate() + + archivePath, err := CreateBackupArchive("backup-secret", gate, store, storePath) + require.NoError(t, err) + require.Contains(t, archivePath, ".encrypted") + + encryptedData, err := os.ReadFile(archivePath) + require.NoError(t, err) + + decryptedReader, err := crypto.AesDecrypt(bytes.NewReader(encryptedData), []byte("backup-secret")) + require.NoError(t, err) + + extractDir := t.TempDir() + err = archive.ExtractTarGz(decryptedReader, extractDir) + require.NoError(t, err) + + dbFound := false + err = filepath.Walk(extractDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.Name() == "portainer.db" { + dbFound = true + } + + return nil + }) + require.NoError(t, err) + require.True(t, dbFound, "decrypted archive should contain portainer.db") +} + +func TestRestoreArchive_NoPassword(t *testing.T) { + t.Parallel() + + _, store1 := datastore.MustNewTestStore(t, true, false) + storePath1 := store1.GetConnection().GetStorePath() + gate := offlinegate.NewOfflineGate() + + archivePath, err := CreateBackupArchive("", gate, store1, storePath1) + require.NoError(t, err) + + archiveData, err := os.ReadFile(archivePath) + require.NoError(t, err) + + _, store2 := datastore.MustNewTestStore(t, true, false) + storePath2 := store2.GetConnection().GetStorePath() + + ctx, cancel := context.WithCancel(t.Context()) + err = RestoreArchive(bytes.NewReader(archiveData), "", storePath2, gate, store2, cancel) + require.NoError(t, err) + + require.ErrorIs(t, ctx.Err(), context.Canceled) + + _, err = os.Stat(filesystem.JoinPaths(storePath2, "portainer.db")) + require.NoError(t, err) +} + +func TestRestoreArchive_WithPassword(t *testing.T) { + t.Parallel() + + _, store1 := datastore.MustNewTestStore(t, true, false) + storePath1 := store1.GetConnection().GetStorePath() + gate := offlinegate.NewOfflineGate() + + archivePath, err := CreateBackupArchive("restore-secret", gate, store1, storePath1) + require.NoError(t, err) + + archiveData, err := os.ReadFile(archivePath) + require.NoError(t, err) + + _, store2 := datastore.MustNewTestStore(t, true, false) + storePath2 := store2.GetConnection().GetStorePath() + + ctx, cancel := context.WithCancel(t.Context()) + err = RestoreArchive(bytes.NewReader(archiveData), "restore-secret", storePath2, gate, store2, cancel) + require.NoError(t, err) + + require.ErrorIs(t, ctx.Err(), context.Canceled) + + _, err = os.Stat(filesystem.JoinPaths(storePath2, "portainer.db")) + require.NoError(t, err) +} + +func TestRestoreArchive_WrongPassword(t *testing.T) { + t.Parallel() + + _, store1 := datastore.MustNewTestStore(t, true, false) + storePath1 := store1.GetConnection().GetStorePath() + gate := offlinegate.NewOfflineGate() + + archivePath, err := CreateBackupArchive("correct-password", gate, store1, storePath1) + require.NoError(t, err) + + archiveData, err := os.ReadFile(archivePath) + require.NoError(t, err) + + _, store2 := datastore.MustNewTestStore(t, true, false) + storePath2 := store2.GetConnection().GetStorePath() + + _, cancel := context.WithCancel(t.Context()) + err = RestoreArchive(bytes.NewReader(archiveData), "wrong-password", storePath2, gate, store2, cancel) + require.Error(t, err) +} diff --git a/api/concurrent/concurrent_test.go b/api/concurrent/concurrent_test.go new file mode 100644 index 0000000000..5809e3a8e9 --- /dev/null +++ b/api/concurrent/concurrent_test.go @@ -0,0 +1,149 @@ +package concurrent + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRun_AllSucceed(t *testing.T) { + t.Parallel() + + fn1 := func(ctx context.Context) (any, error) { return "one", nil } + fn2 := func(ctx context.Context) (any, error) { return "two", nil } + fn3 := func(ctx context.Context) (any, error) { return "three", nil } + + results, err := Run(t.Context(), 0, fn1, fn2, fn3) + + require.NoError(t, err) + require.Len(t, results, 3) + + values := make([]string, 0, len(results)) + for _, r := range results { + values = append(values, r.Result.(string)) + } + require.ElementsMatch(t, []string{"one", "two", "three"}, values) +} + +func TestRun_OneError(t *testing.T) { + t.Parallel() + + sentinel := errors.New("task failed") + + fn1 := func(ctx context.Context) (any, error) { return "ok", nil } + fn2 := func(ctx context.Context) (any, error) { return nil, sentinel } + + _, err := Run(t.Context(), 0, fn1, fn2) + + require.ErrorIs(t, err, sentinel) +} + +func TestRun_NoTasks(t *testing.T) { + t.Parallel() + + results, err := Run(t.Context(), 0) + + require.NoError(t, err) + require.Empty(t, results) +} + +func TestRun_MaxConcurrency(t *testing.T) { + t.Parallel() + + const numTasks = 10 + var peak atomic.Int32 + var active atomic.Int32 + + task := func(ctx context.Context) (any, error) { + current := active.Add(1) + if current > peak.Load() { + peak.Store(current) + } + + time.Sleep(10 * time.Millisecond) + active.Add(-1) + + return nil, nil + } + + tasks := make([]Func, numTasks) + for i := range tasks { + tasks[i] = task + } + + synctest.Test(t, func(t *testing.T) { + results, err := Run(t.Context(), 3, tasks...) + require.NoError(t, err) + require.Len(t, results, numTasks) + require.LessOrEqual(t, peak.Load(), int32(3)) + }) +} + +func TestRun_ZeroConcurrencyUsesAllTasks(t *testing.T) { + t.Parallel() + + const numTasks = 5 + var peak atomic.Int32 + var active atomic.Int32 + + task := func(ctx context.Context) (any, error) { + current := active.Add(1) + if current > peak.Load() { + peak.Store(current) + } + + time.Sleep(20 * time.Millisecond) + active.Add(-1) + + return nil, nil + } + + tasks := make([]Func, numTasks) + for i := range tasks { + tasks[i] = task + } + + synctest.Test(t, func(t *testing.T) { + results, err := Run(t.Context(), 0, tasks...) + require.NoError(t, err) + require.Len(t, results, numTasks) + require.Equal(t, int32(numTasks), peak.Load()) + }) +} + +func TestRun_ContextCancelledBeforeStart(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + called := atomic.Bool{} + fn := func(ctx context.Context) (any, error) { + called.Store(true) + return nil, ctx.Err() + } + + _, err := Run(ctx, 1, fn, fn, fn) + require.Error(t, err) +} + +func TestRun_ContextPassedToTasks(t *testing.T) { + t.Parallel() + + type key struct{} + ctx := context.WithValue(t.Context(), key{}, "testvalue") + + fn := func(ctx context.Context) (any, error) { + return ctx.Value(key{}), nil + } + + results, err := Run(ctx, 0, fn) + + require.NoError(t, err) + require.Equal(t, "testvalue", results[0].Result) +} diff --git a/api/logoutcontext/logoutcontext_test.go b/api/logoutcontext/logoutcontext_test.go new file mode 100644 index 0000000000..203d83b92b --- /dev/null +++ b/api/logoutcontext/logoutcontext_test.go @@ -0,0 +1,106 @@ +package logoutcontext + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetContext_ReturnsActiveContext(t *testing.T) { + t.Parallel() + + token := "token-get-context-active" + defer Cancel(token) + + ctx := GetContext(token) + require.NoError(t, ctx.Err()) +} + +func TestCancel_CancelsContext(t *testing.T) { + t.Parallel() + + token := "token-cancel" + + ctx := GetContext(token) + require.NoError(t, ctx.Err()) + + Cancel(token) + + require.ErrorIs(t, ctx.Err(), context.Canceled) +} + +func TestCancel_RemovesService(t *testing.T) { + t.Parallel() + + token := "token-cancel-removes" + + first := GetContext(token) + Cancel(token) + + second := GetContext(token) + defer Cancel(token) + + require.ErrorIs(t, first.Err(), context.Canceled) + require.NoError(t, second.Err()) + require.NotEqual(t, first, second) +} + +func TestGetService_ReturnsSameServiceForSameToken(t *testing.T) { + t.Parallel() + + token := logoutToken("token-same-service") + defer RemoveService(token) + + s1 := GetService(token) + s2 := GetService(token) + + require.Same(t, s1, s2) +} + +func TestGetService_ReturnsDistinctServicesForDifferentTokens(t *testing.T) { + t.Parallel() + + tokenA := logoutToken("token-distinct-a") + tokenB := logoutToken("token-distinct-b") + defer RemoveService(tokenA) + defer RemoveService(tokenB) + + sA := GetService(tokenA) + sB := GetService(tokenB) + + require.NotSame(t, sA, sB) +} + +func TestGetService_ConcurrentAccess(t *testing.T) { + t.Parallel() + + const goroutines = 50 + token := logoutToken("token-concurrent") + defer RemoveService(token) + + var wg sync.WaitGroup + services := make([]*Service, goroutines) + + for i := range goroutines { + wg.Add(1) + go func(i int) { + defer wg.Done() + services[i] = GetService(token) + }(i) + } + + wg.Wait() + + for i := 1; i < goroutines; i++ { + require.Same(t, services[0], services[i]) + } +} + +func TestLogoutToken_AddsPrefix(t *testing.T) { + t.Parallel() + + result := logoutToken("abc123") + require.Equal(t, LogoutPrefix+"abc123", result) +} diff --git a/api/logs/log_test.go b/api/logs/log_test.go new file mode 100644 index 0000000000..e3034e8f3a --- /dev/null +++ b/api/logs/log_test.go @@ -0,0 +1,111 @@ +package logs + +import ( + "errors" + "testing" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "github.com/stretchr/testify/require" +) + +func saveGlobalLevel(t *testing.T) { + t.Helper() + orig := zerolog.GlobalLevel() + t.Cleanup(func() { zerolog.SetGlobalLevel(orig) }) +} + +func saveLogger(t *testing.T) { + t.Helper() + orig := log.Logger + t.Cleanup(func() { log.Logger = orig }) +} + +func TestSetLoggingLevel_Error(t *testing.T) { + saveGlobalLevel(t) + + SetLoggingLevel("ERROR") + require.Equal(t, zerolog.ErrorLevel, zerolog.GlobalLevel()) +} + +func TestSetLoggingLevel_Warn(t *testing.T) { + saveGlobalLevel(t) + + SetLoggingLevel("WARN") + require.Equal(t, zerolog.WarnLevel, zerolog.GlobalLevel()) +} + +func TestSetLoggingLevel_Info(t *testing.T) { + saveGlobalLevel(t) + + SetLoggingLevel("INFO") + require.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel()) +} + +func TestSetLoggingLevel_Debug(t *testing.T) { + saveGlobalLevel(t) + + SetLoggingLevel("DEBUG") + require.Equal(t, zerolog.DebugLevel, zerolog.GlobalLevel()) +} + +func TestSetLoggingLevel_UnknownLevelIsNoop(t *testing.T) { + saveGlobalLevel(t) + + zerolog.SetGlobalLevel(zerolog.InfoLevel) + SetLoggingLevel("TRACE") + require.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel()) +} + +func TestSetLoggingMode_Pretty(t *testing.T) { + saveLogger(t) + + SetLoggingMode("PRETTY") +} + +func TestSetLoggingMode_Nocolor(t *testing.T) { + saveLogger(t) + + SetLoggingMode("NOCOLOR") +} + +func TestSetLoggingMode_JSON(t *testing.T) { + saveLogger(t) + + SetLoggingMode("JSON") +} + +func TestSetLoggingMode_UnknownModeIsNoop(t *testing.T) { + saveLogger(t) + + SetLoggingMode("UNKNOWN") +} + +func TestFormatMessage_NonNil(t *testing.T) { + t.Parallel() + + require.Equal(t, "hello |", formatMessage("hello")) +} + +func TestFormatMessage_Nil(t *testing.T) { + t.Parallel() + + require.Empty(t, formatMessage(nil)) +} + +type stubCloser struct{ err error } + +func (s *stubCloser) Close() error { return s.err } + +func TestCloseAndLogErr_Success(t *testing.T) { + t.Parallel() + + CloseAndLogErr(&stubCloser{err: nil}) +} + +func TestCloseAndLogErr_Error(t *testing.T) { + t.Parallel() + + CloseAndLogErr(&stubCloser{err: errors.New("close failed")}) +} diff --git a/api/platform/platform_test.go b/api/platform/platform_test.go new file mode 100644 index 0000000000..be186212ee --- /dev/null +++ b/api/platform/platform_test.go @@ -0,0 +1,155 @@ +package platform + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/internal/testhelpers" + + "github.com/stretchr/testify/require" +) + +func TestDetermineContainerPlatform_Podman(t *testing.T) { + t.Setenv(PodmanMode, "1") + + require.Equal(t, PlatformPodman, DetermineContainerPlatform()) +} + +func TestDetermineContainerPlatform_Kubernetes(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + + require.Equal(t, PlatformKubernetes, DetermineContainerPlatform()) +} + +func TestDetermineContainerPlatform_PodmanTakesPrecedenceOverKubernetes(t *testing.T) { + t.Setenv(PodmanMode, "1") + t.Setenv(KubernetesServiceHost, "10.96.0.1") + + require.Equal(t, PlatformPodman, DetermineContainerPlatform()) +} + +func TestCheckDockerEnvTypeForUpgrade_UnixSocket(t *testing.T) { + t.Parallel() + + endpoint := &portainer.Endpoint{URL: "unix:///var/run/docker.sock"} + require.Equal(t, PlatformDockerStandalone, checkDockerEnvTypeForUpgrade(endpoint)) +} + +func TestCheckDockerEnvTypeForUpgrade_Npipe(t *testing.T) { + t.Parallel() + + endpoint := &portainer.Endpoint{URL: "npipe:////./pipe/docker_engine", Type: portainer.DockerEnvironment} + require.Equal(t, PlatformDockerStandalone, checkDockerEnvTypeForUpgrade(endpoint)) +} + +func TestCheckDockerEnvTypeForUpgrade_Swarm(t *testing.T) { + t.Parallel() + + endpoint := &portainer.Endpoint{URL: "tcp://tasks.portainer_agent:9001"} + require.Equal(t, PlatformDockerSwarm, checkDockerEnvTypeForUpgrade(endpoint)) +} + +func TestCheckDockerEnvTypeForUpgrade_RemoteTCP(t *testing.T) { + t.Parallel() + + endpoint := &portainer.Endpoint{URL: "tcp://192.168.1.100:2376"} + require.Equal(t, ContainerPlatform(""), checkDockerEnvTypeForUpgrade(endpoint)) +} + +func TestDetectLocalEnvironment_UnsupportedPlatform(t *testing.T) { + t.Setenv(PodmanMode, "1") + t.Setenv(KubernetesServiceHost, "") + + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{ + {ID: 1, Type: portainer.DockerEnvironment}, + })) + + endpoint, platform, err := detectLocalEnvironment(ds) + require.NoError(t, err) + require.Nil(t, endpoint) + require.Empty(t, platform) +} + +func TestDetectLocalEnvironment_NoEndpoints(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + t.Setenv(PodmanMode, "") + + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{})) + + endpoint, platform, err := detectLocalEnvironment(ds) + require.NoError(t, err) + require.Nil(t, endpoint) + require.Empty(t, platform) +} + +func TestDetectLocalEnvironment_KubernetesEndpointFound(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + t.Setenv(PodmanMode, "") + + kube := portainer.Endpoint{ID: 1, Name: "local-k8s", Type: portainer.KubernetesLocalEnvironment} + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{kube})) + + endpoint, platform, err := detectLocalEnvironment(ds) + require.NoError(t, err) + require.NotNil(t, endpoint) + require.Equal(t, portainer.EndpointID(1), endpoint.ID) + require.Equal(t, PlatformKubernetes, platform) +} + +func TestDetectLocalEnvironment_NoMatchingEndpointType(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + t.Setenv(PodmanMode, "") + + docker := portainer.Endpoint{ID: 1, Type: portainer.DockerEnvironment} + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{docker})) + + _, _, err := detectLocalEnvironment(ds) + require.ErrorIs(t, err, ErrNoLocalEnvironment) +} + +func TestService_GetPlatform(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + t.Setenv(PodmanMode, "") + + kube := portainer.Endpoint{ID: 1, Type: portainer.KubernetesLocalEnvironment} + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{kube})) + + svc := NewService(ds) + + platform, err := svc.GetPlatform() + require.NoError(t, err) + require.Equal(t, PlatformKubernetes, platform) +} + +func TestService_GetLocalEnvironment(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + t.Setenv(PodmanMode, "") + + kube := portainer.Endpoint{ID: 1, Type: portainer.KubernetesLocalEnvironment} + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{kube})) + + svc := NewService(ds) + + env, err := svc.GetLocalEnvironment() + require.NoError(t, err) + require.NotNil(t, env) + require.Equal(t, portainer.EndpointID(1), env.ID) +} + +func TestService_CachesLoadedEnvironment(t *testing.T) { + t.Setenv(KubernetesServiceHost, "10.96.0.1") + t.Setenv(PodmanMode, "") + + kube := portainer.Endpoint{ID: 1, Type: portainer.KubernetesLocalEnvironment} + ds := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{kube})) + + svc := NewService(ds) + + env1, err := svc.GetLocalEnvironment() + require.NoError(t, err) + + env2, err := svc.GetLocalEnvironment() + require.NoError(t, err) + + require.Same(t, env1, env2) +} diff --git a/api/set/set_test.go b/api/set/set_test.go new file mode 100644 index 0000000000..e1093405f6 --- /dev/null +++ b/api/set/set_test.go @@ -0,0 +1,203 @@ +package set + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdd(t *testing.T) { + t.Parallel() + + s := make(Set[int]) + s.Add(1) + s.Add(2) + s.Add(2) + + require.Equal(t, 2, s.Len()) + require.True(t, s.Contains(1)) + require.True(t, s.Contains(2)) +} + +func TestContains(t *testing.T) { + t.Parallel() + + s := make(Set[string]) + s.Add("hello") + + require.True(t, s.Contains("hello")) + require.False(t, s.Contains("world")) +} + +func TestRemove(t *testing.T) { + t.Parallel() + + s := make(Set[int]) + s.Add(1) + s.Add(2) + s.Remove(1) + s.Remove(99) + + require.Equal(t, 1, s.Len()) + require.False(t, s.Contains(1)) + require.True(t, s.Contains(2)) +} + +func TestIsEmpty(t *testing.T) { + t.Parallel() + + s := make(Set[int]) + require.True(t, s.IsEmpty()) + + s.Add(1) + require.False(t, s.IsEmpty()) + + s.Remove(1) + require.True(t, s.IsEmpty()) +} + +func TestKeys(t *testing.T) { + t.Parallel() + + s := ToSet([]int{1, 2, 3}) + keys := s.Keys() + + require.Len(t, keys, 3) + require.ElementsMatch(t, []int{1, 2, 3}, keys) +} + +func TestCopy(t *testing.T) { + t.Parallel() + + original := ToSet([]string{"a", "b", "c"}) + copied := original.Copy() + + require.Equal(t, original.Len(), copied.Len()) + require.True(t, copied.Contains("a")) + require.True(t, copied.Contains("b")) + require.True(t, copied.Contains("c")) + + copied.Add("d") + require.False(t, original.Contains("d")) + + copied.Remove("a") + require.True(t, original.Contains("a")) +} + +func TestDifference(t *testing.T) { + t.Parallel() + + a := ToSet([]int{1, 2, 3, 4}) + b := ToSet([]int{3, 4, 5}) + + diff := a.Difference(b) + + require.Equal(t, 2, diff.Len()) + require.True(t, diff.Contains(1)) + require.True(t, diff.Contains(2)) + require.False(t, diff.Contains(3)) + require.False(t, diff.Contains(4)) +} + +func TestDifference_EmptySecond(t *testing.T) { + t.Parallel() + + a := ToSet([]int{1, 2, 3}) + b := make(Set[int]) + + diff := a.Difference(b) + + require.Equal(t, 3, diff.Len()) +} + +func TestDifference_EmptyFirst(t *testing.T) { + t.Parallel() + + a := make(Set[int]) + b := ToSet([]int{1, 2, 3}) + + diff := a.Difference(b) + + require.True(t, diff.IsEmpty()) +} + +func TestUnion(t *testing.T) { + t.Parallel() + + a := ToSet([]int{1, 2}) + b := ToSet([]int{2, 3}) + c := ToSet([]int{3, 4}) + + u := Union(a, b, c) + + require.Equal(t, 4, u.Len()) + require.True(t, u.Contains(1)) + require.True(t, u.Contains(2)) + require.True(t, u.Contains(3)) + require.True(t, u.Contains(4)) +} + +func TestUnion_NoSets(t *testing.T) { + t.Parallel() + + u := Union[int]() + require.True(t, u.IsEmpty()) +} + +func TestIntersection(t *testing.T) { + t.Parallel() + + a := ToSet([]int{1, 2, 3}) + b := ToSet([]int{2, 3, 4}) + c := ToSet([]int{3, 4, 5}) + + inter := Intersection(a, b, c) + + require.Equal(t, 1, inter.Len()) + require.True(t, inter.Contains(3)) +} + +func TestIntersection_NoOverlap(t *testing.T) { + t.Parallel() + + a := ToSet([]int{1, 2}) + b := ToSet([]int{3, 4}) + + inter := Intersection(a, b) + + require.True(t, inter.IsEmpty()) +} + +func TestIntersection_NoSets(t *testing.T) { + t.Parallel() + + inter := Intersection[int]() + require.True(t, inter.IsEmpty()) +} + +func TestIntersection_SingleSet(t *testing.T) { + t.Parallel() + + a := ToSet([]int{1, 2, 3}) + inter := Intersection(a) + + require.Equal(t, 3, inter.Len()) +} + +func TestToSet(t *testing.T) { + t.Parallel() + + keys := []string{"x", "y", "x"} + s := ToSet(keys) + + require.Equal(t, 2, s.Len()) + require.True(t, s.Contains("x")) + require.True(t, s.Contains("y")) +} + +func TestToSet_Empty(t *testing.T) { + t.Parallel() + + s := ToSet([]int{}) + require.True(t, s.IsEmpty()) +} diff --git a/api/url/url_test.go b/api/url/url_test.go new file mode 100644 index 0000000000..97eca24646 --- /dev/null +++ b/api/url/url_test.go @@ -0,0 +1,67 @@ +package url + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseURL_NoPrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("192.168.1.1:9000") + require.NoError(t, err) + require.Equal(t, "192.168.1.1:9000", u.Host) +} + +func TestParseURL_HTTPPrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("http://example.com:8080") + require.NoError(t, err) + require.Equal(t, "http", u.Scheme) + require.Equal(t, "example.com:8080", u.Host) +} + +func TestParseURL_HTTPSPrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("https://example.com") + require.NoError(t, err) + require.Equal(t, "https", u.Scheme) + require.Equal(t, "example.com", u.Host) +} + +func TestParseURL_TCPPrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("tcp://192.168.1.1:2376") + require.NoError(t, err) + require.Equal(t, "tcp", u.Scheme) + require.Equal(t, "192.168.1.1:2376", u.Host) +} + +func TestParseURL_SlashSlashPrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("//192.168.1.1:2376") + require.NoError(t, err) + require.Equal(t, "192.168.1.1:2376", u.Host) +} + +func TestParseURL_UnixPrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("unix:///var/run/docker.sock") + require.NoError(t, err) + require.Equal(t, "unix", u.Scheme) + require.Equal(t, "/var/run/docker.sock", u.Path) +} + +func TestParseURL_NpipePrefix(t *testing.T) { + t.Parallel() + + u, err := ParseURL("npipe:////./pipe/docker_engine") + require.NoError(t, err) + require.Equal(t, "npipe", u.Scheme) +}