Skip to content

Commit 36379a5

Browse files
committed
wip: spin egress container for each mcp server
Closes: #124
1 parent 90c962b commit 36379a5

File tree

13 files changed

+332
-20
lines changed

13 files changed

+332
-20
lines changed

cmd/thv/app/rm.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/spf13/cobra"
77

88
"github.com/stacklok/toolhive/pkg/lifecycle"
9+
"github.com/stacklok/toolhive/pkg/logger"
910
)
1011

1112
var rmCmd = &cobra.Command{
@@ -37,9 +38,33 @@ func rmCmdFunc(cmd *cobra.Command, args []string) error {
3738
}
3839

3940
// Delete container.
40-
if err := manager.DeleteContainer(ctx, containerName, rmForce); err != nil {
41+
if err := manager.DeleteContainer(ctx, containerName, rmForce, true); err != nil {
4142
return fmt.Errorf("failed to delete container: %v", err)
4243
}
4344

45+
// Delete associated egress container.
46+
egressContainerName := containerName + "-egress"
47+
if err := manager.DeleteContainer(ctx, egressContainerName, rmForce, false); err != nil {
48+
// just log the error and continue
49+
logger.Warnf("failed to delete egress container %q: %v", egressContainerName, err)
50+
}
51+
52+
// Delete networks if there are no containers using them.
53+
toolHiveContainers, err := manager.ListContainers(ctx, listAll)
54+
if err != nil {
55+
return fmt.Errorf("failed to list containers: %v", err)
56+
}
57+
fmt.Println("ToolHive containers:", toolHiveContainers)
58+
59+
if len(toolHiveContainers) == 0 {
60+
// remove networks
61+
if err := manager.DeleteNetwork(ctx, "toolhive-internal"); err != nil {
62+
return fmt.Errorf("failed to delete network: %v", err)
63+
}
64+
if err := manager.DeleteNetwork(ctx, "toolhive-external"); err != nil {
65+
return fmt.Errorf("failed to delete network: %v", err)
66+
}
67+
}
68+
4469
return nil
4570
}

cmd/thv/app/run.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,11 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
245245
return fmt.Errorf("failed to retrieve or pull image: %v", err)
246246
}
247247

248+
// pull the egress image if it is not already pulled
249+
if err := pullImage(ctx, config.EgressImage, rt); err != nil {
250+
return fmt.Errorf("failed to retrieve or pull egress image: %v", err)
251+
}
252+
248253
// Configure the RunConfig with transport, ports, permissions, etc.
249254
if err := configureRunConfig(config, runTransport, runPort, runTargetPort, runEnv); err != nil {
250255
return err

cmd/thv/app/stop.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,16 @@ func stopCmdFunc(cmd *cobra.Command, args []string) error {
4646
}
4747
}
4848

49+
// Stop associated egress container
50+
egressContainerName := containerName + "-egress"
51+
err = manager.StopContainer(ctx, egressContainerName)
52+
if err != nil {
53+
if errors.Is(err, lifecycle.ErrContainerNotFound) {
54+
logger.Infof("Egress container %s is not running", egressContainerName)
55+
} else {
56+
return fmt.Errorf("failed to stop egress container %q: %w", egressContainerName, err)
57+
}
58+
}
59+
4960
return nil
5061
}

