mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Add internal tests (#266)
* add internal tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix linter Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix race conditions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * remove all races Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
db7f3c7a4b
commit
18aa8ba785
@@ -1,62 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Set your default values for sudo and sev
|
||||
sudo_option=false
|
||||
sev_option=false
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
|
||||
case $key in
|
||||
--sudo)
|
||||
sudo_option=true
|
||||
shift
|
||||
;;
|
||||
--sev)
|
||||
sev_option=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $key"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
build_qemu_command() {
|
||||
local qemu_command="/usr/bin/qemu-system-x86_64 -enable-kvm -machine q35 -cpu EPYC -smp 4,maxcpus=64 -m 2048M,slots=5,maxmem=30G -drive if=pflash,format=raw,unit=0,file=$MANAGER_QEMU_OVMF_CODE_FILE,readonly=on -drive if=pflash,format=raw,unit=1,file=img/OVMF_VARS.fd -device virtio-scsi-pci,id=scsi,disable-legacy=on,iommu_platform=true -drive file=img/focal-server-cloudimg-amd64.img,if=none,id=disk0,format=qcow2 -device scsi-hd,drive=disk0 -netdev user,id=vmnic,hostfwd=tcp::2222-:22,hostfwd=tcp::9301-:9031,hostfwd=tcp::7020-:7002 -device virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,romfile= -nographic -monitor pty"
|
||||
|
||||
if [ "$sev_option" = true ]; then
|
||||
qemu_command="$qemu_command -object sev-guest,id=sev0,cbitpos=51,reduced-phys-bits=1 -machine memory-encryption=sev0"
|
||||
fi
|
||||
|
||||
echo "$qemu_command"
|
||||
}
|
||||
|
||||
if [ ! -f "img/OVMF_VARS.fd" ]; then
|
||||
cp "$MANAGER_QEMU_OVMF_VARS_FILE" "img/OVMF_VARS.fd"
|
||||
echo "Copied $MANAGER_QEMU_OVMF_VARS_FILE to img/OVMF_VARS.fd"
|
||||
else
|
||||
echo "img/OVMF_VARS.fd already exists. No need to copy."
|
||||
fi
|
||||
|
||||
echo "Launching VM ..."
|
||||
|
||||
qemu_command=$(build_qemu_command)
|
||||
echo "$qemu_command"
|
||||
|
||||
echo "Mapping CTRL-C to CTRL-]"
|
||||
stty intr ^]
|
||||
|
||||
if [ "$sudo_option" = true ]; then
|
||||
# Split the command and arguments into an array; << operator is known as a "here string"
|
||||
IFS=" " read -r -a qemu_command_array <<< "$qemu_command"
|
||||
# Treat each element in the array as a separate word, preserving spaces within each element
|
||||
sudo "${qemu_command_array[@]}"
|
||||
else
|
||||
$qemu_command
|
||||
fi
|
||||
|
||||
# Restore the mapping
|
||||
stty intr ^c
|
||||
@@ -1,62 +0,0 @@
|
||||
<domain type="kvm">
|
||||
<name>QEmu-alpine-standard-x86_64</name>
|
||||
<uuid>c7a5fdbd-cdaf-9455-926a-d65c16db1809</uuid>
|
||||
<metadata>
|
||||
<libosinfo:libosinfo xmlns:libosinfo="http://libosinfo.org/xmlns/libvirt/domain/1.0">
|
||||
<libosinfo:os id="http://alpinelinux.org/alpinelinux/3.15"/>
|
||||
</libosinfo:libosinfo>
|
||||
</metadata>
|
||||
<memory unit="KiB">4194304</memory>
|
||||
<currentMemory unit="KiB">4194304</currentMemory>
|
||||
<vcpu placement="static">1</vcpu>
|
||||
<os>
|
||||
<type arch="x86_64" machine="q35">hvm</type>
|
||||
<bootmenu enable="yes"/>
|
||||
<loader readonly="yes" type="pflash">/usr/share/OVMF/OVMF_CODE.fd</loader>
|
||||
<nvram template='/usr/share/OVMF/OVMF_VARS.fd'>./img/OVMF_VARS.fd</nvram>
|
||||
<!-- <boot dev='hd'/> -->
|
||||
</os>
|
||||
<features>
|
||||
<acpi/>
|
||||
<apic/>
|
||||
<vmport state="off"/>
|
||||
</features>
|
||||
<cpu mode="host-passthrough" check="none" migratable="on"/>
|
||||
<clock offset="utc">
|
||||
<timer name="rtc" tickpolicy="catchup"/>
|
||||
<timer name="pit" tickpolicy="delay"/>
|
||||
<timer name="hpet" present="no"/>
|
||||
</clock>
|
||||
<on_poweroff>destroy</on_poweroff>
|
||||
<on_reboot>restart</on_reboot>
|
||||
<on_crash>destroy</on_crash>
|
||||
<pm>
|
||||
<suspend-to-mem enabled="no"/>
|
||||
<suspend-to-disk enabled="no"/>
|
||||
</pm>
|
||||
<devices>
|
||||
<emulator>/usr/bin/qemu-system-x86_64</emulator>
|
||||
<disk type="file" device="disk">
|
||||
<driver name="qemu" type="qcow2" discard="unmap"/>
|
||||
<source file="./img/focal-server-cloudimg-amd64.qcow2"/>
|
||||
<target dev="vda" bus="virtio"/>
|
||||
<address type="pci" domain="0x0000" bus="0x04" slot="0x00" function="0x0"/>
|
||||
<boot order="1"/>
|
||||
</disk>
|
||||
<graphics type="spice" autoport="yes">
|
||||
<listen type="address"/>
|
||||
<image compression="off"/>
|
||||
</graphics>
|
||||
<video>
|
||||
<model type="qxl" ram="65536" vram="65536" vgamem="16384" heads="1" primary="yes"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x01" function="0x0"/>
|
||||
</video>
|
||||
<interface type="network">
|
||||
<mac address="52:54:00:03:7b:5f"/>
|
||||
<source network="default"/>
|
||||
<model type="virtio"/>
|
||||
<address type="pci" domain="0x0000" bus="0x01" slot="0x00" function="0x0"/>
|
||||
</interface>
|
||||
</devices>
|
||||
|
||||
</domain>
|
||||
@@ -1,6 +0,0 @@
|
||||
<pool type="dir">
|
||||
<name>virtimages</name>
|
||||
<target>
|
||||
<path>./img</path>
|
||||
</target>
|
||||
</pool>
|
||||
@@ -1,15 +0,0 @@
|
||||
<volume>
|
||||
<name>boot.img</name>
|
||||
<allocation>0</allocation>
|
||||
<capacity unit="G">1</capacity>
|
||||
<target>
|
||||
<format type="qcow2"/>
|
||||
<path>./img/boot.img</path>
|
||||
<permissions>
|
||||
<owner>107</owner>
|
||||
<group>107</group>
|
||||
<mode>0744</mode>
|
||||
<label>virt_image_t</label>
|
||||
</permissions>
|
||||
</target>
|
||||
</volume>
|
||||
@@ -1,64 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExeShCmdStdout executes a shell command capturing the standard output.
|
||||
func ExeShCmdStdout(command string, args ...string) (string, error) {
|
||||
var stdoutBuf, stderrBuf bytes.Buffer
|
||||
|
||||
cmd := exec.Command(command, args...)
|
||||
|
||||
// Capture stdout and stderr using buffers
|
||||
cmd.Stdout = io.MultiWriter(&stdoutBuf, os.Stdout)
|
||||
cmd.Stderr = io.MultiWriter(&stderrBuf, os.Stderr)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error executing command '%s': %s", cmd.String(), err)
|
||||
}
|
||||
|
||||
return stdoutBuf.String(), nil
|
||||
}
|
||||
|
||||
// ExtractCmdAndArgs extracts the command and its arguments from the output string.
|
||||
func ExtractCmdAndArgs(cmdLine string, sudo bool) (string, []string) {
|
||||
lines := strings.Split(cmdLine, "\n")
|
||||
if len(lines) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
parts := strings.Fields(lines[0])
|
||||
if len(parts) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if sudo {
|
||||
parts = append([]string{"sudo"}, parts...)
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
return cmd, args
|
||||
}
|
||||
|
||||
// RunCmdOutput runs the specified command and returns its standard output as a string.
|
||||
func RunCmdOutput(command string, args ...string) (string, error) {
|
||||
cmd := exec.Command(command, args...)
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error executing command '%s': %s", cmd.String(), err)
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCopyFile(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "copyfile_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
srcPath := filepath.Join(tempDir, "source.txt")
|
||||
content := []byte("Hello, World!")
|
||||
if err := os.WriteFile(srcPath, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create source file: %v", err)
|
||||
}
|
||||
|
||||
dstPath := filepath.Join(tempDir, "destination.txt")
|
||||
if err := CopyFile(srcPath, dstPath); err != nil {
|
||||
t.Fatalf("CopyFile failed: %v", err)
|
||||
}
|
||||
|
||||
copiedContent, err := os.ReadFile(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read destination file: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(content, copiedContent) {
|
||||
t.Errorf("Copied content does not match original. Got %s, want %s", copiedContent, content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteFilesInDir(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "deletefiles_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
filenames := []string{"file1.txt", "file2.txt", "file3.txt"}
|
||||
for _, filename := range filenames {
|
||||
filepath := filepath.Join(tempDir, filename)
|
||||
if err := os.WriteFile(filepath, []byte("test"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := DeleteFilesInDir(tempDir); err != nil {
|
||||
t.Fatalf("DeleteFilesInDir failed: %v", err)
|
||||
}
|
||||
|
||||
remainingFiles, err := os.ReadDir(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read directory: %v", err)
|
||||
}
|
||||
|
||||
if len(remainingFiles) != 0 {
|
||||
t.Errorf("Directory not empty after deletion. %d files remain", len(remainingFiles))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksum(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "checksum_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
filePath := filepath.Join(tempDir, "test.txt")
|
||||
content := []byte("Hello, World!")
|
||||
if err := os.WriteFile(filePath, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
checksum, err := Checksum(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Checksum failed: %v", err)
|
||||
}
|
||||
|
||||
expectedChecksum, _ := hex.DecodeString("1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef")
|
||||
if !bytes.Equal(checksum, expectedChecksum) {
|
||||
t.Errorf("File checksum mismatch. Got %x, want %x", checksum, expectedChecksum)
|
||||
}
|
||||
|
||||
dirPath := filepath.Join(tempDir, "testdir")
|
||||
if err := os.Mkdir(dirPath, 0o755); err != nil {
|
||||
t.Fatalf("Failed to create test directory: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dirPath, "file1.txt"), []byte("File 1"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dirPath, "file2.txt"), []byte("File 2"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
dirChecksum, err := Checksum(dirPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Directory Checksum failed: %v", err)
|
||||
}
|
||||
|
||||
if len(dirChecksum) != 32 { // SHA3-256 produces a 32-byte hash
|
||||
t.Errorf("Unexpected directory checksum length. Got %d bytes, want 32 bytes", len(dirChecksum))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumHex(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "checksumhex_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
content := []byte("Hello, World!")
|
||||
if _, err := tempFile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write to test file: %v", err)
|
||||
}
|
||||
tempFile.Close()
|
||||
|
||||
checksumHex, err := ChecksumHex(tempFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("ChecksumHex failed: %v", err)
|
||||
}
|
||||
|
||||
expectedChecksumHex := "1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"
|
||||
if checksumHex != expectedChecksumHex {
|
||||
t.Errorf("ChecksumHex mismatch. Got %s, want %s", checksumHex, expectedChecksumHex)
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package libvirt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/digitalocean/go-libvirt"
|
||||
)
|
||||
|
||||
func Connect(logger *slog.Logger) *libvirt.Libvirt {
|
||||
// This dials libvirt on the local machine, but you can substitute the first
|
||||
// two parameters with "tcp", "<ip address>:<port>" to connect to libvirt on
|
||||
// a remote machine.
|
||||
c, err := net.DialTimeout("unix", "/var/run/libvirt/libvirt-sock", 2*time.Second)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to dial libvirt: %v", err)
|
||||
}
|
||||
|
||||
l := libvirt.New(c)
|
||||
if err := l.Connect(); err != nil {
|
||||
log.Fatalf("failed to connect: %v", err)
|
||||
}
|
||||
|
||||
v, err := l.Version()
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to retrieve libvirt version: %v", err))
|
||||
}
|
||||
logger.Info(fmt.Sprintf("Retrieved libvirt version: %s", v))
|
||||
|
||||
domains, err := l.Domains()
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to retrieve domains: %v", err))
|
||||
}
|
||||
|
||||
for _, d := range domains {
|
||||
logger.Info(fmt.Sprintf("%d\t%s\t%x\n", d.ID, d.Name, d.UUID))
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
@@ -1,167 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package libvirt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
golibvirt "github.com/digitalocean/go-libvirt"
|
||||
)
|
||||
|
||||
var re = regexp.MustCompile(`'([^']*)'`)
|
||||
|
||||
func CreateDomain(ctx context.Context, libvirt *golibvirt.Libvirt, poolXML, volXML, domXML string) (string, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
poolStr, err := readXMLFile(poolXML, "pool.xml")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
poolStr = replaceSubstring(poolStr, "./", wd+"/")
|
||||
|
||||
volStr, err := readXMLFile(volXML, "vol.xml")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
volStr = replaceSubstring(volStr, "./", wd+"/")
|
||||
|
||||
domStr, err := readXMLFile(domXML, "dom.xml")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
domStr = replaceSubstring(domStr, "./", wd+"/")
|
||||
|
||||
dom, err := createDomain(libvirt, poolStr, volStr, domStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create domain: %s", err)
|
||||
}
|
||||
|
||||
return dom.Name, nil
|
||||
}
|
||||
|
||||
func createDomain(libvirtConn *golibvirt.Libvirt, poolXML, volXML, domXML string) (golibvirt.Domain, error) {
|
||||
pool, err := libvirtConn.StoragePoolCreateXML(poolXML, 0)
|
||||
_ = pool
|
||||
if err != nil {
|
||||
lvErr := err.(golibvirt.Error)
|
||||
if lvErr.Code == 9 {
|
||||
name, err := entityName(lvErr.Message)
|
||||
if err != nil {
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
pool, err = libvirtConn.StoragePoolLookupByName(name)
|
||||
if err != nil {
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
goto pool_exists
|
||||
}
|
||||
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
pool_exists:
|
||||
|
||||
_, err = libvirtConn.StorageVolCreateXML(pool, volXML, 0)
|
||||
if err != nil {
|
||||
lvErr := err.(golibvirt.Error)
|
||||
if lvErr.Code == 90 {
|
||||
name, err := entityName(lvErr.Message)
|
||||
if err != nil {
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
_, err = libvirtConn.StorageVolLookupByName(pool, name)
|
||||
if err != nil {
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
goto vol_exists
|
||||
}
|
||||
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
vol_exists:
|
||||
|
||||
dom, err := libvirtConn.DomainDefineXMLFlags(domXML, 0)
|
||||
if err != nil {
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
err = libvirtConn.DomainCreate(dom)
|
||||
if err != nil {
|
||||
lvErr := err.(golibvirt.Error)
|
||||
if lvErr.Code == 55 {
|
||||
return dom, nil
|
||||
}
|
||||
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
// extra flags; not used yet, so callers should always pass 0
|
||||
current, err := libvirtConn.DomainSnapshotCurrent(dom, 0)
|
||||
if err != nil {
|
||||
lvErr := err.(golibvirt.Error)
|
||||
if lvErr.Code == 72 {
|
||||
return dom, nil
|
||||
}
|
||||
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
err = libvirtConn.DomainRevertToSnapshot(current, uint32(golibvirt.DomainSnapshotRevertRunning))
|
||||
if err != nil {
|
||||
return golibvirt.Domain{}, err
|
||||
}
|
||||
|
||||
return dom, nil
|
||||
}
|
||||
|
||||
func entityName(msg string) (string, error) {
|
||||
match := re.FindStringSubmatch(msg)
|
||||
if len(match) < 1 {
|
||||
return "", errors.New("entity not found")
|
||||
}
|
||||
|
||||
return match[1], nil
|
||||
}
|
||||
|
||||
func readXMLFile(filename, defaultFilename string) (string, error) {
|
||||
if filename == "" {
|
||||
filename = "./xml/" + defaultFilename
|
||||
}
|
||||
|
||||
xmlBytes, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read XML file: %s", err)
|
||||
}
|
||||
|
||||
return string(xmlBytes), nil
|
||||
}
|
||||
|
||||
func replaceSubstring(xml, substring, replacement string) string {
|
||||
// Split the file text into lines
|
||||
lines := strings.Split(xml, "\n")
|
||||
|
||||
// Create a variable to hold the resulting string
|
||||
var result strings.Builder
|
||||
|
||||
// Iterate over each line
|
||||
for _, line := range lines {
|
||||
// Replace the substring with the replacement
|
||||
newLine := strings.ReplaceAll(line, substring, replacement)
|
||||
|
||||
// Append the modified line to the resulting string
|
||||
result.WriteString(newLine)
|
||||
result.WriteString("\n")
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
@@ -27,6 +27,8 @@ type handler struct {
|
||||
stopRetry chan struct{}
|
||||
}
|
||||
|
||||
//go:generate mockery --name io.Writer --output ./mocks --filename io_writer.go
|
||||
|
||||
func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler {
|
||||
if opts == nil {
|
||||
opts = &slog.HandlerOptions{}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type failedWriter struct{}
|
||||
|
||||
func (f *failedWriter) Write(p []byte) (n int, err error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// TestNewProtoHandler tests the initialization of the ProtoHandler.
|
||||
func TestNewProtoHandler(t *testing.T) {
|
||||
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
|
||||
|
||||
assert.NotNil(t, handler, "Handler should not be nil")
|
||||
}
|
||||
|
||||
// TestHandleMessageSuccess tests the handling of a message when the write succeeds.
|
||||
func TestHandleMessageSuccess(t *testing.T) {
|
||||
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
|
||||
record := slog.Record{
|
||||
Time: time.Now(),
|
||||
Message: "Test message",
|
||||
Level: slog.LevelInfo,
|
||||
}
|
||||
|
||||
err := handler.Handle(context.Background(), record)
|
||||
|
||||
assert.NoError(t, err, "Handle should not return an error")
|
||||
}
|
||||
|
||||
// TestHandleMessageFailure tests the caching mechanism when the write fails.
|
||||
func TestHandleMessageFailure(t *testing.T) {
|
||||
protohandler := NewProtoHandler(&failedWriter{}, nil, "testCmpID")
|
||||
record := slog.Record{
|
||||
Time: time.Now(),
|
||||
Message: "Test message",
|
||||
Level: slog.LevelInfo,
|
||||
}
|
||||
|
||||
err := protohandler.Handle(context.Background(), record)
|
||||
|
||||
assert.NoError(t, err, "Handle should not return an error even when write fails")
|
||||
assert.NotEmpty(t, protohandler.(*handler).CachedMessages(), "Cached messages should not be empty")
|
||||
}
|
||||
|
||||
// TestEnabled tests that the handler enables logging based on level.
|
||||
func TestEnabled(t *testing.T) {
|
||||
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
|
||||
|
||||
assert.True(t, handler.Enabled(context.Background(), slog.LevelInfo), "Logging should be enabled for LevelInfo")
|
||||
assert.False(t, handler.Enabled(context.Background(), slog.LevelDebug), "Logging should be disabled for LevelDebug by default")
|
||||
}
|
||||
|
||||
// TestPeriodicRetry stops retry after close.
|
||||
func TestCloseStopsRetry(t *testing.T) {
|
||||
mockWriter := io.Discard
|
||||
|
||||
handler := NewProtoHandler(mockWriter, nil, "testCmpID").(*handler)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
err := handler.Close()
|
||||
|
||||
assert.NoError(t, err, "Close should not return an error")
|
||||
time.Sleep(1 * time.Second) // Ensure no retry after close
|
||||
}
|
||||
|
||||
// Utility function to retrieve cached messages.
|
||||
func (h *handler) CachedMessages() [][]byte {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return h.cachedMessages
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package server contains the HTTP, gRPC and CoAP server implementation.
|
||||
// Package server contains the gRPC server implementation.
|
||||
package server
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-sev-guest/client"
|
||||
@@ -205,26 +206,27 @@ func loadCertFile(certFile string) ([]byte, error) {
|
||||
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
var cert, key []byte
|
||||
var err error
|
||||
if _, err = os.Stat(certfile); err == nil {
|
||||
cert, err = os.ReadFile(certfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
|
||||
readFileOrData := func(input string) ([]byte, error) {
|
||||
if len(input) < 1000 && !strings.Contains(input, "\n") {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
} else if os.IsNotExist(err) {
|
||||
cert = []byte(certfile)
|
||||
} else {
|
||||
return tls.Certificate{}, err
|
||||
return []byte(input), nil
|
||||
}
|
||||
if _, err := os.Stat(keyfile); err == nil {
|
||||
key, err = os.ReadFile(keyfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
} else if os.IsNotExist(err) {
|
||||
key = []byte(keyfile)
|
||||
} else {
|
||||
return tls.Certificate{}, err
|
||||
|
||||
cert, err = readFileOrData(certfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
|
||||
}
|
||||
|
||||
key, err = readFileOrData(keyfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
|
||||
}
|
||||
|
||||
return tls.X509KeyPair(cert, key)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
authmocks "github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
const bufSize = 1024 * 1024
|
||||
|
||||
var lis *bufconn.Listener
|
||||
|
||||
func init() {
|
||||
lis = bufconn.Listen(bufSize)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "50051",
|
||||
}
|
||||
logger := slog.Default()
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
assert.NotNil(t, srv)
|
||||
assert.IsType(t, &Server{}, srv)
|
||||
}
|
||||
|
||||
func TestServerStart(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
}
|
||||
buf := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
assert.Contains(t, buf.String(), "TestServer service gRPC server listening at localhost:0 without TLS")
|
||||
}
|
||||
|
||||
func TestServerStartWithTLS(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: string(cert),
|
||||
KeyFile: string(key),
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
logContent := logBuffer.String()
|
||||
fmt.Println(logContent)
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
|
||||
}
|
||||
|
||||
func TestServerStartWithAttestedTLS(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
AttestedTLS: true,
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
qp.On("GetRawQuote", mock.Anything).Return([]byte("mock-quote"), nil)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
logContent := logBuffer.String()
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS")
|
||||
|
||||
qp.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestServerStop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
}
|
||||
buf := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
go func() {
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := srv.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Contains(t, buf.String(), "TestServer gRPC service shutdown at localhost:0")
|
||||
}
|
||||
|
||||
func generateSelfSignedCert() ([]byte, []byte, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cert, err := generateSelfSignedCertFromKey(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return cert, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), nil
|
||||
}
|
||||
|
||||
func generateSelfSignedCertFromKey(key *rsa.PrivateKey) ([]byte, error) {
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), nil
|
||||
}
|
||||
|
||||
type ThreadSafeBuffer struct {
|
||||
buffer strings.Builder
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (b *ThreadSafeBuffer) Write(p []byte) (n int, err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buffer.Write(p)
|
||||
}
|
||||
|
||||
func (b *ThreadSafeBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buffer.String()
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// Server is an autogenerated mock type for the Server type
|
||||
type Server struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields:
|
||||
func (_m *Server) Start() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Start")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields:
|
||||
func (_m *Server) Stop() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewServer creates a new instance of Server. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewServer(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Server {
|
||||
mock := &Server{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
//go:generate mockery --name Server --output ./mocks --filename server.go
|
||||
type Server interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/internal/server/mocks"
|
||||
)
|
||||
|
||||
func TestStopAllServer(t *testing.T) {
|
||||
server1 := new(mocks.Server)
|
||||
server2 := new(mocks.Server)
|
||||
server1.On("Stop").Return(nil)
|
||||
server2.On("Stop").Return(errors.New("failed to stop"))
|
||||
tests := []struct {
|
||||
name string
|
||||
servers []Server
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "All servers stop successfully",
|
||||
servers: []Server{
|
||||
server1,
|
||||
server1,
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "One server fails to stop",
|
||||
servers: []Server{
|
||||
server1,
|
||||
server2,
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "No servers",
|
||||
servers: []Server{},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := stopAllServer(tt.servers...)
|
||||
if (err != nil) != tt.expectedError {
|
||||
t.Errorf("stopAllServer() error = %v, expectedError %v", err, tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopHandler(t *testing.T) {
|
||||
mockServer := new(mocks.Server)
|
||||
mockServer.On("Stop").Return(nil)
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server)
|
||||
triggerSignal bool
|
||||
expectedError bool
|
||||
expectCanceled bool
|
||||
}{
|
||||
{
|
||||
name: "Graceful shutdown on signal",
|
||||
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
return ctx, cancel, logger, "test", []Server{mockServer}
|
||||
},
|
||||
triggerSignal: true,
|
||||
expectedError: false,
|
||||
expectCanceled: true,
|
||||
},
|
||||
{
|
||||
name: "Context canceled",
|
||||
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
return ctx, cancel, logger, "test", []Server{mockServer}
|
||||
},
|
||||
triggerSignal: false,
|
||||
expectedError: false,
|
||||
expectCanceled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx, cancel, logger, svcName, servers := tt.setupFunc()
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- StopHandler(ctx, cancel, logger, svcName, servers...)
|
||||
}()
|
||||
|
||||
if tt.triggerSignal {
|
||||
// Simulate SIGINT
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err := syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
||||
if err != nil {
|
||||
t.Errorf("failed to send signal: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if (err != nil) != tt.expectedError {
|
||||
t.Errorf("StopHandler() error = %v, expectedError %v", err, tt.expectedError)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("StopHandler() timed out")
|
||||
}
|
||||
|
||||
if tt.expectCanceled {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context was canceled as expected
|
||||
default:
|
||||
t.Error("Context was not canceled")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ type AckWriter struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAckWriter(conn net.Conn) *AckWriter {
|
||||
func NewAckWriter(conn net.Conn) io.WriteCloser {
|
||||
aw := &AckWriter{
|
||||
conn: conn,
|
||||
pendingMessages: make(chan *Message, maxConcurrent),
|
||||
@@ -52,14 +52,6 @@ func NewAckWriter(conn net.Conn) *AckWriter {
|
||||
return aw
|
||||
}
|
||||
|
||||
func (aw *AckWriter) WriteProto(msg proto.Message) (int, error) {
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error marshaling protobuf message: %v", err)
|
||||
}
|
||||
return aw.Write(data)
|
||||
}
|
||||
|
||||
func (aw *AckWriter) Write(p []byte) (int, error) {
|
||||
if len(p) > maxMessageSize {
|
||||
return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
|
||||
@@ -176,11 +168,16 @@ func (aw *AckWriter) Close() error {
|
||||
return aw.conn.Close()
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
Read() ([]byte, error)
|
||||
ReadProto(msg proto.Message) error
|
||||
}
|
||||
|
||||
type AckReader struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewAckReader(conn net.Conn) *AckReader {
|
||||
func NewAckReader(conn net.Conn) Reader {
|
||||
return &AckReader{
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package vsock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// MockConn implements net.Conn for testing purposes.
|
||||
type MockConn struct {
|
||||
ReadData []byte
|
||||
WrittenData []byte
|
||||
ReadErr error
|
||||
WriteErr error
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if len(m.ReadData) == 0 {
|
||||
return 0, io.EOF // Ensure we handle this case more predictably
|
||||
}
|
||||
if m.ReadErr != nil {
|
||||
return 0, m.ReadErr
|
||||
}
|
||||
n = copy(b, m.ReadData)
|
||||
m.ReadData = m.ReadData[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.closed {
|
||||
return 0, errors.New("connection closed")
|
||||
}
|
||||
if m.WriteErr != nil {
|
||||
return 0, m.WriteErr
|
||||
}
|
||||
m.WrittenData = append(m.WrittenData, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Implement other net.Conn methods with empty implementations.
|
||||
func (m *MockConn) LocalAddr() net.Addr { return nil }
|
||||
func (m *MockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (m *MockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestAckReader_Read(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid message", []byte("Hello, World!"), false},
|
||||
{"Empty message", []byte{}, false},
|
||||
{"Message at max size", make([]byte, maxMessageSize), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
ar := NewAckReader(mockConn)
|
||||
|
||||
// Prepare mock data
|
||||
messageID := uint32(1)
|
||||
messageLen := uint32(len(tt.data))
|
||||
mockData := make([]byte, 8+len(tt.data))
|
||||
binary.LittleEndian.PutUint32(mockData[:4], messageID)
|
||||
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
|
||||
copy(mockData[8:], tt.data)
|
||||
mockConn.ReadData = mockData
|
||||
|
||||
data, err := ar.Read()
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AckReader.Read() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
if !bytes.Equal(data, tt.data) {
|
||||
t.Errorf("AckReader.Read() got = %v, want %v", data, tt.data)
|
||||
}
|
||||
|
||||
// Check if ACK was sent
|
||||
if len(mockConn.WrittenData) != 4 {
|
||||
t.Errorf("AckReader.Read() did not send ACK")
|
||||
} else {
|
||||
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
|
||||
if ackID != messageID {
|
||||
t.Errorf("AckReader.Read() sent wrong ACK ID, got %d, want %d", ackID, messageID)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckReader_ReadProto(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *manager.ClientStreamMessage
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid proto message", &manager.ClientStreamMessage{}, false},
|
||||
{"Empty proto message", &manager.ClientStreamMessage{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
ar := NewAckReader(mockConn)
|
||||
|
||||
// Prepare mock data
|
||||
protoData, _ := proto.Marshal(tt.msg)
|
||||
messageID := uint32(1)
|
||||
messageLen := uint32(len(protoData))
|
||||
mockData := make([]byte, 8+len(protoData))
|
||||
binary.LittleEndian.PutUint32(mockData[:4], messageID)
|
||||
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
|
||||
copy(mockData[8:], protoData)
|
||||
mockConn.ReadData = mockData
|
||||
|
||||
receivedMsg := &manager.ClientStreamMessage{}
|
||||
err := ar.ReadProto(receivedMsg)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AckReader.ReadProto() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
if receivedMsg.Message != tt.msg.Message {
|
||||
t.Errorf("AckReader.ReadProto() got = %v, want %v", receivedMsg, tt.msg)
|
||||
}
|
||||
|
||||
// Check if ACK was sent
|
||||
if len(mockConn.WrittenData) != 4 {
|
||||
t.Errorf("AckReader.ReadProto() did not send ACK")
|
||||
} else {
|
||||
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
|
||||
if ackID != messageID {
|
||||
t.Errorf("AckReader.ReadProto() sent wrong ACK ID, got %d, want %d", ackID, messageID)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAckWriter(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
writer := NewAckWriter(mockConn)
|
||||
|
||||
if _, ok := writer.(io.Writer); !ok {
|
||||
t.Errorf("NewAckWriter() did not return an io.Writer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAckReader(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
reader := NewAckReader(mockConn)
|
||||
|
||||
assert.NotNil(t, reader)
|
||||
}
|
||||
|
||||
func TestAckWriter_Close(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
aw := NewAckWriter(mockConn)
|
||||
|
||||
err := aw.Close()
|
||||
if err != nil {
|
||||
t.Errorf("AckWriter.Close() error = %v, wantErr %v", err, nil)
|
||||
}
|
||||
|
||||
if !mockConn.closed {
|
||||
t.Errorf("AckWriter.Close() did not close the connection")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestZipDirectoryToMemory(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "zip_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
testFiles := map[string]string{
|
||||
"file1.txt": "Content of file 1",
|
||||
"file2.txt": "Content of file 2",
|
||||
"subdir/file3.txt": "Content of file 3 in subdirectory",
|
||||
}
|
||||
|
||||
for path, content := range testFiles {
|
||||
fullPath := filepath.Join(tempDir, path)
|
||||
err := os.MkdirAll(filepath.Dir(fullPath), 0o755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create directory: %v", err)
|
||||
}
|
||||
err = os.WriteFile(fullPath, []byte(content), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
zipData, err := ZipDirectoryToMemory(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("ZipDirectoryToMemory failed: %v", err)
|
||||
}
|
||||
|
||||
if len(zipData) == 0 {
|
||||
t.Error("Zip data is empty")
|
||||
}
|
||||
|
||||
unzipDir, err := os.MkdirTemp("", "unzip_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory for unzip: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(unzipDir)
|
||||
|
||||
err = UnzipFromMemory(zipData, unzipDir)
|
||||
if err != nil {
|
||||
t.Fatalf("UnzipFromMemory failed: %v", err)
|
||||
}
|
||||
|
||||
for path, expectedContent := range testFiles {
|
||||
fullPath := filepath.Join(unzipDir, path)
|
||||
content, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read unzipped file %s: %v", path, err)
|
||||
continue
|
||||
}
|
||||
if string(content) != expectedContent {
|
||||
t.Errorf("Content mismatch for file %s. Expected: %s, Got: %s", path, expectedContent, string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipDirectoryToMemory_EmptyDirectory(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "empty_zip_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
zipData, err := ZipDirectoryToMemory(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("ZipDirectoryToMemory failed on empty directory: %v", err)
|
||||
}
|
||||
|
||||
if len(zipData) == 0 {
|
||||
t.Error("Zip data is empty for an empty directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnzipFromMemory_InvalidZipData(t *testing.T) {
|
||||
invalidZipData := []byte("This is not a valid zip file")
|
||||
tempDir, err := os.MkdirTemp("", "invalid_unzip_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
err = UnzipFromMemory(invalidZipData, tempDir)
|
||||
if err == nil {
|
||||
t.Error("UnzipFromMemory should fail with invalid zip data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipDirectoryToMemory_NonExistentDirectory(t *testing.T) {
|
||||
nonExistentDir := "/path/to/non/existent/directory"
|
||||
_, err := ZipDirectoryToMemory(nonExistentDir)
|
||||
if err == nil {
|
||||
t.Error("ZipDirectoryToMemory should fail with non-existent directory")
|
||||
}
|
||||
}
|
||||
@@ -146,8 +146,6 @@ func TestHandleConnection(t *testing.T) {
|
||||
receivedMsg := <-ms.eventsChan
|
||||
assert.Equal(t, msg.GetAgentEvent().EventType, receivedMsg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, msg.GetAgentEvent().ComputationId, receivedMsg.GetAgentEvent().ComputationId)
|
||||
|
||||
mockConn.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestReportBrokenConnection(t *testing.T) {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ var (
|
||||
)
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
@@ -120,7 +121,7 @@ func TestAlgo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
@@ -217,7 +218,7 @@ func TestData(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
@@ -318,7 +319,7 @@ func TestAttestation(t *testing.T) {
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
}
|
||||
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user