From 6a014d1e767070190fcb13a5cbc0743f15ef5ae3 Mon Sep 17 00:00:00 2001 From: Amir Raminfar Date: Mon, 25 May 2026 18:15:20 -0700 Subject: [PATCH] feat(cloud): accept container name or id in container-scoped tools (#4743) Co-authored-by: Claude Opus 4.7 (1M context) --- internal/cloud/client_test.go | 3 +- internal/cloud/tools.go | 16 +- internal/cloud/tools_actions.go | 30 ++- internal/cloud/tools_containers.go | 7 +- internal/cloud/tools_logs.go | 7 +- internal/cloud/tools_resolve.go | 180 ++++++++++++++++++ internal/cloud/tools_resolve_test.go | 274 +++++++++++++++++++++++++++ internal/cloud/tools_stream.go | 11 +- internal/cloud/tools_stream_test.go | 31 ++- internal/cloud/tools_test.go | 22 ++- 10 files changed, 537 insertions(+), 44 deletions(-) create mode 100644 internal/cloud/tools_resolve.go create mode 100644 internal/cloud/tools_resolve_test.go diff --git a/internal/cloud/client_test.go b/internal/cloud/client_test.go index afe8090e..53166b4e 100644 --- a/internal/cloud/client_test.go +++ b/internal/cloud/client_test.go @@ -128,9 +128,10 @@ func TestHandleRequest_CallTool_RestartContainer(t *testing.T) { mockClient := &MockClientService{} mockClient.On("ContainerAction", mock.Anything, mock.Anything, container.Restart).Return(nil) - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "nginx", Host: "local"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "nginx", Host: "local"}) mockHost.On("FindContainer", "local", "abc123", container.ContainerLabels(nil)).Return(cs, nil) client := &Client{ diff --git a/internal/cloud/tools.go b/internal/cloud/tools.go index b3a24635..c9cf0671 100644 --- a/internal/cloud/tools.go +++ b/internal/cloud/tools.go @@ -67,8 +67,8 @@ var ( Properties: map[string]paramProperty{}, }) - containerIDParam = paramProperty{Type: "string", Description: "The container ID (from find_containers)"} - hostIDParam = paramProperty{Type: "string", Description: "The host ID where the container is running (from find_containers)"} + containerIDParam = paramProperty{Type: "string", Description: "Container name or ID. You can pass the container name directly (as shown in logs, events, and find_containers) — it does not need to be the opaque ID. Resolved by exact name first, then ID, then a unique name substring. If the name matches more than one container the call fails with the list of candidates so you can disambiguate."} + hostIDParam = paramProperty{Type: "string", Description: "Host name or ID (from list_hosts or find_containers). Optional — omit it when the container name is unique across all hosts; supply it (name or ID) only to scope to a specific host when a name is ambiguous."} boolFalse = false targetedParams = mustSchema(paramSchema{ @@ -77,7 +77,7 @@ var ( "container_id": containerIDParam, "host_id": hostIDParam, }, - Required: []string{"container_id", "host_id"}, + Required: []string{"container_id"}, AdditionalProperties: &boolFalse, }) @@ -172,7 +172,7 @@ Examples: name == "die"; name == "oom"; name in ["die", "oom", "kill"]; name == "query": {Type: "string", Description: "Optional text search query (case-insensitive substring match)"}, "regex": {Type: "string", Description: "Optional regex pattern to match against log messages"}, }, - Required: []string{"container_id", "host_id"}, + Required: []string{"container_id"}, AdditionalProperties: &boolFalse, }) @@ -185,7 +185,7 @@ Examples: name == "die"; name == "oom"; name in ["die", "oom", "kill"]; name == "query": {Type: "string", Description: "Optional text search query (case-insensitive substring match)"}, "regex": {Type: "string", Description: "Optional regex pattern to match against log messages"}, }, - Required: []string{"container_id", "host_id"}, + Required: []string{"container_id"}, AdditionalProperties: &boolFalse, }) ) @@ -202,7 +202,7 @@ func AvailableTools(enableActions bool) []*pb.ToolDefinition { }, { Name: toolFindContainers, - Description: "Search for Docker containers by name, state, or health status. All parameters are optional. Returns container ID, name, image, state, health, and host. Use this before start/stop/restart actions to get the container ID and host.", + Description: "Search for Docker containers by name, state, or health status. All parameters are optional. Returns container ID, name, image, state, health, and host. The container-scoped tools (inspect/logs/start/stop/restart/remove/update) accept a name directly, so you usually don't need to look up the ID first — use this when you want to disambiguate a name that matches multiple containers.", ParametersJson: findContainerParams, Scope: pb.ToolScope_TOOL_SCOPE_INSTANCE, ReadOnly: true, @@ -230,14 +230,14 @@ func AvailableTools(enableActions bool) []*pb.ToolDefinition { }, { Name: toolFetchContainerLogs, - Description: "Fetch raw logs from a running Docker container. Requires container_id and host from find_containers. Optionally filter by time range, log level, text search, or regex pattern. Returns up to 100 matching log lines.", + Description: "Fetch raw logs from a running Docker container. Identify the container by name or ID via container_id; host_id is optional unless the name is ambiguous. Optionally filter by time range, log level, text search, or regex pattern. Returns up to 100 matching log lines.", ParametersJson: fetchLogsParams, Scope: pb.ToolScope_TOOL_SCOPE_CONTAINER, ReadOnly: true, }, { Name: toolStreamLogs, - Description: "Stream live logs from a running Docker container in real time. Requires container_id and host_id from find_containers. Optionally filter by log level, text search, or regex pattern. Streams continuously until cancelled.", + Description: "Stream live logs from a running Docker container in real time. Identify the container by name or ID via container_id; host_id is optional unless the name is ambiguous. Optionally filter by log level, text search, or regex pattern. Streams continuously until cancelled.", ParametersJson: streamLogsParams, Scope: pb.ToolScope_TOOL_SCOPE_CONTAINER, ReadOnly: true, diff --git a/internal/cloud/tools_actions.go b/internal/cloud/tools_actions.go index 6b91c4d9..d83c6502 100644 --- a/internal/cloud/tools_actions.go +++ b/internal/cloud/tools_actions.go @@ -25,14 +25,12 @@ func executeContainerAction(ctx context.Context, name string, argsJSON string, d return nil, err } - if args.ContainerID == "" { - return nil, fmt.Errorf("container_id is required") - } - if args.Host == "" { - return nil, fmt.Errorf("host is required") + hostID, containerID, err := resolveContainerRef(args.ContainerID, args.Host, deps) + if err != nil { + return nil, err } - cs, err := deps.HostService.FindContainer(args.Host, args.ContainerID, deps.Labels) + cs, err := deps.HostService.FindContainer(hostID, containerID, deps.Labels) if err != nil { return nil, fmt.Errorf("container not found: %w", err) } @@ -41,13 +39,13 @@ func executeContainerAction(ctx context.Context, name string, argsJSON string, d return nil, fmt.Errorf("action failed: %w", err) } - message := fmt.Sprintf("Successfully %s container %s.", pastTense(action), args.ContainerID) + message := fmt.Sprintf("Successfully %s container %s.", pastTense(action), cs.Container.Name) return &pb.CallToolResponse{ Success: true, Result: &pb.CallToolResponse_Action{Action: &pb.ActionResult{ Success: true, - ContainerId: args.ContainerID, + ContainerId: cs.Container.ID, Action: string(action), Message: message, }}, @@ -60,14 +58,12 @@ func executeUpdateContainer(ctx context.Context, argsJSON string, deps ToolDeps) return nil, fmt.Errorf("failed to parse arguments: %w", err) } - if args.ContainerID == "" { - return nil, fmt.Errorf("container_id is required") - } - if args.Host == "" { - return nil, fmt.Errorf("host is required") + hostID, containerID, err := resolveContainerRef(args.ContainerID, args.Host, deps) + if err != nil { + return nil, err } - cs, err := deps.HostService.FindContainer(args.Host, args.ContainerID, deps.Labels) + cs, err := deps.HostService.FindContainer(hostID, containerID, deps.Labels) if err != nil { return nil, fmt.Errorf("container not found: %w", err) } @@ -87,16 +83,16 @@ func executeUpdateContainer(ctx context.Context, argsJSON string, deps ToolDeps) return nil, fmt.Errorf("update failed: %w", updateErr) } - message := fmt.Sprintf("Successfully updated container %s by pulling the latest image and recreating it.", args.ContainerID) + message := fmt.Sprintf("Successfully updated container %s by pulling the latest image and recreating it.", cs.Container.Name) if !updated { - message = fmt.Sprintf("Container %s is already running the latest image. No update was needed.", args.ContainerID) + message = fmt.Sprintf("Container %s is already running the latest image. No update was needed.", cs.Container.Name) } return &pb.CallToolResponse{ Success: true, Result: &pb.CallToolResponse_Action{Action: &pb.ActionResult{ Success: true, - ContainerId: args.ContainerID, + ContainerId: cs.Container.ID, Action: "update", Message: message, }}, diff --git a/internal/cloud/tools_containers.go b/internal/cloud/tools_containers.go index 92c9434a..c0c41809 100644 --- a/internal/cloud/tools_containers.go +++ b/internal/cloud/tools_containers.go @@ -168,11 +168,12 @@ func executeInspectContainer(argsJSON string, deps ToolDeps) (*pb.CallToolRespon if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { return nil, fmt.Errorf("failed to parse arguments: %w", err) } - if args.ContainerID == "" || args.Host == "" { - return nil, fmt.Errorf("container_id and host are required") + hostID, containerID, err := resolveContainerRef(args.ContainerID, args.Host, deps) + if err != nil { + return nil, err } - cs, err := deps.HostService.FindContainer(args.Host, args.ContainerID, deps.Labels) + cs, err := deps.HostService.FindContainer(hostID, containerID, deps.Labels) if err != nil { return nil, fmt.Errorf("container not found: %w", err) } diff --git a/internal/cloud/tools_logs.go b/internal/cloud/tools_logs.go index 8673fe7b..3a783ad4 100644 --- a/internal/cloud/tools_logs.go +++ b/internal/cloud/tools_logs.go @@ -27,11 +27,12 @@ func executeFetchContainerLogs(ctx context.Context, argsJSON string, deps ToolDe if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { return nil, fmt.Errorf("failed to parse arguments: %w", err) } - if args.ContainerID == "" || args.Host == "" { - return nil, fmt.Errorf("container_id and host_id are required") + hostID, containerID, err := resolveContainerRef(args.ContainerID, args.Host, deps) + if err != nil { + return nil, err } - cs, err := deps.HostService.FindContainer(args.Host, args.ContainerID, deps.Labels) + cs, err := deps.HostService.FindContainer(hostID, containerID, deps.Labels) if err != nil { return nil, fmt.Errorf("container not found: %w", err) } diff --git a/internal/cloud/tools_resolve.go b/internal/cloud/tools_resolve.go new file mode 100644 index 00000000..e3d655df --- /dev/null +++ b/internal/cloud/tools_resolve.go @@ -0,0 +1,180 @@ +package cloud + +import ( + "fmt" + "strings" + + "github.com/amir20/dozzle/internal/container" +) + +// resolveContainerRef turns an LLM-supplied container reference (a name OR an +// id) plus an optional host reference (a name OR an id) into the concrete +// (hostID, containerID) pair that HostService.FindContainer expects. +// +// LLMs almost always have only the container name (it is what logs, events and +// listings surface), so every container-scoped tool funnels through here rather +// than passing the raw reference straight to FindContainer — which only matches +// exact ids on a known host. +// +// Matching is tiered, id-first so existing id-based callers behave exactly as +// before (this capability is purely additive — names are the new fallback): +// 1. exact id match (full id or short/prefix id) — the legacy path +// 2. exact name match (case-insensitive) +// 3. unique substring of the name +// +// The first tier that yields any candidates wins. An id match is always unique, +// so it is never treated as ambiguous. If a name tier yields more than one +// container the call fails with an error listing every candidate (name + id + +// host) so the caller can disambiguate — we NEVER silently pick one. This is +// what makes the write tools (stop/restart/remove/update) safe to drive by name. +// +// hostRef is optional: when empty the container must resolve unambiguously +// across all hosts; when supplied it scopes the search to that host (matched by +// id or name). +func resolveContainerRef(containerRef, hostRef string, deps ToolDeps) (hostID, containerID string, err error) { + containerRef = strings.TrimSpace(containerRef) + if containerRef == "" { + return "", "", fmt.Errorf("container_id is required") + } + + containers, errs := deps.HostService.ListAllContainers(deps.Labels) + logHostErrors(errs) + hostNames := buildHostNameMap(deps.HostService) + + // Scope to a host if one was supplied. The host reference may be an id or a + // name; an unknown host is an explicit error rather than a silent no-match. + hostRef = strings.TrimSpace(hostRef) + var scopedHostID string + if hostRef != "" { + scopedHostID, err = resolveHostRef(hostRef, deps) + if err != nil { + return "", "", err + } + filtered := containers[:0:0] + for _, c := range containers { + if c.Host == scopedHostID { + filtered = append(filtered, c) + } + } + containers = filtered + } + + // Tiered matching. Each tier collects candidates; the first non-empty tier + // decides the outcome. Id is checked first so a value that is a real + // container id resolves directly — identical to the legacy behavior — and is + // never confused with a name. + var exactID, exactName, substring []container.Container + for _, c := range containers { + switch { + case matchesID(c.ID, containerRef): + exactID = append(exactID, c) + case strings.EqualFold(c.Name, containerRef): + exactName = append(exactName, c) + case containsIgnoreCase(c.Name, containerRef): + substring = append(substring, c) + } + } + + for _, tier := range [][]container.Container{exactID, exactName, substring} { + switch len(tier) { + case 0: + continue + case 1: + return tier[0].Host, tier[0].ID, nil + default: + return "", "", ambiguousError(containerRef, hostRef, tier, hostNames) + } + } + + // Legacy fallback: when an explicit host was given (and resolved to a real + // host id) but the listing produced no match — e.g. a host returned a + // partial error and was omitted from ListAllContainers — pass the reference + // straight through to FindContainer's direct lookup, exactly as before this + // resolver existed. This guarantees id-based callers never regress. + if scopedHostID != "" { + return scopedHostID, containerRef, nil + } + + return "", "", fmt.Errorf("no container matching %q found across all connected hosts; call find_containers to list available containers", containerRef) +} + +// matchesID reports whether ref identifies the container id — either the full id +// or a short/prefix form (Docker ids are commonly referenced by their first 12 +// characters). Comparison is case-insensitive to tolerate sloppy input. +func matchesID(id, ref string) bool { + if id == "" { + return false + } + lid, lref := strings.ToLower(id), strings.ToLower(ref) + if lid == lref { + return true + } + // Treat ref as a prefix only when it is reasonably id-shaped to avoid a + // short generic string accidentally prefix-matching an id. + if len(lref) >= 12 && strings.HasPrefix(lid, lref) { + return true + } + return false +} + +// resolveHostRef resolves a host reference (id or name) to its host id. +func resolveHostRef(hostRef string, deps ToolDeps) (string, error) { + hosts := deps.HostService.Hosts() + var byName []container.Host + for _, h := range hosts { + if h.ID == hostRef { + return h.ID, nil + } + if strings.EqualFold(h.Name, hostRef) { + byName = append(byName, h) + } + } + switch len(byName) { + case 0: + return "", fmt.Errorf("no host matching %q found; call list_hosts to see available hosts", hostRef) + case 1: + return byName[0].ID, nil + default: + names := make([]string, len(byName)) + for i, h := range byName { + names[i] = fmt.Sprintf("%s (id %s)", h.Name, h.ID) + } + return "", fmt.Errorf("host name %q is ambiguous; matches: %s. Pass the host id instead", hostRef, strings.Join(names, "; ")) + } +} + +// ambiguousError builds an actionable error listing every candidate so the +// caller can re-issue the call unambiguously. The hint is tailored to the +// candidate set because the LLM reads it to choose its next action: when the +// candidates span multiple hosts, host_id disambiguates; when they all sit on +// one host, host_id is useless and only the exact id or full name will do. +func ambiguousError(containerRef, hostRef string, candidates []container.Container, hostNames map[string]string) error { + parts := make([]string, len(candidates)) + sameHost := true + for i, c := range candidates { + parts[i] = fmt.Sprintf("%s (id %s on host %s)", c.Name, shortID(c.ID), resolveHostName(c.Host, hostNames)) + if c.Host != candidates[0].Host { + sameHost = false + } + } + + var hint string + switch { + case hostRef != "" || sameHost: + // A host was already supplied, or every candidate is on the same host — + // scoping by host_id cannot narrow it further. + hint = "pass the exact container id or the full container name to disambiguate" + default: + hint = "pass host_id to scope to one host, or pass the exact container id" + } + return fmt.Errorf("%q matches multiple containers: %s. To act on the right one, %s", containerRef, strings.Join(parts, "; "), hint) +} + +// shortID trims a full docker id to its conventional 12-character form for +// display in errors. +func shortID(id string) string { + if len(id) > 12 { + return id[:12] + } + return id +} diff --git a/internal/cloud/tools_resolve_test.go b/internal/cloud/tools_resolve_test.go new file mode 100644 index 00000000..9fe0223d --- /dev/null +++ b/internal/cloud/tools_resolve_test.go @@ -0,0 +1,274 @@ +package cloud + +import ( + "context" + "testing" + + "github.com/amir20/dozzle/internal/container" + container_support "github.com/amir20/dozzle/internal/support/container" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// resolverDeps builds ToolDeps whose HostService resolves against the given +// containers. Hosts are derived from the containers' Host fields, optionally +// augmented with explicit hosts (so a host id and a human name can differ). +func resolverDeps(containers []container.Container, hosts ...container.Host) ToolDeps { + m := &MockHostService{} + m.On("ListAllContainers", container.ContainerLabels(nil)).Return(containers, nil).Maybe() + if len(hosts) == 0 { + seen := map[string]bool{} + for _, c := range containers { + if c.Host != "" && !seen[c.Host] { + seen[c.Host] = true + hosts = append(hosts, container.Host{ID: c.Host, Name: c.Host}) + } + } + } + m.On("Hosts").Return(hosts).Maybe() + return ToolDeps{HostService: m} +} + +func TestResolveContainerRef_ByName(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "abc123def456", Name: "nginx", Host: "local"}, + {ID: "fff999", Name: "redis", Host: "local"}, + }) + + host, id, err := resolveContainerRef("nginx", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "abc123def456", id) +} + +func TestResolveContainerRef_ByNameCaseInsensitive(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "abc123def456", Name: "NginX", Host: "local"}, + }) + + host, id, err := resolveContainerRef("nginx", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "abc123def456", id) +} + +func TestResolveContainerRef_ByFullID(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "abc123def456", Name: "nginx", Host: "local"}, + }) + + host, id, err := resolveContainerRef("abc123def456", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "abc123def456", id) +} + +func TestResolveContainerRef_ByShortIDPrefix(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "abc123def4567890", Name: "nginx", Host: "local"}, + }) + + // 12-char short id form + host, id, err := resolveContainerRef("abc123def456", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "abc123def4567890", id) +} + +func TestResolveContainerRef_BySubstring(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "my-app-frontend", Host: "local"}, + {ID: "id2", Name: "database", Host: "local"}, + }) + + host, id, err := resolveContainerRef("frontend", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "id1", id) +} + +func TestResolveContainerRef_ExactNameBeatsSubstring(t *testing.T) { + // "api" matches "api" exactly and "api-gateway" as a substring. Exact wins, + // so this is NOT ambiguous. + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "api", Host: "local"}, + {ID: "id2", Name: "api-gateway", Host: "local"}, + }) + + host, id, err := resolveContainerRef("api", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "id1", id) +} + +func TestResolveContainerRef_AmbiguousNameAcrossHosts(t *testing.T) { + // Same name on two different hosts, no host supplied → ambiguous, must list + // candidates and must NOT silently pick one. + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "nginx", Host: "host-a"}, + {ID: "id2", Name: "nginx", Host: "host-b"}, + }) + + _, _, err := resolveContainerRef("nginx", "", deps) + assert.Error(t, err) + assert.Contains(t, err.Error(), "matches multiple containers") + assert.Contains(t, err.Error(), "id1") + assert.Contains(t, err.Error(), "id2") + assert.Contains(t, err.Error(), "host-a") + assert.Contains(t, err.Error(), "host-b") + assert.Contains(t, err.Error(), "host_id") +} + +func TestResolveContainerRef_AmbiguousSubstring(t *testing.T) { + // Both candidates are on the same host, so host_id cannot disambiguate — the + // hint must steer the LLM to the exact id / full name, not a useless retry + // with host_id. + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "app-frontend", Host: "local"}, + {ID: "id2", Name: "app-backend", Host: "local"}, + }) + + _, _, err := resolveContainerRef("app", "", deps) + assert.Error(t, err) + assert.Contains(t, err.Error(), "matches multiple containers") + assert.Contains(t, err.Error(), "exact container id or the full container name") + assert.NotContains(t, err.Error(), "host_id") +} + +func TestResolveContainerRef_HostDisambiguates(t *testing.T) { + // Same name on two hosts; supplying the host id resolves cleanly. + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "nginx", Host: "host-a"}, + {ID: "id2", Name: "nginx", Host: "host-b"}, + }) + + host, id, err := resolveContainerRef("nginx", "host-b", deps) + assert.NoError(t, err) + assert.Equal(t, "host-b", host) + assert.Equal(t, "id2", id) +} + +func TestResolveContainerRef_HostByName(t *testing.T) { + // Host referenced by its human name rather than its id. + deps := resolverDeps( + []container.Container{ + {ID: "id1", Name: "nginx", Host: "h-a"}, + {ID: "id2", Name: "nginx", Host: "h-b"}, + }, + container.Host{ID: "h-a", Name: "server-a"}, + container.Host{ID: "h-b", Name: "server-b"}, + ) + + host, id, err := resolveContainerRef("nginx", "server-b", deps) + assert.NoError(t, err) + assert.Equal(t, "h-b", host) + assert.Equal(t, "id2", id) +} + +func TestResolveContainerRef_UnknownHost(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "nginx", Host: "local"}, + }) + + _, _, err := resolveContainerRef("nginx", "nope", deps) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no host matching") +} + +func TestResolveContainerRef_NotFound(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "nginx", Host: "local"}, + }) + + _, _, err := resolveContainerRef("does-not-exist", "", deps) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no container matching") + assert.Contains(t, err.Error(), "find_containers") +} + +func TestResolveContainerRef_EmptyContainer(t *testing.T) { + deps := resolverDeps([]container.Container{ + {ID: "id1", Name: "nginx", Host: "local"}, + }) + + _, _, err := resolveContainerRef("", "", deps) + assert.Error(t, err) + assert.Contains(t, err.Error(), "container_id is required") +} + +func TestResolveContainerRef_LegacyIDWithHostFallsThrough(t *testing.T) { + // host exists but the container is absent from the listing (e.g. a partial + // host error). With an explicit host the resolver falls through to the + // direct (hostID, ref) lookup — preserving the legacy id path exactly. + deps := resolverDeps( + []container.Container{}, + container.Host{ID: "local", Name: "local"}, + ) + + host, id, err := resolveContainerRef("abc123", "local", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "abc123", id) +} + +func TestResolveContainerRef_IDBeatsName(t *testing.T) { + // A pathological case: one container's id equals another's name. Id-first + // ordering means the id match wins and the call is unambiguous — existing + // id-based callers keep working identically. + deps := resolverDeps([]container.Container{ + {ID: "shared", Name: "alpha", Host: "local"}, + {ID: "other", Name: "shared", Host: "local"}, + }) + + host, id, err := resolveContainerRef("shared", "", deps) + assert.NoError(t, err) + assert.Equal(t, "local", host) + assert.Equal(t, "shared", id) +} + +// --- End-to-end tests through ExecuteTool proving the resolver is wired in --- + +func TestExecuteTool_InspectContainer_ByName(t *testing.T) { + mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123def456", Name: "nginx", Host: "local"}) + cs := container_support.NewContainerService(&MockClientService{}, container.Container{ID: "abc123def456", Name: "nginx", Host: "local"}) + mockHost.On("FindContainer", "local", "abc123def456", container.ContainerLabels(nil)).Return(cs, nil) + + // Pass the NAME in container_id and omit host_id entirely. + resp := ExecuteTool(context.Background(), "inspect_container", `{"container_id":"nginx"}`, ToolDeps{HostService: mockHost}) + assert.True(t, resp.Success) + assert.Equal(t, "nginx", resp.GetInspectContainer().Name) + mockHost.AssertCalled(t, "FindContainer", "local", "abc123def456", container.ContainerLabels(nil)) +} + +func TestExecuteTool_RestartContainer_AmbiguousName_NoSilentPick(t *testing.T) { + // Write tool with an ambiguous name across hosts must NOT act — it must + // return the candidate list and never call FindContainer/ContainerAction. + mockClient := &MockClientService{} + mockHost := &MockHostService{} + withResolver(mockHost, + container.Container{ID: "id1", Name: "nginx", Host: "host-a"}, + container.Container{ID: "id2", Name: "nginx", Host: "host-b"}, + ) + + resp := ExecuteTool(context.Background(), "restart_container", `{"container_id":"nginx"}`, ToolDeps{HostService: mockHost, EnableActions: true}) + assert.False(t, resp.Success) + assert.Contains(t, resp.Error, "matches multiple containers") + mockHost.AssertNotCalled(t, "FindContainer", mock.Anything, mock.Anything, mock.Anything) + mockClient.AssertNotCalled(t, "ContainerAction", mock.Anything, mock.Anything, mock.Anything) +} + +func TestExecuteTool_RestartContainer_ByName_HostInferred(t *testing.T) { + mockClient := &MockClientService{} + mockClient.On("ContainerAction", mock.Anything, mock.Anything, container.Restart).Return(nil) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "id1", Name: "nginx", Host: "host-a"}) + + mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "id1", Name: "nginx", Host: "host-a"}) + mockHost.On("FindContainer", "host-a", "id1", container.ContainerLabels(nil)).Return(cs, nil) + + resp := ExecuteTool(context.Background(), "restart_container", `{"container_id":"nginx"}`, ToolDeps{HostService: mockHost, EnableActions: true}) + assert.True(t, resp.Success) + assert.Equal(t, "id1", resp.GetAction().ContainerId) + mockClient.AssertCalled(t, "ContainerAction", mock.Anything, mock.Anything, container.Restart) +} diff --git a/internal/cloud/tools_stream.go b/internal/cloud/tools_stream.go index 41fd8b9b..20e80a23 100644 --- a/internal/cloud/tools_stream.go +++ b/internal/cloud/tools_stream.go @@ -21,8 +21,8 @@ func parseStreamArgs(argsJSON string) (*fetchLogsArgs, *regexp.Regexp, error) { if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { return nil, nil, fmt.Errorf("failed to parse arguments: %w", err) } - if args.ContainerID == "" || args.Host == "" { - return nil, nil, fmt.Errorf("container_id and host_id are required") + if args.ContainerID == "" { + return nil, nil, fmt.Errorf("container_id is required") } var re *regexp.Regexp @@ -68,7 +68,12 @@ func executeStreamLogs(ctx context.Context, requestID string, argsJSON string, d return err } - cs, err := deps.HostService.FindContainer(args.Host, args.ContainerID, deps.Labels) + hostID, containerID, err := resolveContainerRef(args.ContainerID, args.Host, deps) + if err != nil { + return err + } + + cs, err := deps.HostService.FindContainer(hostID, containerID, deps.Labels) if err != nil { return fmt.Errorf("container not found: %w", err) } diff --git a/internal/cloud/tools_stream_test.go b/internal/cloud/tools_stream_test.go index c8727efa..54fd04a7 100644 --- a/internal/cloud/tools_stream_test.go +++ b/internal/cloud/tools_stream_test.go @@ -23,10 +23,18 @@ func TestParseStreamArgs_Valid(t *testing.T) { assert.NotNil(t, re) } -func TestParseStreamArgs_MissingRequired(t *testing.T) { - _, _, err := parseStreamArgs(`{"container_id":"abc"}`) +func TestParseStreamArgs_HostOptional(t *testing.T) { + // host_id is optional now — only container_id is required. + args, _, err := parseStreamArgs(`{"container_id":"abc"}`) + assert.NoError(t, err) + assert.Equal(t, "abc", args.ContainerID) + assert.Empty(t, args.Host) +} + +func TestParseStreamArgs_MissingContainer(t *testing.T) { + _, _, err := parseStreamArgs(`{"host_id":"host1"}`) assert.Error(t, err) - assert.Contains(t, err.Error(), "container_id and host_id are required") + assert.Contains(t, err.Error(), "container_id is required") } func TestParseStreamArgs_InvalidJSON(t *testing.T) { @@ -228,8 +236,9 @@ func TestExecuteStreamLogs_BasicFlow(t *testing.T) { return nil } - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost.On("FindContainer", "host1", "abc123", container.ContainerLabels(nil)).Return(cs, nil) var mu sync.Mutex @@ -265,8 +274,9 @@ func TestExecuteStreamLogs_WithLevelFilter(t *testing.T) { return nil } - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost.On("FindContainer", "host1", "abc123", container.ContainerLabels(nil)).Return(cs, nil) var mu sync.Mutex @@ -306,8 +316,9 @@ func TestExecuteStreamLogs_CancelContext(t *testing.T) { return ctx.Err() } - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost.On("FindContainer", "host1", "abc123", container.ContainerLabels(nil)).Return(cs, nil) var mu sync.Mutex @@ -353,7 +364,12 @@ func TestExecuteStreamLogs_InvalidArgs(t *testing.T) { } func TestExecuteStreamLogs_ContainerNotFound(t *testing.T) { + // host1 exists but the requested container is absent from the listing. With + // an explicit host the resolver falls through to FindContainer's direct + // lookup (legacy path), which reports the not-found error. mockHost := &MockHostService{} + mockHost.On("ListAllContainers", container.ContainerLabels(nil)).Return([]container.Container{}, nil) + mockHost.On("Hosts").Return([]container.Host{{ID: "host1", Name: "host1"}}) mockHost.On("FindContainer", "host1", "missing", container.ContainerLabels(nil)).Return(nil, assert.AnError) send := func(resp *pb.ToolResponse) error { return nil } @@ -372,8 +388,9 @@ func TestExecuteStreamLogs_BatchingAt50(t *testing.T) { return nil } - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "test-container", Host: "host1"}) mockHost.On("FindContainer", mock.Anything, mock.Anything, mock.Anything).Return(cs, nil) var mu sync.Mutex diff --git a/internal/cloud/tools_test.go b/internal/cloud/tools_test.go index c368e1ad..ec05f632 100644 --- a/internal/cloud/tools_test.go +++ b/internal/cloud/tools_test.go @@ -94,6 +94,22 @@ func (m *MockHostService) Hosts() []container.Host { return args.Get(0).([]container.Host) } +// withResolver wires up the ListAllContainers + Hosts mocks that +// resolveContainerRef needs, so container-scoped tool tests can drive the +// resolver. Derives the host list from the containers' Host fields. +func withResolver(m *MockHostService, containers ...container.Container) { + seen := map[string]bool{} + hosts := []container.Host{} + for _, c := range containers { + if c.Host != "" && !seen[c.Host] { + seen[c.Host] = true + hosts = append(hosts, container.Host{ID: c.Host, Name: c.Host}) + } + } + m.On("ListAllContainers", container.ContainerLabels(nil)).Return(containers, nil).Maybe() + m.On("Hosts").Return(hosts).Maybe() +} + type MockClientService struct { mock.Mock } @@ -180,9 +196,10 @@ func TestExecuteTool_RestartContainer(t *testing.T) { mockClient := &MockClientService{} mockClient.On("ContainerAction", mock.Anything, mock.Anything, container.Restart).Return(nil) - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "nginx", Host: "local"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "nginx", Host: "local"}) mockHost.On("FindContainer", "local", "abc123", container.ContainerLabels(nil)).Return(cs, nil) argsJSON := `{"container_id": "abc123", "host_id": "local"}` @@ -201,9 +218,10 @@ func TestExecuteTool_RemoveContainer(t *testing.T) { mockClient := &MockClientService{} mockClient.On("ContainerAction", mock.Anything, mock.Anything, container.Remove).Return(nil) - cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123"}) + cs := container_support.NewContainerService(mockClient, container.Container{ID: "abc123", Name: "nginx", Host: "local"}) mockHost := &MockHostService{} + withResolver(mockHost, container.Container{ID: "abc123", Name: "nginx", Host: "local"}) mockHost.On("FindContainer", "local", "abc123", container.ContainerLabels(nil)).Return(cs, nil) argsJSON := `{"container_id": "abc123", "host_id": "local"}`