From 7a96412c9772fbc5e779199d5760ff6af3ec6c64 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 11:48:33 -0700 Subject: [PATCH 1/9] vm: lazy memory restore via UFFD; stop systemd unit on restore failure --- internal/vm/firecracker.go | 37 ++++ internal/vm/manager.go | 85 +++++++- internal/vm/reconciler.go | 12 +- internal/vm/uffd/handler.go | 358 +++++++++++++++++++++++++++++++ internal/vm/uffd/handler_test.go | 105 +++++++++ internal/vm/uffd/ioctl.go | 132 ++++++++++++ internal/vm/uffd/source.go | 71 ++++++ internal/vm/uffd/source_test.go | 89 ++++++++ 8 files changed, 885 insertions(+), 4 deletions(-) create mode 100644 internal/vm/uffd/handler.go create mode 100644 internal/vm/uffd/handler_test.go create mode 100644 internal/vm/uffd/ioctl.go create mode 100644 internal/vm/uffd/source.go create mode 100644 internal/vm/uffd/source_test.go diff --git a/internal/vm/firecracker.go b/internal/vm/firecracker.go index c95dc38..7efb5fd 100644 --- a/internal/vm/firecracker.go +++ b/internal/vm/firecracker.go @@ -326,3 +326,40 @@ func RestoreSnapshotWithOverrides(socketPath, snapshotPath, memPath, ifaceID, ta } return nil } + +// RestoreSnapshotUffdWithOverrides is the UFFD-backend variant of +// RestoreSnapshotWithOverrides. Instead of pointing Firecracker at the +// mem.snap file (File backend, which synchronously reads and CRC64-verifies +// the entire snapshot before returning), it points Firecracker at a Unix +// domain socket served by our in-process UFFD handler. +// +// LoadSnapshot returns in milliseconds because pages are not read upfront. +// As the guest touches pages, the kernel forwards faults to our handler, +// which serves them from a memory-mapped mem.snap. +// +// uffdSocketPath must be a bound Unix socket; the caller is responsible for +// starting the handler goroutine before invoking this function. +func RestoreSnapshotUffdWithOverrides(socketPath, snapshotPath, uffdSocketPath, ifaceID, tapDevice, blockDeltaDir string) error { + fc := newFCClient(socketPath) + if _, err := fc.Operations.LoadSnapshot(&operations.LoadSnapshotParams{ + Context: context.Background(), + Body: &models.SnapshotLoadParams{ + SnapshotPath: &snapshotPath, + MemBackend: &models.MemoryBackend{ + BackendType: strPtr(models.MemoryBackendBackendTypeUffd), + BackendPath: &uffdSocketPath, + }, + ResumeVM: true, + NetworkOverrides: []*models.NetworkOverride{ + {IfaceID: &ifaceID, HostDevName: &tapDevice}, + }, + BlockDeltaDir: blockDeltaDir, + }, + }); err != nil { + if isTornSnapshotErr(err) { + return fmt.Errorf("load snapshot (uffd): %w: %v", ErrTornSnapshot, err) + } + return fmt.Errorf("load snapshot (uffd): %w", err) + } + return nil +} diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 7ffe244..6e87cd2 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -20,6 +20,7 @@ import ( "google.golang.org/grpc/status" "github.com/superserve-ai/sandbox/internal/network" + "github.com/superserve-ai/sandbox/internal/vm/uffd" pb "github.com/superserve-ai/sandbox/proto/boxdpb" ) @@ -88,6 +89,10 @@ type VMInstance struct { CreatedAt time.Time Metadata map[string]string + // uffdCancel cancels the per-VM UFFD handler goroutine. nil for VMs + // not using UFFD (cold boots, template builds, in-place resumes). + uffdCancel context.CancelFunc + mu sync.RWMutex } @@ -441,6 +446,8 @@ func (m *Manager) DestroyVM(ctx context.Context, vmID string, force bool) error log := m.log.With().Str("vm_id", vmID).Logger() log.Info().Bool("force", force).Msg("destroying VM") + m.cancelUffdHandler(inst) + // Stop the systemd unit if one exists — this is the path for sandbox // VMs launched via startFirecrackerViaSystemd. if err := stopUnit(ctx, systemdUnitName(vmID)); err != nil { @@ -843,6 +850,7 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem vmDir := filepath.Join(m.cfg.RunDir, vmID) socketPath := filepath.Join(vmDir, "firecracker.sock") + uffdSocketPath := filepath.Join(vmDir, "uffd.sock") // Publish all the network/disk/socket fields before starting // Firecracker so the in-memory view is consistent for concurrent @@ -856,8 +864,21 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem inst.SocketPath = socketPath inst.mu.Unlock() + // inPlace (pause/resume) keeps the File backend; fresh restores use UFFD. + if !inPlace { + if err := m.startUffdHandler(inst, uffdSocketPath, memPath); err != nil { + if !inPlace { + m.netMgr.CleanupVM(vmID) + } + m.cleanupRunDir(vmID) + m.setStatus(vmID, StatusError) + return nil, fmt.Errorf("start uffd handler: %w", err) + } + } + pid, startErr := m.startFirecrackerViaSystemd(ctx, vmID, socketPath, diskPath, resourceLimits.BasePath, nsName) if startErr != nil { + m.cancelUffdHandler(inst) if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -875,9 +896,13 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem if inPlace { restoreErr = RestoreSnapshot(socketPath, snapshotPath, memPath, plan.deltaDir) } else { - restoreErr = RestoreSnapshotWithOverrides(socketPath, snapshotPath, memPath, "eth0", tapDevice, plan.deltaDir) + restoreErr = RestoreSnapshotUffdWithOverrides(socketPath, snapshotPath, uffdSocketPath, "eth0", tapDevice, plan.deltaDir) } if restoreErr != nil { + // Firecracker is already running; stop the unit before other + // cleanup or it leaks. See stopUnitDuringRestoreError comment. + m.stopUnitDuringRestoreError(vmID) + m.cancelUffdHandler(inst) if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -892,6 +917,8 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem } if err := m.waitForBoxd(ctx, hostIP, 5*time.Second); err != nil { + m.stopUnitDuringRestoreError(vmID) + m.cancelUffdHandler(inst) if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -1336,6 +1363,62 @@ func (m *Manager) cleanupRunDir(dirName string) { } } +// stopUnitDuringRestoreError stops the per-VM systemd unit when a restore +// aborts after Firecracker started. Uses a fresh context because the +// caller's gRPC ctx is often already cancelled (deadline exceeded under +// load). Without this, the firecracker process leaks. +func (m *Manager) stopUnitDuringRestoreError(vmID string) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := stopUnit(cleanupCtx, systemdUnitName(vmID)); err != nil { + m.log.Warn().Err(err).Str("vm_id", vmID).Msg("systemctl stop failed during restore error cleanup") + } + removeUnitDropIn(vmID) +} + +// startUffdHandler binds the per-VM UFFD socket, opens mem.snap, and +// spawns the fault-serving goroutine. Stores the cancel func on inst so +// destroy/error paths can tear it down. +func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath string) error { + h := uffd.New(uffd.Config{ + SocketPath: uffdSocketPath, + MemSnapPath: memSnapPath, + Logger: m.log.With().Str("vm_id", inst.ID).Logger(), + }) + if err := h.Start(); err != nil { + return err + } + uffdCtx, uffdCancel := context.WithCancel(context.Background()) + inst.mu.Lock() + inst.uffdCancel = uffdCancel + inst.mu.Unlock() + + vmID := inst.ID + go func() { + defer func() { + if r := recover(); r != nil { + m.log.Error().Interface("panic", r).Str("vm_id", vmID).Msg("uffd handler panicked") + } + }() + defer h.Close() + if err := h.Run(uffdCtx); err != nil && !errors.Is(err, context.Canceled) { + m.log.Error().Err(err).Str("vm_id", vmID).Msg("uffd handler exited with error") + } + }() + return nil +} + +// cancelUffdHandler is idempotent and a no-op for non-UFFD VMs. +func (m *Manager) cancelUffdHandler(inst *VMInstance) { + inst.mu.Lock() + cancel := inst.uffdCancel + inst.uffdCancel = nil + inst.mu.Unlock() + if cancel != nil { + cancel() + } +} + // startFirecrackerColdBoot launches Firecracker inside a network namespace, // configures it, and boots the kernel. Used by the template build pipeline // (coldBootFromRootfs) to boot the throwaway VM that we snapshot into a diff --git a/internal/vm/reconciler.go b/internal/vm/reconciler.go index f35f82d..a91550f 100644 --- a/internal/vm/reconciler.go +++ b/internal/vm/reconciler.go @@ -255,13 +255,16 @@ func (r *Reconciler) runOnce(ctx context.Context) { } } - // Drift 3: systemd has a unit, DB says the sandbox is deleted or has - // no row at all. This is an orphan — stop the unit + clean up. + // Drift 3: systemd unit active, DB says deleted/failed/missing. + // `failed` catches restores whose forward-path cleanup didn't run + // (e.g., gRPC ctx fired mid-LoadSnapshot and our work continued + // after the caller gave up). if dbSandboxes != nil { for id := range active { sb, known := dbSandboxes[id] deleted := known && sb.Sandbox.Status == db.SandboxStatusDeleted - if known && !deleted { + failed := known && sb.Sandbox.Status == db.SandboxStatusFailed + if known && !deleted && !failed { continue } if !r.gracePeriodElapsed("orphan:"+id, now) { @@ -276,6 +279,9 @@ func (r *Reconciler) runOnce(ctx context.Context) { if deleted { reason = "systemd unit for soft-deleted sandbox" kind = "systemd_active_db_deleted" + } else if failed { + reason = "systemd unit for failed sandbox" + kind = "systemd_active_db_failed" } log.Warn().Str("vm_id", id).Str("drift", kind).Msg("orphan systemd unit — stopping") if err := stopUnit(ctx, systemdUnitName(id)); err != nil { diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go new file mode 100644 index 0000000..639ecad --- /dev/null +++ b/internal/vm/uffd/handler.go @@ -0,0 +1,358 @@ +package uffd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "sync/atomic" + "syscall" + "time" + "unsafe" + + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" +) + +const defaultFaultWorkers = 4096 + +// GuestRegionMapping is the JSON Firecracker sends over the UDS alongside +// the userfaultfd fd. Field tags must match +// GuestRegionUffdMapping in src/vmm/src/persist.rs of our firecracker fork. +type GuestRegionMapping struct { + BaseHostVirtAddr uint64 `json:"base_host_virt_addr"` + Size uint64 `json:"size"` + Offset uint64 `json:"offset"` + PageSize uint64 `json:"page_size"` + PageSizeKib uint64 `json:"page_size_kib,omitempty"` +} + +func (r *GuestRegionMapping) contains(faultAddr uint64) bool { + return faultAddr >= r.BaseHostVirtAddr && faultAddr < r.BaseHostVirtAddr+r.Size +} + +type Config struct { + SocketPath string + MemSnapPath string + FaultWorkers int // 0 -> defaultFaultWorkers + Logger zerolog.Logger +} + +type Stats struct { + FaultsServed atomic.Uint64 + BytesServed atomic.Uint64 + CopyErrors atomic.Uint64 + UnknownEvents atomic.Uint64 + RemoveEvents atomic.Uint64 + EagainRetries atomic.Uint64 + Eexist atomic.Uint64 +} + +// Handler binds a UDS, receives a userfaultfd fd from Firecracker, then +// serves page faults from mem.snap for the life of the VM. One handler +// per VM. +type Handler struct { + cfg Config + log zerolog.Logger + stats Stats + source *Source + listener *net.UnixListener + uffdFd uintptr // dup'd from SCM_RIGHTS; -1 if not yet received + mappings []GuestRegionMapping + pageSize uint64 + + // closed is signaled when Close has been called. + closed atomic.Bool +} + +func New(cfg Config) *Handler { + if cfg.FaultWorkers <= 0 { + cfg.FaultWorkers = defaultFaultWorkers + } + return &Handler{ + cfg: cfg, + log: cfg.Logger.With().Str("component", "uffd").Logger(), + uffdFd: ^uintptr(0), + } +} + +func (h *Handler) Stats() *Stats { return &h.stats } + +// Start binds the Unix socket and mmaps mem.snap. After it returns, +// Firecracker's LoadSnapshot can be invoked safely. +func (h *Handler) Start() error { + // UFFD handler runs BEFORE startFirecrackerViaSystemd, which is + // where the per-VM rundir is normally created. + if err := os.MkdirAll(filepath.Dir(h.cfg.SocketPath), 0o755); err != nil { + return fmt.Errorf("mkdir socket dir: %w", err) + } + _ = os.Remove(h.cfg.SocketPath) + + src, err := OpenSource(h.cfg.MemSnapPath) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + + listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: h.cfg.SocketPath, Net: "unix"}) + if err != nil { + _ = src.Close() + return fmt.Errorf("listen %s: %w", h.cfg.SocketPath, err) + } + if err := os.Chmod(h.cfg.SocketPath, 0o600); err != nil { + h.log.Warn().Err(err).Msg("chmod socket failed") + } + + h.source = src + h.listener = listener + return nil +} + + +// Run accepts the connection from Firecracker, receives the UFFD fd +// plus JSON mappings, then services page faults until ctx is cancelled +// or the UFFD fd closes. Caller must call Close on return. +func (h *Handler) Run(ctx context.Context) error { + if h.listener == nil || h.source == nil { + return errors.New("Run called before Start (or after Close)") + } + if err := h.acceptAndReceive(ctx); err != nil { + return fmt.Errorf("receive handshake: %w", err) + } + return h.serveLoop(ctx) +} + +func (h *Handler) Close() error { + if !h.closed.CompareAndSwap(false, true) { + return nil + } + var firstErr error + if h.listener != nil { + if err := h.listener.Close(); err != nil && firstErr == nil { + firstErr = fmt.Errorf("close listener: %w", err) + } + } + if h.cfg.SocketPath != "" { + _ = os.Remove(h.cfg.SocketPath) + } + if h.uffdFd != ^uintptr(0) { + _ = unix.Close(int(h.uffdFd)) + h.uffdFd = ^uintptr(0) + } + if h.source != nil { + if err := h.source.Close(); err != nil && firstErr == nil { + firstErr = fmt.Errorf("close source: %w", err) + } + h.source = nil + } + return firstErr +} + +func (h *Handler) acceptAndReceive(ctx context.Context) error { + // ctx with no deadline → 30s default, matching the gRPC deadline on + // the caller side. The listener deadline is the only way to make + // AcceptUnix interruptible. + if deadline, ok := ctx.Deadline(); ok { + _ = h.listener.SetDeadline(deadline) + } else { + _ = h.listener.SetDeadline(time.Now().Add(30 * time.Second)) + } + + conn, err := h.listener.AcceptUnix() + if err != nil { + return fmt.Errorf("accept: %w", err) + } + defer conn.Close() + + body := make([]byte, 4096) + oob := make([]byte, unix.CmsgSpace(4)) + + n, oobn, _, _, err := conn.ReadMsgUnix(body, oob) + if err != nil { + return fmt.Errorf("recvmsg: %w", err) + } + if n == 0 { + return errors.New("empty message from firecracker") + } + + cmsgs, err := unix.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return fmt.Errorf("parse cmsg: %w", err) + } + uffdFd := -1 + for _, cm := range cmsgs { + fds, err := unix.ParseUnixRights(&cm) + if err != nil { + continue + } + for _, fd := range fds { + if uffdFd == -1 { + uffdFd = fd + } else { + _ = unix.Close(fd) + } + } + } + if uffdFd == -1 { + return errors.New("no UFFD fd received in SCM_RIGHTS") + } + + var mappings []GuestRegionMapping + if err := json.Unmarshal(body[:n], &mappings); err != nil { + _ = unix.Close(uffdFd) + return fmt.Errorf("unmarshal mappings: %w (body=%q)", err, string(body[:n])) + } + if len(mappings) == 0 { + _ = unix.Close(uffdFd) + return errors.New("firecracker sent zero mappings") + } + pageSize := mappings[0].PageSize + if pageSize == 0 { + pageSize = mappings[0].PageSizeKib + } + if pageSize == 0 { + _ = unix.Close(uffdFd) + return fmt.Errorf("invalid page size in mappings: %+v", mappings[0]) + } + + h.uffdFd = uintptr(uffdFd) + h.mappings = mappings + h.pageSize = pageSize + + h.log.Info(). + Int("regions", len(mappings)). + Uint64("page_size", pageSize). + Int("uffd_fd", uffdFd). + Msg("uffd handshake complete") + + return nil +} + +func (h *Handler) serveLoop(ctx context.Context) error { + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(h.cfg.FaultWorkers) + + // Closing the UFFD fd is the only way to wake a blocking read on it, + // so bridge context cancellation to a close here. + go func() { + <-gctx.Done() + if h.uffdFd != ^uintptr(0) { + _ = unix.Close(int(h.uffdFd)) + h.uffdFd = ^uintptr(0) + } + }() + + var msgBuf [32]byte + for { + if err := gctx.Err(); err != nil { + return g.Wait() + } + fd := h.uffdFd + if fd == ^uintptr(0) { + return g.Wait() + } + n, err := unix.Read(int(fd), msgBuf[:]) + if err != nil { + if errors.Is(err, syscall.EINTR) { + continue + } + // EBADF: closed by our cancellation goroutine. + // EAGAIN: transient. + if errors.Is(err, syscall.EBADF) || errors.Is(err, syscall.EAGAIN) { + return g.Wait() + } + return fmt.Errorf("read uffd_msg: %w", err) + } + if n == 0 { + return g.Wait() // Firecracker exited. + } + if n < int(unsafe.Sizeof(uffdMsg{})) { + h.log.Warn().Int("bytes", n).Msg("short uffd_msg read; skipping") + continue + } + msg := *(*uffdMsg)(unsafe.Pointer(&msgBuf[0])) + switch msg.Event { + case UFFD_EVENT_PAGEFAULT: + pf := msg.asPageFault() + fdCopy := fd + pageSize := h.pageSize + h.servePagefault(g, fdCopy, pf.Address, pageSize) + case UFFD_EVENT_REMOVE: + h.stats.RemoveEvents.Add(1) + rm := msg.asRemove() + h.handleRemove(fd, rm.Start, rm.End) + default: + h.stats.UnknownEvents.Add(1) + h.log.Debug().Uint8("event", msg.Event).Msg("ignoring unknown uffd event") + } + } +} + +// servePagefault dispatches a single fault to the worker pool. The pool +// blocks if all FaultWorkers slots are busy — natural backpressure. +func (h *Handler) servePagefault(g *errgroup.Group, fd uintptr, faultAddr, pageSize uint64) { + g.Go(func() error { + pageAddr := faultAddr &^ (pageSize - 1) + + region := h.findRegion(pageAddr) + if region == nil { + h.log.Error().Uint64("addr", faultAddr).Msg("page fault address outside all regions; killing handler") + return fmt.Errorf("address %#x outside guest regions", faultAddr) + } + + offset := region.Offset + (pageAddr - region.BaseHostVirtAddr) + srcPtr, err := h.source.PagePointer(offset, pageSize) + if err != nil { + h.stats.CopyErrors.Add(1) + return fmt.Errorf("source page lookup: %w", err) + } + + _, copyErr := ioctlCopy(fd, pageAddr, srcPtr, pageSize, 0) + if copyErr != nil { + // EEXIST: another fault on the same page raced ahead of us; + // it's already mapped. Benign. + if errors.Is(copyErr, syscall.EEXIST) { + h.stats.Eexist.Add(1) + return nil + } + // EAGAIN: a REMOVE event is pending in the kernel queue; + // all ioctls return EAGAIN until it's drained. The kernel + // re-fires the fault after we handle the REMOVE. + if errors.Is(copyErr, syscall.EAGAIN) { + h.stats.EagainRetries.Add(1) + return nil + } + h.stats.CopyErrors.Add(1) + return fmt.Errorf("UFFDIO_COPY at %#x: %w", pageAddr, copyErr) + } + h.stats.FaultsServed.Add(1) + h.stats.BytesServed.Add(pageSize) + return nil + }) +} + +// handleRemove responds to balloon-device removals by zeroing the range. +// Without this, subsequent faults on these pages would resolve to stale +// mem.snap contents instead of zero pages. +func (h *Handler) handleRemove(fd uintptr, start, end uint64) { + if end <= start { + return + } + length := end - start + if err := ioctlZeropage(fd, start, length); err != nil { + h.log.Warn().Err(err).Uint64("start", start).Uint64("end", end).Msg("UFFDIO_ZEROPAGE failed") + } +} + +func (h *Handler) findRegion(faultAddr uint64) *GuestRegionMapping { + for i := range h.mappings { + if h.mappings[i].contains(faultAddr) { + return &h.mappings[i] + } + } + return nil +} + diff --git a/internal/vm/uffd/handler_test.go b/internal/vm/uffd/handler_test.go new file mode 100644 index 0000000..e4f7ca4 --- /dev/null +++ b/internal/vm/uffd/handler_test.go @@ -0,0 +1,105 @@ +package uffd + +import ( + "encoding/json" + "testing" + "unsafe" +) + +// firecracker-produced JSON from src/vmm/src/persist.rs:725-768. Verified +// against the upstream struct definition (GuestRegionUffdMapping) and the +// reference example handler in src/firecracker/examples/uffd/uffd_utils.rs. +const sampleMappingsJSON = `[ + {"base_host_virt_addr": 140737488355328, "size": 268435456, "offset": 0, "page_size": 4096, "page_size_kib": 4096} +]` + +func TestParseGuestRegionMappings(t *testing.T) { + var mappings []GuestRegionMapping + if err := json.Unmarshal([]byte(sampleMappingsJSON), &mappings); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(mappings) != 1 { + t.Fatalf("got %d mappings, want 1", len(mappings)) + } + m := mappings[0] + if m.BaseHostVirtAddr != 140737488355328 { + t.Errorf("BaseHostVirtAddr = %d, want 140737488355328", m.BaseHostVirtAddr) + } + if m.Size != 268435456 { + t.Errorf("Size = %d, want 268435456 (256 MiB)", m.Size) + } + if m.Offset != 0 { + t.Errorf("Offset = %d, want 0", m.Offset) + } + if m.PageSize != 4096 { + t.Errorf("PageSize = %d, want 4096", m.PageSize) + } +} + +func TestGuestRegionMapping_Contains(t *testing.T) { + r := &GuestRegionMapping{ + BaseHostVirtAddr: 0x1000_0000, + Size: 0x1000_0000, + } + cases := []struct { + name string + addr uint64 + want bool + }{ + {"start of region", 0x1000_0000, true}, + {"middle of region", 0x1800_0000, true}, + {"last byte", 0x1FFF_FFFF, true}, + {"first byte after region", 0x2000_0000, false}, + {"before region", 0x0FFF_FFFF, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := r.contains(tc.addr); got != tc.want { + t.Errorf("contains(%#x) = %v, want %v", tc.addr, got, tc.want) + } + }) + } +} + +func TestHandler_FindRegion(t *testing.T) { + h := &Handler{ + mappings: []GuestRegionMapping{ + {BaseHostVirtAddr: 0x1000_0000, Size: 0x1000_0000}, + {BaseHostVirtAddr: 0x3000_0000, Size: 0x2000_0000}, + }, + } + if r := h.findRegion(0x1500_0000); r == nil || r.BaseHostVirtAddr != 0x1000_0000 { + t.Errorf("findRegion(0x1500_0000): got %+v, want region 1", r) + } + if r := h.findRegion(0x4000_0000); r == nil || r.BaseHostVirtAddr != 0x3000_0000 { + t.Errorf("findRegion(0x4000_0000): got %+v, want region 2", r) + } + if r := h.findRegion(0x2500_0000); r != nil { + t.Errorf("findRegion(0x2500_0000) = %+v, want nil (gap between regions)", r) + } +} + +func TestIoctlNumbers(t *testing.T) { + // These constants are computed from 's _IOWR/_IOR + // macros and should never change. Pinning them here so a code edit + // can't silently break the kernel ABI. + if UFFDIO_COPY != 0xC028AA03 { + t.Errorf("UFFDIO_COPY = %#x, want 0xC028AA03", UFFDIO_COPY) + } + if UFFDIO_ZEROPAGE != 0xC020AA04 { + t.Errorf("UFFDIO_ZEROPAGE = %#x, want 0xC020AA04", UFFDIO_ZEROPAGE) + } + if UFFDIO_UNREGISTER != 0x8010AA01 { + t.Errorf("UFFDIO_UNREGISTER = %#x, want 0x8010AA01", UFFDIO_UNREGISTER) + } +} + +func TestUffdMsgSize(t *testing.T) { + // uffd_msg is 32 bytes packed in the kernel ABI. Reading 32 bytes + // from the UFFD fd is correct iff the Go struct also weighs 32 bytes. + var msg uffdMsg + want := uintptr(32) + if got := unsafe.Sizeof(msg); got != want { + t.Errorf("sizeof(uffdMsg) = %d, want %d", got, want) + } +} diff --git a/internal/vm/uffd/ioctl.go b/internal/vm/uffd/ioctl.go new file mode 100644 index 0000000..d785f87 --- /dev/null +++ b/internal/vm/uffd/ioctl.go @@ -0,0 +1,132 @@ +// Package uffd implements a userfaultfd page-fault handler for Firecracker +// snapshot restore. The handler receives a userfaultfd file descriptor from +// Firecracker over a Unix domain socket, then serves page faults by copying +// pages from a memory snapshot file into the guest's address space. +// +// One handler per VM. Lifecycle is tied to the VM's gRPC restore call — +// the handler goroutine is spawned before LoadSnapshot is issued and +// terminates when the VM is destroyed or the UFFD fd closes. +package uffd + +import ( + "fmt" + "syscall" + "unsafe" +) + +// Hardcoded from ; verified against Linux 6.8 (our +// prod target). _IO[RW] encoding: (dir << 30) | (sizeof << 16) | (type +// << 8) | nr, with type = 0xAA. Pinned by TestIoctlNumbers so a struct +// resize can't silently break the kernel ABI. +const ( + UFFDIO_COPY uintptr = 0xC028AA03 + UFFDIO_ZEROPAGE uintptr = 0xC020AA04 + UFFDIO_UNREGISTER uintptr = 0x8010AA01 +) + +const ( + UFFD_EVENT_PAGEFAULT uint8 = 0x12 + UFFD_EVENT_FORK uint8 = 0x13 + UFFD_EVENT_REMAP uint8 = 0x14 + UFFD_EVENT_REMOVE uint8 = 0x15 + UFFD_EVENT_UNMAP uint8 = 0x16 +) + +const ( + UFFDIO_COPY_MODE_DONTWAKE uint64 = 1 << 0 + UFFDIO_COPY_MODE_WP uint64 = 1 << 1 +) + +type uffdioRange struct { + Start uint64 + Len uint64 +} + +type uffdioCopy struct { + Dst uint64 + Src uint64 + Len uint64 + Mode uint64 + Copy int64 // out: bytes copied, or -errno on failure +} + +type uffdioZeropage struct { + Range uffdioRange + Mode uint64 + Zeropage int64 +} + +// uffdMsg matches the packed 32-byte struct uffd_msg in the kernel. +// Arg is the union; interpret based on Event. +type uffdMsg struct { + Event uint8 + Reserved1 uint8 + Reserved2 uint16 + Reserved3 uint32 + Arg [24]byte +} + +type PageFaultArg struct { + Flags uint64 + Address uint64 + Ptid uint32 + _ uint32 +} + +type RemoveArg struct { + Start uint64 + End uint64 + _ uint64 +} + +func (m *uffdMsg) asPageFault() PageFaultArg { + return *(*PageFaultArg)(unsafe.Pointer(&m.Arg[0])) +} + +func (m *uffdMsg) asRemove() RemoveArg { + return *(*RemoveArg)(unsafe.Pointer(&m.Arg[0])) +} + +// ioctlCopy returns bytes_copied on success, or an error wrapping the +// kernel errno. Callers should check errors.Is(err, syscall.EEXIST) and +// errors.Is(err, syscall.EAGAIN) — these are benign races (concurrent +// fault on same page, or REMOVE event pending in queue). +// +// The kernel reports failures via op.Copy = -errno, which we extract +// because the syscall return errno is often less specific. +func ioctlCopy(uffdFd uintptr, dst, src uint64, length uint64, mode uint64) (int64, error) { + op := uffdioCopy{ + Dst: dst, + Src: src, + Len: length, + Mode: mode, + } + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uffdFd, UFFDIO_COPY, uintptr(unsafe.Pointer(&op))) + if op.Copy < 0 { + return op.Copy, syscall.Errno(-op.Copy) + } + if errno != 0 { + return op.Copy, errno + } + return op.Copy, nil +} + +func ioctlZeropage(uffdFd uintptr, start uint64, length uint64) error { + op := uffdioZeropage{ + Range: uffdioRange{Start: start, Len: length}, + } + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uffdFd, UFFDIO_ZEROPAGE, uintptr(unsafe.Pointer(&op))) + if errno != 0 && op.Zeropage <= 0 { + return fmt.Errorf("UFFDIO_ZEROPAGE: %w (zeropage=%d)", errno, op.Zeropage) + } + return nil +} + +func ioctlUnregister(uffdFd uintptr, start uint64, length uint64) error { + r := uffdioRange{Start: start, Len: length} + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uffdFd, UFFDIO_UNREGISTER, uintptr(unsafe.Pointer(&r))) + if errno != 0 { + return fmt.Errorf("UFFDIO_UNREGISTER: %w", errno) + } + return nil +} diff --git a/internal/vm/uffd/source.go b/internal/vm/uffd/source.go new file mode 100644 index 0000000..b6ca620 --- /dev/null +++ b/internal/vm/uffd/source.go @@ -0,0 +1,71 @@ +package uffd + +import ( + "fmt" + "os" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Source is a memory-mapped read-only view of a snapshot's mem.snap file. +type Source struct { + path string + file *os.File + mapping []byte +} + +// OpenSource opens path and mmaps it PROT_READ | MAP_PRIVATE. +// MAP_POPULATE is intentionally NOT set — that would force synchronous +// memory loading, which is the exact behavior UFFD is meant to avoid. +func OpenSource(path string) (*Source, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open mem.snap: %w", err) + } + st, err := f.Stat() + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("stat mem.snap: %w", err) + } + size := st.Size() + if size <= 0 { + _ = f.Close() + return nil, fmt.Errorf("mem.snap %s is empty", path) + } + data, err := unix.Mmap(int(f.Fd()), 0, int(size), unix.PROT_READ, unix.MAP_PRIVATE) + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("mmap mem.snap: %w", err) + } + return &Source{path: path, file: f, mapping: data}, nil +} + +func (s *Source) Size() int { return len(s.mapping) } + +// PagePointer returns a kernel-facing address into the mmap'd backing, +// for direct use as the src argument to UFFDIO_COPY. The returned +// address is valid only while Source is open. +func (s *Source) PagePointer(offset, length uint64) (uint64, error) { + if offset+length > uint64(len(s.mapping)) { + return 0, fmt.Errorf("offset %d + length %d exceeds file size %d", offset, length, len(s.mapping)) + } + return uint64(uintptr(unsafe.Pointer(unsafe.SliceData(s.mapping))) + uintptr(offset)), nil +} + +func (s *Source) Close() error { + var firstErr error + if s.mapping != nil { + if err := unix.Munmap(s.mapping); err != nil && firstErr == nil { + firstErr = fmt.Errorf("munmap: %w", err) + } + s.mapping = nil + } + if s.file != nil { + if err := s.file.Close(); err != nil && firstErr == nil { + firstErr = fmt.Errorf("close: %w", err) + } + s.file = nil + } + return firstErr +} diff --git a/internal/vm/uffd/source_test.go b/internal/vm/uffd/source_test.go new file mode 100644 index 0000000..d40bab6 --- /dev/null +++ b/internal/vm/uffd/source_test.go @@ -0,0 +1,89 @@ +package uffd + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestSourceOpenAndRead(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "mem.snap") + + // 8 KiB of recognizable bytes — two pages, each with a distinct pattern + // at the page boundary so we can verify PagePointer offsets. + page := 4096 + buf := make([]byte, 2*page) + for i := range buf[:page] { + buf[i] = 0x11 + } + for i := range buf[page:] { + buf[page+i] = 0x22 + } + if err := os.WriteFile(path, buf, 0o644); err != nil { + t.Fatalf("write test file: %v", err) + } + + src, err := OpenSource(path) + if err != nil { + t.Fatalf("OpenSource: %v", err) + } + t.Cleanup(func() { _ = src.Close() }) + + if src.Size() != 2*page { + t.Errorf("Size = %d, want %d", src.Size(), 2*page) + } + + // PagePointer should report distinct offsets for the two pages. + p0, err := src.PagePointer(0, uint64(page)) + if err != nil { + t.Fatalf("PagePointer page 0: %v", err) + } + p1, err := src.PagePointer(uint64(page), uint64(page)) + if err != nil { + t.Fatalf("PagePointer page 1: %v", err) + } + if p1-p0 != uint64(page) { + t.Errorf("page pointer delta = %d, want %d", p1-p0, page) + } + + // Verify mmap contents by reading the internal mapping (test is in + // the same package). PagePointer returns a kernel-facing uint64; we + // don't want to round-trip it through unsafe.Pointer in test code. + if !bytes.Equal(src.mapping[:page], buf[:page]) { + t.Errorf("page 0 contents mismatch") + } + if !bytes.Equal(src.mapping[page:], buf[page:]) { + t.Errorf("page 1 contents mismatch") + } +} + +func TestSourceBoundsCheck(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "mem.snap") + if err := os.WriteFile(path, make([]byte, 4096), 0o644); err != nil { + t.Fatalf("write test file: %v", err) + } + src, err := OpenSource(path) + if err != nil { + t.Fatalf("OpenSource: %v", err) + } + t.Cleanup(func() { _ = src.Close() }) + + // Reading past the end must error rather than reach into nothing. + if _, err := src.PagePointer(4096, 4096); err == nil { + t.Error("PagePointer past EOF returned nil error, want bounds error") + } +} + +func TestSourceRejectsEmpty(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.snap") + if err := os.WriteFile(path, nil, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if _, err := OpenSource(path); err == nil { + t.Error("OpenSource on empty file returned nil error, want non-nil") + } +} From f4e54dea0d5b87a566a3a82c1285da5616d4d2d7 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 12:18:16 -0700 Subject: [PATCH 2/9] vm,uffd: prefetch on restore; record access pattern during template build --- cmd/vmd/main.go | 2 + internal/vm/build.go | 16 ++++ internal/vm/manager.go | 77 ++++++++++++++++-- internal/vm/uffd/handler.go | 39 ++++++++-- internal/vm/uffd/prefetch.go | 146 +++++++++++++++++++++++++++++++++++ internal/vm/uffd/recorder.go | 80 +++++++++++++++++++ 6 files changed, 345 insertions(+), 15 deletions(-) create mode 100644 internal/vm/uffd/prefetch.go create mode 100644 internal/vm/uffd/recorder.go diff --git a/cmd/vmd/main.go b/cmd/vmd/main.go index e3e51a8..93dc31a 100644 --- a/cmd/vmd/main.go +++ b/cmd/vmd/main.go @@ -248,6 +248,7 @@ func main() { // ---- VM manager ---- maxRestores, _ := strconv.Atoi(envOrDefault("VMD_MAX_CONCURRENT_RESTORES", "100")) + uffdPrefetchEnabled := envOrDefault("VMD_UFFD_PREFETCH_ENABLED", "true") != "false" mgr, err := vm.NewManager(vm.ManagerConfig{ FirecrackerBin: cfg.FirecrackerBin, @@ -260,6 +261,7 @@ func main() { BoxdBinaryPath: cfg.BoxdBinaryPath, HostInterface: cfg.HostInterface, MaxConcurrentRestores: maxRestores, + UffdPrefetchEnabled: uffdPrefetchEnabled, }, netMgr, log) if err != nil { log.Fatal().Err(err).Msg("failed to initialize VM manager") diff --git a/internal/vm/build.go b/internal/vm/build.go index 38a1075..44ee7e8 100644 --- a/internal/vm/build.go +++ b/internal/vm/build.go @@ -167,6 +167,22 @@ func (m *Manager) buildTemplateSync(ctx context.Context, buildVMID string, req B return nil, fmt.Errorf("read build meta: %w", err) } + // Best-effort: a missing access.log just means sandboxes fall back + // to sequential prefetch. + if m.cfg.UffdPrefetchEnabled { + recordingVMID := "record-" + buildVMID + accessLogPath := filepath.Join(snapshotDir, "access.log") + recCfg := VMConfig{ + VCPU: req.VCPU, + MemoryMiB: req.MemoryMiB, + BasePath: result.BasePath, + DeltaDir: snapshotDir, + } + if recErr := m.RecordAccessPattern(ctx, recordingVMID, result.SnapshotPath, result.MemFilePath, accessLogPath, recCfg, nil, 2*time.Second); recErr != nil { + log.Warn().Err(recErr).Msg("access-pattern recording failed (sandbox will fall back to sequential prefetch)") + } + } + log.Info().Dur("total", time.Since(buildStart)).Msg("template build complete") return result, nil } diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 6e87cd2..652e030 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -127,6 +127,12 @@ type ManagerConfig struct { // prevent a spike of concurrent sandbox creates from saturating host // file I/O, netns setup, and Firecracker boots. 0 → default 100. MaxConcurrentRestores int + + // UffdPrefetchEnabled turns on background prefetch in the UFFD + // handler. When true, the handler walks mem.snap after the handshake + // and pre-copies pages into guest memory so the first exec doesn't + // stall on cold pages. + UffdPrefetchEnabled bool } // --------------------------------------------------------------------------- @@ -753,6 +759,14 @@ func (m *Manager) assertUnderVMSnapshotDir(vmID, p string) error { // RestoreVMSnapshot boots a VM from a previously captured snapshot. func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, memPath string, resourceLimits VMConfig, netCfg *network.Config, +) (*VMInstance, error) { + return m.restoreVMSnapshot(ctx, vmID, snapshotPath, memPath, resourceLimits, netCfg, nil) +} + +// restoreVMSnapshot is the implementation. The recorder argument is +// non-nil only for template-build access-pattern recording. +func (m *Manager) restoreVMSnapshot(ctx context.Context, vmID, snapshotPath, memPath string, + resourceLimits VMConfig, netCfg *network.Config, recorder *uffd.Recorder, ) (*VMInstance, error) { log := m.log.With().Str("vm_id", vmID).Logger() @@ -866,7 +880,7 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem // inPlace (pause/resume) keeps the File backend; fresh restores use UFFD. if !inPlace { - if err := m.startUffdHandler(inst, uffdSocketPath, memPath); err != nil { + if err := m.startUffdHandler(inst, uffdSocketPath, memPath, recorder); err != nil { if !inPlace { m.netMgr.CleanupVM(vmID) } @@ -1379,12 +1393,25 @@ func (m *Manager) stopUnitDuringRestoreError(vmID string) { // startUffdHandler binds the per-VM UFFD socket, opens mem.snap, and // spawns the fault-serving goroutine. Stores the cancel func on inst so // destroy/error paths can tear it down. -func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath string) error { - h := uffd.New(uffd.Config{ - SocketPath: uffdSocketPath, - MemSnapPath: memSnapPath, - Logger: m.log.With().Str("vm_id", inst.ID).Logger(), - }) +// +// recorder, when non-nil, hooks into OnPageFault for template-build +// recording. Prefetch is disabled in recording mode — recording captures +// guest demand order, which prefetch would scramble. +func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath string, recorder *uffd.Recorder) error { + accessLogPath := filepath.Join(filepath.Dir(memSnapPath), "access.log") + cfg := uffd.Config{ + SocketPath: uffdSocketPath, + MemSnapPath: memSnapPath, + AccessLogPath: accessLogPath, + Logger: m.log.With().Str("vm_id", inst.ID).Logger(), + } + if recorder != nil { + cfg.OnPageFault = recorder.Record + cfg.PrefetchEnabled = false + } else { + cfg.PrefetchEnabled = m.cfg.UffdPrefetchEnabled + } + h := uffd.New(cfg) if err := h.Start(); err != nil { return err } @@ -1408,6 +1435,42 @@ func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath return nil } +// RecordAccessPattern restores the snapshot under UFFD-recording mode, +// lets the guest run for settleDuration, then destroys and writes the +// observed page-access order to outputPath. Called by the template +// build pipeline; produces the access.log the prefetcher consumes. +func (m *Manager) RecordAccessPattern(ctx context.Context, vmID, snapshotPath, memPath, outputPath string, + resourceLimits VMConfig, netCfg *network.Config, settleDuration time.Duration, +) error { + if _, err := os.Stat(outputPath); err == nil { + m.log.Info().Str("path", outputPath).Msg("access log already exists, skipping recording") + return nil + } + + recorder := uffd.NewRecorder() + inst, err := m.restoreVMSnapshot(ctx, vmID, snapshotPath, memPath, resourceLimits, netCfg, recorder) + if err != nil { + return fmt.Errorf("restore for recording: %w", err) + } + defer func() { + _ = m.DestroyVM(context.Background(), inst.ID, true) + }() + + timer := time.NewTimer(settleDuration) + defer timer.Stop() + select { + case <-timer.C: + case <-ctx.Done(): + return ctx.Err() + } + + if err := recorder.Flush(outputPath, false); err != nil { + return fmt.Errorf("flush access log: %w", err) + } + m.log.Info().Str("path", outputPath).Int("pages", recorder.Len()).Msg("access pattern recorded") + return nil +} + // cancelUffdHandler is idempotent and a no-op for non-UFFD VMs. func (m *Manager) cancelUffdHandler(inst *VMInstance) { inst.mu.Lock() diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go index 639ecad..94c0f27 100644 --- a/internal/vm/uffd/handler.go +++ b/internal/vm/uffd/handler.go @@ -39,17 +39,34 @@ type Config struct { SocketPath string MemSnapPath string FaultWorkers int // 0 -> defaultFaultWorkers - Logger zerolog.Logger + + // AccessLogPath, when set and present on disk, supplies a recorded + // page-access order from template build time. The prefetcher replays + // it. Missing file → sequential fallback. + AccessLogPath string + + // PrefetchEnabled controls whether the prefetcher runs at all. + PrefetchEnabled bool + + // OnPageFault, if non-nil, is invoked for each PageFault event with + // the backing-file offset of the page about to be served. Used by + // template-builder to record access patterns. Must be fast (called + // on the fault hot path). + OnPageFault func(offset uint64) + + Logger zerolog.Logger } type Stats struct { - FaultsServed atomic.Uint64 - BytesServed atomic.Uint64 - CopyErrors atomic.Uint64 - UnknownEvents atomic.Uint64 - RemoveEvents atomic.Uint64 - EagainRetries atomic.Uint64 - Eexist atomic.Uint64 + FaultsServed atomic.Uint64 + BytesServed atomic.Uint64 + CopyErrors atomic.Uint64 + UnknownEvents atomic.Uint64 + RemoveEvents atomic.Uint64 + EagainRetries atomic.Uint64 + Eexist atomic.Uint64 + PrefetchedPages atomic.Uint64 + PrefetchSkipped atomic.Uint64 // EEXIST during prefetch (page already faulted by guest) } // Handler binds a UDS, receives a userfaultfd fd from Firecracker, then @@ -122,6 +139,9 @@ func (h *Handler) Run(ctx context.Context) error { if err := h.acceptAndReceive(ctx); err != nil { return fmt.Errorf("receive handshake: %w", err) } + if h.cfg.PrefetchEnabled { + go h.runPrefetch(ctx) + } return h.serveLoop(ctx) } @@ -304,6 +324,9 @@ func (h *Handler) servePagefault(g *errgroup.Group, fd uintptr, faultAddr, pageS } offset := region.Offset + (pageAddr - region.BaseHostVirtAddr) + if h.cfg.OnPageFault != nil { + h.cfg.OnPageFault(offset) + } srcPtr, err := h.source.PagePointer(offset, pageSize) if err != nil { h.stats.CopyErrors.Add(1) diff --git a/internal/vm/uffd/prefetch.go b/internal/vm/uffd/prefetch.go new file mode 100644 index 0000000..7690f55 --- /dev/null +++ b/internal/vm/uffd/prefetch.go @@ -0,0 +1,146 @@ +package uffd + +import ( + "bufio" + "context" + "errors" + "fmt" + "os" + "strconv" + "strings" + "syscall" +) + +// runPrefetch warms guest memory pre-emptively. Replays AccessLogPath if +// present (only pages the guest will touch); otherwise walks the +// snapshot sequentially. Pages already faulted by the guest return +// EEXIST from UFFDIO_COPY and are skipped cheaply. +func (h *Handler) runPrefetch(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + h.log.Error().Interface("panic", r).Msg("prefetch panicked") + } + }() + + pageSize := h.pageSize + if pageSize == 0 { + h.log.Warn().Msg("prefetch: no page size, skipping") + return + } + fd := h.uffdFd + if fd == ^uintptr(0) { + return + } + + offsets, ordered := h.prefetchOffsets() + h.log.Info().Bool("ordered", ordered).Int("offsets", len(offsets)).Msg("prefetch starting") + + for _, off := range offsets { + if ctx.Err() != nil { + return + } + fd := h.uffdFd + if fd == ^uintptr(0) { + return + } + if err := h.prefetchOne(fd, off, pageSize); err != nil { + // Abort on first non-benign error; on-demand fault serving + // is unaffected. + h.log.Warn().Err(err).Uint64("offset", off).Msg("prefetch aborted") + return + } + } + h.log.Debug().Uint64("served", h.stats.PrefetchedPages.Load()).Msg("prefetch complete") +} + +func (h *Handler) prefetchOne(fd uintptr, offset, pageSize uint64) error { + region := h.regionForOffset(offset) + if region == nil { + return nil + } + dst := region.BaseHostVirtAddr + (offset - region.Offset) + srcPtr, err := h.source.PagePointer(offset, pageSize) + if err != nil { + return err + } + _, copyErr := ioctlCopy(fd, dst, srcPtr, pageSize, 0) + if copyErr != nil { + if errors.Is(copyErr, syscall.EEXIST) { + h.stats.PrefetchSkipped.Add(1) + return nil + } + if errors.Is(copyErr, syscall.EAGAIN) { + // REMOVE pending in queue; main fault loop will drain it. + // Guest will re-fault this page later if it needs it. + h.stats.PrefetchSkipped.Add(1) + return nil + } + return fmt.Errorf("prefetch UFFDIO_COPY at %#x: %w", dst, copyErr) + } + h.stats.PrefetchedPages.Add(1) + return nil +} + +func (h *Handler) regionForOffset(offset uint64) *GuestRegionMapping { + for i := range h.mappings { + r := &h.mappings[i] + if offset >= r.Offset && offset < r.Offset+r.Size { + return r + } + } + return nil +} + +// prefetchOffsets returns page-aligned offsets to prefetch in order. +// The bool is true when access.log drove the order, false for fallback. +func (h *Handler) prefetchOffsets() ([]uint64, bool) { + if h.cfg.AccessLogPath != "" { + if offsets, err := readAccessLog(h.cfg.AccessLogPath, h.pageSize); err == nil && len(offsets) > 0 { + return offsets, true + } + } + // Sequential fallback: walk every region from start to end. + var out []uint64 + for _, r := range h.mappings { + for off := r.Offset; off < r.Offset+r.Size; off += h.pageSize { + out = append(out, off) + } + } + return out, false +} + +// readAccessLog parses newline-separated decimal offsets. Misaligned +// lines and blank/# lines are silently skipped. +func readAccessLog(path string, pageSize uint64) ([]uint64, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var offsets []uint64 + seen := make(map[uint64]struct{}) + scanner := bufio.NewScanner(f) + for scanner.Scan() { + s := strings.TrimSpace(scanner.Text()) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + n, err := strconv.ParseUint(s, 10, 64) + if err != nil { + continue + } + if n%pageSize != 0 { + continue + } + if _, dup := seen[n]; dup { + continue + } + seen[n] = struct{}{} + offsets = append(offsets, n) + } + if err := scanner.Err(); err != nil { + return offsets, err + } + return offsets, nil +} diff --git a/internal/vm/uffd/recorder.go b/internal/vm/uffd/recorder.go new file mode 100644 index 0000000..2ce4a07 --- /dev/null +++ b/internal/vm/uffd/recorder.go @@ -0,0 +1,80 @@ +package uffd + +import ( + "bufio" + "fmt" + "os" + "sort" + "strconv" + "sync" +) + +// Recorder collects page-fault offsets in observation order for later +// replay by the prefetcher. Safe for concurrent Record calls; capture +// order approximates guest first-touch order under our worker pool. +type Recorder struct { + mu sync.Mutex + seen map[uint64]struct{} + ordered []uint64 +} + +func NewRecorder() *Recorder { + return &Recorder{seen: make(map[uint64]struct{})} +} + +func (r *Recorder) Record(offset uint64) { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.seen[offset]; ok { + return + } + r.seen[offset] = struct{}{} + r.ordered = append(r.ordered, offset) +} + +func (r *Recorder) Len() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.ordered) +} + +// Flush writes recorded offsets to path, one per line. sortAscending=true +// emits numerical order instead of capture order — use when parallel +// workers make capture order unreliable. +func (r *Recorder) Flush(path string, sortAscending bool) error { + r.mu.Lock() + out := make([]uint64, len(r.ordered)) + copy(out, r.ordered) + r.mu.Unlock() + + if sortAscending { + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + } + + tmp := path + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return fmt.Errorf("create %s: %w", tmp, err) + } + w := bufio.NewWriter(f) + for _, off := range out { + if _, err := w.WriteString(strconv.FormatUint(off, 10) + "\n"); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("write: %w", err) + } + } + if err := w.Flush(); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("flush: %w", err) + } + if err := f.Close(); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("close: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("rename: %w", err) + } + return nil +} From b4c6d1d50468bc87f42265eb87ffac205113d0d6 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 13:59:43 -0700 Subject: [PATCH 3/9] uffd: clear O_NONBLOCK on UFFD fd so serveLoop blocks instead of exiting on EAGAIN --- internal/vm/uffd/handler.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go index 94c0f27..2c6eaf9 100644 --- a/internal/vm/uffd/handler.go +++ b/internal/vm/uffd/handler.go @@ -220,6 +220,12 @@ func (h *Handler) acceptAndReceive(ctx context.Context) error { return errors.New("no UFFD fd received in SCM_RIGHTS") } + // Firecracker creates UFFD with O_NONBLOCK; clear it so unix.Read + // blocks (cancellation closes the fd → EBADF wakes the read). + if flags, err := unix.FcntlInt(uintptr(uffdFd), unix.F_GETFL, 0); err == nil { + _, _ = unix.FcntlInt(uintptr(uffdFd), unix.F_SETFL, flags&^unix.O_NONBLOCK) + } + var mappings []GuestRegionMapping if err := json.Unmarshal(body[:n], &mappings); err != nil { _ = unix.Close(uffdFd) @@ -279,9 +285,8 @@ func (h *Handler) serveLoop(ctx context.Context) error { if errors.Is(err, syscall.EINTR) { continue } - // EBADF: closed by our cancellation goroutine. - // EAGAIN: transient. - if errors.Is(err, syscall.EBADF) || errors.Is(err, syscall.EAGAIN) { + // EBADF: cancellation goroutine closed the fd. Exit cleanly. + if errors.Is(err, syscall.EBADF) { return g.Wait() } return fmt.Errorf("read uffd_msg: %w", err) From a7ffc1110093a26f7e9988ebd14ff173165bc814 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 14:14:10 -0700 Subject: [PATCH 4/9] uffd: serialize ioctls vs close with RWMutex + atomic fd to close fd-reuse race --- internal/vm/uffd/handler.go | 81 +++++++++++++++++++++++++----------- internal/vm/uffd/prefetch.go | 41 +++++++++++------- 2 files changed, 82 insertions(+), 40 deletions(-) diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go index 2c6eaf9..888a571 100644 --- a/internal/vm/uffd/handler.go +++ b/internal/vm/uffd/handler.go @@ -8,6 +8,7 @@ import ( "net" "os" "path/filepath" + "sync" "sync/atomic" "syscall" "time" @@ -20,6 +21,10 @@ import ( const defaultFaultWorkers = 4096 +// uffdClosed is the sentinel value of Handler.uffdFd before the fd is +// received and after it has been closed. +const uffdClosed = ^uintptr(0) + // GuestRegionMapping is the JSON Firecracker sends over the UDS alongside // the userfaultfd fd. Field tags must match // GuestRegionUffdMapping in src/vmm/src/persist.rs of our firecracker fork. @@ -78,11 +83,16 @@ type Handler struct { stats Stats source *Source listener *net.UnixListener - uffdFd uintptr // dup'd from SCM_RIGHTS; -1 if not yet received + + // fdMu (RLock per ioctl, Lock around close) prevents an ioctl from + // racing with close + kernel fd-reuse. uffdFd is atomic so the + // lock-free serveLoop read is sound under Go's memory model. + fdMu sync.RWMutex + uffdFd atomic.Uintptr // uffdClosed when not held + mappings []GuestRegionMapping pageSize uint64 - // closed is signaled when Close has been called. closed atomic.Bool } @@ -90,11 +100,12 @@ func New(cfg Config) *Handler { if cfg.FaultWorkers <= 0 { cfg.FaultWorkers = defaultFaultWorkers } - return &Handler{ - cfg: cfg, - log: cfg.Logger.With().Str("component", "uffd").Logger(), - uffdFd: ^uintptr(0), + h := &Handler{ + cfg: cfg, + log: cfg.Logger.With().Str("component", "uffd").Logger(), } + h.uffdFd.Store(uffdClosed) + return h } func (h *Handler) Stats() *Stats { return &h.stats } @@ -158,10 +169,7 @@ func (h *Handler) Close() error { if h.cfg.SocketPath != "" { _ = os.Remove(h.cfg.SocketPath) } - if h.uffdFd != ^uintptr(0) { - _ = unix.Close(int(h.uffdFd)) - h.uffdFd = ^uintptr(0) - } + h.closeFd() if h.source != nil { if err := h.source.Close(); err != nil && firstErr == nil { firstErr = fmt.Errorf("close source: %w", err) @@ -171,6 +179,21 @@ func (h *Handler) Close() error { return firstErr } +// closeFd closes the userfaultfd under the write lock so no in-flight +// ioctl can be operating on the fd when the kernel frees the slot +// (preventing fd-number reuse from redirecting a worker's ioctl to an +// unrelated open). Idempotent. +func (h *Handler) closeFd() { + h.fdMu.Lock() + defer h.fdMu.Unlock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + return + } + _ = unix.Close(int(fd)) + h.uffdFd.Store(uffdClosed) +} + func (h *Handler) acceptAndReceive(ctx context.Context) error { // ctx with no deadline → 30s default, matching the gRPC deadline on // the caller side. The listener deadline is the only way to make @@ -244,7 +267,7 @@ func (h *Handler) acceptAndReceive(ctx context.Context) error { return fmt.Errorf("invalid page size in mappings: %+v", mappings[0]) } - h.uffdFd = uintptr(uffdFd) + h.uffdFd.Store(uintptr(uffdFd)) h.mappings = mappings h.pageSize = pageSize @@ -265,10 +288,7 @@ func (h *Handler) serveLoop(ctx context.Context) error { // so bridge context cancellation to a close here. go func() { <-gctx.Done() - if h.uffdFd != ^uintptr(0) { - _ = unix.Close(int(h.uffdFd)) - h.uffdFd = ^uintptr(0) - } + h.closeFd() }() var msgBuf [32]byte @@ -276,8 +296,8 @@ func (h *Handler) serveLoop(ctx context.Context) error { if err := gctx.Err(); err != nil { return g.Wait() } - fd := h.uffdFd - if fd == ^uintptr(0) { + fd := h.uffdFd.Load() + if fd == uffdClosed { return g.Wait() } n, err := unix.Read(int(fd), msgBuf[:]) @@ -302,13 +322,11 @@ func (h *Handler) serveLoop(ctx context.Context) error { switch msg.Event { case UFFD_EVENT_PAGEFAULT: pf := msg.asPageFault() - fdCopy := fd - pageSize := h.pageSize - h.servePagefault(g, fdCopy, pf.Address, pageSize) + h.servePagefault(g, pf.Address, h.pageSize) case UFFD_EVENT_REMOVE: h.stats.RemoveEvents.Add(1) rm := msg.asRemove() - h.handleRemove(fd, rm.Start, rm.End) + h.handleRemove(rm.Start, rm.End) default: h.stats.UnknownEvents.Add(1) h.log.Debug().Uint8("event", msg.Event).Msg("ignoring unknown uffd event") @@ -318,7 +336,7 @@ func (h *Handler) serveLoop(ctx context.Context) error { // servePagefault dispatches a single fault to the worker pool. The pool // blocks if all FaultWorkers slots are busy — natural backpressure. -func (h *Handler) servePagefault(g *errgroup.Group, fd uintptr, faultAddr, pageSize uint64) { +func (h *Handler) servePagefault(g *errgroup.Group, faultAddr, pageSize uint64) { g.Go(func() error { pageAddr := faultAddr &^ (pageSize - 1) @@ -338,7 +356,14 @@ func (h *Handler) servePagefault(g *errgroup.Group, fd uintptr, faultAddr, pageS return fmt.Errorf("source page lookup: %w", err) } + h.fdMu.RLock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + h.fdMu.RUnlock() + return nil + } _, copyErr := ioctlCopy(fd, pageAddr, srcPtr, pageSize, 0) + h.fdMu.RUnlock() if copyErr != nil { // EEXIST: another fault on the same page raced ahead of us; // it's already mapped. Benign. @@ -365,12 +390,20 @@ func (h *Handler) servePagefault(g *errgroup.Group, fd uintptr, faultAddr, pageS // handleRemove responds to balloon-device removals by zeroing the range. // Without this, subsequent faults on these pages would resolve to stale // mem.snap contents instead of zero pages. -func (h *Handler) handleRemove(fd uintptr, start, end uint64) { +func (h *Handler) handleRemove(start, end uint64) { if end <= start { return } length := end - start - if err := ioctlZeropage(fd, start, length); err != nil { + h.fdMu.RLock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + h.fdMu.RUnlock() + return + } + err := ioctlZeropage(fd, start, length) + h.fdMu.RUnlock() + if err != nil { h.log.Warn().Err(err).Uint64("start", start).Uint64("end", end).Msg("UFFDIO_ZEROPAGE failed") } } diff --git a/internal/vm/uffd/prefetch.go b/internal/vm/uffd/prefetch.go index 7690f55..bfefea7 100644 --- a/internal/vm/uffd/prefetch.go +++ b/internal/vm/uffd/prefetch.go @@ -27,8 +27,7 @@ func (h *Handler) runPrefetch(ctx context.Context) { h.log.Warn().Msg("prefetch: no page size, skipping") return } - fd := h.uffdFd - if fd == ^uintptr(0) { + if h.uffdFd.Load() == uffdClosed { return } @@ -39,46 +38,56 @@ func (h *Handler) runPrefetch(ctx context.Context) { if ctx.Err() != nil { return } - fd := h.uffdFd - if fd == ^uintptr(0) { - return - } - if err := h.prefetchOne(fd, off, pageSize); err != nil { + stop, err := h.prefetchOne(off, pageSize) + if err != nil { // Abort on first non-benign error; on-demand fault serving // is unaffected. h.log.Warn().Err(err).Uint64("offset", off).Msg("prefetch aborted") return } + if stop { + return + } } h.log.Debug().Uint64("served", h.stats.PrefetchedPages.Load()).Msg("prefetch complete") } -func (h *Handler) prefetchOne(fd uintptr, offset, pageSize uint64) error { +// prefetchOne returns stop=true when the fd has been closed under us, +// signaling the caller to exit the prefetch loop. +func (h *Handler) prefetchOne(offset, pageSize uint64) (stop bool, err error) { region := h.regionForOffset(offset) if region == nil { - return nil + return false, nil } dst := region.BaseHostVirtAddr + (offset - region.Offset) - srcPtr, err := h.source.PagePointer(offset, pageSize) - if err != nil { - return err + srcPtr, perr := h.source.PagePointer(offset, pageSize) + if perr != nil { + return false, perr + } + + h.fdMu.RLock() + fd := h.uffdFd.Load() + if fd == uffdClosed { + h.fdMu.RUnlock() + return true, nil } _, copyErr := ioctlCopy(fd, dst, srcPtr, pageSize, 0) + h.fdMu.RUnlock() if copyErr != nil { if errors.Is(copyErr, syscall.EEXIST) { h.stats.PrefetchSkipped.Add(1) - return nil + return false, nil } if errors.Is(copyErr, syscall.EAGAIN) { // REMOVE pending in queue; main fault loop will drain it. // Guest will re-fault this page later if it needs it. h.stats.PrefetchSkipped.Add(1) - return nil + return false, nil } - return fmt.Errorf("prefetch UFFDIO_COPY at %#x: %w", dst, copyErr) + return false, fmt.Errorf("prefetch UFFDIO_COPY at %#x: %w", dst, copyErr) } h.stats.PrefetchedPages.Add(1) - return nil + return false, nil } func (h *Handler) regionForOffset(offset uint64) *GuestRegionMapping { From 32f3c33dee36aa3fd0451393b5c75c446e148d66 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 14:34:46 -0700 Subject: [PATCH 5/9] uffd: replace fixed 2s recording timeout with fault-rate convergence detection --- cmd/vmd/main.go | 2 + internal/vm/build.go | 2 +- internal/vm/manager.go | 102 ++++++++++++++++++++++++++++++++---- internal/vm/manager_test.go | 99 ++++++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 11 deletions(-) diff --git a/cmd/vmd/main.go b/cmd/vmd/main.go index 93dc31a..b7f94d3 100644 --- a/cmd/vmd/main.go +++ b/cmd/vmd/main.go @@ -249,6 +249,7 @@ func main() { // ---- VM manager ---- maxRestores, _ := strconv.Atoi(envOrDefault("VMD_MAX_CONCURRENT_RESTORES", "100")) uffdPrefetchEnabled := envOrDefault("VMD_UFFD_PREFETCH_ENABLED", "true") != "false" + uffdRecordMaxSeconds, _ := strconv.Atoi(envOrDefault("VMD_UFFD_RECORD_MAX_SECONDS", "10")) mgr, err := vm.NewManager(vm.ManagerConfig{ FirecrackerBin: cfg.FirecrackerBin, @@ -262,6 +263,7 @@ func main() { HostInterface: cfg.HostInterface, MaxConcurrentRestores: maxRestores, UffdPrefetchEnabled: uffdPrefetchEnabled, + UffdRecordMaxSeconds: uffdRecordMaxSeconds, }, netMgr, log) if err != nil { log.Fatal().Err(err).Msg("failed to initialize VM manager") diff --git a/internal/vm/build.go b/internal/vm/build.go index 44ee7e8..9efb2df 100644 --- a/internal/vm/build.go +++ b/internal/vm/build.go @@ -178,7 +178,7 @@ func (m *Manager) buildTemplateSync(ctx context.Context, buildVMID string, req B BasePath: result.BasePath, DeltaDir: snapshotDir, } - if recErr := m.RecordAccessPattern(ctx, recordingVMID, result.SnapshotPath, result.MemFilePath, accessLogPath, recCfg, nil, 2*time.Second); recErr != nil { + if recErr := m.RecordAccessPattern(ctx, recordingVMID, result.SnapshotPath, result.MemFilePath, accessLogPath, recCfg, nil); recErr != nil { log.Warn().Err(recErr).Msg("access-pattern recording failed (sandbox will fall back to sequential prefetch)") } } diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 652e030..285e130 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -92,6 +92,8 @@ type VMInstance struct { // uffdCancel cancels the per-VM UFFD handler goroutine. nil for VMs // not using UFFD (cold boots, template builds, in-place resumes). uffdCancel context.CancelFunc + // uffdHandler is the per-VM UFFD handler; nil for non-UFFD VMs. + uffdHandler *uffd.Handler mu sync.RWMutex } @@ -133,6 +135,10 @@ type ManagerConfig struct { // and pre-copies pages into guest memory so the first exec doesn't // stall on cold pages. UffdPrefetchEnabled bool + + // UffdRecordMaxSeconds caps how long RecordAccessPattern waits for the + // guest's fault rate to settle before giving up. 0 → 10s default. + UffdRecordMaxSeconds int } // --------------------------------------------------------------------------- @@ -1418,6 +1424,7 @@ func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath uffdCtx, uffdCancel := context.WithCancel(context.Background()) inst.mu.Lock() inst.uffdCancel = uffdCancel + inst.uffdHandler = h inst.mu.Unlock() vmID := inst.ID @@ -1435,12 +1442,18 @@ func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath return nil } +const ( + recordTickInterval = 100 * time.Millisecond + recordMinDuration = 500 * time.Millisecond + recordQuietTicks = 3 + recordDefaultMax = 10 * time.Second +) + // RecordAccessPattern restores the snapshot under UFFD-recording mode, -// lets the guest run for settleDuration, then destroys and writes the -// observed page-access order to outputPath. Called by the template -// build pipeline; produces the access.log the prefetcher consumes. +// waits for the guest's page-fault rate to settle, and writes the +// observed page-access order to outputPath. func (m *Manager) RecordAccessPattern(ctx context.Context, vmID, snapshotPath, memPath, outputPath string, - resourceLimits VMConfig, netCfg *network.Config, settleDuration time.Duration, + resourceLimits VMConfig, netCfg *network.Config, ) error { if _, err := os.Stat(outputPath); err == nil { m.log.Info().Str("path", outputPath).Msg("access log already exists, skipping recording") @@ -1456,21 +1469,90 @@ func (m *Manager) RecordAccessPattern(ctx context.Context, vmID, snapshotPath, m _ = m.DestroyVM(context.Background(), inst.ID, true) }() - timer := time.NewTimer(settleDuration) - defer timer.Stop() - select { - case <-timer.C: - case <-ctx.Done(): + inst.mu.RLock() + handler := inst.uffdHandler + inst.mu.RUnlock() + if handler == nil { + return errors.New("UFFD handler missing after restore-for-recording") + } + + maxDuration := recordDefaultMax + if m.cfg.UffdRecordMaxSeconds > 0 { + maxDuration = time.Duration(m.cfg.UffdRecordMaxSeconds) * time.Second + } + + settled, elapsed, totalFaults := m.waitFaultsSettle(ctx, handler, maxDuration) + if errors.Is(ctx.Err(), context.Canceled) || errors.Is(ctx.Err(), context.DeadlineExceeded) { return ctx.Err() } + pages := recorder.Len() + mkEvent := func() *zerolog.Event { + ev := m.log.Info() + if !settled { + ev = m.log.Warn() + } + return ev.Str("template_vm", vmID).Bool("settled", settled). + Dur("elapsed", elapsed).Uint64("total_faults", totalFaults).Int("pages", pages) + } + + if pages == 0 { + // Missing access.log → sequential prefetch fallback, which is a + // cleaner signal than a zero-byte log. + mkEvent().Msg("access pattern recording produced no pages; access.log not written") + return nil + } + if err := recorder.Flush(outputPath, false); err != nil { return fmt.Errorf("flush access log: %w", err) } - m.log.Info().Str("path", outputPath).Int("pages", recorder.Len()).Msg("access pattern recorded") + mkEvent().Str("path", outputPath).Msg("access pattern recorded") return nil } +// waitFaultsSettle returns settled=true once the guest has produced +// activity then stayed quiet for recordQuietTicks past +// recordMinDuration. settled=false means maxDuration or ctx tripped. +func (m *Manager) waitFaultsSettle(ctx context.Context, h *uffd.Handler, maxDuration time.Duration) (settled bool, elapsed time.Duration, totalFaults uint64) { + start := time.Now() + ticker := time.NewTicker(recordTickInterval) + defer ticker.Stop() + deadline := time.NewTimer(maxDuration) + defer deadline.Stop() + + var lastFaults uint64 + seenActivity := false + quietCount := 0 + + for { + select { + case <-ctx.Done(): + return false, time.Since(start), h.Stats().FaultsServed.Load() + case <-deadline.C: + return false, time.Since(start), h.Stats().FaultsServed.Load() + case <-ticker.C: + cur := h.Stats().FaultsServed.Load() + delta := cur - lastFaults + lastFaults = cur + if delta > 0 { + seenActivity = true + quietCount = 0 + continue + } + if !seenActivity { + continue + } + if time.Since(start) < recordMinDuration { + continue + } + quietCount++ + if quietCount >= recordQuietTicks { + return true, time.Since(start), cur + } + } + } +} + // cancelUffdHandler is idempotent and a no-op for non-UFFD VMs. func (m *Manager) cancelUffdHandler(inst *VMInstance) { inst.mu.Lock() diff --git a/internal/vm/manager_test.go b/internal/vm/manager_test.go index 694eb78..fa27efd 100644 --- a/internal/vm/manager_test.go +++ b/internal/vm/manager_test.go @@ -5,6 +5,11 @@ import ( "os" "path/filepath" "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/superserve-ai/sandbox/internal/vm/uffd" ) // TestPlanRestore pins the restore-decision behavior across the four input @@ -374,3 +379,97 @@ func TestDeleteSnapshotFiles_NoSnapshotDirConfigured_Rejected(t *testing.T) { t.Error("expected rejection when SnapshotDir is unconfigured") } } + +func newTestManager() *Manager { + return &Manager{log: zerolog.Nop()} +} + +func TestWaitFaultsSettle_Converges(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + h.Stats().FaultsServed.Add(100) + } + } + }() + time.AfterFunc(150*time.Millisecond, func() { close(stop) }) + + mgr := newTestManager() + settled, elapsed, total := mgr.waitFaultsSettle(context.Background(), h, 5*time.Second) + if !settled { + t.Errorf("expected settled=true, got false (elapsed=%v total=%d)", elapsed, total) + } + if elapsed > 2*time.Second { + t.Errorf("settle took too long: %v", elapsed) + } + if total == 0 { + t.Error("expected non-zero total faults") + } +} + +func TestWaitFaultsSettle_NoActivityHitsCeiling(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + mgr := newTestManager() + settled, elapsed, total := mgr.waitFaultsSettle(context.Background(), h, 300*time.Millisecond) + if settled { + t.Errorf("expected settled=false (no activity), got true") + } + if total != 0 { + t.Errorf("expected zero total faults, got %d", total) + } + if elapsed < 250*time.Millisecond { + t.Errorf("ceiling hit too early: %v", elapsed) + } +} + +func TestWaitFaultsSettle_NeverQuietHitsCeiling(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + stop := make(chan struct{}) + defer close(stop) + go func() { + ticker := time.NewTicker(30 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + h.Stats().FaultsServed.Add(1) + } + } + }() + + mgr := newTestManager() + settled, elapsed, total := mgr.waitFaultsSettle(context.Background(), h, 400*time.Millisecond) + if settled { + t.Errorf("expected settled=false (continuous activity), got true") + } + if total == 0 { + t.Error("expected non-zero total faults") + } + if elapsed < 350*time.Millisecond { + t.Errorf("ceiling tripped too early: %v", elapsed) + } +} + +func TestWaitFaultsSettle_CtxCancel(t *testing.T) { + h := uffd.New(uffd.Config{Logger: zerolog.Nop()}) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(100*time.Millisecond, cancel) + + mgr := newTestManager() + settled, elapsed, _ := mgr.waitFaultsSettle(ctx, h, 30*time.Second) + if settled { + t.Errorf("expected settled=false on cancel") + } + if elapsed > 500*time.Millisecond { + t.Errorf("cancel did not propagate promptly: %v", elapsed) + } +} From e803de915f368be874b237efbe5e209430938556 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 14:41:10 -0700 Subject: [PATCH 6/9] uffd: split handshake/serve ctxs and add VMD_UFFD_ENABLED kill switch --- cmd/vmd/main.go | 2 ++ internal/vm/build.go | 2 +- internal/vm/manager.go | 54 +++++++++++++++++++++++-------------- internal/vm/uffd/handler.go | 18 +++++++++---- 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/cmd/vmd/main.go b/cmd/vmd/main.go index b7f94d3..b02f621 100644 --- a/cmd/vmd/main.go +++ b/cmd/vmd/main.go @@ -248,6 +248,7 @@ func main() { // ---- VM manager ---- maxRestores, _ := strconv.Atoi(envOrDefault("VMD_MAX_CONCURRENT_RESTORES", "100")) + uffdEnabled := envOrDefault("VMD_UFFD_ENABLED", "true") != "false" uffdPrefetchEnabled := envOrDefault("VMD_UFFD_PREFETCH_ENABLED", "true") != "false" uffdRecordMaxSeconds, _ := strconv.Atoi(envOrDefault("VMD_UFFD_RECORD_MAX_SECONDS", "10")) @@ -262,6 +263,7 @@ func main() { BoxdBinaryPath: cfg.BoxdBinaryPath, HostInterface: cfg.HostInterface, MaxConcurrentRestores: maxRestores, + UffdEnabled: uffdEnabled, UffdPrefetchEnabled: uffdPrefetchEnabled, UffdRecordMaxSeconds: uffdRecordMaxSeconds, }, netMgr, log) diff --git a/internal/vm/build.go b/internal/vm/build.go index 9efb2df..1441247 100644 --- a/internal/vm/build.go +++ b/internal/vm/build.go @@ -169,7 +169,7 @@ func (m *Manager) buildTemplateSync(ctx context.Context, buildVMID string, req B // Best-effort: a missing access.log just means sandboxes fall back // to sequential prefetch. - if m.cfg.UffdPrefetchEnabled { + if m.cfg.UffdEnabled && m.cfg.UffdPrefetchEnabled { recordingVMID := "record-" + buildVMID accessLogPath := filepath.Join(snapshotDir, "access.log") recCfg := VMConfig{ diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 285e130..86d1b31 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -130,10 +130,16 @@ type ManagerConfig struct { // file I/O, netns setup, and Firecracker boots. 0 → default 100. MaxConcurrentRestores int + // UffdEnabled gates the UFFD lazy-restore path. false → fresh + // restores fall back to the File memory backend (synchronous CRC64), + // same as in-place resume. Default true; flip to false as an ops + // circuit breaker without redeploying. + UffdEnabled bool + // UffdPrefetchEnabled turns on background prefetch in the UFFD // handler. When true, the handler walks mem.snap after the handshake // and pre-copies pages into guest memory so the first exec doesn't - // stall on cold pages. + // stall on cold pages. Ignored when UffdEnabled is false. UffdPrefetchEnabled bool // UffdRecordMaxSeconds caps how long RecordAccessPattern waits for the @@ -884,12 +890,12 @@ func (m *Manager) restoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem inst.SocketPath = socketPath inst.mu.Unlock() - // inPlace (pause/resume) keeps the File backend; fresh restores use UFFD. - if !inPlace { - if err := m.startUffdHandler(inst, uffdSocketPath, memPath, recorder); err != nil { - if !inPlace { - m.netMgr.CleanupVM(vmID) - } + // inPlace resume always uses File backend; UffdEnabled=false is the + // ops circuit breaker that forces fresh restores onto File too. + useUffd := !inPlace && m.cfg.UffdEnabled + if useUffd { + if err := m.startUffdHandler(ctx, inst, uffdSocketPath, memPath, recorder); err != nil { + m.netMgr.CleanupVM(vmID) m.cleanupRunDir(vmID) m.setStatus(vmID, StatusError) return nil, fmt.Errorf("start uffd handler: %w", err) @@ -913,10 +919,14 @@ func (m *Manager) restoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem log.Info().Msg("restoring snapshot") var restoreErr error - if inPlace { - restoreErr = RestoreSnapshot(socketPath, snapshotPath, memPath, plan.deltaDir) - } else { + switch { + case useUffd: restoreErr = RestoreSnapshotUffdWithOverrides(socketPath, snapshotPath, uffdSocketPath, "eth0", tapDevice, plan.deltaDir) + case inPlace: + restoreErr = RestoreSnapshot(socketPath, snapshotPath, memPath, plan.deltaDir) + default: + // UFFD disabled but fresh restore — File backend with network overrides. + restoreErr = RestoreSnapshotWithOverrides(socketPath, snapshotPath, memPath, "eth0", tapDevice, plan.deltaDir) } if restoreErr != nil { // Firecracker is already running; stop the unit before other @@ -1397,13 +1407,11 @@ func (m *Manager) stopUnitDuringRestoreError(vmID string) { } // startUffdHandler binds the per-VM UFFD socket, opens mem.snap, and -// spawns the fault-serving goroutine. Stores the cancel func on inst so -// destroy/error paths can tear it down. -// -// recorder, when non-nil, hooks into OnPageFault for template-build -// recording. Prefetch is disabled in recording mode — recording captures -// guest demand order, which prefetch would scramble. -func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath string, recorder *uffd.Recorder) error { +// launches the handler goroutine. reqCtx is the caller's request ctx; +// its deadline applies to the handshake. The serve loop uses an +// independent ctx so it outlives the request. recorder, when non-nil, +// hooks OnPageFault for template-build recording. +func (m *Manager) startUffdHandler(reqCtx context.Context, inst *VMInstance, uffdSocketPath, memSnapPath string, recorder *uffd.Recorder) error { accessLogPath := filepath.Join(filepath.Dir(memSnapPath), "access.log") cfg := uffd.Config{ SocketPath: uffdSocketPath, @@ -1421,9 +1429,9 @@ func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath if err := h.Start(); err != nil { return err } - uffdCtx, uffdCancel := context.WithCancel(context.Background()) + lifetimeCtx, lifetimeCancel := context.WithCancel(context.Background()) inst.mu.Lock() - inst.uffdCancel = uffdCancel + inst.uffdCancel = lifetimeCancel inst.uffdHandler = h inst.mu.Unlock() @@ -1435,7 +1443,13 @@ func (m *Manager) startUffdHandler(inst *VMInstance, uffdSocketPath, memSnapPath } }() defer h.Close() - if err := h.Run(uffdCtx); err != nil && !errors.Is(err, context.Canceled) { + if err := h.AcceptHandshake(reqCtx); err != nil { + if !errors.Is(err, context.Canceled) { + m.log.Error().Err(err).Str("vm_id", vmID).Msg("uffd handshake failed") + } + return + } + if err := h.Serve(lifetimeCtx); err != nil && !errors.Is(err, context.Canceled) { m.log.Error().Err(err).Str("vm_id", vmID).Msg("uffd handler exited with error") } }() diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go index 888a571..adf0a2d 100644 --- a/internal/vm/uffd/handler.go +++ b/internal/vm/uffd/handler.go @@ -140,16 +140,24 @@ func (h *Handler) Start() error { } -// Run accepts the connection from Firecracker, receives the UFFD fd -// plus JSON mappings, then services page faults until ctx is cancelled -// or the UFFD fd closes. Caller must call Close on return. -func (h *Handler) Run(ctx context.Context) error { +// AcceptHandshake blocks on the UDS listener until Firecracker connects +// and sends the UFFD fd + region mappings. ctx supplies the handshake +// deadline — kept separate from Serve's lifetime ctx so a slow handshake +// under burst doesn't share fate with the page-fault loop. +func (h *Handler) AcceptHandshake(ctx context.Context) error { if h.listener == nil || h.source == nil { - return errors.New("Run called before Start (or after Close)") + return errors.New("AcceptHandshake called before Start (or after Close)") } if err := h.acceptAndReceive(ctx); err != nil { return fmt.Errorf("receive handshake: %w", err) } + return nil +} + +// Serve runs the page-fault loop (and prefetcher, if enabled) until ctx +// is cancelled or the UFFD fd is closed. AcceptHandshake must have +// returned nil first. +func (h *Handler) Serve(ctx context.Context) error { if h.cfg.PrefetchEnabled { go h.runPrefetch(ctx) } From dd64f9c772954c83339514ae276abc3b56dfd7b6 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 16:05:10 -0700 Subject: [PATCH 7/9] =?UTF-8?q?uffd:=20prod-readiness=20fixes=20=E2=80=94?= =?UTF-8?q?=20ID=20prefix,=20handshake=20ack,=20accept=20watchdog,=20bound?= =?UTF-8?q?ed=20ioctls,=20LoadSnapshot=20timeout?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/vm/build.go | 5 ++- internal/vm/firecracker.go | 6 ++- internal/vm/manager.go | 19 ++++++++++ internal/vm/uffd/handler.go | 74 ++++++++++++++++++++++++++++++------- internal/vm/uffd/ioctl.go | 8 ++-- 5 files changed, 92 insertions(+), 20 deletions(-) diff --git a/internal/vm/build.go b/internal/vm/build.go index 1441247..8dd450f 100644 --- a/internal/vm/build.go +++ b/internal/vm/build.go @@ -168,9 +168,10 @@ func (m *Manager) buildTemplateSync(ctx context.Context, buildVMID string, req B } // Best-effort: a missing access.log just means sandboxes fall back - // to sequential prefetch. + // to sequential prefetch. The "build-" prefix must remain so isBuildVM + // skips persistence + reconciler for this throwaway VM. if m.cfg.UffdEnabled && m.cfg.UffdPrefetchEnabled { - recordingVMID := "record-" + buildVMID + recordingVMID := "build-record-" + req.TemplateID accessLogPath := filepath.Join(snapshotDir, "access.log") recCfg := VMConfig{ VCPU: req.VCPU, diff --git a/internal/vm/firecracker.go b/internal/vm/firecracker.go index 7efb5fd..6b1223b 100644 --- a/internal/vm/firecracker.go +++ b/internal/vm/firecracker.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "strings" + "time" httptransport "github.com/go-openapi/runtime/client" "github.com/go-openapi/strfmt" @@ -340,9 +341,12 @@ func RestoreSnapshotWithOverrides(socketPath, snapshotPath, memPath, ifaceID, ta // uffdSocketPath must be a bound Unix socket; the caller is responsible for // starting the handler goroutine before invoking this function. func RestoreSnapshotUffdWithOverrides(socketPath, snapshotPath, uffdSocketPath, ifaceID, tapDevice, blockDeltaDir string) error { + // Bound LoadSnapshot so a hung Firecracker doesn't wedge vmd. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() fc := newFCClient(socketPath) if _, err := fc.Operations.LoadSnapshot(&operations.LoadSnapshotParams{ - Context: context.Background(), + Context: ctx, Body: &models.SnapshotLoadParams{ SnapshotPath: &snapshotPath, MemBackend: &models.MemoryBackend{ diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 86d1b31..5af5f07 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -946,6 +946,25 @@ func (m *Manager) restoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem return nil, fmt.Errorf("restore snapshot: %w", restoreErr) } + // LoadSnapshot doesn't ack the UFFD handshake — Firecracker can + // return success while our handler silently failed, which would + // leave the guest hanging on first page fault. + if useUffd { + inst.mu.RLock() + uffdH := inst.uffdHandler + inst.mu.RUnlock() + if uffdH != nil { + if hsErr := uffdH.WaitHandshake(ctx); hsErr != nil { + m.stopUnitDuringRestoreError(vmID) + m.cancelUffdHandler(inst) + m.netMgr.CleanupVM(vmID) + m.cleanupRunDir(vmID) + m.setStatus(vmID, StatusError) + return nil, fmt.Errorf("uffd handshake: %w", hsErr) + } + } + } + if err := m.waitForBoxd(ctx, hostIP, 5*time.Second); err != nil { m.stopUnitDuringRestoreError(vmID) m.cancelUffdHandler(inst) diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go index adf0a2d..757aad0 100644 --- a/internal/vm/uffd/handler.go +++ b/internal/vm/uffd/handler.go @@ -19,7 +19,10 @@ import ( "golang.org/x/sys/unix" ) -const defaultFaultWorkers = 4096 +// defaultFaultWorkers caps concurrent in-flight UFFDIO_COPY ioctls per +// VM. Bounded so handler teardown via g.Wait doesn't stall on a wedged +// worker. +const defaultFaultWorkers = 256 // uffdClosed is the sentinel value of Handler.uffdFd before the fd is // received and after it has been closed. @@ -93,6 +96,9 @@ type Handler struct { mappings []GuestRegionMapping pageSize uint64 + // handshakeDone carries the AcceptHandshake outcome to WaitHandshake. + handshakeDone chan error + closed atomic.Bool } @@ -101,8 +107,9 @@ func New(cfg Config) *Handler { cfg.FaultWorkers = defaultFaultWorkers } h := &Handler{ - cfg: cfg, - log: cfg.Logger.With().Str("component", "uffd").Logger(), + cfg: cfg, + log: cfg.Logger.With().Str("component", "uffd").Logger(), + handshakeDone: make(chan error, 1), } h.uffdFd.Store(uffdClosed) return h @@ -143,17 +150,42 @@ func (h *Handler) Start() error { // AcceptHandshake blocks on the UDS listener until Firecracker connects // and sends the UFFD fd + region mappings. ctx supplies the handshake // deadline — kept separate from Serve's lifetime ctx so a slow handshake -// under burst doesn't share fate with the page-fault loop. +// under burst doesn't share fate with the page-fault loop. The outcome +// is published to WaitHandshake. func (h *Handler) AcceptHandshake(ctx context.Context) error { if h.listener == nil || h.source == nil { - return errors.New("AcceptHandshake called before Start (or after Close)") + err := errors.New("AcceptHandshake called before Start (or after Close)") + h.publishHandshake(err) + return err } if err := h.acceptAndReceive(ctx); err != nil { - return fmt.Errorf("receive handshake: %w", err) + err = fmt.Errorf("receive handshake: %w", err) + h.publishHandshake(err) + return err } + h.publishHandshake(nil) return nil } +func (h *Handler) publishHandshake(err error) { + select { + case h.handshakeDone <- err: + default: + } +} + +// WaitHandshake returns the AcceptHandshake outcome, blocking until it +// completes or ctx is cancelled. Idempotent. +func (h *Handler) WaitHandshake(ctx context.Context) error { + select { + case err := <-h.handshakeDone: + h.publishHandshake(err) + return err + case <-ctx.Done(): + return ctx.Err() + } +} + // Serve runs the page-fault loop (and prefetcher, if enabled) until ctx // is cancelled or the UFFD fd is closed. AcceptHandshake must have // returned nil first. @@ -203,15 +235,24 @@ func (h *Handler) closeFd() { } func (h *Handler) acceptAndReceive(ctx context.Context) error { - // ctx with no deadline → 30s default, matching the gRPC deadline on - // the caller side. The listener deadline is the only way to make - // AcceptUnix interruptible. + // SetDeadline is the only way to make AcceptUnix interruptible; the + // watchdog below sets it to the past on ctx cancel. if deadline, ok := ctx.Deadline(); ok { _ = h.listener.SetDeadline(deadline) } else { _ = h.listener.SetDeadline(time.Now().Add(30 * time.Second)) } + stop := make(chan struct{}) + defer close(stop) + go func() { + select { + case <-ctx.Done(): + _ = h.listener.SetDeadline(time.Now()) + case <-stop: + } + }() + conn, err := h.listener.AcceptUnix() if err != nil { return fmt.Errorf("accept: %w", err) @@ -233,18 +274,25 @@ func (h *Handler) acceptAndReceive(ctx context.Context) error { if err != nil { return fmt.Errorf("parse cmsg: %w", err) } + // Firecracker sends exactly one fd; cap defensively. uffdFd := -1 for _, cm := range cmsgs { fds, err := unix.ParseUnixRights(&cm) if err != nil { continue } - for _, fd := range fds { - if uffdFd == -1 { - uffdFd = fd - } else { + if len(fds) > 1 { + for _, fd := range fds { _ = unix.Close(fd) } + return fmt.Errorf("expected 1 fd in SCM_RIGHTS, got %d", len(fds)) + } + if len(fds) == 1 { + if uffdFd != -1 { + _ = unix.Close(fds[0]) + return errors.New("multiple SCM_RIGHTS messages; expected one") + } + uffdFd = fds[0] } } if uffdFd == -1 { diff --git a/internal/vm/uffd/ioctl.go b/internal/vm/uffd/ioctl.go index d785f87..8f12445 100644 --- a/internal/vm/uffd/ioctl.go +++ b/internal/vm/uffd/ioctl.go @@ -14,10 +14,10 @@ import ( "unsafe" ) -// Hardcoded from ; verified against Linux 6.8 (our -// prod target). _IO[RW] encoding: (dir << 30) | (sizeof << 16) | (type -// << 8) | nr, with type = 0xAA. Pinned by TestIoctlNumbers so a struct -// resize can't silently break the kernel ABI. +// From , verified against Linux 6.8. _IO[RW] +// encoding: (dir << 30) | (sizeof << 16) | (type << 8) | nr, with +// type = 0xAA. Pinned by TestIoctlNumbers so a struct resize can't +// silently break the kernel ABI. const ( UFFDIO_COPY uintptr = 0xC028AA03 UFFDIO_ZEROPAGE uintptr = 0xC020AA04 From 37edee6d9c2fa8bf180689cb847fd5fe568f2366 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 19:07:33 -0700 Subject: [PATCH 8/9] uffd: join prefetch goroutine before munmap; treat ZEROPAGE EAGAIN as benign; drop dead UNREGISTER wrapper --- internal/vm/uffd/handler.go | 18 +++++++++++++++++- internal/vm/uffd/ioctl.go | 10 ++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/internal/vm/uffd/handler.go b/internal/vm/uffd/handler.go index 757aad0..68c5a41 100644 --- a/internal/vm/uffd/handler.go +++ b/internal/vm/uffd/handler.go @@ -99,6 +99,9 @@ type Handler struct { // handshakeDone carries the AcceptHandshake outcome to WaitHandshake. handshakeDone chan error + // prefetchWg lets Close wait for the prefetch goroutine before munmap. + prefetchWg sync.WaitGroup + closed atomic.Bool } @@ -191,7 +194,11 @@ func (h *Handler) WaitHandshake(ctx context.Context) error { // returned nil first. func (h *Handler) Serve(ctx context.Context) error { if h.cfg.PrefetchEnabled { - go h.runPrefetch(ctx) + h.prefetchWg.Add(1) + go func() { + defer h.prefetchWg.Done() + h.runPrefetch(ctx) + }() } return h.serveLoop(ctx) } @@ -210,6 +217,9 @@ func (h *Handler) Close() error { _ = os.Remove(h.cfg.SocketPath) } h.closeFd() + // Drain prefetch before munmap. Fault workers are already joined by + // serveLoop's g.Wait; only the bare prefetch goroutine needs this. + h.prefetchWg.Wait() if h.source != nil { if err := h.source.Close(); err != nil && firstErr == nil { firstErr = fmt.Errorf("close source: %w", err) @@ -460,6 +470,12 @@ func (h *Handler) handleRemove(start, end uint64) { err := ioctlZeropage(fd, start, length) h.fdMu.RUnlock() if err != nil { + // EAGAIN: another REMOVE event is queued ahead of us — the next + // REMOVE will cover this range, so skipping the zero is correct. + if errors.Is(err, syscall.EAGAIN) { + h.stats.EagainRetries.Add(1) + return + } h.log.Warn().Err(err).Uint64("start", start).Uint64("end", end).Msg("UFFDIO_ZEROPAGE failed") } } diff --git a/internal/vm/uffd/ioctl.go b/internal/vm/uffd/ioctl.go index 8f12445..a2dcba7 100644 --- a/internal/vm/uffd/ioctl.go +++ b/internal/vm/uffd/ioctl.go @@ -122,11 +122,5 @@ func ioctlZeropage(uffdFd uintptr, start uint64, length uint64) error { return nil } -func ioctlUnregister(uffdFd uintptr, start uint64, length uint64) error { - r := uffdioRange{Start: start, Len: length} - _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uffdFd, UFFDIO_UNREGISTER, uintptr(unsafe.Pointer(&r))) - if errno != 0 { - return fmt.Errorf("UFFDIO_UNREGISTER: %w", errno) - } - return nil -} +// No UFFDIO_UNREGISTER wrapper: closing the fd auto-unregisters all +// regions. Constant is pinned by TestIoctlNumbers for ABI docs. From c919a35ef4738ef224253fcc77c057a461639095 Mon Sep 17 00:00:00 2001 From: Amit Patil Date: Tue, 12 May 2026 19:51:35 -0700 Subject: [PATCH 9/9] uffd: tests for Close idempotency, WaitHandshake, and Recorder dedup/flush --- internal/vm/uffd/handler_test.go | 53 ++++++++++++++++++ internal/vm/uffd/recorder_test.go | 92 +++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 internal/vm/uffd/recorder_test.go diff --git a/internal/vm/uffd/handler_test.go b/internal/vm/uffd/handler_test.go index e4f7ca4..1e2f39e 100644 --- a/internal/vm/uffd/handler_test.go +++ b/internal/vm/uffd/handler_test.go @@ -1,9 +1,14 @@ package uffd import ( + "context" "encoding/json" + "errors" "testing" + "time" "unsafe" + + "github.com/rs/zerolog" ) // firecracker-produced JSON from src/vmm/src/persist.rs:725-768. Verified @@ -103,3 +108,51 @@ func TestUffdMsgSize(t *testing.T) { t.Errorf("sizeof(uffdMsg) = %d, want %d", got, want) } } + +func TestHandler_CloseIdempotent(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + if err := h.Close(); err != nil { + t.Errorf("first Close: %v", err) + } + // Second call must be a no-op — closed CAS guards against double-Munmap + // and double-close of listener/fd. + if err := h.Close(); err != nil { + t.Errorf("second Close: %v", err) + } +} + +func TestHandler_WaitHandshake_Success(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + h.publishHandshake(nil) + if err := h.WaitHandshake(context.Background()); err != nil { + t.Errorf("WaitHandshake: %v", err) + } + // Idempotent: re-publish lets a second caller see the same outcome. + if err := h.WaitHandshake(context.Background()); err != nil { + t.Errorf("second WaitHandshake: %v", err) + } +} + +func TestHandler_WaitHandshake_Error(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + want := errors.New("boom") + h.publishHandshake(want) + got := h.WaitHandshake(context.Background()) + if !errors.Is(got, want) { + t.Errorf("WaitHandshake = %v, want %v", got, want) + } +} + +func TestHandler_WaitHandshake_CtxCancel(t *testing.T) { + h := New(Config{Logger: zerolog.Nop()}) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + start := time.Now() + err := h.WaitHandshake(ctx) + if !errors.Is(err, context.Canceled) { + t.Errorf("WaitHandshake = %v, want context.Canceled", err) + } + if time.Since(start) > 100*time.Millisecond { + t.Errorf("cancel didn't propagate promptly: %v", time.Since(start)) + } +} diff --git a/internal/vm/uffd/recorder_test.go b/internal/vm/uffd/recorder_test.go new file mode 100644 index 0000000..523c6e1 --- /dev/null +++ b/internal/vm/uffd/recorder_test.go @@ -0,0 +1,92 @@ +package uffd + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestRecorder_DedupAndOrder(t *testing.T) { + r := NewRecorder() + r.Record(4096) + r.Record(8192) + r.Record(4096) // duplicate — must be ignored + r.Record(12288) + r.Record(8192) // duplicate + + if r.Len() != 3 { + t.Fatalf("Len = %d, want 3", r.Len()) + } + + path := filepath.Join(t.TempDir(), "access.log") + if err := r.Flush(path, false); err != nil { + t.Fatalf("Flush: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + want := []string{"4096", "8192", "12288"} + if !reflect.DeepEqual(lines, want) { + t.Errorf("flushed contents = %v, want %v (first-touch order, deduped)", lines, want) + } +} + +func TestRecorder_FlushAtomicRename(t *testing.T) { + r := NewRecorder() + r.Record(4096) + + dir := t.TempDir() + path := filepath.Join(dir, "access.log") + + if err := r.Flush(path, false); err != nil { + t.Fatalf("Flush: %v", err) + } + + // .tmp must NOT survive a successful flush — the rename should have + // moved it into place. A leftover .tmp signals a torn flush. + if _, err := os.Stat(path + ".tmp"); !os.IsNotExist(err) { + t.Errorf("tmp file %q still exists after successful flush", path+".tmp") + } + if _, err := os.Stat(path); err != nil { + t.Errorf("output file missing after flush: %v", err) + } +} + +func TestRecorder_FlushEmpty(t *testing.T) { + r := NewRecorder() + path := filepath.Join(t.TempDir(), "access.log") + if err := r.Flush(path, false); err != nil { + t.Fatalf("Flush on empty recorder: %v", err) + } + info, err := os.Stat(path) + if err != nil { + t.Fatalf("output file missing: %v", err) + } + if info.Size() != 0 { + t.Errorf("empty recorder flushed %d bytes, want 0", info.Size()) + } +} + +func TestRecorder_FlushSorted(t *testing.T) { + r := NewRecorder() + r.Record(12288) + r.Record(4096) + r.Record(8192) + + path := filepath.Join(t.TempDir(), "access.log") + if err := r.Flush(path, true); err != nil { + t.Fatalf("Flush: %v", err) + } + + data, _ := os.ReadFile(path) + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + want := []string{"4096", "8192", "12288"} + if !reflect.DeepEqual(lines, want) { + t.Errorf("sorted flush = %v, want %v", lines, want) + } +}