NOISSUE - Add internal tests (#266)

* add internal tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix linter

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix race conditions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* remove all races

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-09 21:01:11 +03:00
committed by GitHub
parent db7f3c7a4b
commit 18aa8ba785
21 changed files with 1013 additions and 455 deletions
-62
View File
@@ -1,62 +0,0 @@
#!/bin/bash
# Set your default values for sudo and sev
sudo_option=false
sev_option=false
# Parse command line arguments
while [[ $# -gt 0 ]]; do
key="$1"
case $key in
--sudo)
sudo_option=true
shift
;;
--sev)
sev_option=true
shift
;;
*)
echo "Unknown option: $key"
exit 1
;;
esac
done
build_qemu_command() {
local qemu_command="/usr/bin/qemu-system-x86_64 -enable-kvm -machine q35 -cpu EPYC -smp 4,maxcpus=64 -m 2048M,slots=5,maxmem=30G -drive if=pflash,format=raw,unit=0,file=$MANAGER_QEMU_OVMF_CODE_FILE,readonly=on -drive if=pflash,format=raw,unit=1,file=img/OVMF_VARS.fd -device virtio-scsi-pci,id=scsi,disable-legacy=on,iommu_platform=true -drive file=img/focal-server-cloudimg-amd64.img,if=none,id=disk0,format=qcow2 -device scsi-hd,drive=disk0 -netdev user,id=vmnic,hostfwd=tcp::2222-:22,hostfwd=tcp::9301-:9031,hostfwd=tcp::7020-:7002 -device virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,romfile= -nographic -monitor pty"
if [ "$sev_option" = true ]; then
qemu_command="$qemu_command -object sev-guest,id=sev0,cbitpos=51,reduced-phys-bits=1 -machine memory-encryption=sev0"
fi
echo "$qemu_command"
}
if [ ! -f "img/OVMF_VARS.fd" ]; then
cp "$MANAGER_QEMU_OVMF_VARS_FILE" "img/OVMF_VARS.fd"
echo "Copied $MANAGER_QEMU_OVMF_VARS_FILE to img/OVMF_VARS.fd"
else
echo "img/OVMF_VARS.fd already exists. No need to copy."
fi
echo "Launching VM ..."
qemu_command=$(build_qemu_command)
echo "$qemu_command"
echo "Mapping CTRL-C to CTRL-]"
stty intr ^]
if [ "$sudo_option" = true ]; then
# Split the command and arguments into an array; << operator is known as a "here string"
IFS=" " read -r -a qemu_command_array <<< "$qemu_command"
# Treat each element in the array as a separate word, preserving spaces within each element
sudo "${qemu_command_array[@]}"
else
$qemu_command
fi
# Restore the mapping
stty intr ^c
-62
View File
@@ -1,62 +0,0 @@
<domain type="kvm">
<name>QEmu-alpine-standard-x86_64</name>
<uuid>c7a5fdbd-cdaf-9455-926a-d65c16db1809</uuid>
<metadata>
<libosinfo:libosinfo xmlns:libosinfo="http://libosinfo.org/xmlns/libvirt/domain/1.0">
<libosinfo:os id="http://alpinelinux.org/alpinelinux/3.15"/>
</libosinfo:libosinfo>
</metadata>
<memory unit="KiB">4194304</memory>
<currentMemory unit="KiB">4194304</currentMemory>
<vcpu placement="static">1</vcpu>
<os>
<type arch="x86_64" machine="q35">hvm</type>
<bootmenu enable="yes"/>
<loader readonly="yes" type="pflash">/usr/share/OVMF/OVMF_CODE.fd</loader>
<nvram template='/usr/share/OVMF/OVMF_VARS.fd'>./img/OVMF_VARS.fd</nvram>
<!-- <boot dev='hd'/> -->
</os>
<features>
<acpi/>
<apic/>
<vmport state="off"/>
</features>
<cpu mode="host-passthrough" check="none" migratable="on"/>
<clock offset="utc">
<timer name="rtc" tickpolicy="catchup"/>
<timer name="pit" tickpolicy="delay"/>
<timer name="hpet" present="no"/>
</clock>
<on_poweroff>destroy</on_poweroff>
<on_reboot>restart</on_reboot>
<on_crash>destroy</on_crash>
<pm>
<suspend-to-mem enabled="no"/>
<suspend-to-disk enabled="no"/>
</pm>
<devices>
<emulator>/usr/bin/qemu-system-x86_64</emulator>
<disk type="file" device="disk">
<driver name="qemu" type="qcow2" discard="unmap"/>
<source file="./img/focal-server-cloudimg-amd64.qcow2"/>
<target dev="vda" bus="virtio"/>
<address type="pci" domain="0x0000" bus="0x04" slot="0x00" function="0x0"/>
<boot order="1"/>
</disk>
<graphics type="spice" autoport="yes">
<listen type="address"/>
<image compression="off"/>
</graphics>
<video>
<model type="qxl" ram="65536" vram="65536" vgamem="16384" heads="1" primary="yes"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x01" function="0x0"/>
</video>
<interface type="network">
<mac address="52:54:00:03:7b:5f"/>
<source network="default"/>
<model type="virtio"/>
<address type="pci" domain="0x0000" bus="0x01" slot="0x00" function="0x0"/>
</interface>
</devices>
</domain>
-6
View File
@@ -1,6 +0,0 @@
<pool type="dir">
<name>virtimages</name>
<target>
<path>./img</path>
</target>
</pool>
-15
View File
@@ -1,15 +0,0 @@
<volume>
<name>boot.img</name>
<allocation>0</allocation>
<capacity unit="G">1</capacity>
<target>
<format type="qcow2"/>
<path>./img/boot.img</path>
<permissions>
<owner>107</owner>
<group>107</group>
<mode>0744</mode>
<label>virt_image_t</label>
</permissions>
</target>
</volume>
-64
View File
@@ -1,64 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"strings"
)
// ExeShCmdStdout executes a shell command capturing the standard output.
func ExeShCmdStdout(command string, args ...string) (string, error) {
var stdoutBuf, stderrBuf bytes.Buffer
cmd := exec.Command(command, args...)
// Capture stdout and stderr using buffers
cmd.Stdout = io.MultiWriter(&stdoutBuf, os.Stdout)
cmd.Stderr = io.MultiWriter(&stderrBuf, os.Stderr)
err := cmd.Run()
if err != nil {
return "", fmt.Errorf("error executing command '%s': %s", cmd.String(), err)
}
return stdoutBuf.String(), nil
}
// ExtractCmdAndArgs extracts the command and its arguments from the output string.
func ExtractCmdAndArgs(cmdLine string, sudo bool) (string, []string) {
lines := strings.Split(cmdLine, "\n")
if len(lines) == 0 {
return "", nil
}
parts := strings.Fields(lines[0])
if len(parts) == 0 {
return "", nil
}
if sudo {
parts = append([]string{"sudo"}, parts...)
}
cmd := parts[0]
args := parts[1:]
return cmd, args
}
// RunCmdOutput runs the specified command and returns its standard output as a string.
func RunCmdOutput(command string, args ...string) (string, error) {
cmd := exec.Command(command, args...)
output, err := cmd.Output()
if err != nil {
return "", fmt.Errorf("error executing command '%s': %s", cmd.String(), err)
}
return string(output), nil
}
+136
View File
@@ -0,0 +1,136 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"bytes"
"encoding/hex"
"os"
"path/filepath"
"testing"
)
func TestCopyFile(t *testing.T) {
tempDir, err := os.MkdirTemp("", "copyfile_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
srcPath := filepath.Join(tempDir, "source.txt")
content := []byte("Hello, World!")
if err := os.WriteFile(srcPath, content, 0o644); err != nil {
t.Fatalf("Failed to create source file: %v", err)
}
dstPath := filepath.Join(tempDir, "destination.txt")
if err := CopyFile(srcPath, dstPath); err != nil {
t.Fatalf("CopyFile failed: %v", err)
}
copiedContent, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("Failed to read destination file: %v", err)
}
if !bytes.Equal(content, copiedContent) {
t.Errorf("Copied content does not match original. Got %s, want %s", copiedContent, content)
}
}
func TestDeleteFilesInDir(t *testing.T) {
tempDir, err := os.MkdirTemp("", "deletefiles_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
filenames := []string{"file1.txt", "file2.txt", "file3.txt"}
for _, filename := range filenames {
filepath := filepath.Join(tempDir, filename)
if err := os.WriteFile(filepath, []byte("test"), 0o644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
}
if err := DeleteFilesInDir(tempDir); err != nil {
t.Fatalf("DeleteFilesInDir failed: %v", err)
}
remainingFiles, err := os.ReadDir(tempDir)
if err != nil {
t.Fatalf("Failed to read directory: %v", err)
}
if len(remainingFiles) != 0 {
t.Errorf("Directory not empty after deletion. %d files remain", len(remainingFiles))
}
}
func TestChecksum(t *testing.T) {
tempDir, err := os.MkdirTemp("", "checksum_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
filePath := filepath.Join(tempDir, "test.txt")
content := []byte("Hello, World!")
if err := os.WriteFile(filePath, content, 0o644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
checksum, err := Checksum(filePath)
if err != nil {
t.Fatalf("Checksum failed: %v", err)
}
expectedChecksum, _ := hex.DecodeString("1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef")
if !bytes.Equal(checksum, expectedChecksum) {
t.Errorf("File checksum mismatch. Got %x, want %x", checksum, expectedChecksum)
}
dirPath := filepath.Join(tempDir, "testdir")
if err := os.Mkdir(dirPath, 0o755); err != nil {
t.Fatalf("Failed to create test directory: %v", err)
}
if err := os.WriteFile(filepath.Join(dirPath, "file1.txt"), []byte("File 1"), 0o644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
if err := os.WriteFile(filepath.Join(dirPath, "file2.txt"), []byte("File 2"), 0o644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
dirChecksum, err := Checksum(dirPath)
if err != nil {
t.Fatalf("Directory Checksum failed: %v", err)
}
if len(dirChecksum) != 32 { // SHA3-256 produces a 32-byte hash
t.Errorf("Unexpected directory checksum length. Got %d bytes, want 32 bytes", len(dirChecksum))
}
}
func TestChecksumHex(t *testing.T) {
tempFile, err := os.CreateTemp("", "checksumhex_test")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
content := []byte("Hello, World!")
if _, err := tempFile.Write(content); err != nil {
t.Fatalf("Failed to write to test file: %v", err)
}
tempFile.Close()
checksumHex, err := ChecksumHex(tempFile.Name())
if err != nil {
t.Fatalf("ChecksumHex failed: %v", err)
}
expectedChecksumHex := "1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"
if checksumHex != expectedChecksumHex {
t.Errorf("ChecksumHex mismatch. Got %s, want %s", checksumHex, expectedChecksumHex)
}
}
-45
View File
@@ -1,45 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package libvirt
import (
"fmt"
"log"
"log/slog"
"net"
"time"
"github.com/digitalocean/go-libvirt"
)
func Connect(logger *slog.Logger) *libvirt.Libvirt {
// This dials libvirt on the local machine, but you can substitute the first
// two parameters with "tcp", "<ip address>:<port>" to connect to libvirt on
// a remote machine.
c, err := net.DialTimeout("unix", "/var/run/libvirt/libvirt-sock", 2*time.Second)
if err != nil {
log.Fatalf("failed to dial libvirt: %v", err)
}
l := libvirt.New(c)
if err := l.Connect(); err != nil {
log.Fatalf("failed to connect: %v", err)
}
v, err := l.Version()
if err != nil {
logger.Error(fmt.Sprintf("failed to retrieve libvirt version: %v", err))
}
logger.Info(fmt.Sprintf("Retrieved libvirt version: %s", v))
domains, err := l.Domains()
if err != nil {
logger.Error(fmt.Sprintf("failed to retrieve domains: %v", err))
}
for _, d := range domains {
logger.Info(fmt.Sprintf("%d\t%s\t%x\n", d.ID, d.Name, d.UUID))
}
return l
}
-167
View File
@@ -1,167 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package libvirt
import (
"context"
"errors"
"fmt"
"os"
"regexp"
"strings"
golibvirt "github.com/digitalocean/go-libvirt"
)
var re = regexp.MustCompile(`'([^']*)'`)
func CreateDomain(ctx context.Context, libvirt *golibvirt.Libvirt, poolXML, volXML, domXML string) (string, error) {
wd, err := os.Getwd()
if err != nil {
return "", err
}
poolStr, err := readXMLFile(poolXML, "pool.xml")
if err != nil {
return "", err
}
poolStr = replaceSubstring(poolStr, "./", wd+"/")
volStr, err := readXMLFile(volXML, "vol.xml")
if err != nil {
return "", err
}
volStr = replaceSubstring(volStr, "./", wd+"/")
domStr, err := readXMLFile(domXML, "dom.xml")
if err != nil {
return "", err
}
domStr = replaceSubstring(domStr, "./", wd+"/")
dom, err := createDomain(libvirt, poolStr, volStr, domStr)
if err != nil {
return "", fmt.Errorf("failed to create domain: %s", err)
}
return dom.Name, nil
}
func createDomain(libvirtConn *golibvirt.Libvirt, poolXML, volXML, domXML string) (golibvirt.Domain, error) {
pool, err := libvirtConn.StoragePoolCreateXML(poolXML, 0)
_ = pool
if err != nil {
lvErr := err.(golibvirt.Error)
if lvErr.Code == 9 {
name, err := entityName(lvErr.Message)
if err != nil {
return golibvirt.Domain{}, err
}
pool, err = libvirtConn.StoragePoolLookupByName(name)
if err != nil {
return golibvirt.Domain{}, err
}
goto pool_exists
}
return golibvirt.Domain{}, err
}
pool_exists:
_, err = libvirtConn.StorageVolCreateXML(pool, volXML, 0)
if err != nil {
lvErr := err.(golibvirt.Error)
if lvErr.Code == 90 {
name, err := entityName(lvErr.Message)
if err != nil {
return golibvirt.Domain{}, err
}
_, err = libvirtConn.StorageVolLookupByName(pool, name)
if err != nil {
return golibvirt.Domain{}, err
}
goto vol_exists
}
return golibvirt.Domain{}, err
}
vol_exists:
dom, err := libvirtConn.DomainDefineXMLFlags(domXML, 0)
if err != nil {
return golibvirt.Domain{}, err
}
err = libvirtConn.DomainCreate(dom)
if err != nil {
lvErr := err.(golibvirt.Error)
if lvErr.Code == 55 {
return dom, nil
}
return golibvirt.Domain{}, err
}
// extra flags; not used yet, so callers should always pass 0
current, err := libvirtConn.DomainSnapshotCurrent(dom, 0)
if err != nil {
lvErr := err.(golibvirt.Error)
if lvErr.Code == 72 {
return dom, nil
}
return golibvirt.Domain{}, err
}
err = libvirtConn.DomainRevertToSnapshot(current, uint32(golibvirt.DomainSnapshotRevertRunning))
if err != nil {
return golibvirt.Domain{}, err
}
return dom, nil
}
func entityName(msg string) (string, error) {
match := re.FindStringSubmatch(msg)
if len(match) < 1 {
return "", errors.New("entity not found")
}
return match[1], nil
}
func readXMLFile(filename, defaultFilename string) (string, error) {
if filename == "" {
filename = "./xml/" + defaultFilename
}
xmlBytes, err := os.ReadFile(filename)
if err != nil {
return "", fmt.Errorf("failed to read XML file: %s", err)
}
return string(xmlBytes), nil
}
func replaceSubstring(xml, substring, replacement string) string {
// Split the file text into lines
lines := strings.Split(xml, "\n")
// Create a variable to hold the resulting string
var result strings.Builder
// Iterate over each line
for _, line := range lines {
// Replace the substring with the replacement
newLine := strings.ReplaceAll(line, substring, replacement)
// Append the modified line to the resulting string
result.WriteString(newLine)
result.WriteString("\n")
}
return result.String()
}
+2
View File
@@ -27,6 +27,8 @@ type handler struct {
stopRetry chan struct{}
}
//go:generate mockery --name io.Writer --output ./mocks --filename io_writer.go
func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler {
if opts == nil {
opts = &slog.HandlerOptions{}
+83
View File
@@ -0,0 +1,83 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package logger
import (
"context"
"io"
"log/slog"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type failedWriter struct{}
func (f *failedWriter) Write(p []byte) (n int, err error) {
return 0, io.ErrUnexpectedEOF
}
// TestNewProtoHandler tests the initialization of the ProtoHandler.
func TestNewProtoHandler(t *testing.T) {
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
assert.NotNil(t, handler, "Handler should not be nil")
}
// TestHandleMessageSuccess tests the handling of a message when the write succeeds.
func TestHandleMessageSuccess(t *testing.T) {
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
record := slog.Record{
Time: time.Now(),
Message: "Test message",
Level: slog.LevelInfo,
}
err := handler.Handle(context.Background(), record)
assert.NoError(t, err, "Handle should not return an error")
}
// TestHandleMessageFailure tests the caching mechanism when the write fails.
func TestHandleMessageFailure(t *testing.T) {
protohandler := NewProtoHandler(&failedWriter{}, nil, "testCmpID")
record := slog.Record{
Time: time.Now(),
Message: "Test message",
Level: slog.LevelInfo,
}
err := protohandler.Handle(context.Background(), record)
assert.NoError(t, err, "Handle should not return an error even when write fails")
assert.NotEmpty(t, protohandler.(*handler).CachedMessages(), "Cached messages should not be empty")
}
// TestEnabled tests that the handler enables logging based on level.
func TestEnabled(t *testing.T) {
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
assert.True(t, handler.Enabled(context.Background(), slog.LevelInfo), "Logging should be enabled for LevelInfo")
assert.False(t, handler.Enabled(context.Background(), slog.LevelDebug), "Logging should be disabled for LevelDebug by default")
}
// TestPeriodicRetry stops retry after close.
func TestCloseStopsRetry(t *testing.T) {
mockWriter := io.Discard
handler := NewProtoHandler(mockWriter, nil, "testCmpID").(*handler)
time.Sleep(2 * time.Second)
err := handler.Close()
assert.NoError(t, err, "Close should not return an error")
time.Sleep(1 * time.Second) // Ensure no retry after close
}
// Utility function to retrieve cached messages.
func (h *handler) CachedMessages() [][]byte {
h.mutex.Lock()
defer h.mutex.Unlock()
return h.cachedMessages
}
+1 -1
View File
@@ -1,5 +1,5 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package server contains the HTTP, gRPC and CoAP server implementation.
// Package server contains the gRPC server implementation.
package server
+19 -17
View File
@@ -18,6 +18,7 @@ import (
"math/big"
"net"
"os"
"strings"
"time"
"github.com/google/go-sev-guest/client"
@@ -205,26 +206,27 @@ func loadCertFile(certFile string) ([]byte, error) {
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
var cert, key []byte
var err error
if _, err = os.Stat(certfile); err == nil {
cert, err = os.ReadFile(certfile)
if err != nil {
return tls.Certificate{}, err
readFileOrData := func(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
}
}
} else if os.IsNotExist(err) {
cert = []byte(certfile)
} else {
return tls.Certificate{}, err
return []byte(input), nil
}
if _, err := os.Stat(keyfile); err == nil {
key, err = os.ReadFile(keyfile)
if err != nil {
return tls.Certificate{}, err
}
} else if os.IsNotExist(err) {
key = []byte(keyfile)
} else {
return tls.Certificate{}, err
cert, err = readFileOrData(certfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
}
key, err = readFileOrData(keyfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
}
return tls.X509KeyPair(cert, key)
}
+250
View File
@@ -0,0 +1,250 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log/slog"
"math/big"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
authmocks "github.com/ultravioletrs/cocos/agent/mocks"
"github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
"github.com/ultravioletrs/cocos/internal/server"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
const bufSize = 1024 * 1024
var lis *bufconn.Listener
func init() {
lis = bufconn.Listen(bufSize)
}
func TestNew(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
config := server.Config{
Host: "localhost",
Port: "50051",
}
logger := slog.Default()
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
assert.NotNil(t, srv)
assert.IsType(t, &Server{}, srv)
}
func TestServerStart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := server.Config{
Host: "localhost",
Port: "0",
}
buf := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
assert.NoError(t, err)
}()
wg.Wait()
time.Sleep(100 * time.Millisecond)
cancel()
assert.Contains(t, buf.String(), "TestServer service gRPC server listening at localhost:0 without TLS")
}
func TestServerStartWithTLS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config := server.Config{
Host: "localhost",
Port: "0",
CertFile: string(cert),
KeyFile: string(key),
}
logBuffer := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
assert.NoError(t, err)
}()
wg.Wait()
time.Sleep(200 * time.Millisecond)
cancel()
time.Sleep(200 * time.Millisecond)
logContent := logBuffer.String()
fmt.Println(logContent)
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
func TestServerStartWithAttestedTLS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := server.Config{
Host: "localhost",
Port: "0",
AttestedTLS: true,
}
logBuffer := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
qp.On("GetRawQuote", mock.Anything).Return([]byte("mock-quote"), nil)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
assert.NoError(t, err)
}()
wg.Wait()
time.Sleep(100 * time.Millisecond)
cancel()
time.Sleep(100 * time.Millisecond)
logContent := logBuffer.String()
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS")
qp.AssertExpectations(t)
}
func TestServerStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := server.Config{
Host: "localhost",
Port: "0",
}
buf := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
go func() {
err := srv.Start()
assert.NoError(t, err)
}()
time.Sleep(100 * time.Millisecond)
cancel()
time.Sleep(100 * time.Millisecond)
err := srv.Stop()
assert.NoError(t, err)
assert.Contains(t, buf.String(), "TestServer gRPC service shutdown at localhost:0")
}
func generateSelfSignedCert() ([]byte, []byte, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
cert, err := generateSelfSignedCertFromKey(key)
if err != nil {
return nil, nil, err
}
return cert, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), nil
}
func generateSelfSignedCertFromKey(key *rsa.PrivateKey) ([]byte, error) {
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), nil
}
type ThreadSafeBuffer struct {
buffer strings.Builder
mu sync.Mutex
}
func (b *ThreadSafeBuffer) Write(p []byte) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buffer.Write(p)
}
func (b *ThreadSafeBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buffer.String()
}
+60
View File
@@ -0,0 +1,60 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
package mocks
import mock "github.com/stretchr/testify/mock"
// Server is an autogenerated mock type for the Server type
type Server struct {
mock.Mock
}
// Start provides a mock function with given fields:
func (_m *Server) Start() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Start")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Stop provides a mock function with given fields:
func (_m *Server) Stop() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Stop")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// NewServer creates a new instance of Server. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewServer(t interface {
mock.TestingT
Cleanup(func())
}) *Server {
mock := &Server{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+1
View File
@@ -11,6 +11,7 @@ import (
"syscall"
)
//go:generate mockery --name Server --output ./mocks --filename server.go
type Server interface {
Start() error
Stop() error
+138
View File
@@ -0,0 +1,138 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
"context"
"errors"
"log/slog"
"os"
"syscall"
"testing"
"time"
"github.com/ultravioletrs/cocos/internal/server/mocks"
)
func TestStopAllServer(t *testing.T) {
server1 := new(mocks.Server)
server2 := new(mocks.Server)
server1.On("Stop").Return(nil)
server2.On("Stop").Return(errors.New("failed to stop"))
tests := []struct {
name string
servers []Server
expectedError bool
}{
{
name: "All servers stop successfully",
servers: []Server{
server1,
server1,
},
expectedError: false,
},
{
name: "One server fails to stop",
servers: []Server{
server1,
server2,
},
expectedError: true,
},
{
name: "No servers",
servers: []Server{},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := stopAllServer(tt.servers...)
if (err != nil) != tt.expectedError {
t.Errorf("stopAllServer() error = %v, expectedError %v", err, tt.expectedError)
}
})
}
}
func TestStopHandler(t *testing.T) {
mockServer := new(mocks.Server)
mockServer.On("Stop").Return(nil)
tests := []struct {
name string
setupFunc func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server)
triggerSignal bool
expectedError bool
expectCanceled bool
}{
{
name: "Graceful shutdown on signal",
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
ctx, cancel := context.WithCancel(context.Background())
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
return ctx, cancel, logger, "test", []Server{mockServer}
},
triggerSignal: true,
expectedError: false,
expectCanceled: true,
},
{
name: "Context canceled",
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
ctx, cancel := context.WithCancel(context.Background())
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
return ctx, cancel, logger, "test", []Server{mockServer}
},
triggerSignal: false,
expectedError: false,
expectCanceled: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel, logger, svcName, servers := tt.setupFunc()
defer cancel()
errChan := make(chan error)
go func() {
errChan <- StopHandler(ctx, cancel, logger, svcName, servers...)
}()
if tt.triggerSignal {
// Simulate SIGINT
go func() {
time.Sleep(100 * time.Millisecond)
err := syscall.Kill(syscall.Getpid(), syscall.SIGINT)
if err != nil {
t.Errorf("failed to send signal: %v", err)
}
}()
}
select {
case err := <-errChan:
if (err != nil) != tt.expectedError {
t.Errorf("StopHandler() error = %v, expectedError %v", err, tt.expectedError)
}
case <-time.After(2 * time.Second):
t.Error("StopHandler() timed out")
}
if tt.expectCanceled {
select {
case <-ctx.Done():
// Context was canceled as expected
default:
t.Error("Context was not canceled")
}
}
})
}
}
+7 -10
View File
@@ -38,7 +38,7 @@ type AckWriter struct {
wg sync.WaitGroup
}
func NewAckWriter(conn net.Conn) *AckWriter {
func NewAckWriter(conn net.Conn) io.WriteCloser {
aw := &AckWriter{
conn: conn,
pendingMessages: make(chan *Message, maxConcurrent),
@@ -52,14 +52,6 @@ func NewAckWriter(conn net.Conn) *AckWriter {
return aw
}
func (aw *AckWriter) WriteProto(msg proto.Message) (int, error) {
data, err := proto.Marshal(msg)
if err != nil {
return 0, fmt.Errorf("error marshaling protobuf message: %v", err)
}
return aw.Write(data)
}
func (aw *AckWriter) Write(p []byte) (int, error) {
if len(p) > maxMessageSize {
return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
@@ -176,11 +168,16 @@ func (aw *AckWriter) Close() error {
return aw.conn.Close()
}
type Reader interface {
Read() ([]byte, error)
ReadProto(msg proto.Message) error
}
type AckReader struct {
conn net.Conn
}
func NewAckReader(conn net.Conn) *AckReader {
func NewAckReader(conn net.Conn) Reader {
return &AckReader{
conn: conn,
}
+205
View File
@@ -0,0 +1,205 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package vsock
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
)
// MockConn implements net.Conn for testing purposes.
type MockConn struct {
ReadData []byte
WrittenData []byte
ReadErr error
WriteErr error
closed bool
mu sync.Mutex
}
func (m *MockConn) Read(b []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed {
return 0, io.EOF
}
if len(m.ReadData) == 0 {
return 0, io.EOF // Ensure we handle this case more predictably
}
if m.ReadErr != nil {
return 0, m.ReadErr
}
n = copy(b, m.ReadData)
m.ReadData = m.ReadData[n:]
return n, nil
}
func (m *MockConn) Write(b []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed {
return 0, errors.New("connection closed")
}
if m.WriteErr != nil {
return 0, m.WriteErr
}
m.WrittenData = append(m.WrittenData, b...)
return len(b), nil
}
func (m *MockConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.closed = true
return nil
}
// Implement other net.Conn methods with empty implementations.
func (m *MockConn) LocalAddr() net.Addr { return nil }
func (m *MockConn) RemoteAddr() net.Addr { return nil }
func (m *MockConn) SetDeadline(t time.Time) error { return nil }
func (m *MockConn) SetReadDeadline(t time.Time) error { return nil }
func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestAckReader_Read(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{"Valid message", []byte("Hello, World!"), false},
{"Empty message", []byte{}, false},
{"Message at max size", make([]byte, maxMessageSize), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConn := &MockConn{}
ar := NewAckReader(mockConn)
// Prepare mock data
messageID := uint32(1)
messageLen := uint32(len(tt.data))
mockData := make([]byte, 8+len(tt.data))
binary.LittleEndian.PutUint32(mockData[:4], messageID)
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
copy(mockData[8:], tt.data)
mockConn.ReadData = mockData
data, err := ar.Read()
if (err != nil) != tt.wantErr {
t.Errorf("AckReader.Read() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if !bytes.Equal(data, tt.data) {
t.Errorf("AckReader.Read() got = %v, want %v", data, tt.data)
}
// Check if ACK was sent
if len(mockConn.WrittenData) != 4 {
t.Errorf("AckReader.Read() did not send ACK")
} else {
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
if ackID != messageID {
t.Errorf("AckReader.Read() sent wrong ACK ID, got %d, want %d", ackID, messageID)
}
}
}
})
}
}
func TestAckReader_ReadProto(t *testing.T) {
tests := []struct {
name string
msg *manager.ClientStreamMessage
wantErr bool
}{
{"Valid proto message", &manager.ClientStreamMessage{}, false},
{"Empty proto message", &manager.ClientStreamMessage{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConn := &MockConn{}
ar := NewAckReader(mockConn)
// Prepare mock data
protoData, _ := proto.Marshal(tt.msg)
messageID := uint32(1)
messageLen := uint32(len(protoData))
mockData := make([]byte, 8+len(protoData))
binary.LittleEndian.PutUint32(mockData[:4], messageID)
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
copy(mockData[8:], protoData)
mockConn.ReadData = mockData
receivedMsg := &manager.ClientStreamMessage{}
err := ar.ReadProto(receivedMsg)
if (err != nil) != tt.wantErr {
t.Errorf("AckReader.ReadProto() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if receivedMsg.Message != tt.msg.Message {
t.Errorf("AckReader.ReadProto() got = %v, want %v", receivedMsg, tt.msg)
}
// Check if ACK was sent
if len(mockConn.WrittenData) != 4 {
t.Errorf("AckReader.ReadProto() did not send ACK")
} else {
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
if ackID != messageID {
t.Errorf("AckReader.ReadProto() sent wrong ACK ID, got %d, want %d", ackID, messageID)
}
}
}
})
}
}
func TestNewAckWriter(t *testing.T) {
mockConn := &MockConn{}
writer := NewAckWriter(mockConn)
if _, ok := writer.(io.Writer); !ok {
t.Errorf("NewAckWriter() did not return an io.Writer")
}
}
func TestNewAckReader(t *testing.T) {
mockConn := &MockConn{}
reader := NewAckReader(mockConn)
assert.NotNil(t, reader)
}
func TestAckWriter_Close(t *testing.T) {
mockConn := &MockConn{}
aw := NewAckWriter(mockConn)
err := aw.Close()
if err != nil {
t.Errorf("AckWriter.Close() error = %v, wantErr %v", err, nil)
}
if !mockConn.closed {
t.Errorf("AckWriter.Close() did not close the connection")
}
}
+106
View File
@@ -0,0 +1,106 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"os"
"path/filepath"
"testing"
)
func TestZipDirectoryToMemory(t *testing.T) {
tempDir, err := os.MkdirTemp("", "zip_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
testFiles := map[string]string{
"file1.txt": "Content of file 1",
"file2.txt": "Content of file 2",
"subdir/file3.txt": "Content of file 3 in subdirectory",
}
for path, content := range testFiles {
fullPath := filepath.Join(tempDir, path)
err := os.MkdirAll(filepath.Dir(fullPath), 0o755)
if err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
err = os.WriteFile(fullPath, []byte(content), 0o644)
if err != nil {
t.Fatalf("Failed to write test file: %v", err)
}
}
zipData, err := ZipDirectoryToMemory(tempDir)
if err != nil {
t.Fatalf("ZipDirectoryToMemory failed: %v", err)
}
if len(zipData) == 0 {
t.Error("Zip data is empty")
}
unzipDir, err := os.MkdirTemp("", "unzip_test")
if err != nil {
t.Fatalf("Failed to create temp directory for unzip: %v", err)
}
defer os.RemoveAll(unzipDir)
err = UnzipFromMemory(zipData, unzipDir)
if err != nil {
t.Fatalf("UnzipFromMemory failed: %v", err)
}
for path, expectedContent := range testFiles {
fullPath := filepath.Join(unzipDir, path)
content, err := os.ReadFile(fullPath)
if err != nil {
t.Errorf("Failed to read unzipped file %s: %v", path, err)
continue
}
if string(content) != expectedContent {
t.Errorf("Content mismatch for file %s. Expected: %s, Got: %s", path, expectedContent, string(content))
}
}
}
func TestZipDirectoryToMemory_EmptyDirectory(t *testing.T) {
tempDir, err := os.MkdirTemp("", "empty_zip_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
zipData, err := ZipDirectoryToMemory(tempDir)
if err != nil {
t.Fatalf("ZipDirectoryToMemory failed on empty directory: %v", err)
}
if len(zipData) == 0 {
t.Error("Zip data is empty for an empty directory")
}
}
func TestUnzipFromMemory_InvalidZipData(t *testing.T) {
invalidZipData := []byte("This is not a valid zip file")
tempDir, err := os.MkdirTemp("", "invalid_unzip_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
err = UnzipFromMemory(invalidZipData, tempDir)
if err == nil {
t.Error("UnzipFromMemory should fail with invalid zip data")
}
}
func TestZipDirectoryToMemory_NonExistentDirectory(t *testing.T) {
nonExistentDir := "/path/to/non/existent/directory"
_, err := ZipDirectoryToMemory(nonExistentDir)
if err == nil {
t.Error("ZipDirectoryToMemory should fail with non-existent directory")
}
}
-2
View File
@@ -146,8 +146,6 @@ func TestHandleConnection(t *testing.T) {
receivedMsg := <-ms.eventsChan
assert.Equal(t, msg.GetAgentEvent().EventType, receivedMsg.GetAgentEvent().EventType)
assert.Equal(t, msg.GetAgentEvent().ComputationId, receivedMsg.GetAgentEvent().ComputationId)
mockConn.AssertExpectations(t)
}
func TestReportBrokenConnection(t *testing.T) {
+5 -4
View File
@@ -22,6 +22,7 @@ import (
"github.com/ultravioletrs/cocos/pkg/sdk"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
@@ -33,7 +34,7 @@ var (
)
func TestAlgo(t *testing.T) {
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
if err != nil {
t.Fatalf("Failed to dial bufnet: %v", err)
}
@@ -120,7 +121,7 @@ func TestAlgo(t *testing.T) {
}
func TestData(t *testing.T) {
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
if err != nil {
t.Fatalf("Failed to dial bufnet: %v", err)
}
@@ -217,7 +218,7 @@ func TestData(t *testing.T) {
}
func TestResult(t *testing.T) {
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
if err != nil {
t.Fatalf("Failed to dial bufnet: %v", err)
}
@@ -318,7 +319,7 @@ func TestAttestation(t *testing.T) {
0x05, 0x06, 0x07, 0x08,
}
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
if err != nil {
t.Fatalf("Failed to dial bufnet: %v", err)
}