fix(ecr): prevent deadlock on ECR token refresh during stack deployment [BE-12842] (#2564)

This commit is contained in:
Oscar Zhou
2026-05-07 08:34:19 +12:00
committed by GitHub
parent e7ec69708e
commit c3b0b9a2e0
8 changed files with 269 additions and 27 deletions
+4 -4
View File
@@ -8,8 +8,8 @@ import (
"time"
)
func (s *Service) GetEncodedAuthorizationToken() (token *string, expiry *time.Time, err error) {
getAuthorizationTokenOutput, err := s.client.GetAuthorizationToken(context.TODO(), nil)
func (s *Service) GetEncodedAuthorizationToken(ctx context.Context) (token *string, expiry *time.Time, err error) {
getAuthorizationTokenOutput, err := s.client.GetAuthorizationToken(ctx, nil)
if err != nil {
return
}
@@ -27,8 +27,8 @@ func (s *Service) GetEncodedAuthorizationToken() (token *string, expiry *time.Ti
return
}
func (s *Service) GetAuthorizationToken() (token *string, expiry *time.Time, err error) {
tokenEncodedStr, expiry, err := s.GetEncodedAuthorizationToken()
func (s *Service) GetAuthorizationToken(ctx context.Context) (token *string, expiry *time.Time, err error) {
tokenEncodedStr, expiry, err := s.GetEncodedAuthorizationToken(ctx)
if err != nil {
return
}
+11 -15
View File
@@ -66,7 +66,7 @@ func (manager *ComposeStackManager) Up(ctx context.Context, stack *portainer.Sta
EnvFilePath: envFilePath,
Host: url,
ProjectName: stack.Name,
Registries: portainerRegistriesToAuthConfigs(manager.dataStore, options.Registries),
Registries: portainerRegistriesToAuthConfigs(options.Registries),
},
ForceRecreate: options.ForceRecreate,
AbortOnContainerExit: options.AbortOnContainerExit,
@@ -98,7 +98,7 @@ func (manager *ComposeStackManager) Run(ctx context.Context, stack *portainer.St
EnvFilePath: envFilePath,
Host: url,
ProjectName: stack.Name,
Registries: portainerRegistriesToAuthConfigs(manager.dataStore, options.Registries),
Registries: portainerRegistriesToAuthConfigs(options.Registries),
},
Remove: options.Remove,
Args: options.Args,
@@ -147,7 +147,7 @@ func (manager *ComposeStackManager) Pull(ctx context.Context, stack *portainer.S
EnvFilePath: envFilePath,
Host: url,
ProjectName: stack.Name,
Registries: portainerRegistriesToAuthConfigs(manager.dataStore, options.Registries),
Registries: portainerRegistriesToAuthConfigs(options.Registries),
})
return errors.Wrap(err, "failed to pull images of the stack")
}
@@ -230,7 +230,12 @@ func copyConfigEnvVars(w io.Writer, envs []portainer.Pair) error {
return nil
}
func portainerRegistriesToAuthConfigs(tx dataservices.DataStoreTx, registries []portainer.Registry) []types.AuthConfig {
// portainerRegistriesToAuthConfigs converts registries to Docker auth configs.
// Callers must ensure ECR tokens are valid before calling this function (e.g. via
// registryutils.ValidateRegistriesECRTokens with a real DataStoreTx). This function
// intentionally performs no DB writes to avoid write-lock contention when called inside
// an active BoltDB write transaction.
func portainerRegistriesToAuthConfigs(registries []portainer.Registry) []types.AuthConfig {
var authConfigs []types.AuthConfig
for _, r := range registries {
@@ -243,7 +248,7 @@ func portainerRegistriesToAuthConfigs(tx dataservices.DataStoreTx, registries []
if r.Authentication {
var err error
ac.Username, ac.Password, err = getEffectiveRegUsernamePassword(tx, &r)
ac.Username, ac.Password, err = getEffectiveRegUsernamePassword(&r)
if err != nil {
continue
}
@@ -255,16 +260,7 @@ func portainerRegistriesToAuthConfigs(tx dataservices.DataStoreTx, registries []
return authConfigs
}
func getEffectiveRegUsernamePassword(tx dataservices.DataStoreTx, registry *portainer.Registry) (string, string, error) {
if err := registryutils.EnsureRegTokenValid(tx, registry); err != nil {
log.Warn().
Err(err).
Str("RegistryName", registry.Name).
Msg("Failed to validate registry token. Skip logging with this registry.")
return "", "", err
}
func getEffectiveRegUsernamePassword(registry *portainer.Registry) (string, string, error) {
username, password, err := registryutils.GetRegEffectiveCredential(registry)
if err != nil {
log.Warn().
+72
View File
@@ -4,6 +4,7 @@ import (
"io"
"os"
"testing"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/filesystem"
@@ -95,3 +96,74 @@ func Test_createEnvFile_mergesDefultAndInplaceEnvVars(t *testing.T) {
assert.Equal(t, []byte("VAR1=VAL1\nVAR2=VAL2\n\nVAR1=NEW_VAL1\nVAR3=VAL3\n"), content)
}
func Test_portainerRegistriesToAuthConfigs(t *testing.T) {
t.Parallel()
t.Run("returns empty slice for empty input", func(t *testing.T) {
t.Parallel()
result := portainerRegistriesToAuthConfigs([]portainer.Registry{})
require.Nil(t, result)
})
t.Run("uses registry URL, username and password for non-authenticated registry", func(t *testing.T) {
t.Parallel()
registries := []portainer.Registry{
{URL: "registry.example.com", Username: "user", Password: "pass", Authentication: false},
}
result := portainerRegistriesToAuthConfigs(registries)
require.Len(t, result, 1)
require.Equal(t, "registry.example.com", result[0].ServerAddress)
require.Equal(t, "user", result[0].Username)
require.Equal(t, "pass", result[0].Password)
})
t.Run("uses username and password for authenticated non-ECR registry", func(t *testing.T) {
t.Parallel()
registries := []portainer.Registry{
{URL: "registry.example.com", Username: "user", Password: "pass", Authentication: true, Type: portainer.CustomRegistry},
}
result := portainerRegistriesToAuthConfigs(registries)
require.Len(t, result, 1)
require.Equal(t, "user", result[0].Username)
require.Equal(t, "pass", result[0].Password)
})
t.Run("parses ECR access token for authenticated ECR registry with valid token", func(t *testing.T) {
t.Parallel()
registries := []portainer.Registry{
{
URL: "123456789.dkr.ecr.us-east-1.amazonaws.com",
Username: "AKIAIOSFODNN7EXAMPLE",
Password: "secretkey",
Authentication: true,
Type: portainer.EcrRegistry,
Ecr: portainer.EcrData{Region: "us-east-1"},
AccessToken: "AWS:ecr-password",
AccessTokenExpiry: time.Now().Add(time.Hour).Unix(),
},
}
result := portainerRegistriesToAuthConfigs(registries)
require.Len(t, result, 1)
require.Equal(t, "AWS", result[0].Username)
require.Equal(t, "ecr-password", result[0].Password)
})
t.Run("includes valid registries and skips ones with credential errors", func(t *testing.T) {
t.Parallel()
registries := []portainer.Registry{
{URL: "valid.example.com", Username: "user", Password: "pass", Authentication: false},
{
URL: "123456789.dkr.ecr.us-east-1.amazonaws.com",
Authentication: true,
Type: portainer.EcrRegistry,
Ecr: portainer.EcrData{Region: "us-east-1"},
AccessToken: "no-colon-token",
AccessTokenExpiry: time.Now().Add(time.Hour).Unix(),
},
}
result := portainerRegistriesToAuthConfigs(registries)
require.Len(t, result, 1)
require.Equal(t, "valid.example.com", result[0].ServerAddress)
})
}
+1 -1
View File
@@ -62,7 +62,7 @@ func (manager *SwarmStackManager) Login(ctx context.Context, registries []portai
for _, registry := range registries {
if registry.Authentication {
username, password, err := getEffectiveRegUsernamePassword(manager.dataStore, &registry)
username, password, err := getEffectiveRegUsernamePassword(&registry)
if err != nil {
continue
}
+40 -7
View File
@@ -1,6 +1,8 @@
package registryutils
import (
"context"
"fmt"
"time"
portainer "github.com/portainer/portainer/api"
@@ -14,22 +16,45 @@ func isRegTokenValid(registry *portainer.Registry) (valid bool) {
return registry.AccessToken != "" && registry.AccessTokenExpiry > time.Now().Unix()
}
func doGetRegToken(tx dataservices.DataStoreTx, registry *portainer.Registry) error {
func fetchRegToken(registry *portainer.Registry) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ecrClient := ecr.NewService(registry.Username, registry.Password, registry.Ecr.Region)
accessToken, expiryAt, err := ecrClient.GetAuthorizationToken()
accessToken, expiryAt, err := ecrClient.GetAuthorizationToken(ctx)
if err != nil {
return err
}
registry.AccessToken = *accessToken
registry.AccessTokenExpiry = expiryAt.Unix()
return nil
}
func doGetRegToken(tx dataservices.DataStoreTx, registry *portainer.Registry) error {
if err := fetchRegToken(registry); err != nil {
return err
}
return tx.Registry().Update(registry.ID, registry)
}
func parseRegToken(registry *portainer.Registry) (username, password string, err error) {
return ecr.NewService(registry.Username, registry.Password, registry.Ecr.Region).
ParseAuthorizationToken(registry.AccessToken)
// ValidateRegistriesECRTokens refreshes and persists ECR tokens for all registries that need it.
// Must be called with a real DataStoreTx (not a top-level DataStore) to avoid write-lock contention.
func ValidateRegistriesECRTokens(tx dataservices.DataStoreTx, registries []portainer.Registry) error {
for i := range registries {
reg := &registries[i]
if reg.Type != portainer.EcrRegistry {
continue
}
if isRegTokenValid(reg) {
continue
}
if err := doGetRegToken(tx, reg); err != nil {
return fmt.Errorf("ECR registry %q credentials are invalid or expired. Error: %w", reg.Name, err)
}
}
return nil
}
func EnsureRegTokenValid(tx dataservices.DataStoreTx, registry *portainer.Registry) error {
@@ -57,7 +82,15 @@ func GetRegEffectiveCredential(registry *portainer.Registry) (username, password
password = registry.Password
if registry.Type == portainer.EcrRegistry {
username, password, err = parseRegToken(registry)
// Fallback token refresh in case the upstream caller did not pre-validate the token.
if !isRegTokenValid(registry) {
if err := fetchRegToken(registry); err != nil {
return "", "", fmt.Errorf("ECR registry %q credentials are invalid or expired. Error: %w", registry.Name, err)
}
}
username, password, err = ecr.NewService(registry.Username, registry.Password, registry.Ecr.Region).
ParseAuthorizationToken(registry.AccessToken)
}
return
@@ -0,0 +1,131 @@
package registryutils_test
import (
"testing"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/registryutils"
"github.com/stretchr/testify/require"
)
func newECRRegistry(id portainer.RegistryID, accessToken string, expiry int64) portainer.Registry {
return portainer.Registry{
ID: id,
Type: portainer.EcrRegistry,
Name: "test-ecr",
Username: "AKIAIOSFODNN7EXAMPLE",
Password: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
Ecr: portainer.EcrData{Region: "us-east-1"},
AccessToken: accessToken,
AccessTokenExpiry: expiry,
}
}
func TestValidateRegistriesECRTokens(t *testing.T) {
t.Parallel()
t.Run("skips non-ECR registries without error", func(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, true, false)
registries := []portainer.Registry{
{ID: 1, Type: portainer.DockerHubRegistry, Name: "dockerhub"},
{ID: 2, Type: portainer.CustomRegistry, Name: "custom"},
}
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return registryutils.ValidateRegistriesECRTokens(tx, registries)
}))
})
t.Run("skips ECR registries with valid tokens", func(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, true, false)
reg := newECRRegistry(1, "valid-token", time.Now().Add(time.Hour).Unix())
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return registryutils.ValidateRegistriesECRTokens(tx, []portainer.Registry{reg})
}))
})
t.Run("returns nil for empty registry list", func(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, true, false)
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return registryutils.ValidateRegistriesECRTokens(tx, []portainer.Registry{})
}))
})
t.Run("returns error for ECR registry with missing token", func(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, true, false)
reg := newECRRegistry(1, "", 0)
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return tx.Registry().Create(&reg)
}))
var validateErr error
_ = ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
validateErr = registryutils.ValidateRegistriesECRTokens(tx, []portainer.Registry{reg})
return nil
})
require.Error(t, validateErr)
require.Contains(t, validateErr.Error(), "test-ecr")
})
t.Run("stops on first invalid ECR registry and includes its name in error", func(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, true, false)
validECR := newECRRegistry(1, "valid-token", time.Now().Add(time.Hour).Unix())
invalidECR := newECRRegistry(2, "", 0)
invalidECR.Name = "invalid-ecr"
nonECR := portainer.Registry{ID: 3, Type: portainer.DockerHubRegistry}
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return tx.Registry().Create(&invalidECR)
}))
var validateErr error
_ = ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
validateErr = registryutils.ValidateRegistriesECRTokens(tx, []portainer.Registry{validECR, invalidECR, nonECR})
return nil
})
require.Error(t, validateErr)
require.Contains(t, validateErr.Error(), "invalid-ecr")
})
}
func TestGetRegEffectiveCredential(t *testing.T) {
t.Parallel()
t.Run("returns username and password directly for non-ECR registry", func(t *testing.T) {
t.Parallel()
reg := &portainer.Registry{
Type: portainer.DockerHubRegistry,
Username: "user",
Password: "pass",
}
username, password, err := registryutils.GetRegEffectiveCredential(reg)
require.NoError(t, err)
require.Equal(t, "user", username)
require.Equal(t, "pass", password)
})
t.Run("parses ECR access token when token is valid", func(t *testing.T) {
t.Parallel()
reg := newECRRegistry(1, "AWS:ecr-password", time.Now().Add(time.Hour).Unix())
username, password, err := registryutils.GetRegEffectiveCredential(&reg)
require.NoError(t, err)
require.Equal(t, "AWS", username)
require.Equal(t, "ecr-password", password)
})
t.Run("returns error for ECR registry with missing token and invalid credentials", func(t *testing.T) {
t.Parallel()
reg := newECRRegistry(1, "", 0)
_, _, err := registryutils.GetRegEffectiveCredential(&reg)
require.Error(t, err)
require.Contains(t, err.Error(), "test-ecr")
})
}
@@ -7,6 +7,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/registryutils"
"github.com/portainer/portainer/api/stacks/stackutils"
"github.com/pkg/errors"
@@ -39,6 +40,10 @@ func CreateComposeStackDeploymentConfigTx(tx dataservices.DataStoreTx, securityC
filteredRegistries := security.FilterRegistries(registries, user, securityContext.UserMemberships, endpoint.ID)
if err := registryutils.ValidateRegistriesECRTokens(tx, filteredRegistries); err != nil {
return nil, err
}
config := &ComposeStackDeploymentConfig{
stack: stack,
endpoint: endpoint,
@@ -9,6 +9,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/registryutils"
"github.com/portainer/portainer/api/stacks/stackutils"
)
@@ -37,6 +38,10 @@ func CreateSwarmStackDeploymentConfigTx(tx dataservices.DataStoreTx, securityCon
filteredRegistries := security.FilterRegistries(registries, user, securityContext.UserMemberships, endpoint.ID)
if err := registryutils.ValidateRegistriesECRTokens(tx, filteredRegistries); err != nil {
return nil, err
}
config := &SwarmStackDeploymentConfig{
stack: stack,
endpoint: endpoint,