Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b1d758f
Add POSIX shared memory transport for multi-worker dataloaders
dfalbel Apr 23, 2026
0afc8b9
Rename mori references to shm
dfalbel Apr 23, 2026
996e198
Add test that re-runs all dataloader tests with SHM transport
dfalbel Apr 23, 2026
af05e5e
Fix Windows build: move #ifdef inside function bodies
dfalbel Apr 23, 2026
daf0665
Enable SHM transport by default on Unix, test both paths
dfalbel Apr 23, 2026
8308044
Fix writable SHM mappings, enable SHM by default on Unix
dfalbel Apr 23, 2026
a313d26
Clean up unread SHM segments when iterator is finalized
dfalbel Apr 23, 2026
92557a1
Add regression test for SHM cleanup on early iteration stop
dfalbel Apr 23, 2026
e402021
Handle zero-length tensors in SHM transfer path
dfalbel Apr 23, 2026
d072635
Handle non-tensor batch objects in SHM transport path
dfalbel Apr 23, 2026
53579ae
Preserve list attributes in SHM roundtrip, drop torch_shared_batch class
dfalbel Apr 23, 2026
f26e74f
Keep SHM mappings alive for derived tensors
dfalbel Apr 23, 2026
c7d2548
Preserve requires_grad when round-tripping tensors through SHM
dfalbel Apr 23, 2026
54baf47
Preserve aliasing when the same tensor appears multiple times in a batch
dfalbel Apr 23, 2026
9abf8d1
Fall back to serialization when SHM allocation fails
dfalbel Apr 24, 2026
a5982d8
Respect explicit socket-transport override over SHM
dfalbel Apr 24, 2026
14f5753
Preserve aliasing for repeated non-contiguous tensors
dfalbel Apr 24, 2026
d380231
Unlink already-created SHM segments when conversion fails partway
dfalbel Apr 24, 2026
22f947b
Preserve canonical dtype names in SHM descriptors
dfalbel Apr 24, 2026
d09dfb4
Guard 0-dim tensors before taking &shape[0] in SHM path
dfalbel Apr 24, 2026
d9fb783
Wait indefinitely for prefetched tasks during finalization
dfalbel Apr 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ jobs:
bash bench/benchmark_cpu_cache/run.sh 2>&1 | tee -a $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY

- name: Run SHM dataloader benchmark
run: |
echo '## SHM Dataloader Benchmark' >> $GITHUB_STEP_SUMMARY
Rscript bench/bench-shm.R 2>&1 | tee -a $GITHUB_STEP_SUMMARY

build-gpu-image:
needs: lantern
if: ${{ always() && needs.lantern.result != 'failed' }}
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# torch (development version)

- Multi-worker dataloaders now use POSIX shared memory for tensor transfer on
Unix systems, resulting in up to 2x faster data loading. To revert to the
previous behavior, set `options(torch.dataloader_use_shm = FALSE)`. (#1456)

# torch 0.17.0

## Breaking changes
Expand Down
16 changes: 16 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -16941,6 +16941,22 @@ cpp_buffer_from_tensor <- function(data) {
.Call(`_torch_cpp_buffer_from_tensor`, data)
}

cpp_tensor_to_shm <- function(tensor) {
.Call(`_torch_cpp_tensor_to_shm`, tensor)
}

cpp_tensor_from_shm <- function(name, nbytes_dbl, shape, options) {
.Call(`_torch_cpp_tensor_from_shm`, name, nbytes_dbl, shape, options)
}

cpp_shm_exists <- function(name) {
.Call(`_torch_cpp_shm_exists`, name)
}

cpp_shm_unlink <- function(name) {
invisible(.Call(`_torch_cpp_shm_unlink`, name))
}

cpp_torch_tensor_dtype <- function(x) {
.Call(`_torch_cpp_torch_tensor_dtype`, x)
}
Expand Down
188 changes: 168 additions & 20 deletions R/utils-data-dataloader.R
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ MultiProcessingDataLoaderIter <- R6::R6Class(
}

worker_config <- function(id, num_workers, seed, init_fn, globals,
packages, socket_port = NULL) {
packages, socket_port = NULL, use_shm = FALSE) {
library(torch)
.worker_info <<- list(
id = id,
Expand All @@ -386,27 +386,28 @@ MultiProcessingDataLoaderIter <- R6::R6Class(
if (!is.null(init_fn)) {
init_fn(id)
}

.socket_con <<- NULL
if (!is.null(socket_port)) {
# We need to wait for the main process to start the server, so here we
# retry a few times until the connection works.
for(i in 1:20) {
tr <- try({.socket_con <<- socketConnection(
port = socket_port,
blocking = TRUE,
port = socket_port,
blocking = TRUE,
open = "a+b"
)}, silent = TRUE)
)}, silent = TRUE)

if (!inherits(tr, "try-error")) break
Sys.sleep(0.5)

if (i == 20) {
runtime_error("Could not create a connection with the main process.")
}
}
}


.use_shm <<- use_shm
}

