From 207bfd99af4b308a1609f3439317fbd83145f106 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 18 Feb 2026 13:53:04 +0300 Subject: [PATCH] COCOS-525-487 - Refactor attestation and atls (#562) * Refactor attestation handling to remove quoteprovider dependency - Removed references to quoteprovider in various files, replacing them with vtpm where necessary. - Updated function signatures and implementations to use SEVNonce instead of quoteprovider.Nonce. - Introduced new vtpm package to handle SEV-related attestation logic, including fetching and verifying attestation reports. - Adjusted tests to reflect changes in the attestation logic and ensure compatibility with the new structure. - Deleted the now redundant quoteprovider/sev_test.go file. Signed-off-by: Sammy Oina * fix: Add veraison/go-cose dependency to go.mod Signed-off-by: Sammy Oina * feat: Introduce TLS package for enhanced security configuration and refactor client code to utilize new TLS utilities Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/api/grpc/requests.go | 3 +- agent/api/grpc/server.go | 9 +- agent/api/grpc/server_test.go | 19 +- agent/api/logging.go | 3 +- agent/api/metrics.go | 3 +- agent/service.go | 5 +- agent/service_test.go | 7 +- algorithm.lin-reg-py.enc | Bin 0 -> 3552 bytes cli/attestation.go | 7 +- cli/attestation_test.go | 17 +- cmd/attestation-service/main.go | 3 +- encryption.key | 1 + kbs-admin.key | 3 + kbs-admin.pub | 3 + pkg/attestation/azure/snp.go | 7 +- pkg/attestation/quoteprovider/sev_test.go | 377 ------------------ .../{quoteprovider => vtpm}/sev.go | 57 +-- pkg/attestation/vtpm/vtpm.go | 9 +- pkg/attestation/vtpm/vtpm_test.go | 17 +- pkg/clients/doc.go | 2 +- pkg/clients/grpc/agent/agent.go | 3 +- pkg/clients/grpc/connect_test.go | 19 +- pkg/clients/grpc/grpc.go | 22 +- pkg/clients/http/client.go | 18 +- pkg/clients/http/client_test.go | 9 +- pkg/sdk/agent_test.go | 11 +- pkg/tls/doc.go | 6 + pkg/{clients => tls}/tls.go | 31 +- pkg/{clients => tls}/tls_test.go | 142 +++---- 29 files changed, 222 insertions(+), 591 deletions(-) create mode 100644 algorithm.lin-reg-py.enc create mode 100644 encryption.key create mode 100644 kbs-admin.key create mode 100644 kbs-admin.pub delete mode 100644 pkg/attestation/quoteprovider/sev_test.go rename pkg/attestation/{quoteprovider => vtpm}/sev.go (78%) create mode 100644 pkg/tls/doc.go rename pkg/{clients => tls}/tls.go (80%) rename pkg/{clients => tls}/tls_test.go (77%) diff --git a/agent/api/grpc/requests.go b/agent/api/grpc/requests.go index 36ac45c1..6363d48b 100644 --- a/agent/api/grpc/requests.go +++ b/agent/api/grpc/requests.go @@ -6,7 +6,6 @@ import ( "errors" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) @@ -42,7 +41,7 @@ func (req resultReq) validate() error { } type attestationReq struct { - TeeNonce [quoteprovider.Nonce]byte + TeeNonce [vtpm.SEVNonce]byte VtpmNonce [vtpm.Nonce]byte AttType attestation.PlatformType } diff --git a/agent/api/grpc/server.go b/agent/api/grpc/server.go index 9159d1fb..f7f9b685 100644 --- a/agent/api/grpc/server.go +++ b/agent/api/grpc/server.go @@ -14,7 +14,6 @@ import ( "github.com/go-kit/kit/transport/grpc" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -134,7 +133,7 @@ func encodeResultResponse(_ context.Context, response any) (any, error) { func validateNonce(nonce []byte, maxLen int, target any) error { if len(nonce) > maxLen { switch maxLen { - case quoteprovider.Nonce: + case vtpm.SEVNonce: return ErrTEENonceLength case vtpm.Nonce: return ErrVTPMNonceLength @@ -144,7 +143,7 @@ func validateNonce(nonce []byte, maxLen int, target any) error { } switch t := target.(type) { - case *[quoteprovider.Nonce]byte: + case *[vtpm.SEVNonce]byte: copy(t[:], nonce) case *[vtpm.Nonce]byte: copy(t[:], nonce) @@ -156,10 +155,10 @@ func validateNonce(nonce []byte, maxLen int, target any) error { func decodeAttestationRequest(_ context.Context, grpcReq any) (any, error) { req := grpcReq.(*agent.AttestationRequest) - var reportData [quoteprovider.Nonce]byte + var reportData [vtpm.SEVNonce]byte var nonce [vtpm.Nonce]byte - if err := validateNonce(req.TeeNonce, quoteprovider.Nonce, &reportData); err != nil { + if err := validateNonce(req.TeeNonce, vtpm.SEVNonce, &reportData); err != nil { return nil, err } diff --git a/agent/api/grpc/server_test.go b/agent/api/grpc/server_test.go index e788ff20..cf00dd39 100644 --- a/agent/api/grpc/server_test.go +++ b/agent/api/grpc/server_test.go @@ -12,7 +12,6 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/mocks" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -229,7 +228,7 @@ func TestAttestation(t *testing.T) { return len(resp.File) > 0 })).Return(nil).Once() - reportData := [quoteprovider.Nonce]byte{} + reportData := [vtpm.SEVNonce]byte{} vtpmNonce := [vtpm.Nonce]byte{} attestationType := attestation.SNP mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil) @@ -298,8 +297,8 @@ func TestValidateNonce(t *testing.T) { }{ { name: "valid TEE nonce", - nonce: make([]byte, quoteprovider.Nonce), - maxLen: quoteprovider.Nonce, + nonce: make([]byte, vtpm.SEVNonce), + maxLen: vtpm.SEVNonce, shouldError: false, }, { @@ -310,8 +309,8 @@ func TestValidateNonce(t *testing.T) { }, { name: "TEE nonce too long", - nonce: make([]byte, quoteprovider.Nonce+1), - maxLen: quoteprovider.Nonce, + nonce: make([]byte, vtpm.SEVNonce+1), + maxLen: vtpm.SEVNonce, shouldError: true, expectedErr: ErrTEENonceLength, }, @@ -326,8 +325,8 @@ func TestValidateNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.maxLen == quoteprovider.Nonce { - var target [quoteprovider.Nonce]byte + if tt.maxLen == vtpm.SEVNonce { + var target [vtpm.SEVNonce]byte err := validateNonce(tt.nonce, tt.maxLen, &target) if tt.shouldError { assert.Error(t, err) @@ -388,7 +387,7 @@ func TestEncodeResultResponse(t *testing.T) { } func TestDecodeAttestationRequest(t *testing.T) { - teeNonce := make([]byte, quoteprovider.Nonce) + teeNonce := make([]byte, vtpm.SEVNonce) vtpmNonce := make([]byte, vtpm.Nonce) req := &agent.AttestationRequest{ @@ -406,7 +405,7 @@ func TestDecodeAttestationRequest(t *testing.T) { func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) { // Test with TEE nonce too long - teeNonce := make([]byte, quoteprovider.Nonce+1) + teeNonce := make([]byte, vtpm.SEVNonce+1) req := &agent.AttestationRequest{TeeNonce: teeNonce} _, err := decodeAttestationRequest(context.Background(), req) diff --git a/agent/api/logging.go b/agent/api/logging.go index 30aa77b2..3d9fc839 100644 --- a/agent/api/logging.go +++ b/agent/api/logging.go @@ -13,7 +13,6 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) @@ -105,7 +104,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e return lm.svc.Result(ctx) } -func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) { +func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin)) if err != nil { diff --git a/agent/api/metrics.go b/agent/api/metrics.go index 4aa2557f..b4dc1f2e 100644 --- a/agent/api/metrics.go +++ b/agent/api/metrics.go @@ -12,7 +12,6 @@ import ( "github.com/go-kit/kit/metrics" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) @@ -91,7 +90,7 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) { return ms.svc.Result(ctx) } -func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { +func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { defer func(begin time.Time) { ms.counter.With("method", "attestation").Add(1) ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds()) diff --git a/agent/service.go b/agent/service.go index 84db1f70..b5573561 100644 --- a/agent/service.go +++ b/agent/service.go @@ -21,7 +21,6 @@ import ( "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner" @@ -120,7 +119,7 @@ type Service interface { Algo(ctx context.Context, algorithm Algorithm) error Data(ctx context.Context, dataset Dataset) error Result(ctx context.Context) ([]byte, error) - Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) + Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) State() string @@ -418,7 +417,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) { return as.result, as.runError } -func (as *agentService) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { +func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { rawQuote, err := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType) if err != nil { return []byte{}, errors.Wrap(ErrAttestationFailed, err) diff --git a/agent/service_test.go b/agent/service_test.go index 62f7c9ed..3796751c 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -24,7 +24,6 @@ import ( "github.com/ultravioletrs/cocos/agent/statemachine" smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks" "golang.org/x/crypto/sha3" @@ -336,7 +335,7 @@ func TestAttestation(t *testing.T) { cases := []struct { name string - reportData [quoteprovider.Nonce]byte + reportData [vtpm.SEVNonce]byte nonce [vtpm.Nonce]byte rawQuote []uint8 platform attestation.PlatformType @@ -457,8 +456,8 @@ func TestAzureAttestationToken(t *testing.T) { } } -func generateReportData() [quoteprovider.Nonce]byte { - bytes := make([]byte, quoteprovider.Nonce) +func generateReportData() [vtpm.SEVNonce]byte { + bytes := make([]byte, vtpm.SEVNonce) _, err := rand.Read(bytes) if err != nil { log.Fatalf("Failed to generate random bytes: %v", err) diff --git a/algorithm.lin-reg-py.enc b/algorithm.lin-reg-py.enc new file mode 100644 index 0000000000000000000000000000000000000000..3e2a8f205a34a05517e4ccedbb0a5c2055c61f2c GIT binary patch literal 3552 zcmV<64IlDK2utmi|BQk-AkP3fyA))G!9KJcsr)eDz`b_$=?X$#c~v=hD38Eaj!!@l zCmgvDLKyK+vF7-@H1Pqr_;68jB`g+Ff;@1L-}nO}6bJR&$swtLphOdfr`}1a zCT($kxgMOY1#oQNh1AmFd@c=^l?__WKD54-tH{kbNx53EJb%=>CE1@vcY)nAbE9dV zNez+cs`XUExy*Dg5-pK4jOHMsVEX(=W|UJy zbp+y!v+sCDbl)N6VI?&m6gG8c)&T{RmF)CEoo%go$bXUJEg+N~B#Zdo3jqC3i@igJ z**RzPPHGHm)(tJ%sXXMT5;3=+PpK5^PJZpn)&}q8{T*(-Qqk@#y=f+T{pq+WVV+cf zpniy`4x~M;vMEhxywFLqF_!?G_t239EP~E6my96Z4#nyw;C8K2?psg&_O6+S#?MmB zYY(#4df^aV%SACsL$$g)8ETzfVvOH&l#smtt1&PT_YUVF+}xJU{!7}*-^(R!%$1e; zMmH@H$QR@nd$lgmS9Db5YATU8wZ0L5Gf(rCoW9*Y;lA}yO}?zT`%}F9 zXC;*CI;5`AIL!`nf57pNI4Ja zuvv()MZPSnMswF(F-`DEs}nqr9M)3G`=GeXQyjHBXYG3y5L0*D1Xr&{E;LoQ;Yj%S zWq>8hdlqS1kJ>C1$l`pw;vNezE6uZ}L?hU3}-aW`X?9%s{;k}5+}WjWl26>A;hH)e@~J*swXA}W~NH%kEPh80DFb;%IsF6 zb;po449;i46cwp8pS%U7JLOKUHHBWxx%4wlY0FnIH{=u9a?StX4T*hI} zjL}Er1OaJa{S`~tW3fSDJm=UdrFcD{0+A72NX3-X+7^M9Da5b{&{sUxbP`ufz&nl^ zg7t>dcgXDxA*b=$P0ek8w}H!>0Tv%UoHA(Gu>-GNsia_kwux15dY%<~JikHh!BGvv@=>wLpbJmV!#?PUj zAYwX5+xpUBCZ4;WYL54nTpP)a>Q@uZLI~cRMq3l`c-n?Ego+~ld;+K&tau=+NH87@ z^Gzx+o-vlHj+1_N`}gyq{}h!_j@Xa)FilI>lNz`_g3}wWSe*1WnaAw&Lf3P-&oYtinsFvE{<*0B8&WN z^^urZ){3$NPSgAH+&9D>;{?^kei==8PG3UCEgIUqs`Lb=!&s#otQq>@!v`H4J45?3 z2UCWCZ;T-vL#1gNR28fU??#eGTV(9Mt!L`GJSp##NBdO3t06Y~5$5u4<)6Dbu|H5Q z>T=@E6B>_F%_aKrm4q2L6WPb$S3sj70*V+ufHZL@+cwpuTmr#TeB$n{#4hkPlpslvc!Bu zh#|Cg!h~Z(kKSq6^q1FY9XdwqFga6|-(wt#0^lzepP61v1QQKhiq2H=H+NvO4Wflq zvRgyCBq>>iob5=Gk3X7{7%w6z7W@l5su80Glb6%>z)RVv3IMK@h#omB<%%(6Q-d!F zDk>jTXoW#Vuk0nER_kwwa;|@3==QDF3*kNg04L0ss;o^)m1F8GI6%@zQdqi}YqV$P z!$&PX*Xg1RWGlx&#i5W76)n`bih~e@1-T?G$$l`j0Br({VdR4hW{DT zS{wOJZSM$ztB)jNarz-l+xFq5m2pj<@bi8({&%!tcJ&fg+*z7jaA&wmR>s@0!`jKN zXt^p7U4@I1C)&NQ?=O=nUF#}BA*Av2w3S24z_O!5H}BoTgR$dsM;o8y$$prW{tSX2 z!P&TSjOs#B=t*pUZU4DUR;;tAa8nBS01r%C#@MZvI;nlN)$>llJ6I75^(@4^U?w%# z4}Nm-p|DL&OB~_yBV-BMZD(E|I{mGIlz`ejO4{8vTiP8hj++y`WKU=n}oDiGnJ{f5SusJ>t@SgMx8> z`3f0)oFp9Rbqfg~&?0SIKV~Zx98?Y$3=oJ=QYA-2W|v515kk8!*`zG3vL3Ra%5U({ zOn>H&={l$|Taitr&hHCqYzEP0V;3>7cX)AiGv)_9LSCq9b* z98~@v#OGI#)X{=MS%aJv(w}deWHW;%OhW_9-Mqk@)W~ILf%vQa^=dDH(hv94T2)BU`w4b`*M+t{xbLgM+v>}Y40S>-pv3b-@&@bN6yf$hhc_X| zVH%)&=Qqu7bp2~K$9VK9fX7xbsJpSHEU4oFK2^l&bKy|4FG?FXnWt3~xWA;1klMJETd4K* zcMZr5IZ8DlXWh0JT%#u=?RekoFgVjlrj?-CH{S0j-vQ;dZbuU|4*+uNh<>5JtnNOr ztFO_$2;dM9mmEvCR!JjnA;#7mVm1mp?p1q0oxNt%*y7UwF|sq5gHnglER!R0n4
    M@%>0@f%C0)mH?~7FtK8ew$ibH&Gx`nu}19823t77@2>9 zKypgN??U@oQ`L1Xni4xjipz4Xs5AwZg=C#L!ZLmzR+gh4H}`mWMfqN5^~a^5(SsG* z6&Ig!8-dTKKpNJjzkcaG5nbzY-?OH_?t^v(-Q4nM&l9pn;#wOC#4wfgVUsLA quoteprovider.Nonce { - msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", quoteprovider.Nonce) + if len(teeNonce) > vtpm.SEVNonce { + msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", vtpm.SEVNonce) cmd.Println(msg) return } diff --git a/cli/attestation_test.go b/cli/attestation_test.go index 2379b439..3eb4439e 100644 --- a/cli/attestation_test.go +++ b/cli/attestation_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk/mocks" ) @@ -37,8 +36,8 @@ func TestNewAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOut(&buf) - reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce) - mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil) + reportData := bytes.Repeat([]byte{0x01}, vtpm.SEVNonce) + mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(reportData), mock.Anything).Return(nil) cmd.SetArgs([]string{hex.EncodeToString(reportData)}) err := cmd.Execute() @@ -50,7 +49,7 @@ func TestNewGetAttestationCmd(t *testing.T) { validattestation, err := os.ReadFile("../attestation.bin") require.NoError(t, err) - teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)) + teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)) vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce)) tokenNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce)) @@ -184,7 +183,7 @@ func TestNewGetAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOut(&buf) - mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { + mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { _, err := args.Get(4).(*os.File).Write(tc.mockResponse) require.NoError(t, err) }) @@ -891,7 +890,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) { }, { name: "TEE nonce too large", - args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))}, + args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce+1))}, setupMock: func(sdk *mocks.SDK) { }, expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes", @@ -912,7 +911,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) { }, { name: "successful TDX attestation", - args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))}, + args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))}, setupMock: func(sdk *mocks.SDK) { sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil).Run(func(args mock.Arguments) { @@ -925,7 +924,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) { }, { name: "file creation error", - args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))}, + args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))}, setupMock: func(sdk *mocks.SDK) { }, expectedErr: "Error creating attestation file", @@ -1380,7 +1379,7 @@ func TestContextCancellation(t *testing.T) { cmd.SetOut(&buf) cmd.SetErr(&buf) - teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)) + teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)) cmd.SetArgs([]string{"snp", "--tee", teeNonceHex}) err := cmd.Execute() diff --git a/cmd/attestation-service/main.go b/cmd/attestation-service/main.go index 9e2488b8..5c9ff527 100644 --- a/cmd/attestation-service/main.go +++ b/cmd/attestation-service/main.go @@ -18,7 +18,6 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/tdx" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/sync/errgroup" @@ -89,7 +88,7 @@ func main() { } if ccPlatform == attestation.SNP || ccPlatform == attestation.SNPvTPM { - if err := quoteprovider.FetchCertificates(uint(cfg.Vmpl)); err != nil { + if err := vtpm.FetchSEVCertificates(uint(cfg.Vmpl)); err != nil { logger.Error(fmt.Sprintf("failed to fetch certificates: %s", err)) exitCode = 1 return diff --git a/encryption.key b/encryption.key new file mode 100644 index 00000000..840cae66 --- /dev/null +++ b/encryption.key @@ -0,0 +1 @@ +bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3 diff --git a/kbs-admin.key b/kbs-admin.key new file mode 100644 index 00000000..62418798 --- /dev/null +++ b/kbs-admin.key @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY +-----END PRIVATE KEY----- diff --git a/kbs-admin.pub b/kbs-admin.pub new file mode 100644 index 00000000..8fc02406 --- /dev/null +++ b/kbs-admin.pub @@ -0,0 +1,3 @@ +-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw= +-----END PUBLIC KEY----- diff --git a/pkg/attestation/azure/snp.go b/pkg/attestation/azure/snp.go index d2ab3dad..a7ee9707 100644 --- a/pkg/attestation/azure/snp.go +++ b/pkg/attestation/azure/snp.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-tpm-tools/proto/attest" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/protobuf/proto" ) @@ -130,7 +129,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error { return errors.Wrap(fmt.Errorf("failed to convert TEE report to proto"), err) } - return quoteprovider.VerifyAttestationReportTLS(attestationReport, teeNonce, a.Policy) + return vtpm.VerifySEVAttestationReportTLS(attestationReport, teeNonce, a.Policy) } func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error { @@ -148,7 +147,7 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce [] } snpReport := quote.GetSevSnpAttestation() - if err = quoteprovider.VerifyAttestationReportTLS(snpReport, nil, a.Policy); err != nil { + if err = vtpm.VerifySEVAttestationReportTLS(snpReport, nil, a.Policy); err != nil { return fmt.Errorf("failed to verify vTPM attestation report: %w", err) } @@ -266,7 +265,7 @@ func GenerateAttestationPolicy(token, product string, policy uint64) (*attestati return nil, fmt.Errorf("failed to decode reportID: %w", err) } - sevSnpProduct := quoteprovider.GetProductName(product) + sevSnpProduct := vtpm.GetSEVProductName(product) return &attestation.Config{ Config: &check.Config{ diff --git a/pkg/attestation/quoteprovider/sev_test.go b/pkg/attestation/quoteprovider/sev_test.go deleted file mode 100644 index 975cb10f..00000000 --- a/pkg/attestation/quoteprovider/sev_test.go +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -//go:build !embed - -package quoteprovider - -import ( - "os" - "path" - "testing" - - "github.com/google/go-sev-guest/proto/check" - "github.com/google/go-sev-guest/proto/sevsnp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestFillInAttestationLocal(t *testing.T) { - originalHome := os.Getenv("HOME") - defer func() { - os.Setenv("HOME", originalHome) - }() - - tempDir, err := os.MkdirTemp("", "test_home") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - os.Setenv("HOME", tempDir) - - cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan) - err = os.MkdirAll(cocosDir, 0o755) - require.NoError(t, err) - - bundleContent := []byte("mock ASK ARK bundle") - bundlePath := path.Join(cocosDir, arkAskBundleName) - err = os.WriteFile(bundlePath, bundleContent, 0o644) - require.NoError(t, err) - - config := &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductMilan, - }, - Policy: &check.Policy{}, - } - - tests := []struct { - name string - attestation *sevsnp.Attestation - setupFunc func() - expectedError bool - errorContains string - }{ - { - name: "Empty attestation - creates new chain", - attestation: &sevsnp.Attestation{ - CertificateChain: nil, - }, - setupFunc: func() {}, - expectedError: true, - errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block", - }, - { - name: "Attestation with existing chain - no changes needed", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("existing ASK cert"), - ArkCert: []byte("existing ARK cert"), - }, - }, - setupFunc: func() {}, - expectedError: false, - }, - { - name: "Attestation with empty chain - tries to load from file", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - setupFunc: func() {}, - expectedError: true, - errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block", - }, - { - name: "No bundle file exists - no error", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - setupFunc: func() { - os.Remove(bundlePath) - }, - expectedError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - os.Setenv("HOME", tempDir) - if _, err := os.Stat(bundlePath); os.IsNotExist(err) { - if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil { - t.Fatalf("Failed to write bundle file: %v", err) - } - } - - tt.setupFunc() - - err := fillInAttestationLocal(tt.attestation, config) - - if tt.expectedError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestGetProductName(t *testing.T) { - tests := []struct { - name string - product string - expected sevsnp.SevProduct_SevProductName - }{ - { - name: "Milan product", - product: sevSnpProductMilan, - expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN, - }, - { - name: "Genoa product", - product: sevSnpProductGenoa, - expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA, - }, - { - name: "Unknown product", - product: "UnknownProduct", - expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN, - }, - { - name: "Empty product", - product: "", - expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN, - }, - { - name: "Case sensitive - milan lowercase", - product: "milan", - expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GetProductName(tt.product) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestVerifyReport(t *testing.T) { - tests := []struct { - name string - attestation *sevsnp.Attestation - config *check.Config - expectedError bool - errorContains string - }{ - { - name: "Invalid product line", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: "InvalidProduct", - }, - Policy: &check.Policy{}, - }, - expectedError: true, - errorContains: "product name must be", - }, - { - name: "Valid Milan product line", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("mock ask cert"), - ArkCert: []byte("mock ark cert"), - }, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductMilan, - }, - Policy: &check.Policy{}, - }, - expectedError: true, - errorContains: "attestation verification failed", - }, - { - name: "Valid Genoa product line", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("mock ask cert"), - ArkCert: []byte("mock ark cert"), - }, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductGenoa, - }, - Policy: &check.Policy{}, - }, - expectedError: true, - errorContains: "attestation verification failed", - }, - { - name: "Config with existing product policy", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("mock ask cert"), - ArkCert: []byte("mock ark cert"), - }, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductMilan, - }, - Policy: &check.Policy{ - Product: &sevsnp.SevProduct{ - Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN, - }, - }, - }, - expectedError: true, - errorContains: "attestation verification failed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := verifyReport(tt.attestation, tt.config) - - if tt.expectedError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValidateReport(t *testing.T) { - tests := []struct { - name string - attestation *sevsnp.Attestation - config *check.Config - expectedError bool - errorContains string - }{ - { - name: "Basic validation test", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - config: &check.Config{ - Policy: &check.Policy{ - Policy: 196608, - }, - }, - expectedError: true, - errorContains: "attestation validation failed", - }, - { - name: "Validation with report data", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - config: &check.Config{ - Policy: &check.Policy{ - Policy: 196608, - ReportData: []byte("test report datatest report datatest report datatest report data"), - }, - }, - expectedError: true, - errorContains: "attestation validation failed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateReport(tt.attestation, tt.config) - - if tt.expectedError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestFetchAttestation(t *testing.T) { - tests := []struct { - name string - reportData []byte - vmpl uint - expectedError bool - errorContains string - }{ - { - name: "Report data too large", - reportData: make([]byte, Nonce+1), - vmpl: 0, - expectedError: true, - errorContains: "could not get quote provider", - }, - { - name: "Valid report data size", - reportData: make([]byte, 32), - vmpl: 0, - expectedError: true, - errorContains: "could not get quote provider", - }, - { - name: "Maximum valid report data size", - reportData: make([]byte, Nonce), - vmpl: 1, - expectedError: true, - errorContains: "could not get quote provider", - }, - { - name: "Empty report data", - reportData: []byte{}, - vmpl: 0, - expectedError: true, - errorContains: "could not get quote provider", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := FetchAttestation(tt.reportData, tt.vmpl) - - if tt.expectedError { - assert.Error(t, err) - assert.Empty(t, result) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - assert.NotEmpty(t, result) - } - }) - } -} - -func TestGetLeveledQuoteProvider(t *testing.T) { - t.Run("GetLeveledQuoteProvider call", func(t *testing.T) { - provider, err := GetLeveledQuoteProvider() - - if err != nil { - assert.Error(t, err) - assert.Nil(t, provider) - } else { - assert.NoError(t, err) - assert.NotNil(t, provider) - } - }) -} diff --git a/pkg/attestation/quoteprovider/sev.go b/pkg/attestation/vtpm/sev.go similarity index 78% rename from pkg/attestation/quoteprovider/sev.go rename to pkg/attestation/vtpm/sev.go index 5dcf437c..b23b2ded 100644 --- a/pkg/attestation/quoteprovider/sev.go +++ b/pkg/attestation/vtpm/sev.go @@ -3,7 +3,7 @@ //go:build !embed -package quoteprovider +package vtpm import ( "crypto/rand" @@ -32,7 +32,7 @@ const ( cocosDirectory = "/cocos" arkAskBundleName = "ask_ark.pem" vcekName = "vcek.pem" - Nonce = 64 + SEVNonce = 64 sevSnpProductMilan = "Milan" sevSnpProductGenoa = "Genoa" ) @@ -43,9 +43,9 @@ var ( ) var ( - ErrProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa)) - ErrAttVerification = errors.New("attestation verification failed") - errAttValidation = errors.New("attestation validation failed") + ErrSEVProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa)) + ErrSEVAttVerification = errors.New("attestation verification failed") + errSEVAttValidation = errors.New("attestation validation failed") ) func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error { @@ -77,16 +77,17 @@ func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) return nil } +// verifyReport verifies the SEV-SNP attestation report. func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error { sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust) if err != nil { - return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrAttVerification, err)) + return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrSEVAttVerification, err)) } if cfg.Policy.Product == nil { - productName := GetProductName(cfg.RootOfTrust.ProductLine) + productName := GetSEVProductName(cfg.RootOfTrust.ProductLine) if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN { - return ErrProductLine + return ErrSEVProductLine } sopts.Product = &sevsnp.SevProduct{ @@ -107,30 +108,33 @@ func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error { } if err := verify.SnpAttestation(attestationPB, sopts); err != nil { - return errors.Wrap(ErrAttVerification, err) + return errors.Wrap(ErrSEVAttVerification, err) } return nil } +// validateReport validates the SEV-SNP attestation report against policy. func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error { opts, err := validate.PolicyToOptions(cfg.Policy) if err != nil { - return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrAttVerification, err)) + return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrSEVAttVerification, err)) } if err = validate.SnpAttestation(attestationPB, opts); err != nil { - return errors.Wrap(errAttValidation, err) + return errors.Wrap(errSEVAttValidation, err) } return nil } -func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { +// getLeveledQuoteProvider returns a leveled quote provider for SEV-SNP. +func getLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { return client.GetLeveledQuoteProvider() } -func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error { +// VerifySEVAttestationReportTLS verifies a SEV-SNP attestation report for TLS (exported for azure package). +func VerifySEVAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error { config := policy.Config // Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report. @@ -141,10 +145,11 @@ func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData [] config.Policy.ReportData = reportData[:] } - return VerifyAndValidate(attestationPB, config) + return verifySEVAndValidate(attestationPB, config) } -func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error { +// verifySEVAndValidate performs both verification and validation of a SEV-SNP attestation. +func verifySEVAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error { logger.Init("", false, false, io.Discard) if err := verifyReport(attestationPB, cfg); err != nil { @@ -158,15 +163,16 @@ func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) err return nil } -func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) { - var reportData [Nonce]byte +// fetchSEVAttestation fetches a SEV-SNP attestation report. +func fetchSEVAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) { + var reportData [SEVNonce]byte - qp, err := GetLeveledQuoteProvider() + qp, err := getLeveledQuoteProvider() if err != nil { return []byte{}, fmt.Errorf("could not get quote provider") } - if len(reportData) > Nonce { + if len(reportData) > SEVNonce { return []byte{}, fmt.Errorf("attestation report size mismatch") } copy(reportData[:], reportDataSlice) @@ -206,7 +212,8 @@ func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) { return result, nil } -func GetProductName(product string) sevsnp.SevProduct_SevProductName { +// GetSEVProductName maps a product string to a SEV product name. +func GetSEVProductName(product string) sevsnp.SevProduct_SevProductName { switch product { case sevSnpProductMilan: return sevsnp.SevProduct_SEV_PRODUCT_MILAN @@ -217,6 +224,7 @@ func GetProductName(product string) sevsnp.SevProduct_SevProductName { } } +// derToPem converts DER-encoded certificate to PEM format. func derToPem(der []byte) []byte { // Try to parse to make sure it's a certificate if _, err := x509.ParseCertificate(der); err != nil { @@ -226,15 +234,16 @@ func derToPem(der []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) } -func FetchCertificates(vmpl uint) error { - var reportData [Nonce]byte +// FetchSEVCertificates fetches SEV-SNP certificates from KDS. +func FetchSEVCertificates(vmpl uint) error { + var reportData [SEVNonce]byte - qp, err := GetLeveledQuoteProvider() + qp, err := getLeveledQuoteProvider() if err != nil { return fmt.Errorf("could not get quote provider") } - if len(reportData) > Nonce { + if len(reportData) > SEVNonce { return fmt.Errorf("attestation report size mismatch") } diff --git a/pkg/attestation/vtpm/vtpm.go b/pkg/attestation/vtpm/vtpm.go index 827eca5e..a777d671 100644 --- a/pkg/attestation/vtpm/vtpm.go +++ b/pkg/attestation/vtpm/vtpm.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-tpm/tpmutil" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "golang.org/x/crypto/sha3" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/prototext" @@ -120,7 +119,7 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) } func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) { - return quoteprovider.FetchAttestation(teeNonce, v.vmpl) + return fetchSEVAttestation(teeNonce, v.vmpl) } func (v provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) { @@ -171,7 +170,7 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error { } attestationReport := sevsnp.Attestation{Report: attestReport, CertificateChain: nil} - return quoteprovider.VerifyAttestationReportTLS(&attestationReport, teeNonce, v.Policy) + return VerifySEVAttestationReportTLS(&attestationReport, teeNonce, v.Policy) } func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error { @@ -234,7 +233,7 @@ func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Write attestData := sha3.Sum512(nonce) - if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil { + if err := VerifySEVAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil { return fmt.Errorf("failed to verify TEE attestation report: %v", err) } @@ -336,7 +335,7 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint) attestData := sha3.Sum512(teeNonce) - rawTeeAttestation, err := quoteprovider.FetchAttestation(attestData[:], vmpl) + rawTeeAttestation, err := fetchSEVAttestation(attestData[:], vmpl) if err != nil { return fmt.Errorf("failed to fetch TEE attestation report: %v", err) } diff --git a/pkg/attestation/vtpm/vtpm_test.go b/pkg/attestation/vtpm/vtpm_test.go index 955664e9..16043a48 100644 --- a/pkg/attestation/vtpm/vtpm_test.go +++ b/pkg/attestation/vtpm/vtpm_test.go @@ -22,12 +22,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "google.golang.org/protobuf/encoding/protojson" ) -const sevSnpProductMilan = "Milan" - var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}} type mockTPM struct { @@ -679,13 +676,13 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) { name: "Valid attestation, distorted signature", attestationReport: attestationPB, reportData: reportData, - err: quoteprovider.ErrAttVerification, + err: ErrSEVAttVerification, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } @@ -713,13 +710,13 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) { name: "Valid attestation, unknown product", attestationReport: attestationPB, reportData: reportData, - err: quoteprovider.ErrProductLine, + err: ErrSEVProductLine, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } @@ -752,7 +749,7 @@ func TestVerifyAttestationReportSuccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } @@ -780,13 +777,13 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) { name: "Valid attestation, malformed policy (measurement)", attestationReport: attestationPB, reportData: reportData, - err: quoteprovider.ErrAttVerification, + err: ErrSEVAttVerification, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } diff --git a/pkg/clients/doc.go b/pkg/clients/doc.go index 1b921bd5..e949c6e2 100644 --- a/pkg/clients/doc.go +++ b/pkg/clients/doc.go @@ -2,5 +2,5 @@ // SPDX-License-Identifier: Apache-2.0 // Package clients contains the domain concept definitions needed to support -// HTTP/gRPC Client functionality. +// client configuration for HTTP/gRPC connections. package clients diff --git a/pkg/clients/grpc/agent/agent.go b/pkg/clients/grpc/agent/agent.go index 6408282c..c98d5643 100644 --- a/pkg/clients/grpc/agent/agent.go +++ b/pkg/clients/grpc/agent/agent.go @@ -9,6 +9,7 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/clients" "github.com/ultravioletrs/cocos/pkg/clients/grpc" + "github.com/ultravioletrs/cocos/pkg/tls" grpchealth "google.golang.org/grpc/health/grpc_health_v1" ) @@ -21,7 +22,7 @@ func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc return nil, nil, err } - if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() { + if client.Secure() != tls.WithMATLS.String() && client.Secure() != tls.WithATLS.String() && client.Secure() != tls.WithTLS.String() { health := grpchealth.NewHealthClient(client.Connection()) resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{ Service: "agent", diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go index 2a3d38a4..f3fe8e7b 100644 --- a/pkg/clients/grpc/connect_test.go +++ b/pkg/clients/grpc/connect_test.go @@ -22,6 +22,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" ) func TestNewClient(t *testing.T) { @@ -119,7 +120,7 @@ func TestNewClient(t *testing.T) { ServerCAFile: "nonexistent.pem", }, wantErr: true, - err: clients.ErrFailedToLoadRootCA, + err: tls.ErrFailedToLoadRootCA, }, { name: "Fail with invalid ClientCert", @@ -130,7 +131,7 @@ func TestNewClient(t *testing.T) { ClientKey: clientKeyFile, }, wantErr: true, - err: clients.ErrFailedToLoadClientCertKey, + err: tls.ErrFailedToLoadClientCertKey, }, { name: "Fail with invalid ClientKey", @@ -141,7 +142,7 @@ func TestNewClient(t *testing.T) { ClientKey: "nonexistent.pem", }, wantErr: true, - err: clients.ErrFailedToLoadClientCertKey, + err: tls.ErrFailedToLoadClientCertKey, }, } @@ -169,32 +170,32 @@ func TestNewClient(t *testing.T) { func TestClientSecure(t *testing.T) { tests := []struct { name string - secure clients.Security + secure tls.Security expected string }{ { name: "Without TLS", - secure: clients.WithoutTLS, + secure: tls.WithoutTLS, expected: "without TLS", }, { name: "With TLS", - secure: clients.WithTLS, + secure: tls.WithTLS, expected: "with TLS", }, { name: "With mTLS", - secure: clients.WithMTLS, + secure: tls.WithMTLS, expected: "with mTLS", }, { name: "With aTLS", - secure: clients.WithATLS, + secure: tls.WithATLS, expected: "with aTLS", }, { name: "With maTLS", - secure: clients.WithMATLS, + secure: tls.WithMATLS, expected: "with maTLS", }, } diff --git a/pkg/clients/grpc/grpc.go b/pkg/clients/grpc/grpc.go index 76806c49..b68e7414 100644 --- a/pkg/clients/grpc/grpc.go +++ b/pkg/clients/grpc/grpc.go @@ -6,6 +6,7 @@ package grpc import ( "github.com/absmach/supermq/pkg/errors" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -26,7 +27,7 @@ type Client interface { type client struct { *grpc.ClientConn cfg clients.ClientConfiguration - security clients.Security + security tls.Security } var _ Client = (*client)(nil) @@ -59,14 +60,19 @@ func (c *client) Connection() *grpc.ClientConn { return c.ClientConn } -func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) { +func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, error) { opts := []grpc.DialOption{ grpc.WithStatsHandler(otelgrpc.NewClientHandler()), } - security := clients.WithoutTLS + security := tls.WithoutTLS if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS { - result, err := clients.LoadATLSConfig(agcfg) + result, err := tls.LoadATLSConfig( + agcfg.AttestationPolicy, + agcfg.ServerCAFile, + agcfg.ClientCert, + agcfg.ClientKey, + ) if err != nil { return nil, security, err } @@ -90,13 +96,13 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Securit return conn, security, nil } -func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) { - result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey) +func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) { + result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey) if err != nil { - return nil, clients.WithoutTLS, err + return nil, tls.WithoutTLS, err } - if result.Security == clients.WithoutTLS || result.Config == nil { + if result.Security == tls.WithoutTLS || result.Config == nil { return insecure.NewCredentials(), result.Security, nil } diff --git a/pkg/clients/http/client.go b/pkg/clients/http/client.go index c1f9759f..97fd3933 100644 --- a/pkg/clients/http/client.go +++ b/pkg/clients/http/client.go @@ -8,6 +8,7 @@ import ( "time" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" ) type Client interface { @@ -19,7 +20,7 @@ type Client interface { type client struct { transport *http.Transport cfg clients.ClientConfiguration - security clients.Security + security tls.Security } var _ Client = (*client)(nil) @@ -49,17 +50,22 @@ func (c *client) Timeout() time.Duration { return c.cfg.Config().Timeout } -func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) { +func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Security, error) { transport := &http.Transport{ MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, } - security := clients.WithoutTLS + security := tls.WithoutTLS if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS { - result, err := clients.LoadATLSConfig(*agcfg) + result, err := tls.LoadATLSConfig( + agcfg.AttestationPolicy, + agcfg.ServerCAFile, + agcfg.ClientCert, + agcfg.ClientKey, + ) if err != nil { return nil, security, err } @@ -69,12 +75,12 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients. } else { conf := cfg.Config() - result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey) + result, err := tls.LoadBasicConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey) if err != nil { return nil, security, err } - if result.Security != clients.WithoutTLS { + if result.Security != tls.WithoutTLS { transport.TLSClientConfig = result.Config } diff --git a/pkg/clients/http/client_test.go b/pkg/clients/http/client_test.go index 85bbd90f..3686b8c9 100644 --- a/pkg/clients/http/client_test.go +++ b/pkg/clients/http/client_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" ) func TestConfig_Configuration(t *testing.T) { @@ -144,7 +145,7 @@ func TestClient_Secure(t *testing.T) { URL: "http://localhost:8080", Timeout: 30 * time.Second, }, - expected: clients.WithoutTLS.String(), + expected: tls.WithoutTLS.String(), }, } @@ -183,7 +184,7 @@ func TestCreateTransport_DefaultSettings(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, transport) - assert.Equal(t, clients.WithoutTLS, security) + assert.Equal(t, tls.WithoutTLS, security) assert.Equal(t, 100, transport.MaxIdleConns) assert.Equal(t, 90*time.Second, transport.IdleConnTimeout) assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) @@ -205,7 +206,7 @@ func TestCreateTransport_ATLSError(t *testing.T) { assert.Error(t, err) assert.Nil(t, transport) - assert.Equal(t, clients.WithoutTLS, security) + assert.Equal(t, tls.WithoutTLS, security) assert.Contains(t, err.Error(), "failed to stat attestation policy") } @@ -220,7 +221,7 @@ func TestCreateTransport_BasicTLSError(t *testing.T) { assert.Error(t, err) assert.Nil(t, transport) - assert.Equal(t, clients.WithoutTLS, security) + assert.Equal(t, tls.WithoutTLS, security) assert.Contains(t, err.Error(), "failed to load root ca file") } diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go index f20a4ddc..5d19530b 100644 --- a/pkg/sdk/agent_test.go +++ b/pkg/sdk/agent_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk" "golang.org/x/crypto/sha3" @@ -388,7 +387,7 @@ func TestAttestation(t *testing.T) { cases := []struct { name string userKey any - reportData [quoteprovider.Nonce]byte + reportData [vtpm.SEVNonce]byte nonce [vtpm.Nonce]byte response *agent.AttestationResponse svcRes []byte @@ -397,7 +396,7 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report successfully", userKey: resultConsumerKey, - reportData: [quoteprovider.Nonce]byte(reportData), + reportData: [vtpm.SEVNonce]byte(reportData), nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, @@ -408,7 +407,7 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report with different key type", userKey: resultConsumer1Key, - reportData: [quoteprovider.Nonce]byte(reportData), + reportData: [vtpm.SEVNonce]byte(reportData), nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, @@ -419,7 +418,7 @@ func TestAttestation(t *testing.T) { { name: "failed to fetch attestation report", userKey: resultConsumerKey, - reportData: [quoteprovider.Nonce]byte(reportData), + reportData: [vtpm.SEVNonce]byte(reportData), nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, @@ -429,7 +428,7 @@ func TestAttestation(t *testing.T) { { name: "invalid report data", userKey: resultConsumerKey, - reportData: [quoteprovider.Nonce]byte{}, + reportData: [vtpm.SEVNonce]byte{}, nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, diff --git a/pkg/tls/doc.go b/pkg/tls/doc.go new file mode 100644 index 00000000..fce8e36e --- /dev/null +++ b/pkg/tls/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Package tls provides TLS configuration utilities for client connections. +// It supports standard TLS, mutual TLS (mTLS), and attested TLS (aTLS). +package tls diff --git a/pkg/clients/tls.go b/pkg/tls/tls.go similarity index 80% rename from pkg/clients/tls.go rename to pkg/tls/tls.go index 7538b736..51bd4d8d 100644 --- a/pkg/clients/tls.go +++ b/pkg/tls/tls.go @@ -1,7 +1,7 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 -package clients +package tls import ( "crypto/rand" @@ -53,20 +53,20 @@ var ( errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file") ) -// TLSResult contains the result of TLS configuration. -type TLSResult struct { +// Result contains the result of TLS configuration. +type Result struct { Config *tls.Config Security Security } -// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS). -func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) { +// LoadBasicConfig loads standard TLS configuration (TLS/mTLS). +func LoadBasicConfig(serverCAFile, clientCert, clientKey string) (*Result, error) { tlsConfig := &tls.Config{} security := WithoutTLS // If no TLS configuration is provided, return nil config (no TLS) if serverCAFile == "" && clientCert == "" && clientKey == "" { - return &TLSResult{Config: nil, Security: security}, nil + return &Result{Config: nil, Security: security}, nil } if serverCAFile != "" { @@ -96,14 +96,15 @@ func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, security = WithMTLS } - return &TLSResult{Config: tlsConfig, Security: security}, nil + return &Result{Config: tlsConfig, Security: security}, nil } // LoadATLSConfig configures Attested TLS. -func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { +// Parameters are passed individually to avoid circular dependencies with the clients package. +func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey string) (*Result, error) { security := WithATLS - info, err := os.Stat(cfg.AttestationPolicy) + info, err := os.Stat(attestationPolicy) if err != nil { return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err) } @@ -112,12 +113,12 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { return nil, errAttestationPolicyIrregular } - attestation.AttestationPolicyPath = cfg.AttestationPolicy + attestation.AttestationPolicyPath = attestationPolicy var rootCAs *x509.CertPool - if cfg.ServerCAFile != "" { - rootCAs, err = loadRootCAs(cfg.ServerCAFile) + if serverCAFile != "" { + rootCAs, err = loadRootCAs(serverCAFile) if err != nil { return nil, err } @@ -141,8 +142,8 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { }, } - if cfg.ClientCert != "" || cfg.ClientKey != "" { - certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) + if clientCert != "" || clientKey != "" { + certificate, err := tls.LoadX509KeyPair(clientCert, clientKey) if err != nil { return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err) } @@ -150,7 +151,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { tlsConfig.Certificates = []tls.Certificate{certificate} } - return &TLSResult{Config: tlsConfig, Security: security}, nil + return &Result{Config: tlsConfig, Security: security}, nil } // loadRootCAs loads root CA certificates from a file. diff --git a/pkg/clients/tls_test.go b/pkg/tls/tls_test.go similarity index 77% rename from pkg/clients/tls_test.go rename to pkg/tls/tls_test.go index b9e82a66..d37d3e98 100644 --- a/pkg/clients/tls_test.go +++ b/pkg/tls/tls_test.go @@ -1,7 +1,7 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 -package clients +package tls import ( "crypto/rand" @@ -64,7 +64,7 @@ func TestSecurity_String(t *testing.T) { } } -func TestLoadBasicTLSConfig(t *testing.T) { +func TestLoadBasicConfig(t *testing.T) { // Create temporary directory for test files tmpDir := t.TempDir() @@ -164,7 +164,7 @@ func TestLoadBasicTLSConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey) + result, err := LoadBasicConfig(tt.serverCAFile, tt.clientCert, tt.clientKey) if tt.expectError { assert.Error(t, err) @@ -201,100 +201,86 @@ func TestLoadATLSConfig(t *testing.T) { require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644)) tests := []struct { - name string - config AttestedClientConfig - expectedSec Security - expectError bool - errorMsg string + name string + attestationPolicy string + serverCAFile string + clientCert string + clientKey string + expectedSec Security + expectError bool + errorMsg string }{ { - name: "ValidATLSConfig", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ServerCAFile: "", - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithATLS, - expectError: false, + name: "ValidATLSConfig", + attestationPolicy: policyFile, + serverCAFile: "", + clientCert: "", + clientKey: "", + expectedSec: WithATLS, + expectError: false, }, { - name: "ValidMATLSConfig", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ServerCAFile: caFile, - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithMATLS, - expectError: false, + name: "ValidMATLSConfig", + attestationPolicy: policyFile, + serverCAFile: caFile, + clientCert: "", + clientKey: "", + expectedSec: WithMATLS, + expectError: false, }, { - name: "ValidATLSWithClientCert", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ClientCert: certFile, - ClientKey: keyFile, - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithATLS, - expectError: false, + name: "ValidATLSWithClientCert", + attestationPolicy: policyFile, + serverCAFile: "", + clientCert: certFile, + clientKey: keyFile, + expectedSec: WithATLS, + expectError: false, }, { - name: "NonexistentPolicyFile", - config: AttestedClientConfig{ - AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"), - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, - errorMsg: "failed to stat attestation policy file", + name: "NonexistentPolicyFile", + attestationPolicy: filepath.Join(tmpDir, "nonexistent.json"), + serverCAFile: "", + clientCert: "", + clientKey: "", + expectedSec: WithoutTLS, + expectError: true, + errorMsg: "failed to stat attestation policy file", }, { - name: "PolicyFileIsDirectory", - config: AttestedClientConfig{ - AttestationPolicy: tmpDir, // Directory instead of file - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, - errorMsg: "attestation policy file is not a regular file", + name: "PolicyFileIsDirectory", + attestationPolicy: tmpDir, // Directory instead of file + serverCAFile: "", + clientCert: "", + clientKey: "", + expectedSec: WithoutTLS, + expectError: true, + errorMsg: "attestation policy file is not a regular file", }, { - name: "InvalidCAFile", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"), - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, - errorMsg: "failed to read certificate file", + name: "InvalidCAFile", + attestationPolicy: policyFile, + serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"), + clientCert: "", + clientKey: "", + expectedSec: WithoutTLS, + expectError: true, + errorMsg: "failed to read certificate file", }, { - name: "InvalidClientCert", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ClientCert: filepath.Join(tmpDir, "nonexistent.crt"), - ClientKey: keyFile, - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, + name: "InvalidClientCert", + attestationPolicy: policyFile, + serverCAFile: "", + clientCert: filepath.Join(tmpDir, "nonexistent.crt"), + clientKey: keyFile, + expectedSec: WithoutTLS, + expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := LoadATLSConfig(tt.config) + result, err := LoadATLSConfig(tt.attestationPolicy, tt.serverCAFile, tt.clientCert, tt.clientKey) if tt.expectError { assert.Error(t, err)