pkg/api/v1/servers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func (s *ServerRoutes) deleteServer(w http.ResponseWriter, r *http.Request) {
145145
ctx := r.Context()
146146
name := chi.URLParam(r, "name")
147147
forceDelete := r.URL.Query().Get("force") == "true"
148-
err := s.manager.DeleteContainer(ctx, name, forceDelete)
148+
err := s.manager.DeleteContainer(ctx, name, forceDelete, true)
149149
if err != nil {
150150
if errors.Is(err, lifecycle.ErrContainerNotFound) {
151151
http.Error(w, "Server not found", http.StatusNotFound)

pkg/container/docker/client.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ func NewClient(ctx context.Context) (*Client, error) {
8080

8181
c, err := NewClientWithSocketPath(ctx, socketPath, runtimeType)
8282
if err != nil {
83-
logger.Debugf("Failed to create client for %s: %v", sp, err)
8483
lastErr = err
84+
logger.Debugf("Failed to create client for %s: %v", sp, err)
8585
continue
8686
}
8787

@@ -989,6 +989,11 @@ func (c *Client) getPermissionConfigFromProfile(
989989
return nil, fmt.Errorf("unsupported transport type: %s", transportType)
990990
}
991991

992+
// Add necessary capabilities for egress containers
993+
if profile.Name == permissions.ProfileEgress {
994+
config.CapAdd = append(config.CapAdd, "CAP_SETUID", "CAP_SETGID")
995+
}
996+
992997
return config, nil
993998
}
994999

@@ -1334,3 +1339,54 @@ func (c *Client) handleExistingContainer(
13341339
// Container was removed and needs to be recreated
13351340
return false, nil
13361341
}
1342+
1343+
// CreateNetwork creates a network following configuration.
1344+
func (c *Client) CreateNetwork(
1345+
ctx context.Context,
1346+
name string,
1347+
labels map[string]string,
1348+
internal bool,
1349+
) (string, error) {
1350+
// Check if the network already exists
1351+
networks, err := c.client.NetworkList(ctx, network.ListOptions{
1352+
Filters: filters.NewArgs(filters.Arg("name", name)),
1353+
})
1354+
if err != nil {
1355+
return "", fmt.Errorf("failed to list networks: %w", err)
1356+
}
1357+
if len(networks) > 0 {
1358+
// Network already exists, return its ID
1359+
return networks[0].ID, nil
1360+
}
1361+
1362+
networkCreate := network.CreateOptions{
1363+
Driver: "bridge",
1364+
Internal: internal,
1365+
Labels: labels,
1366+
}
1367+
1368+
resp, err := c.client.NetworkCreate(ctx, name, networkCreate)
1369+
if err != nil {
1370+
return "", err
1371+
}
1372+
return resp.ID, nil
1373+
}
1374+
1375+
// DeleteNetwork deletes a network by name.
1376+
func (c *Client) DeleteNetwork(ctx context.Context, name string) error {
1377+
// find the network by name
1378+
networks, err := c.client.NetworkList(ctx, network.ListOptions{
1379+
Filters: filters.NewArgs(filters.Arg("name", name)),
1380+
})
1381+
if err != nil {
1382+
return err
1383+
}
1384+
if len(networks) == 0 {
1385+
return fmt.Errorf("network %s not found", name)
1386+
}
1387+
1388+
if err := c.client.NetworkRemove(ctx, networks[0].ID); err != nil {
1389+
return fmt.Errorf("failed to remove network %s: %w", name, err)
1390+
}
1391+
return nil
1392+
}

pkg/container/kubernetes/client.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515
backoff "github.com/cenkalti/backoff/v4"
1616
appsv1 "k8s.io/api/apps/v1"
1717
corev1 "k8s.io/api/core/v1"
18+
networkingv1 "k8s.io/api/networking/v1"
19+
1820
"k8s.io/apimachinery/pkg/api/errors"
1921
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2022
"k8s.io/apimachinery/pkg/util/intstr"
@@ -567,6 +569,85 @@ func (*Client) StopContainer(_ context.Context, _ string) error {
567569
return nil
568570
}
569571

572+
func (c *Client) CreateNetwork(ctx context.Context, name string, labels map[string]string, internal bool) (string, error) {
573+
namespace := getCurrentNamespace()
574+
575+
// Check if the NetworkPolicy already exists
576+
_, err := c.client.NetworkingV1().NetworkPolicies(namespace).Get(ctx, name, metav1.GetOptions{})
577+
if err == nil {
578+
return name, nil // NetworkPolicy already exists
579+
}
580+
if !errors.IsNotFound(err) {
581+
return "", fmt.Errorf("failed to check if NetworkPolicy exists: %w", err)
582+
}
583+
584+
// Define the NetworkPolicy spec based on the 'internal' flag
585+
policyTypes := []networkingv1.PolicyType{networkingv1.PolicyTypeIngress}
586+
var ingressRules []networkingv1.NetworkPolicyIngressRule
587+
588+
if internal {
589+
// Restrict ingress to pods with the same labels
590+
ingressRules = []networkingv1.NetworkPolicyIngressRule{
591+
{
592+
From: []networkingv1.NetworkPolicyPeer{
593+
{
594+
PodSelector: &metav1.LabelSelector{
595+
MatchLabels: labels,
596+
},
597+
},
598+
},
599+
},
600+
}
601+
} else {
602+
// Allow all ingress traffic
603+
ingressRules = []networkingv1.NetworkPolicyIngressRule{
604+
{
605+
From: []networkingv1.NetworkPolicyPeer{
606+
{
607+
NamespaceSelector: &metav1.LabelSelector{},
608+
},
609+
},
610+
},
611+
}
612+
}
613+
614+
// Create the NetworkPolicy object
615+
policy := &networkingv1.NetworkPolicy{
616+
ObjectMeta: metav1.ObjectMeta{
617+
Name: name,
618+
Namespace: namespace,
619+
Labels: labels,
620+
},
621+
Spec: networkingv1.NetworkPolicySpec{
622+
PodSelector: metav1.LabelSelector{
623+
MatchLabels: labels,
624+
},
625+
PolicyTypes: policyTypes,
626+
Ingress: ingressRules,
627+
},
628+
}
629+
630+
// Create the NetworkPolicy in Kubernetes
631+
_, err = c.client.NetworkingV1().NetworkPolicies(namespace).Create(ctx, policy, metav1.CreateOptions{})
632+
if err != nil {
633+
return "", fmt.Errorf("failed to create NetworkPolicy: %w", err)
634+
}
635+
636+
return name, nil
637+
}
638+
639+
func (c *Client) DeleteNetwork(ctx context.Context, name string) error {
640+
namespace := getCurrentNamespace() // Implement this function to retrieve the current namespace
641+
642+
err := c.client.NetworkingV1().NetworkPolicies(namespace).Delete(ctx, name, metav1.DeleteOptions{})
643+
if err != nil {
644+
return fmt.Errorf("failed to delete network policy %q in namespace %q: %w", name, namespace, err)
645+
}
646+
647+
fmt.Printf("Network policy %q in namespace %q has been deleted successfully.\n", name, namespace)
648+
return nil
649+
}
650+
570651
// waitForStatefulSetReady waits for a statefulset to be ready using the watch API
571652
func waitForStatefulSetReady(ctx context.Context, clientset kubernetes.Interface, namespace, name string) error {
572653
// Create a field selector to watch only this specific statefulset

pkg/container/runtime/types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ type Runtime interface {
8888

8989
// BuildImage builds a Docker image from a Dockerfile in the specified context directory
9090
BuildImage(ctx context.Context, contextDir, imageName string) error
91+
92+
// CreateNetwork creates a network
93+
CreateNetwork(ctx context.Context, networkName string, labels map[string]string, internal bool) (string, error)
94+
95+
// DeleteNetwork deletes a network
96+
DeleteNetwork(ctx context.Context, networkName string) error
9197
}
9298

9399
// Monitor defines the interface for container monitoring

pkg/labels/labels.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ func AddStandardLabels(labels map[string]string, containerName, containerBaseNam
4343
labels[LabelToolType] = "mcp"
4444
}
4545

46+
// AddNetworkLabels adds network-related labels to a network
47+
func AddNetworkLabels(labels map[string]string, networkName string) {
48+
labels[LabelEnabled] = "true"
49+
labels[LabelName] = networkName
50+
}
51+
4652
// FormatToolHiveFilter formats a filter for ToolHive containers
4753
func FormatToolHiveFilter() string {
4854
return fmt.Sprintf("%s=true", LabelEnabled)

pkg/lifecycle/manager.go

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ type Manager interface {
3030
// ListContainers lists all ToolHive-managed containers.
3131
ListContainers(ctx context.Context, listAll bool) ([]rt.ContainerInfo, error)
3232
// DeleteContainer deletes a container and its associated proxy process.
33-
DeleteContainer(ctx context.Context, name string, forceDelete bool) error
33+
DeleteContainer(ctx context.Context, name string, forceDelete bool, removeConfig bool) error
34+
// DeleteNetwork deletes a network.
35+
DeleteNetwork(ctx context.Context, name string) error
3436
// StopContainer stops a container and its associated proxy process.
3537
StopContainer(ctx context.Context, name string) error
3638
// RunContainer runs a container in the foreground.
@@ -86,7 +88,7 @@ func (d *defaultManager) ListContainers(ctx context.Context, listAll bool) ([]rt
8688
return toolHiveContainers, nil
8789
}
8890

89-
func (d *defaultManager) DeleteContainer(ctx context.Context, name string, forceDelete bool) error {
91+
func (d *defaultManager) DeleteContainer(ctx context.Context, name string, forceDelete bool, removeConfig bool) error {
9092
// We need several fields from the container struct for deletion.
9193
container, err := d.findContainerByName(ctx, name)
9294
if err != nil {
@@ -114,24 +116,26 @@ func (d *defaultManager) DeleteContainer(ctx context.Context, name string, force
114116
return fmt.Errorf("failed to remove container: %v", err)
115117
}
116118

117-
// Get the base name from the container labels
118-
baseName := labels.GetContainerBaseName(containerLabels)
119-
if baseName != "" {
120-
// Delete the saved state if it exists
121-
if err := runner.DeleteSavedConfig(ctx, baseName); err != nil {
122-
logger.Warnf("Warning: Failed to delete saved state: %v", err)
123-
} else {
124-
logger.Infof("Saved state for %s removed", baseName)
119+
if removeConfig {
120+
// Get the base name from the container labels
121+
baseName := labels.GetContainerBaseName(containerLabels)
122+
if baseName != "" {
123+
// Delete the saved state if it exists
124+
if err := runner.DeleteSavedConfig(ctx, baseName); err != nil {
125+
logger.Warnf("Warning: Failed to delete saved state: %v", err)
126+
} else {
127+
logger.Infof("Saved state for %s removed", baseName)
128+
}
125129
}
126-
}
127130

128-
logger.Infof("Container %s removed", name)
131+
logger.Infof("Container %s removed", name)
129132

130-
if shouldRemoveClientConfig() {
131-
if err := removeClientConfigurations(name); err != nil {
132-
logger.Warnf("Warning: Failed to remove client configurations: %v", err)
133-
} else {
134-
logger.Infof("Client configurations for %s removed", name)
133+
if shouldRemoveClientConfig() {
134+
if err := removeClientConfigurations(name); err != nil {
135+
logger.Warnf("Warning: Failed to remove client configurations: %v", err)
136+
} else {
137+
logger.Infof("Client configurations for %s removed", name)
138+
}
135139
}
136140
}
137141

@@ -445,6 +449,16 @@ func (d *defaultManager) stopContainer(ctx context.Context, containerID, contain
445449
return nil
446450
}
447451

452+
func (d *defaultManager) DeleteNetwork(ctx context.Context, name string) error {
453+
// Remove the network
454+
logger.Infof("Removing network %s...", name)
455+
if err := d.runtime.DeleteNetwork(ctx, name); err != nil {
456+
return fmt.Errorf("failed to remove network: %v", err)
457+
}
458+
459+
return nil
460+
}
461+
448462
func shouldRemoveClientConfig() bool {
449463
c := config.GetConfig()
450464
return len(c.Clients.RegisteredClients) > 0 || c.Clients.AutoDiscovery

0 commit comments

Comments
 (0)