fetcher <- self$.dataset_fetcher$fetch
Expand All @@ -424,7 +425,8 @@ MultiProcessingDataLoaderIter <- R6::R6Class(
init_fn = self$.worker_init_fn,
globals = self$.worker_globals,
packages = self$.worker_packages,
socket_port = worker$port
socket_port = worker$port,
use_shm = worker$using_shm
)
)

Expand Down Expand Up @@ -464,11 +466,11 @@ MultiProcessingDataLoaderIter <- R6::R6Class(
# send task to the worker
if (coro::is_exhausted(index)) {
worker$session$call(function() {
torch:::to_exportable_tensor(coro::exhausted(), .socket_con)
torch:::to_exportable_tensor(coro::exhausted(), .socket_con, .use_shm)
})
} else {
worker$session$call(function(index) {
torch:::to_exportable_tensor(fetcher(index), .socket_con)
torch:::to_exportable_tensor(fetcher(index), .socket_con, .use_shm)
}, list(index = index))
}

Expand All @@ -481,7 +483,27 @@ MultiProcessingDataLoaderIter <- R6::R6Class(
task <- private$tasks[[1]]
private$tasks <- private$tasks[-1]

if (!task$using_socket_con) {
if (task$using_shm) {
# SHM path: tensor data is in shared memory, only a small
# reference comes through callr's pipe.
p <- task$session$poll_process(timeout = self$.timeout)
if (p == "timeout") {
runtime_error("dataloader worker timed out.")
}
result <- task$session$read()
if (!is.null(result$error)) {
if (packageVersion("callr") >= "3.7.1") {
rlang::abort(
"Error when getting dataset item.",
parent = result$error,
class = "runtime_error"
)
} else {
runtime_error(result$error$message)
}
}
from_exportable_tensor(result$result)
} else if (!task$using_socket_con) {
# wait for the process to be ready
p <- task$session$poll_process(timeout = self$.timeout)
if (p == "timeout") {
Expand Down Expand Up @@ -563,6 +585,19 @@ MultiProcessingDataLoaderIter <- R6::R6Class(
private = list(
tasks = list(),
finalize = function() {
# Drain any prefetched tasks so their SHM segments are cleaned up.
# Wait indefinitely (-1) — the worker will finish or die when the
# session is closed. A timeout would leak SHM segments.
for (task in private$tasks) {
tryCatch({
task$session$poll_process(timeout = -1)
result <- task$session$read()
if (!is.null(result$result)) {
shm_unlink_recursive(result$result)
}
}, error = function(e) NULL)
}
private$tasks <- list()
lapply(private$workers, function(x) {
x$close_socket_con()
})
Expand Down Expand Up @@ -597,7 +632,12 @@ as_iterator.dataloader <- function(x) {

# takes a tensor and saves its state in a field so we can
# reconstruct it after transferring via futures
to_exportable_tensor <- function(x, con) {
to_exportable_tensor <- function(x, con, use_shm = FALSE) {
if (use_shm) {
result <- tryCatch(tensors_to_shared(x), error = function(e) NULL)
if (!is.null(result)) return(result)
# SHM failed (e.g. /dev/shm full in Docker) — fall back to serialization
}
if (is.null(con)) {
return(tensor_to_raw_vector(x))
}
Expand All @@ -606,13 +646,114 @@ to_exportable_tensor <- function(x, con) {
}

# Rebuild a batch received from a worker process.
# Handles three payload shapes: the coro exhaustion sentinel (passed
# through), SHM descriptors (rebuilt via tensors_from_shared), and
# serialized payloads (raw vector or open connection, read with
# torch_load). Anything else — e.g. plain scalars from a custom
# collate_fn — is returned unchanged.
from_exportable_tensor <- function(x) {
  if (coro::is_exhausted(x)) {
    return(x)
  }
  # tensors_from_shared is a no-op for non-shared objects
  payload <- tensors_from_shared(x)
  if (inherits(payload, "connection")) {
    return(torch_load(payload))
  }
  if (is.raw(payload)) {
    con <- rawConnection(payload)
    on.exit(close(con), add = TRUE)
    return(torch_load(con))
  }
  # Non-tensor, non-serialized payload — pass through as-is.
  payload
}

# Map C++ dtype names to names accepted by torch_tensor_from_buffer /
# dtype_from_string. Most map via tolower(); the entries in `renames`
# have different canonical names.
dtype_to_shm_string <- function(dtype) {
  renames <- c(
    char = "int8",
    byte = "uint8",
    complexfloat = "cfloat",
    complexdouble = "cdouble"
  )
  key <- tolower(as.character(dtype))
  if (key %in% names(renames)) {
    renames[[key]]
  } else {
    key
  }
}

# Convert batch tensors to POSIX shared memory for IPC.
# Single memcpy: tensor data -> SHM. Called in the worker process.
# Memoizes by data pointer so tensors sharing storage produce the same
# SHM descriptor, preserving aliasing through the roundtrip.
#
# Returns the input structure with every tensor replaced by a
# "torch_shared_tensor" descriptor (name, nbytes, shape, dtype,
# requires_grad). Non-tensor leaves and the exhaustion sentinel are
# returned unchanged. On any error, all SHM segments created so far are
# unlinked before the error is re-raised.
tensors_to_shared <- function(x) {
  # Keyed by the tensor's external-pointer address so aliased tensors
  # map to one SHM segment.
  memo <- new.env(parent = emptyenv())
  to_shared <- function(x) {
    if (is_torch_tensor(x)) {
      key <- xptr_address(x)
      if (!is.null(memo[[key]])) return(memo[[key]])
      # SHM copy needs a flat CPU buffer; contiguous() is a no-op when
      # the tensor is already contiguous.
      t <- x$cpu()$contiguous()
      shm <- cpp_tensor_to_shm(t)
      result <- structure(
        list(name = shm$name, nbytes = shm$nbytes,
          shape = t$shape, dtype = dtype_to_shm_string(t$dtype),
          requires_grad = t$requires_grad),
        class = "torch_shared_tensor"
      )
      memo[[key]] <- result
      return(result)
    }
    if (is.list(x)) {
      # Recurse into (possibly nested) batch containers, keeping
      # attributes (names, classes) intact.
      out <- lapply(x, to_shared)
      attributes(out) <- attributes(x)
      return(out)
    }
    # Non-tensor leaf (e.g. scalar from a custom collate_fn): pass through.
    x
  }
  if (coro::is_exhausted(x)) return(x)
  tryCatch(to_shared(x), error = function(e) {
    # Clean up any SHM segments created before the failure so a partial
    # conversion does not leak segments.
    for (key in ls(memo)) {
      nm <- memo[[key]]$name
      if (nzchar(nm)) tryCatch(cpp_shm_unlink(nm), error = function(e2) NULL)
    }
    # Re-raise so the caller can fall back to serialization.
    stop(e)
  })
}

# Reconstruct tensors from POSIX shared memory.
# Memoizes by SHM name so duplicate references reconstruct to the same
# tensor, preserving storage sharing from the original batch.
tensors_from_shared <- function(x) {
  if (coro::is_exhausted(x)) return(x)
  seen <- new.env(parent = emptyenv())
  rebuild <- function(obj) {
    if (inherits(obj, "torch_shared_tensor")) {
      # Zero-length tensors have no backing SHM segment; build directly.
      if (obj$nbytes == 0) {
        empty <- torch_tensor(numeric(0), dtype = obj$dtype)$reshape(obj$shape)
        if (isTRUE(obj$requires_grad)) empty <- empty$requires_grad_(TRUE)
        return(empty)
      }
      cached <- seen[[obj$name]]
      if (!is.null(cached)) return(cached)
      tensor <- cpp_tensor_from_shm(
        obj$name, obj$nbytes, obj$shape, list(dtype = obj$dtype)
      )
      if (isTRUE(obj$requires_grad)) tensor <- tensor$requires_grad_(TRUE)
      seen[[obj$name]] <- tensor
      return(tensor)
    }
    if (is.list(obj)) {
      # Recurse into containers, preserving names/classes.
      rebuilt <- lapply(obj, rebuild)
      attributes(rebuilt) <- attributes(obj)
      return(rebuilt)
    }
    obj
  }
  rebuild(x)
}

# Unlink SHM segments from a shared result without mapping them.
# Used during cleanup of prefetched but unconsumed tasks.
shm_unlink_recursive <- function(x) {
  # Order matters: a torch_shared_tensor is itself a classed list, so
  # the class check must precede the generic list recursion.
  if (inherits(x, "torch_shared_tensor")) {
    if (nzchar(x$name)) cpp_shm_unlink(x$name)
  } else if (is.list(x)) {
    for (element in x) {
      shm_unlink_recursive(element)
    }
  }
  invisible(NULL)
}

walk_fields <- function(env, nms, func) {
Expand Down Expand Up @@ -657,10 +798,13 @@ r_session <- R6::R6Class(
con = NULL,
session = NULL,
using_socket_con = FALSE,
using_shm = FALSE,
initialize = function() {
if (use_socket_con()) {
self$port <- parallelly::freePort()
self$port <- parallelly::freePort()
self$using_socket_con <- TRUE
} else if (use_shm()) {
self$using_shm <- TRUE
}
self$session <- callr::r_session$new()
},
Expand All @@ -681,3 +825,7 @@ r_session <- R6::R6Class(
# Socket transport is opt-in; unset option means FALSE.
use_socket_con <- function() {
  getOption("torch.dataloader_use_socket_con", default = FALSE)
}

# SHM transport defaults to on for Unix-alikes and off on Windows
# (no POSIX shared memory there); the option overrides either way.
use_shm <- function() {
  platform_default <- !identical(.Platform$OS.type, "windows")
  getOption("torch.dataloader_use_shm", default = platform_default)
}
55 changes: 55 additions & 0 deletions bench/bench-shm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env Rscript
# Benchmark: POSIX shared memory IPC vs default callr pipe for dataloaders.
# Measures only data transfer time (excludes worker startup).

library(torch)

# Build a dataset generator over an n x p matrix of N(0, 1) draws.
# Note: returns the generator; call the result to instantiate.
make_ds <- function(n, p) {
  dataset(
    initialize = function() {
      self$x <- matrix(rnorm(n * p), nrow = n, ncol = p)
    },
    .getitem = function(idx) {
      torch_tensor(self$x[idx, ])
    },
    .length = function() {
      nrow(self$x)
    }
  )
}

# Time iterating one full epoch with `nw` workers, excluding the first
# batch (which pays worker startup cost). Returns the median elapsed
# seconds over `n_reps` repetitions.
bench_transfer <- function(n, p, bs, nw, n_reps = 3) {
  elapsed <- vapply(seq_len(n_reps), function(rep) {
    dl <- dataloader(make_ds(n, p)(), batch_size = bs, num_workers = nw)
    it <- dataloader_make_iter(dl)
    # first batch warms up workers, discard it
    dataloader_next(it)
    t0 <- proc.time()[["elapsed"]]
    repeat {
      batch <- dataloader_next(it, completed = NULL)
      if (is.null(batch)) break
    }
    proc.time()[["elapsed"]] - t0
  }, numeric(1))
  median(elapsed)
}

# Benchmark configurations: dataset rows (n), features (p), batch size
# (bs), and a human-readable label for the results table.
configs <- list(
  list(n = 500, p = 1000, bs = 32, label = "500x1K, bs=32"),
  list(n = 200, p = 50000, bs = 64, label = "200x50K, bs=64"),
  list(n = 200, p = 100000, bs = 64, label = "200x100K, bs=64"),
  list(n = 100, p = 500000, bs = 64, label = "100x500K, bs=64"),
  list(n = 100, p = 1000000, bs = 32, label = "100x1M, bs=32")
)

# Markdown table header (consumed by $GITHUB_STEP_SUMMARY).
cat("| Config | Default | SHM | Speedup | MB/batch |\n")
cat("|---|---|---|---|---|\n")

for (config in configs) {
  # float32 payload size per batch, in MiB
  mb_per_batch <- config$bs * config$p * 4 / 1024^2

  options(torch.dataloader_use_shm = FALSE)
  time_default <- bench_transfer(config$n, config$p, config$bs, 2)

  options(torch.dataloader_use_shm = TRUE)
  time_shm <- bench_transfer(config$n, config$p, config$bs, 2)

  cat(sprintf(
    "| %s | %.3fs | %.3fs | %.2fx | %.1f |\n",
    config$label, time_default, time_shm,
    time_default / time_shm, mb_per_batch
  ))
}
Loading