diff --git a/cmd/template-builder/main.go b/cmd/template-builder/main.go index f4c507a..8729b05 100644 --- a/cmd/template-builder/main.go +++ b/cmd/template-builder/main.go @@ -289,7 +289,8 @@ func runBuild(ctx context.Context, cfg buildConfig) error { emitUser("system", "Saving template") snapPath := filepath.Join(snapDir, "vmstate.snap") memPath := filepath.Join(snapDir, "mem.snap") - if err := vm.CreateSnapshot(socketPath, snapPath, memPath, snapDir); err != nil { + // Flatten: per-sandbox restores skip apply_delta. Safe here — base isn't shared yet. + if err := vm.CreateSnapshot(socketPath, snapPath, memPath, snapDir, vm.SnapshotFlatten); err != nil { return fmt.Errorf("snapshot: %w", err) } emitInternal("system", "snapshot captured") diff --git a/internal/vm/fc/models/snapshot_create_params.go b/internal/vm/fc/models/snapshot_create_params.go index 608f9b3..d58459e 100644 --- a/internal/vm/fc/models/snapshot_create_params.go +++ b/internal/vm/fc/models/snapshot_create_params.go @@ -20,6 +20,9 @@ type SnapshotCreateParams struct { // Directory for block device delta files. When set, overlay block devices write delta files (containing only dirty blocks) into this directory, named {drive_id}.delta. BlockDeltaDir string `json:"block_delta_dir,omitempty"` + // If true, bake each overlay's dirty blocks into its base.ext4 and zero the side-car bitmap. Only safe at template-creation time. Requires block_delta_dir. + Flatten bool `json:"flatten,omitempty"` + // Path to the file that will contain the guest memory. // Required: true MemFilePath *string `json:"mem_file_path"` diff --git a/internal/vm/firecracker.go b/internal/vm/firecracker.go index 6b1223b..68c703f 100644 --- a/internal/vm/firecracker.go +++ b/internal/vm/firecracker.go @@ -231,10 +231,27 @@ func StartInstance(socketPath string) error { // Snapshot operations // --------------------------------------------------------------------------- +// SnapshotMode controls per-disk flatten behavior at snapshot creation. +type SnapshotMode int + +const ( + // SnapshotNormal: leave overlay deltas as-is. Sandboxes restored from + // this snapshot replay the delta into a per-VM overlay on create. + SnapshotNormal SnapshotMode = iota + // SnapshotFlatten: bake each overlay's dirty blocks into base.ext4 and + // zero the side-car bitmap. Sandboxes restored from this snapshot skip + // apply_delta. Only safe when the base isn't shared with other live VMs. + SnapshotFlatten +) + // CreateSnapshot pauses the VM and creates a full snapshot. Non-empty // blockDeltaDir tells the forked engine to also emit .delta files // containing dirty blocks — required to create sandboxes from this template. -func CreateSnapshot(socketPath, snapshotPath, memPath, blockDeltaDir string) error { +// mode=SnapshotFlatten bakes those deltas into base.ext4 (see SnapshotMode). +func CreateSnapshot(socketPath, snapshotPath, memPath, blockDeltaDir string, mode SnapshotMode) error { + if mode == SnapshotFlatten && blockDeltaDir == "" { + return fmt.Errorf("SnapshotFlatten requires non-empty blockDeltaDir") + } fc := newFCClient(socketPath) ctx := context.Background() @@ -254,6 +271,7 @@ func CreateSnapshot(socketPath, snapshotPath, memPath, blockDeltaDir string) err MemFilePath: &memPath, SnapshotType: models.SnapshotCreateParamsSnapshotTypeFull, BlockDeltaDir: blockDeltaDir, + Flatten: mode == SnapshotFlatten, }, }); err != nil { return fmt.Errorf("create snapshot: %w", err) diff --git a/internal/vm/firecracker_test.go b/internal/vm/firecracker_test.go new file mode 100644 index 0000000..63d3576 --- /dev/null +++ b/internal/vm/firecracker_test.go @@ -0,0 +1,142 @@ +package vm + +import ( + "encoding/json" + "io" + "net" + "net/http" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +// TestCreateSnapshot_FlattenFieldInJSONBody asserts that the Go-side enum +// (SnapshotNormal/SnapshotFlatten) serializes the `flatten` field correctly +// in the JSON body sent to firecracker. The SnapshotNormal cases use the +// substring check (not unmarshal) so an accidentally-dropped `omitempty` +// would surface as `"flatten": false` appearing in the body — unmarshal +// alone would silently accept that as the zero value. +func TestCreateSnapshot_FlattenFieldInJSONBody(t *testing.T) { + cases := []struct { + name string + mode SnapshotMode + blockDeltaDir string + assertBody func(t *testing.T, body []byte) + }{ + { + name: "normal_mode_with_empty_delta_dir_omits_flatten", + mode: SnapshotNormal, + blockDeltaDir: "", + assertBody: func(t *testing.T, body []byte) { + if strings.Contains(string(body), "flatten") { + t.Errorf("flatten field must be omitted for SnapshotNormal; body=%s", string(body)) + } + }, + }, + { + name: "normal_mode_with_delta_dir_omits_flatten", + mode: SnapshotNormal, + blockDeltaDir: "/tmp/delta", + assertBody: func(t *testing.T, body []byte) { + if strings.Contains(string(body), "flatten") { + t.Errorf("flatten field must be omitted for SnapshotNormal; body=%s", string(body)) + } + }, + }, + { + name: "flatten_mode_sends_true", + mode: SnapshotFlatten, + blockDeltaDir: "/tmp/delta", + assertBody: func(t *testing.T, body []byte) { + var decoded struct { + Flatten bool `json:"flatten"` + } + if err := json.Unmarshal(body, &decoded); err != nil { + t.Fatalf("unmarshal body: %v (body=%s)", err, string(body)) + } + if !decoded.Flatten { + t.Errorf("flatten=%v, want true (body=%s)", decoded.Flatten, string(body)) + } + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "fc.sock") + + var ( + bodyMu sync.Mutex + capturedBody []byte + ) + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen unix: %v", err) + } + defer ln.Close() + + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPatch && r.URL.Path == "/vm": + w.WriteHeader(http.StatusNoContent) + case r.Method == http.MethodPut && r.URL.Path == "/snapshot/create": + b, _ := io.ReadAll(r.Body) + bodyMu.Lock() + capturedBody = b + bodyMu.Unlock() + w.WriteHeader(http.StatusNoContent) + default: + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + }), + } + go srv.Serve(ln) + defer srv.Close() + waitForUnixSocket(t, socketPath) + + if err := CreateSnapshot(socketPath, "/tmp/snap", "/tmp/mem", tc.blockDeltaDir, tc.mode); err != nil { + t.Fatalf("CreateSnapshot: %v", err) + } + + bodyMu.Lock() + body := capturedBody + bodyMu.Unlock() + if body == nil { + t.Fatal("snapshot/create handler never invoked") + } + tc.assertBody(t, body) + }) + } +} + +// Client-side guard: SnapshotFlatten with empty blockDeltaDir is rejected +// before any RPC. +func TestCreateSnapshot_FlattenRequiresBlockDeltaDir(t *testing.T) { + err := CreateSnapshot("/dev/null/unused", "/tmp/snap", "/tmp/mem", "", SnapshotFlatten) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "blockDeltaDir") { + t.Errorf("unexpected error: %v", err) + } +} + +// waitForUnixSocket blocks until the listener at socketPath accepts a connection +// or the deadline elapses. Avoids the race where CreateSnapshot dials before +// http.Server.Serve has installed its handler in the accept loop. +func waitForUnixSocket(t *testing.T, socketPath string) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if conn, err := net.Dial("unix", socketPath); err == nil { + _ = conn.Close() + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("server at %s never became ready", socketPath) +} diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 6e1062e..66c5f94 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -539,7 +539,7 @@ func (m *Manager) PauseVM(ctx context.Context, vmID, snapshotDir string) (snapsh memPath = filepath.Join(snapshotDir, "mem.snap") log.Info().Str("snapshot_path", snapshotPath).Msg("pausing VM — creating snapshot") - if err := CreateSnapshot(inst.SocketPath, snapshotPath, memPath, ""); err != nil { + if err := CreateSnapshot(inst.SocketPath, snapshotPath, memPath, "", SnapshotNormal); err != nil { return "", "", m.handleVMError(vmID, fmt.Errorf("create snapshot: %w", err)) } @@ -673,7 +673,7 @@ func (m *Manager) CreateVMSnapshot(ctx context.Context, vmID, snapshotDir string snapshotPath = filepath.Join(snapshotDir, "vmstate.snap") memPath = filepath.Join(snapshotDir, "mem.snap") - if err := CreateSnapshot(inst.SocketPath, snapshotPath, memPath, ""); err != nil { + if err := CreateSnapshot(inst.SocketPath, snapshotPath, memPath, "", SnapshotNormal); err != nil { return "", "", fmt.Errorf("create snapshot: %w", err) }