diff --git a/cmd/vmd/main.go b/cmd/vmd/main.go index e3e51a8..b02f621 100644 --- a/cmd/vmd/main.go +++ b/cmd/vmd/main.go @@ -248,6 +248,9 @@ func main() { // ---- VM manager ---- maxRestores, _ := strconv.Atoi(envOrDefault("VMD_MAX_CONCURRENT_RESTORES", "100")) + uffdEnabled := envOrDefault("VMD_UFFD_ENABLED", "true") != "false" + uffdPrefetchEnabled := envOrDefault("VMD_UFFD_PREFETCH_ENABLED", "true") != "false" + uffdRecordMaxSeconds, _ := strconv.Atoi(envOrDefault("VMD_UFFD_RECORD_MAX_SECONDS", "10")) mgr, err := vm.NewManager(vm.ManagerConfig{ FirecrackerBin: cfg.FirecrackerBin, @@ -260,6 +263,9 @@ func main() { BoxdBinaryPath: cfg.BoxdBinaryPath, HostInterface: cfg.HostInterface, MaxConcurrentRestores: maxRestores, + UffdEnabled: uffdEnabled, + UffdPrefetchEnabled: uffdPrefetchEnabled, + UffdRecordMaxSeconds: uffdRecordMaxSeconds, }, netMgr, log) if err != nil { log.Fatal().Err(err).Msg("failed to initialize VM manager") diff --git a/internal/vm/build.go b/internal/vm/build.go index 38a1075..8dd450f 100644 --- a/internal/vm/build.go +++ b/internal/vm/build.go @@ -167,6 +167,23 @@ func (m *Manager) buildTemplateSync(ctx context.Context, buildVMID string, req B return nil, fmt.Errorf("read build meta: %w", err) } + // Best-effort: a missing access.log just means sandboxes fall back + // to sequential prefetch. The "build-" prefix must remain so isBuildVM + // skips persistence + reconciler for this throwaway VM. 
+ if m.cfg.UffdEnabled && m.cfg.UffdPrefetchEnabled { + recordingVMID := "build-record-" + req.TemplateID + accessLogPath := filepath.Join(snapshotDir, "access.log") + recCfg := VMConfig{ + VCPU: req.VCPU, + MemoryMiB: req.MemoryMiB, + BasePath: result.BasePath, + DeltaDir: snapshotDir, + } + if recErr := m.RecordAccessPattern(ctx, recordingVMID, result.SnapshotPath, result.MemFilePath, accessLogPath, recCfg, nil); recErr != nil { + log.Warn().Err(recErr).Msg("access-pattern recording failed (sandbox will fall back to sequential prefetch)") + } + } + log.Info().Dur("total", time.Since(buildStart)).Msg("template build complete") return result, nil } diff --git a/internal/vm/firecracker.go b/internal/vm/firecracker.go index c95dc38..6b1223b 100644 --- a/internal/vm/firecracker.go +++ b/internal/vm/firecracker.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "strings" + "time" httptransport "github.com/go-openapi/runtime/client" "github.com/go-openapi/strfmt" @@ -326,3 +327,43 @@ func RestoreSnapshotWithOverrides(socketPath, snapshotPath, memPath, ifaceID, ta } return nil } + +// RestoreSnapshotUffdWithOverrides is the UFFD-backend variant of +// RestoreSnapshotWithOverrides. Instead of pointing Firecracker at the +// mem.snap file (File backend, which synchronously reads and CRC64-verifies +// the entire snapshot before returning), it points Firecracker at a Unix +// domain socket served by our in-process UFFD handler. +// +// LoadSnapshot returns in milliseconds because pages are not read upfront. +// As the guest touches pages, the kernel forwards faults to our handler, +// which serves them from a memory-mapped mem.snap. +// +// uffdSocketPath must be a bound Unix socket; the caller is responsible for +// starting the handler goroutine before invoking this function. +func RestoreSnapshotUffdWithOverrides(socketPath, snapshotPath, uffdSocketPath, ifaceID, tapDevice, blockDeltaDir string) error { + // Bound LoadSnapshot so a hung Firecracker doesn't wedge vmd. 
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + fc := newFCClient(socketPath) + if _, err := fc.Operations.LoadSnapshot(&operations.LoadSnapshotParams{ + Context: ctx, + Body: &models.SnapshotLoadParams{ + SnapshotPath: &snapshotPath, + MemBackend: &models.MemoryBackend{ + BackendType: strPtr(models.MemoryBackendBackendTypeUffd), + BackendPath: &uffdSocketPath, + }, + ResumeVM: true, + NetworkOverrides: []*models.NetworkOverride{ + {IfaceID: &ifaceID, HostDevName: &tapDevice}, + }, + BlockDeltaDir: blockDeltaDir, + }, + }); err != nil { + if isTornSnapshotErr(err) { + return fmt.Errorf("load snapshot (uffd): %w: %v", ErrTornSnapshot, err) + } + return fmt.Errorf("load snapshot (uffd): %w", err) + } + return nil +} diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 7ffe244..5af5f07 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -20,6 +20,7 @@ import ( "google.golang.org/grpc/status" "github.com/superserve-ai/sandbox/internal/network" + "github.com/superserve-ai/sandbox/internal/vm/uffd" pb "github.com/superserve-ai/sandbox/proto/boxdpb" ) @@ -88,6 +89,12 @@ type VMInstance struct { CreatedAt time.Time Metadata map[string]string + // uffdCancel cancels the per-VM UFFD handler goroutine. nil for VMs + // not using UFFD (cold boots, template builds, in-place resumes). + uffdCancel context.CancelFunc + // uffdHandler is the per-VM UFFD handler; nil for non-UFFD VMs. + uffdHandler *uffd.Handler + mu sync.RWMutex } @@ -122,6 +129,22 @@ type ManagerConfig struct { // prevent a spike of concurrent sandbox creates from saturating host // file I/O, netns setup, and Firecracker boots. 0 → default 100. MaxConcurrentRestores int + + // UffdEnabled gates the UFFD lazy-restore path. false → fresh + // restores fall back to the File memory backend (synchronous CRC64), + // same as in-place resume. Default true; flip to false as an ops + // circuit breaker without redeploying. 
+ UffdEnabled bool + + // UffdPrefetchEnabled turns on background prefetch in the UFFD + // handler. When true, the handler walks mem.snap after the handshake + // and pre-copies pages into guest memory so the first exec doesn't + // stall on cold pages. Ignored when UffdEnabled is false. + UffdPrefetchEnabled bool + + // UffdRecordMaxSeconds caps how long RecordAccessPattern waits for the + // guest's fault rate to settle before giving up. 0 → 10s default. + UffdRecordMaxSeconds int } // --------------------------------------------------------------------------- @@ -441,6 +464,8 @@ func (m *Manager) DestroyVM(ctx context.Context, vmID string, force bool) error log := m.log.With().Str("vm_id", vmID).Logger() log.Info().Bool("force", force).Msg("destroying VM") + m.cancelUffdHandler(inst) + // Stop the systemd unit if one exists — this is the path for sandbox // VMs launched via startFirecrackerViaSystemd. if err := stopUnit(ctx, systemdUnitName(vmID)); err != nil { @@ -746,6 +771,14 @@ func (m *Manager) assertUnderVMSnapshotDir(vmID, p string) error { // RestoreVMSnapshot boots a VM from a previously captured snapshot. func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, memPath string, resourceLimits VMConfig, netCfg *network.Config, +) (*VMInstance, error) { + return m.restoreVMSnapshot(ctx, vmID, snapshotPath, memPath, resourceLimits, netCfg, nil) +} + +// restoreVMSnapshot is the implementation. The recorder argument is +// non-nil only for template-build access-pattern recording. 
+func (m *Manager) restoreVMSnapshot(ctx context.Context, vmID, snapshotPath, memPath string, + resourceLimits VMConfig, netCfg *network.Config, recorder *uffd.Recorder, ) (*VMInstance, error) { log := m.log.With().Str("vm_id", vmID).Logger() @@ -843,6 +876,7 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem vmDir := filepath.Join(m.cfg.RunDir, vmID) socketPath := filepath.Join(vmDir, "firecracker.sock") + uffdSocketPath := filepath.Join(vmDir, "uffd.sock") // Publish all the network/disk/socket fields before starting // Firecracker so the in-memory view is consistent for concurrent @@ -856,8 +890,21 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem inst.SocketPath = socketPath inst.mu.Unlock() + // inPlace resume always uses File backend; UffdEnabled=false is the + // ops circuit breaker that forces fresh restores onto File too. + useUffd := !inPlace && m.cfg.UffdEnabled + if useUffd { + if err := m.startUffdHandler(ctx, inst, uffdSocketPath, memPath, recorder); err != nil { + m.netMgr.CleanupVM(vmID) + m.cleanupRunDir(vmID) + m.setStatus(vmID, StatusError) + return nil, fmt.Errorf("start uffd handler: %w", err) + } + } + pid, startErr := m.startFirecrackerViaSystemd(ctx, vmID, socketPath, diskPath, resourceLimits.BasePath, nsName) if startErr != nil { + m.cancelUffdHandler(inst) if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -872,12 +919,20 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem log.Info().Msg("restoring snapshot") var restoreErr error - if inPlace { + switch { + case useUffd: + restoreErr = RestoreSnapshotUffdWithOverrides(socketPath, snapshotPath, uffdSocketPath, "eth0", tapDevice, plan.deltaDir) + case inPlace: restoreErr = RestoreSnapshot(socketPath, snapshotPath, memPath, plan.deltaDir) - } else { + default: + // UFFD disabled but fresh restore — File backend with network overrides. 
restoreErr = RestoreSnapshotWithOverrides(socketPath, snapshotPath, memPath, "eth0", tapDevice, plan.deltaDir) } if restoreErr != nil { + // Firecracker is already running; stop the unit before other + // cleanup or it leaks. See stopUnitDuringRestoreError comment. + m.stopUnitDuringRestoreError(vmID) + m.cancelUffdHandler(inst) if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -891,7 +946,28 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem return nil, fmt.Errorf("restore snapshot: %w", restoreErr) } + // LoadSnapshot doesn't ack the UFFD handshake — Firecracker can + // return success while our handler silently failed, which would + // leave the guest hanging on first page fault. + if useUffd { + inst.mu.RLock() + uffdH := inst.uffdHandler + inst.mu.RUnlock() + if uffdH != nil { + if hsErr := uffdH.WaitHandshake(ctx); hsErr != nil { + m.stopUnitDuringRestoreError(vmID) + m.cancelUffdHandler(inst) + m.netMgr.CleanupVM(vmID) + m.cleanupRunDir(vmID) + m.setStatus(vmID, StatusError) + return nil, fmt.Errorf("uffd handshake: %w", hsErr) + } + } + } + if err := m.waitForBoxd(ctx, hostIP, 5*time.Second); err != nil { + m.stopUnitDuringRestoreError(vmID) + m.cancelUffdHandler(inst) if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -1336,6 +1412,191 @@ func (m *Manager) cleanupRunDir(dirName string) { } } +// stopUnitDuringRestoreError stops the per-VM systemd unit when a restore +// aborts after Firecracker started. Uses a fresh context because the +// caller's gRPC ctx is often already cancelled (deadline exceeded under +// load). Without this, the firecracker process leaks. 
+func (m *Manager) stopUnitDuringRestoreError(vmID string) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := stopUnit(cleanupCtx, systemdUnitName(vmID)); err != nil { + m.log.Warn().Err(err).Str("vm_id", vmID).Msg("systemctl stop failed during restore error cleanup") + } + removeUnitDropIn(vmID) +} + +// startUffdHandler binds the per-VM UFFD socket, opens mem.snap, and +// launches the handler goroutine. reqCtx is the caller's request ctx; +// its deadline applies to the handshake. The serve loop uses an +// independent ctx so it outlives the request. recorder, when non-nil, +// hooks OnPageFault for template-build recording. +func (m *Manager) startUffdHandler(reqCtx context.Context, inst *VMInstance, uffdSocketPath, memSnapPath string, recorder *uffd.Recorder) error { + accessLogPath := filepath.Join(filepath.Dir(memSnapPath), "access.log") + cfg := uffd.Config{ + SocketPath: uffdSocketPath, + MemSnapPath: memSnapPath, + AccessLogPath: accessLogPath, + Logger: m.log.With().Str("vm_id", inst.ID).Logger(), + } + if recorder != nil { + cfg.OnPageFault = recorder.Record + cfg.PrefetchEnabled = false + } else { + cfg.PrefetchEnabled = m.cfg.UffdPrefetchEnabled + } + h := uffd.New(cfg) + if err := h.Start(); err != nil { + return err + } + lifetimeCtx, lifetimeCancel := context.WithCancel(context.Background()) + inst.mu.Lock() + inst.uffdCancel = lifetimeCancel + inst.uffdHandler = h + inst.mu.Unlock() + + vmID := inst.ID + go func() { + defer func() { + if r := recover(); r != nil { + m.log.Error().Interface("panic", r).Str("vm_id", vmID).Msg("uffd handler panicked") + } + }() + defer h.Close() + if err := h.AcceptHandshake(reqCtx); err != nil { + if !errors.Is(err, context.Canceled) { + m.log.Error().Err(err).Str("vm_id", vmID).Msg("uffd handshake failed") + } + return + } + if err := h.Serve(lifetimeCtx); err != nil && !errors.Is(err, context.Canceled) { + m.log.Error().Err(err).Str("vm_id", 
vmID).Msg("uffd handler exited with error") + } + }() + return nil +} + +const ( + recordTickInterval = 100 * time.Millisecond + recordMinDuration = 500 * time.Millisecond + recordQuietTicks = 3 + recordDefaultMax = 10 * time.Second +) + +// RecordAccessPattern restores the snapshot under UFFD-recording mode, +// waits for the guest's page-fault rate to settle, and writes the +// observed page-access order to outputPath. +func (m *Manager) RecordAccessPattern(ctx context.Context, vmID, snapshotPath, memPath, outputPath string, + resourceLimits VMConfig, netCfg *network.Config, +) error { + if _, err := os.Stat(outputPath); err == nil { + m.log.Info().Str("path", outputPath).Msg("access log already exists, skipping recording") + return nil + } + + recorder := uffd.NewRecorder() + inst, err := m.restoreVMSnapshot(ctx, vmID, snapshotPath, memPath, resourceLimits, netCfg, recorder) + if err != nil { + return fmt.Errorf("restore for recording: %w", err) + } + defer func() { + _ = m.DestroyVM(context.Background(), inst.ID, true) + }() + + inst.mu.RLock() + handler := inst.uffdHandler + inst.mu.RUnlock() + if handler == nil { + return errors.New("UFFD handler missing after restore-for-recording") + } + + maxDuration := recordDefaultMax + if m.cfg.UffdRecordMaxSeconds > 0 { + maxDuration = time.Duration(m.cfg.UffdRecordMaxSeconds) * time.Second + } + + settled, elapsed, totalFaults := m.waitFaultsSettle(ctx, handler, maxDuration) + if errors.Is(ctx.Err(), context.Canceled) || errors.Is(ctx.Err(), context.DeadlineExceeded) { + return ctx.Err() + } + + pages := recorder.Len() + mkEvent := func() *zerolog.Event { + ev := m.log.Info() + if !settled { + ev = m.log.Warn() + } + return ev.Str("template_vm", vmID).Bool("settled", settled). + Dur("elapsed", elapsed).Uint64("total_faults", totalFaults).Int("pages", pages) + } + + if pages == 0 { + // Missing access.log → sequential prefetch fallback, which is a + // cleaner signal than a zero-byte log. 
+ mkEvent().Msg("access pattern recording produced no pages; access.log not written") + return nil + } + + if err := recorder.Flush(outputPath, false); err != nil { + return fmt.Errorf("flush access log: %w", err) + } + mkEvent().Str("path", outputPath).Msg("access pattern recorded") + return nil +} + +// waitFaultsSettle returns settled=true once the guest has produced +// activity then stayed quiet for recordQuietTicks past +// recordMinDuration. settled=false means maxDuration or ctx tripped. +func (m *Manager) waitFaultsSettle(ctx context.Context, h *uffd.Handler, maxDuration time.Duration) (settled bool, elapsed time.Duration, totalFaults uint64) { + start := time.Now() + ticker := time.NewTicker(recordTickInterval) + defer ticker.Stop() + deadline := time.NewTimer(maxDuration) + defer deadline.Stop() + + var lastFaults uint64 + seenActivity := false + quietCount := 0 + + for { + select { + case <-ctx.Done(): + return false, time.Since(start), h.Stats().FaultsServed.Load() + case <-deadline.C: + return false, time.Since(start), h.Stats().FaultsServed.Load() + case <-ticker.C: + cur := h.Stats().FaultsServed.Load() + delta := cur - lastFaults + lastFaults = cur + if delta > 0 { + seenActivity = true + quietCount = 0 + continue + } + if !seenActivity { + continue + } + if time.Since(start) < recordMinDuration { + continue + } + quietCount++ + if quietCount >= recordQuietTicks { + return true, time.Since(start), cur + } + } + } +} + +// cancelUffdHandler is idempotent and a no-op for non-UFFD VMs. +func (m *Manager) cancelUffdHandler(inst *VMInstance) { + inst.mu.Lock() + cancel := inst.uffdCancel + inst.uffdCancel = nil + inst.mu.Unlock() + if cancel != nil { + cancel() + } +} + // startFirecrackerColdBoot launches Firecracker inside a network namespace, // configures it, and boots the kernel. 
Used by the template build pipeline // (coldBootFromRootfs) to boot the throwaway VM that we snapshot into a diff --git a/internal/vm/manager_test.go b/internal/vm/manager_test.go index 694eb78..fa27efd 100644 --- a/internal/vm/manager_test.go +++ b/internal/vm/manager_test.go @@ -5,6 +5,11 @@ import ( "os" "path/filepath" "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/superserve-ai/sandbox/internal/vm/uffd" ) // TestPlanRestore pins the restore-decision behavior across the four input @@ -374,3 +379,97 @@ func TestDeleteSnapshotFiles_NoSnapshotDirConfigured_Rejected(t *testing.T) { t.Error("expected rejection when SnapshotDir is unconfigured") } } + +func newTestManager() *Manager { + return &Manager{log: zerolog.Nop()} +} + +func TestWaitFaultsSettle_Converges(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + h.Stats().FaultsServed.Add(100) + } + } + }() + time.AfterFunc(150*time.Millisecond, func() { close(stop) }) + + mgr := newTestManager() + settled, elapsed, total := mgr.waitFaultsSettle(context.Background(), h, 5*time.Second) + if !settled { + t.Errorf("expected settled=true, got false (elapsed=%v total=%d)", elapsed, total) + } + if elapsed > 2*time.Second { + t.Errorf("settle took too long: %v", elapsed) + } + if total == 0 { + t.Error("expected non-zero total faults") + } +} + +func TestWaitFaultsSettle_NoActivityHitsCeiling(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + mgr := newTestManager() + settled, elapsed, total := mgr.waitFaultsSettle(context.Background(), h, 300*time.Millisecond) + if settled { + t.Errorf("expected settled=false (no activity), got true") + } + if total != 0 { + t.Errorf("expected zero total faults, got %d", total) + } + if elapsed < 250*time.Millisecond { + t.Errorf("ceiling hit too early: %v", 
elapsed) + } +} + +func TestWaitFaultsSettle_NeverQuietHitsCeiling(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + stop := make(chan struct{}) + defer close(stop) + go func() { + ticker := time.NewTicker(30 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + h.Stats().FaultsServed.Add(1) + } + } + }() + + mgr := newTestManager() + settled, elapsed, total := mgr.waitFaultsSettle(context.Background(), h, 400*time.Millisecond) + if settled { + t.Errorf("expected settled=false (continuous activity), got true") + } + if total == 0 { + t.Error("expected non-zero total faults") + } + if elapsed < 350*time.Millisecond { + t.Errorf("ceiling tripped too early: %v", elapsed) + } +} + +func TestWaitFaultsSettle_CtxCancel(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(100*time.Millisecond, cancel) + + mgr := newTestManager() + settled, elapsed, _ := mgr.waitFaultsSettle(ctx, h, 30*time.Second) + if settled { + t.Errorf("expected settled=false on cancel") + } + if elapsed > 500*time.Millisecond { + t.Errorf("cancel did not propagate promptly: %v", elapsed) + } +} diff --git a/internal/vm/reconciler.go b/internal/vm/reconciler.go index f35f82d..a91550f 100644 --- a/internal/vm/reconciler.go +++ b/internal/vm/reconciler.go @@ -255,13 +255,16 @@ func (r *Reconciler) runOnce(ctx context.Context) { } } - // Drift 3: systemd has a unit, DB says the sandbox is deleted or has - // no row at all. This is an orphan — stop the unit + clean up. + // Drift 3: systemd unit active, DB says deleted/failed/missing. + // `failed` catches restores whose forward-path cleanup didn't run + // (e.g., gRPC ctx fired mid-LoadSnapshot and our work continued + // after the caller gave up). 
if dbSandboxes != nil { for id := range active { sb, known := dbSandboxes[id] deleted := known && sb.Sandbox.Status == db.SandboxStatusDeleted - if known && !deleted { + failed := known && sb.Sandbox.Status == db.SandboxStatusFailed + if known && !deleted && !failed { continue } if !r.gracePeriodElapsed("orphan:"+id, now) { @@ -276,6 +279,9 @@ func (r *Reconciler) runOnce(ctx context.Context) { if deleted { reason = "systemd unit for soft-deleted sandbox" kind = "systemd_active_db_deleted" + } else if failed { + reason = "systemd unit for failed sandbox" + kind = "systemd_active_db_failed" } log.Warn().Str("vm_id", id).Str("drift", kind).Msg("orphan systemd unit — stopping") if err := stopUnit(ctx, systemdUnitName(id)); err != nil { diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go new file mode 100644 index 0000000..68c5a41 --- /dev/null +++ b/internal/vm/uffd/handler.go @@ -0,0 +1,491 @@ +package uffd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "sync" + "sync/atomic" + "syscall" + "time" + "unsafe" + + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" +) + +// defaultFaultWorkers caps concurrent in-flight UFFDIO_COPY ioctls per +// VM. Bounded so handler teardown via g.Wait doesn't stall on a wedged +// worker. +const defaultFaultWorkers = 256 + +// uffdClosed is the sentinel value of Handler.uffdFd before the fd is +// received and after it has been closed. +const uffdClosed = ^uintptr(0) + +// GuestRegionMapping is the JSON Firecracker sends over the UDS alongside +// the userfaultfd fd. Field tags must match +// GuestRegionUffdMapping in src/vmm/src/persist.rs of our firecracker fork. 
+type GuestRegionMapping struct { + BaseHostVirtAddr uint64 `json:"base_host_virt_addr"` + Size uint64 `json:"size"` + Offset uint64 `json:"offset"` + PageSize uint64 `json:"page_size"` + PageSizeKib uint64 `json:"page_size_kib,omitempty"` +} + +func (r *GuestRegionMapping) contains(faultAddr uint64) bool { + return faultAddr >= r.BaseHostVirtAddr && faultAddr < r.BaseHostVirtAddr+r.Size +} + +type Config struct { + SocketPath string + MemSnapPath string + FaultWorkers int // 0 -> defaultFaultWorkers + + // AccessLogPath, when set and present on disk, supplies a recorded + // page-access order from template build time. The prefetcher replays + // it. Missing file → sequential fallback. + AccessLogPath string + + // PrefetchEnabled controls whether the prefetcher runs at all. + PrefetchEnabled bool + + // OnPageFault, if non-nil, is invoked for each PageFault event with + // the backing-file offset of the page about to be served. Used by + // template-builder to record access patterns. Must be fast (called + // on the fault hot path). + OnPageFault func(offset uint64) + + Logger zerolog.Logger +} + +type Stats struct { + FaultsServed atomic.Uint64 + BytesServed atomic.Uint64 + CopyErrors atomic.Uint64 + UnknownEvents atomic.Uint64 + RemoveEvents atomic.Uint64 + EagainRetries atomic.Uint64 + Eexist atomic.Uint64 + PrefetchedPages atomic.Uint64 + PrefetchSkipped atomic.Uint64 // EEXIST during prefetch (page already faulted by guest) +} + +// Handler binds a UDS, receives a userfaultfd fd from Firecracker, then +// serves page faults from mem.snap for the life of the VM. One handler +// per VM. +type Handler struct { + cfg Config + log zerolog.Logger + stats Stats + source *Source + listener *net.UnixListener + + // fdMu (RLock per ioctl, Lock around close) prevents an ioctl from + // racing with close + kernel fd-reuse. uffdFd is atomic so the + // lock-free serveLoop read is sound under Go's memory model. 
+ fdMu sync.RWMutex + uffdFd atomic.Uintptr // uffdClosed when not held + + mappings []GuestRegionMapping + pageSize uint64 + + // handshakeDone carries the AcceptHandshake outcome to WaitHandshake. + handshakeDone chan error + + // prefetchWg lets Close wait for the prefetch goroutine before munmap. + prefetchWg sync.WaitGroup + + closed atomic.Bool +} + +func New(cfg Config) *Handler { + if cfg.FaultWorkers <= 0 { + cfg.FaultWorkers = defaultFaultWorkers + } + h := &Handler{ + cfg: cfg, + log: cfg.Logger.With().Str("component", "uffd").Logger(), + handshakeDone: make(chan error, 1), + } + h.uffdFd.Store(uffdClosed) + return h +} + +func (h *Handler) Stats() *Stats { return &h.stats } + +// Start binds the Unix socket and mmaps mem.snap. After it returns, +// Firecracker's LoadSnapshot can be invoked safely. +func (h *Handler) Start() error { + // UFFD handler runs BEFORE startFirecrackerViaSystemd, which is + // where the per-VM rundir is normally created. + if err := os.MkdirAll(filepath.Dir(h.cfg.SocketPath), 0o755); err != nil { + return fmt.Errorf("mkdir socket dir: %w", err) + } + _ = os.Remove(h.cfg.SocketPath) + + src, err := OpenSource(h.cfg.MemSnapPath) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + + listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: h.cfg.SocketPath, Net: "unix"}) + if err != nil { + _ = src.Close() + return fmt.Errorf("listen %s: %w", h.cfg.SocketPath, err) + } + if err := os.Chmod(h.cfg.SocketPath, 0o600); err != nil { + h.log.Warn().Err(err).Msg("chmod socket failed") + } + + h.source = src + h.listener = listener + return nil +} + + +// AcceptHandshake blocks on the UDS listener until Firecracker connects +// and sends the UFFD fd + region mappings. ctx supplies the handshake +// deadline — kept separate from Serve's lifetime ctx so a slow handshake +// under burst doesn't share fate with the page-fault loop. The outcome +// is published to WaitHandshake. 
+func (h *Handler) AcceptHandshake(ctx context.Context) error { + if h.listener == nil || h.source == nil { + err := errors.New("AcceptHandshake called before Start (or after Close)") + h.publishHandshake(err) + return err + } + if err := h.acceptAndReceive(ctx); err != nil { + err = fmt.Errorf("receive handshake: %w", err) + h.publishHandshake(err) + return err + } + h.publishHandshake(nil) + return nil +} + +func (h *Handler) publishHandshake(err error) { + select { + case h.handshakeDone <- err: + default: + } +} + +// WaitHandshake returns the AcceptHandshake outcome, blocking until it +// completes or ctx is cancelled. Idempotent. +func (h *Handler) WaitHandshake(ctx context.Context) error { + select { + case err := <-h.handshakeDone: + h.publishHandshake(err) + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// Serve runs the page-fault loop (and prefetcher, if enabled) until ctx +// is cancelled or the UFFD fd is closed. AcceptHandshake must have +// returned nil first. +func (h *Handler) Serve(ctx context.Context) error { + if h.cfg.PrefetchEnabled { + h.prefetchWg.Add(1) + go func() { + defer h.prefetchWg.Done() + h.runPrefetch(ctx) + }() + } + return h.serveLoop(ctx) +} + +func (h *Handler) Close() error { + if !h.closed.CompareAndSwap(false, true) { + return nil + } + var firstErr error + if h.listener != nil { + if err := h.listener.Close(); err != nil && firstErr == nil { + firstErr = fmt.Errorf("close listener: %w", err) + } + } + if h.cfg.SocketPath != "" { + _ = os.Remove(h.cfg.SocketPath) + } + h.closeFd() + // Drain prefetch before munmap. Fault workers are already joined by + // serveLoop's g.Wait; only the bare prefetch goroutine needs this. 
+ h.prefetchWg.Wait() + if h.source != nil { + if err := h.source.Close(); err != nil && firstErr == nil { + firstErr = fmt.Errorf("close source: %w", err) + } + h.source = nil + } + return firstErr +} + +// closeFd closes the userfaultfd under the write lock so no in-flight +// ioctl can be operating on the fd when the kernel frees the slot +// (preventing fd-number reuse from redirecting a worker's ioctl to an +// unrelated open). Idempotent. +func (h *Handler) closeFd() { + h.fdMu.Lock() + defer h.fdMu.Unlock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + return + } + _ = unix.Close(int(fd)) + h.uffdFd.Store(uffdClosed) +} + +func (h *Handler) acceptAndReceive(ctx context.Context) error { + // SetDeadline is the only way to make AcceptUnix interruptible; the + // watchdog below sets it to the past on ctx cancel. + if deadline, ok := ctx.Deadline(); ok { + _ = h.listener.SetDeadline(deadline) + } else { + _ = h.listener.SetDeadline(time.Now().Add(30 * time.Second)) + } + + stop := make(chan struct{}) + defer close(stop) + go func() { + select { + case <-ctx.Done(): + _ = h.listener.SetDeadline(time.Now()) + case <-stop: + } + }() + + conn, err := h.listener.AcceptUnix() + if err != nil { + return fmt.Errorf("accept: %w", err) + } + defer conn.Close() + + body := make([]byte, 4096) + oob := make([]byte, unix.CmsgSpace(4)) + + n, oobn, _, _, err := conn.ReadMsgUnix(body, oob) + if err != nil { + return fmt.Errorf("recvmsg: %w", err) + } + if n == 0 { + return errors.New("empty message from firecracker") + } + + cmsgs, err := unix.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return fmt.Errorf("parse cmsg: %w", err) + } + // Firecracker sends exactly one fd; cap defensively. 
+ uffdFd := -1 + for _, cm := range cmsgs { + fds, err := unix.ParseUnixRights(&cm) + if err != nil { + continue + } + if len(fds) > 1 { + for _, fd := range fds { + _ = unix.Close(fd) + } + return fmt.Errorf("expected 1 fd in SCM_RIGHTS, got %d", len(fds)) + } + if len(fds) == 1 { + if uffdFd != -1 { + _ = unix.Close(fds[0]) + return errors.New("multiple SCM_RIGHTS messages; expected one") + } + uffdFd = fds[0] + } + } + if uffdFd == -1 { + return errors.New("no UFFD fd received in SCM_RIGHTS") + } + + // Firecracker creates UFFD with O_NONBLOCK; clear it so unix.Read + // blocks (cancellation closes the fd → EBADF wakes the read). + if flags, err := unix.FcntlInt(uintptr(uffdFd), unix.F_GETFL, 0); err == nil { + _, _ = unix.FcntlInt(uintptr(uffdFd), unix.F_SETFL, flags&^unix.O_NONBLOCK) + } + + var mappings []GuestRegionMapping + if err := json.Unmarshal(body[:n], &mappings); err != nil { + _ = unix.Close(uffdFd) + return fmt.Errorf("unmarshal mappings: %w (body=%q)", err, string(body[:n])) + } + if len(mappings) == 0 { + _ = unix.Close(uffdFd) + return errors.New("firecracker sent zero mappings") + } + pageSize := mappings[0].PageSize + if pageSize == 0 { + pageSize = mappings[0].PageSizeKib + } + if pageSize == 0 { + _ = unix.Close(uffdFd) + return fmt.Errorf("invalid page size in mappings: %+v", mappings[0]) + } + + h.uffdFd.Store(uintptr(uffdFd)) + h.mappings = mappings + h.pageSize = pageSize + + h.log.Info(). + Int("regions", len(mappings)). + Uint64("page_size", pageSize). + Int("uffd_fd", uffdFd). + Msg("uffd handshake complete") + + return nil +} + +func (h *Handler) serveLoop(ctx context.Context) error { + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(h.cfg.FaultWorkers) + + // Closing the UFFD fd is the only way to wake a blocking read on it, + // so bridge context cancellation to a close here. 
+ go func() { + <-gctx.Done() + h.closeFd() + }() + + var msgBuf [32]byte + for { + if err := gctx.Err(); err != nil { + return g.Wait() + } + fd := h.uffdFd.Load() + if fd == uffdClosed { + return g.Wait() + } + n, err := unix.Read(int(fd), msgBuf[:]) + if err != nil { + if errors.Is(err, syscall.EINTR) { + continue + } + // EBADF: cancellation goroutine closed the fd. Exit cleanly. + if errors.Is(err, syscall.EBADF) { + return g.Wait() + } + return fmt.Errorf("read uffd_msg: %w", err) + } + if n == 0 { + return g.Wait() // Firecracker exited. + } + if n < int(unsafe.Sizeof(uffdMsg{})) { + h.log.Warn().Int("bytes", n).Msg("short uffd_msg read; skipping") + continue + } + msg := *(*uffdMsg)(unsafe.Pointer(&msgBuf[0])) + switch msg.Event { + case UFFD_EVENT_PAGEFAULT: + pf := msg.asPageFault() + h.servePagefault(g, pf.Address, h.pageSize) + case UFFD_EVENT_REMOVE: + h.stats.RemoveEvents.Add(1) + rm := msg.asRemove() + h.handleRemove(rm.Start, rm.End) + default: + h.stats.UnknownEvents.Add(1) + h.log.Debug().Uint8("event", msg.Event).Msg("ignoring unknown uffd event") + } + } +} + +// servePagefault dispatches a single fault to the worker pool. The pool +// blocks if all FaultWorkers slots are busy — natural backpressure. 
+func (h *Handler) servePagefault(g *errgroup.Group, faultAddr, pageSize uint64) {
+	g.Go(func() error {
+		// UFFDIO_COPY requires a page-aligned destination; round the
+		// faulting address down to its page base.
+		pageAddr := faultAddr &^ (pageSize - 1)
+
+		region := h.findRegion(pageAddr)
+		if region == nil {
+			// Returning an error tears down the whole errgroup (and with
+			// it the serve loop): an out-of-range fault means our region
+			// table disagrees with the kernel, which is unrecoverable.
+			h.log.Error().Uint64("addr", faultAddr).Msg("page fault address outside all regions; killing handler")
+			return fmt.Errorf("address %#x outside guest regions", faultAddr)
+		}
+
+		// Translate the guest-mapping host virtual address into a
+		// mem.snap file offset via the region's base/offset pair.
+		offset := region.Offset + (pageAddr - region.BaseHostVirtAddr)
+		if h.cfg.OnPageFault != nil {
+			h.cfg.OnPageFault(offset)
+		}
+		srcPtr, err := h.source.PagePointer(offset, pageSize)
+		if err != nil {
+			h.stats.CopyErrors.Add(1)
+			return fmt.Errorf("source page lookup: %w", err)
+		}
+
+		// NOTE(review): the RLock around Load+ioctl presumably pairs with
+		// a write-lock in the close path so the fd number cannot be closed
+		// and recycled mid-ioctl — confirm against closeFd.
+		h.fdMu.RLock()
+		fd := h.uffdFd.Load()
+		if fd == uffdClosed {
+			h.fdMu.RUnlock()
+			return nil
+		}
+		_, copyErr := ioctlCopy(fd, pageAddr, srcPtr, pageSize, 0)
+		h.fdMu.RUnlock()
+		if copyErr != nil {
+			// EEXIST: another fault on the same page raced ahead of us;
+			// it's already mapped. Benign.
+			if errors.Is(copyErr, syscall.EEXIST) {
+				h.stats.Eexist.Add(1)
+				return nil
+			}
+			// EAGAIN: a REMOVE event is pending in the kernel queue;
+			// all ioctls return EAGAIN until it's drained. The kernel
+			// re-fires the fault after we handle the REMOVE.
+			if errors.Is(copyErr, syscall.EAGAIN) {
+				h.stats.EagainRetries.Add(1)
+				return nil
+			}
+			h.stats.CopyErrors.Add(1)
+			return fmt.Errorf("UFFDIO_COPY at %#x: %w", pageAddr, copyErr)
+		}
+		h.stats.FaultsServed.Add(1)
+		h.stats.BytesServed.Add(pageSize)
+		return nil
+	})
+}
+
+// handleRemove responds to balloon-device removals by zeroing the range.
+// Without this, subsequent faults on these pages would resolve to stale
+// mem.snap contents instead of zero pages.
+func (h *Handler) handleRemove(start, end uint64) { + if end <= start { + return + } + length := end - start + h.fdMu.RLock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + h.fdMu.RUnlock() + return + } + err := ioctlZeropage(fd, start, length) + h.fdMu.RUnlock() + if err != nil { + // EAGAIN: another REMOVE event is queued ahead of us — the next + // REMOVE will cover this range, so skipping the zero is correct. + if errors.Is(err, syscall.EAGAIN) { + h.stats.EagainRetries.Add(1) + return + } + h.log.Warn().Err(err).Uint64("start", start).Uint64("end", end).Msg("UFFDIO_ZEROPAGE failed") + } +} + +func (h *Handler) findRegion(faultAddr uint64) *GuestRegionMapping { + for i := range h.mappings { + if h.mappings[i].contains(faultAddr) { + return &h.mappings[i] + } + } + return nil +} + diff --git a/internal/vm/uffd/handler_test.go b/internal/vm/uffd/handler_test.go new file mode 100644 index 0000000..1e2f39e --- /dev/null +++ b/internal/vm/uffd/handler_test.go @@ -0,0 +1,158 @@ +package uffd + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + "unsafe" + + "github.com/rs/zerolog" +) + +// firecracker-produced JSON from src/vmm/src/persist.rs:725-768. Verified +// against the upstream struct definition (GuestRegionUffdMapping) and the +// reference example handler in src/firecracker/examples/uffd/uffd_utils.rs. 
+const sampleMappingsJSON = `[ + {"base_host_virt_addr": 140737488355328, "size": 268435456, "offset": 0, "page_size": 4096, "page_size_kib": 4096} +]` + +func TestParseGuestRegionMappings(t *testing.T) { + var mappings []GuestRegionMapping + if err := json.Unmarshal([]byte(sampleMappingsJSON), &mappings); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(mappings) != 1 { + t.Fatalf("got %d mappings, want 1", len(mappings)) + } + m := mappings[0] + if m.BaseHostVirtAddr != 140737488355328 { + t.Errorf("BaseHostVirtAddr = %d, want 140737488355328", m.BaseHostVirtAddr) + } + if m.Size != 268435456 { + t.Errorf("Size = %d, want 268435456 (256 MiB)", m.Size) + } + if m.Offset != 0 { + t.Errorf("Offset = %d, want 0", m.Offset) + } + if m.PageSize != 4096 { + t.Errorf("PageSize = %d, want 4096", m.PageSize) + } +} + +func TestGuestRegionMapping_Contains(t *testing.T) { + r := &GuestRegionMapping{ + BaseHostVirtAddr: 0x1000_0000, + Size: 0x1000_0000, + } + cases := []struct { + name string + addr uint64 + want bool + }{ + {"start of region", 0x1000_0000, true}, + {"middle of region", 0x1800_0000, true}, + {"last byte", 0x1FFF_FFFF, true}, + {"first byte after region", 0x2000_0000, false}, + {"before region", 0x0FFF_FFFF, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := r.contains(tc.addr); got != tc.want { + t.Errorf("contains(%#x) = %v, want %v", tc.addr, got, tc.want) + } + }) + } +} + +func TestHandler_FindRegion(t *testing.T) { + h := &Handler{ + mappings: []GuestRegionMapping{ + {BaseHostVirtAddr: 0x1000_0000, Size: 0x1000_0000}, + {BaseHostVirtAddr: 0x3000_0000, Size: 0x2000_0000}, + }, + } + if r := h.findRegion(0x1500_0000); r == nil || r.BaseHostVirtAddr != 0x1000_0000 { + t.Errorf("findRegion(0x1500_0000): got %+v, want region 1", r) + } + if r := h.findRegion(0x4000_0000); r == nil || r.BaseHostVirtAddr != 0x3000_0000 { + t.Errorf("findRegion(0x4000_0000): got %+v, want region 2", r) + } + if r := 
h.findRegion(0x2500_0000); r != nil { + t.Errorf("findRegion(0x2500_0000) = %+v, want nil (gap between regions)", r) + } +} + +func TestIoctlNumbers(t *testing.T) { + // These constants are computed from 's _IOWR/_IOR + // macros and should never change. Pinning them here so a code edit + // can't silently break the kernel ABI. + if UFFDIO_COPY != 0xC028AA03 { + t.Errorf("UFFDIO_COPY = %#x, want 0xC028AA03", UFFDIO_COPY) + } + if UFFDIO_ZEROPAGE != 0xC020AA04 { + t.Errorf("UFFDIO_ZEROPAGE = %#x, want 0xC020AA04", UFFDIO_ZEROPAGE) + } + if UFFDIO_UNREGISTER != 0x8010AA01 { + t.Errorf("UFFDIO_UNREGISTER = %#x, want 0x8010AA01", UFFDIO_UNREGISTER) + } +} + +func TestUffdMsgSize(t *testing.T) { + // uffd_msg is 32 bytes packed in the kernel ABI. Reading 32 bytes + // from the UFFD fd is correct iff the Go struct also weighs 32 bytes. + var msg uffdMsg + want := uintptr(32) + if got := unsafe.Sizeof(msg); got != want { + t.Errorf("sizeof(uffdMsg) = %d, want %d", got, want) + } +} + +func TestHandler_CloseIdempotent(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + if err := h.Close(); err != nil { + t.Errorf("first Close: %v", err) + } + // Second call must be a no-op — closed CAS guards against double-Munmap + // and double-close of listener/fd. + if err := h.Close(); err != nil { + t.Errorf("second Close: %v", err) + } +} + +func TestHandler_WaitHandshake_Success(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + h.publishHandshake(nil) + if err := h.WaitHandshake(context.Background()); err != nil { + t.Errorf("WaitHandshake: %v", err) + } + // Idempotent: re-publish lets a second caller see the same outcome. 
+ if err := h.WaitHandshake(context.Background()); err != nil { + t.Errorf("second WaitHandshake: %v", err) + } +} + +func TestHandler_WaitHandshake_Error(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + want := errors.New("boom") + h.publishHandshake(want) + got := h.WaitHandshake(context.Background()) + if !errors.Is(got, want) { + t.Errorf("WaitHandshake = %v, want %v", got, want) + } +} + +func TestHandler_WaitHandshake_CtxCancel(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + start := time.Now() + err := h.WaitHandshake(ctx) + if !errors.Is(err, context.Canceled) { + t.Errorf("WaitHandshake = %v, want context.Canceled", err) + } + if time.Since(start) > 100*time.Millisecond { + t.Errorf("cancel didn't propagate promptly: %v", time.Since(start)) + } +} diff --git a/internal/vm/uffd/ioctl.go b/internal/vm/uffd/ioctl.go new file mode 100644 index 0000000..a2dcba7 --- /dev/null +++ b/internal/vm/uffd/ioctl.go @@ -0,0 +1,126 @@ +// Package uffd implements a userfaultfd page-fault handler for Firecracker +// snapshot restore. The handler receives a userfaultfd file descriptor from +// Firecracker over a Unix domain socket, then serves page faults by copying +// pages from a memory snapshot file into the guest's address space. +// +// One handler per VM. Lifecycle is tied to the VM's gRPC restore call — +// the handler goroutine is spawned before LoadSnapshot is issued and +// terminates when the VM is destroyed or the UFFD fd closes. +package uffd + +import ( + "fmt" + "syscall" + "unsafe" +) + +// From , verified against Linux 6.8. _IO[RW] +// encoding: (dir << 30) | (sizeof << 16) | (type << 8) | nr, with +// type = 0xAA. Pinned by TestIoctlNumbers so a struct resize can't +// silently break the kernel ABI. 
+const ( + UFFDIO_COPY uintptr = 0xC028AA03 + UFFDIO_ZEROPAGE uintptr = 0xC020AA04 + UFFDIO_UNREGISTER uintptr = 0x8010AA01 +) + +const ( + UFFD_EVENT_PAGEFAULT uint8 = 0x12 + UFFD_EVENT_FORK uint8 = 0x13 + UFFD_EVENT_REMAP uint8 = 0x14 + UFFD_EVENT_REMOVE uint8 = 0x15 + UFFD_EVENT_UNMAP uint8 = 0x16 +) + +const ( + UFFDIO_COPY_MODE_DONTWAKE uint64 = 1 << 0 + UFFDIO_COPY_MODE_WP uint64 = 1 << 1 +) + +type uffdioRange struct { + Start uint64 + Len uint64 +} + +type uffdioCopy struct { + Dst uint64 + Src uint64 + Len uint64 + Mode uint64 + Copy int64 // out: bytes copied, or -errno on failure +} + +type uffdioZeropage struct { + Range uffdioRange + Mode uint64 + Zeropage int64 +} + +// uffdMsg matches the packed 32-byte struct uffd_msg in the kernel. +// Arg is the union; interpret based on Event. +type uffdMsg struct { + Event uint8 + Reserved1 uint8 + Reserved2 uint16 + Reserved3 uint32 + Arg [24]byte +} + +type PageFaultArg struct { + Flags uint64 + Address uint64 + Ptid uint32 + _ uint32 +} + +type RemoveArg struct { + Start uint64 + End uint64 + _ uint64 +} + +func (m *uffdMsg) asPageFault() PageFaultArg { + return *(*PageFaultArg)(unsafe.Pointer(&m.Arg[0])) +} + +func (m *uffdMsg) asRemove() RemoveArg { + return *(*RemoveArg)(unsafe.Pointer(&m.Arg[0])) +} + +// ioctlCopy returns bytes_copied on success, or an error wrapping the +// kernel errno. Callers should check errors.Is(err, syscall.EEXIST) and +// errors.Is(err, syscall.EAGAIN) — these are benign races (concurrent +// fault on same page, or REMOVE event pending in queue). +// +// The kernel reports failures via op.Copy = -errno, which we extract +// because the syscall return errno is often less specific. 
+func ioctlCopy(uffdFd uintptr, dst, src uint64, length uint64, mode uint64) (int64, error) { + op := uffdioCopy{ + Dst: dst, + Src: src, + Len: length, + Mode: mode, + } + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uffdFd, UFFDIO_COPY, uintptr(unsafe.Pointer(&op))) + if op.Copy < 0 { + return op.Copy, syscall.Errno(-op.Copy) + } + if errno != 0 { + return op.Copy, errno + } + return op.Copy, nil +} + +func ioctlZeropage(uffdFd uintptr, start uint64, length uint64) error { + op := uffdioZeropage{ + Range: uffdioRange{Start: start, Len: length}, + } + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uffdFd, UFFDIO_ZEROPAGE, uintptr(unsafe.Pointer(&op))) + if errno != 0 && op.Zeropage <= 0 { + return fmt.Errorf("UFFDIO_ZEROPAGE: %w (zeropage=%d)", errno, op.Zeropage) + } + return nil +} + +// No UFFDIO_UNREGISTER wrapper: closing the fd auto-unregisters all +// regions. Constant is pinned by TestIoctlNumbers for ABI docs. diff --git a/internal/vm/uffd/prefetch.go b/internal/vm/uffd/prefetch.go new file mode 100644 index 0000000..bfefea7 --- /dev/null +++ b/internal/vm/uffd/prefetch.go @@ -0,0 +1,155 @@ +package uffd + +import ( + "bufio" + "context" + "errors" + "fmt" + "os" + "strconv" + "strings" + "syscall" +) + +// runPrefetch warms guest memory pre-emptively. Replays AccessLogPath if +// present (only pages the guest will touch); otherwise walks the +// snapshot sequentially. Pages already faulted by the guest return +// EEXIST from UFFDIO_COPY and are skipped cheaply. 
+func (h *Handler) runPrefetch(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + h.log.Error().Interface("panic", r).Msg("prefetch panicked") + } + }() + + pageSize := h.pageSize + if pageSize == 0 { + h.log.Warn().Msg("prefetch: no page size, skipping") + return + } + if h.uffdFd.Load() == uffdClosed { + return + } + + offsets, ordered := h.prefetchOffsets() + h.log.Info().Bool("ordered", ordered).Int("offsets", len(offsets)).Msg("prefetch starting") + + for _, off := range offsets { + if ctx.Err() != nil { + return + } + stop, err := h.prefetchOne(off, pageSize) + if err != nil { + // Abort on first non-benign error; on-demand fault serving + // is unaffected. + h.log.Warn().Err(err).Uint64("offset", off).Msg("prefetch aborted") + return + } + if stop { + return + } + } + h.log.Debug().Uint64("served", h.stats.PrefetchedPages.Load()).Msg("prefetch complete") +} + +// prefetchOne returns stop=true when the fd has been closed under us, +// signaling the caller to exit the prefetch loop. +func (h *Handler) prefetchOne(offset, pageSize uint64) (stop bool, err error) { + region := h.regionForOffset(offset) + if region == nil { + return false, nil + } + dst := region.BaseHostVirtAddr + (offset - region.Offset) + srcPtr, perr := h.source.PagePointer(offset, pageSize) + if perr != nil { + return false, perr + } + + h.fdMu.RLock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + h.fdMu.RUnlock() + return true, nil + } + _, copyErr := ioctlCopy(fd, dst, srcPtr, pageSize, 0) + h.fdMu.RUnlock() + if copyErr != nil { + if errors.Is(copyErr, syscall.EEXIST) { + h.stats.PrefetchSkipped.Add(1) + return false, nil + } + if errors.Is(copyErr, syscall.EAGAIN) { + // REMOVE pending in queue; main fault loop will drain it. + // Guest will re-fault this page later if it needs it. 
+ h.stats.PrefetchSkipped.Add(1) + return false, nil + } + return false, fmt.Errorf("prefetch UFFDIO_COPY at %#x: %w", dst, copyErr) + } + h.stats.PrefetchedPages.Add(1) + return false, nil +} + +func (h *Handler) regionForOffset(offset uint64) *GuestRegionMapping { + for i := range h.mappings { + r := &h.mappings[i] + if offset >= r.Offset && offset < r.Offset+r.Size { + return r + } + } + return nil +} + +// prefetchOffsets returns page-aligned offsets to prefetch in order. +// The bool is true when access.log drove the order, false for fallback. +func (h *Handler) prefetchOffsets() ([]uint64, bool) { + if h.cfg.AccessLogPath != "" { + if offsets, err := readAccessLog(h.cfg.AccessLogPath, h.pageSize); err == nil && len(offsets) > 0 { + return offsets, true + } + } + // Sequential fallback: walk every region from start to end. + var out []uint64 + for _, r := range h.mappings { + for off := r.Offset; off < r.Offset+r.Size; off += h.pageSize { + out = append(out, off) + } + } + return out, false +} + +// readAccessLog parses newline-separated decimal offsets. Misaligned +// lines and blank/# lines are silently skipped. 
+func readAccessLog(path string, pageSize uint64) ([]uint64, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var offsets []uint64 + seen := make(map[uint64]struct{}) + scanner := bufio.NewScanner(f) + for scanner.Scan() { + s := strings.TrimSpace(scanner.Text()) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + n, err := strconv.ParseUint(s, 10, 64) + if err != nil { + continue + } + if n%pageSize != 0 { + continue + } + if _, dup := seen[n]; dup { + continue + } + seen[n] = struct{}{} + offsets = append(offsets, n) + } + if err := scanner.Err(); err != nil { + return offsets, err + } + return offsets, nil +} diff --git a/internal/vm/uffd/recorder.go b/internal/vm/uffd/recorder.go new file mode 100644 index 0000000..2ce4a07 --- /dev/null +++ b/internal/vm/uffd/recorder.go @@ -0,0 +1,80 @@ +package uffd + +import ( + "bufio" + "fmt" + "os" + "sort" + "strconv" + "sync" +) + +// Recorder collects page-fault offsets in observation order for later +// replay by the prefetcher. Safe for concurrent Record calls; capture +// order approximates guest first-touch order under our worker pool. +type Recorder struct { + mu sync.Mutex + seen map[uint64]struct{} + ordered []uint64 +} + +func NewRecorder() *Recorder { + return &Recorder{seen: make(map[uint64]struct{})} +} + +func (r *Recorder) Record(offset uint64) { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.seen[offset]; ok { + return + } + r.seen[offset] = struct{}{} + r.ordered = append(r.ordered, offset) +} + +func (r *Recorder) Len() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.ordered) +} + +// Flush writes recorded offsets to path, one per line. sortAscending=true +// emits numerical order instead of capture order — use when parallel +// workers make capture order unreliable. 
+func (r *Recorder) Flush(path string, sortAscending bool) error { + r.mu.Lock() + out := make([]uint64, len(r.ordered)) + copy(out, r.ordered) + r.mu.Unlock() + + if sortAscending { + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + } + + tmp := path + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return fmt.Errorf("create %s: %w", tmp, err) + } + w := bufio.NewWriter(f) + for _, off := range out { + if _, err := w.WriteString(strconv.FormatUint(off, 10) + "\n"); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("write: %w", err) + } + } + if err := w.Flush(); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("flush: %w", err) + } + if err := f.Close(); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("close: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("rename: %w", err) + } + return nil +} diff --git a/internal/vm/uffd/recorder_test.go b/internal/vm/uffd/recorder_test.go new file mode 100644 index 0000000..523c6e1 --- /dev/null +++ b/internal/vm/uffd/recorder_test.go @@ -0,0 +1,92 @@ +package uffd + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestRecorder_DedupAndOrder(t *testing.T) { + r := NewRecorder() + r.Record(4096) + r.Record(8192) + r.Record(4096) // duplicate — must be ignored + r.Record(12288) + r.Record(8192) // duplicate + + if r.Len() != 3 { + t.Fatalf("Len = %d, want 3", r.Len()) + } + + path := filepath.Join(t.TempDir(), "access.log") + if err := r.Flush(path, false); err != nil { + t.Fatalf("Flush: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + want := []string{"4096", "8192", "12288"} + if !reflect.DeepEqual(lines, want) { + t.Errorf("flushed contents = %v, want %v (first-touch order, deduped)", lines, want) + } +} + +func TestRecorder_FlushAtomicRename(t *testing.T) { + r := 
NewRecorder() + r.Record(4096) + + dir := t.TempDir() + path := filepath.Join(dir, "access.log") + + if err := r.Flush(path, false); err != nil { + t.Fatalf("Flush: %v", err) + } + + // .tmp must NOT survive a successful flush — the rename should have + // moved it into place. A leftover .tmp signals a torn flush. + if _, err := os.Stat(path + ".tmp"); !os.IsNotExist(err) { + t.Errorf("tmp file %q still exists after successful flush", path+".tmp") + } + if _, err := os.Stat(path); err != nil { + t.Errorf("output file missing after flush: %v", err) + } +} + +func TestRecorder_FlushEmpty(t *testing.T) { + r := NewRecorder() + path := filepath.Join(t.TempDir(), "access.log") + if err := r.Flush(path, false); err != nil { + t.Fatalf("Flush on empty recorder: %v", err) + } + info, err := os.Stat(path) + if err != nil { + t.Fatalf("output file missing: %v", err) + } + if info.Size() != 0 { + t.Errorf("empty recorder flushed %d bytes, want 0", info.Size()) + } +} + +func TestRecorder_FlushSorted(t *testing.T) { + r := NewRecorder() + r.Record(12288) + r.Record(4096) + r.Record(8192) + + path := filepath.Join(t.TempDir(), "access.log") + if err := r.Flush(path, true); err != nil { + t.Fatalf("Flush: %v", err) + } + + data, _ := os.ReadFile(path) + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + want := []string{"4096", "8192", "12288"} + if !reflect.DeepEqual(lines, want) { + t.Errorf("sorted flush = %v, want %v", lines, want) + } +} diff --git a/internal/vm/uffd/source.go b/internal/vm/uffd/source.go new file mode 100644 index 0000000..b6ca620 --- /dev/null +++ b/internal/vm/uffd/source.go @@ -0,0 +1,71 @@ +package uffd + +import ( + "fmt" + "os" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Source is a memory-mapped read-only view of a snapshot's mem.snap file. +type Source struct { + path string + file *os.File + mapping []byte +} + +// OpenSource opens path and mmaps it PROT_READ | MAP_PRIVATE. 
+// MAP_POPULATE is intentionally NOT set — that would force synchronous +// memory loading, which is the exact behavior UFFD is meant to avoid. +func OpenSource(path string) (*Source, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open mem.snap: %w", err) + } + st, err := f.Stat() + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("stat mem.snap: %w", err) + } + size := st.Size() + if size <= 0 { + _ = f.Close() + return nil, fmt.Errorf("mem.snap %s is empty", path) + } + data, err := unix.Mmap(int(f.Fd()), 0, int(size), unix.PROT_READ, unix.MAP_PRIVATE) + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("mmap mem.snap: %w", err) + } + return &Source{path: path, file: f, mapping: data}, nil +} + +func (s *Source) Size() int { return len(s.mapping) } + +// PagePointer returns a kernel-facing address into the mmap'd backing, +// for direct use as the src argument to UFFDIO_COPY. The returned +// address is valid only while Source is open. +func (s *Source) PagePointer(offset, length uint64) (uint64, error) { + if offset+length > uint64(len(s.mapping)) { + return 0, fmt.Errorf("offset %d + length %d exceeds file size %d", offset, length, len(s.mapping)) + } + return uint64(uintptr(unsafe.Pointer(unsafe.SliceData(s.mapping))) + uintptr(offset)), nil +} + +func (s *Source) Close() error { + var firstErr error + if s.mapping != nil { + if err := unix.Munmap(s.mapping); err != nil && firstErr == nil { + firstErr = fmt.Errorf("munmap: %w", err) + } + s.mapping = nil + } + if s.file != nil { + if err := s.file.Close(); err != nil && firstErr == nil { + firstErr = fmt.Errorf("close: %w", err) + } + s.file = nil + } + return firstErr +} diff --git a/internal/vm/uffd/source_test.go b/internal/vm/uffd/source_test.go new file mode 100644 index 0000000..d40bab6 --- /dev/null +++ b/internal/vm/uffd/source_test.go @@ -0,0 +1,89 @@ +package uffd + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func 
TestSourceOpenAndRead(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "mem.snap") + + // 8 KiB of recognizable bytes — two pages, each with a distinct pattern + // at the page boundary so we can verify PagePointer offsets. + page := 4096 + buf := make([]byte, 2*page) + for i := range buf[:page] { + buf[i] = 0x11 + } + for i := range buf[page:] { + buf[page+i] = 0x22 + } + if err := os.WriteFile(path, buf, 0o644); err != nil { + t.Fatalf("write test file: %v", err) + } + + src, err := OpenSource(path) + if err != nil { + t.Fatalf("OpenSource: %v", err) + } + t.Cleanup(func() { _ = src.Close() }) + + if src.Size() != 2*page { + t.Errorf("Size = %d, want %d", src.Size(), 2*page) + } + + // PagePointer should report distinct offsets for the two pages. + p0, err := src.PagePointer(0, uint64(page)) + if err != nil { + t.Fatalf("PagePointer page 0: %v", err) + } + p1, err := src.PagePointer(uint64(page), uint64(page)) + if err != nil { + t.Fatalf("PagePointer page 1: %v", err) + } + if p1-p0 != uint64(page) { + t.Errorf("page pointer delta = %d, want %d", p1-p0, page) + } + + // Verify mmap contents by reading the internal mapping (test is in + // the same package). PagePointer returns a kernel-facing uint64; we + // don't want to round-trip it through unsafe.Pointer in test code. + if !bytes.Equal(src.mapping[:page], buf[:page]) { + t.Errorf("page 0 contents mismatch") + } + if !bytes.Equal(src.mapping[page:], buf[page:]) { + t.Errorf("page 1 contents mismatch") + } +} + +func TestSourceBoundsCheck(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "mem.snap") + if err := os.WriteFile(path, make([]byte, 4096), 0o644); err != nil { + t.Fatalf("write test file: %v", err) + } + src, err := OpenSource(path) + if err != nil { + t.Fatalf("OpenSource: %v", err) + } + t.Cleanup(func() { _ = src.Close() }) + + // Reading past the end must error rather than reach into nothing. 
+ if _, err := src.PagePointer(4096, 4096); err == nil { + t.Error("PagePointer past EOF returned nil error, want bounds error") + } +} + +func TestSourceRejectsEmpty(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.snap") + if err := os.WriteFile(path, nil, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if _, err := OpenSource(path); err == nil { + t.Error("OpenSource on empty file returned nil error, want non-nil") + } +}