Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,34 @@
import numpy as np
from numpy.typing import ArrayLike

try:
import ray
RAY_AVAILABLE = True

@ray.remote
def _process_block_remote(in_array, out_array, sl, block_idx, total_blocks):
"""
Ray remote function to process a single block.
Reads block from input, computes if lazy, writes to output.

Returns the block index for progress tracking.
"""
# Read block from input array
block = in_array[sl]

# If it's a dask array, compute it to get numpy array
if hasattr(block, 'compute'):
block = block.compute()

# Write block directly to output array (writes to S3)
out_array[sl] = block

return block_idx

except ImportError:
RAY_AVAILABLE = False
_process_block_remote = None # Not available


def _get_size(shape: Tuple[int, ...], itemsize: int) -> int:
"""
Expand Down Expand Up @@ -188,60 +216,125 @@ def _slice_along_dim(dim: int) -> Generator:

@staticmethod
def store(
    in_array: ArrayLike, out_array: ArrayLike, block_shape: tuple, use_ray: bool = True, ray_num_cpus: int = None
) -> None:
    """
    Partitions the last 3 dimensions of an array
    into non-overlapping blocks and writes them to a Zarr array.
    Can use Ray for parallel processing or sequential processing.

    :param in_array: The input array (can be dask array or numpy array)
    :param out_array: The output array
    :param block_shape: Tuple of (block_depth, block_height, block_width)
    :param use_ray: If True, use Ray for parallel processing. If False,
        use sequential processing. Silently ignored when Ray is not
        installed.
    :param ray_num_cpus: Number of CPUs to use for Ray. If None, uses
        all available CPUs.
    """
    import sys

    logger = logging.getLogger(__name__)

    # Total number of blocks (ceil division per dimension) for
    # progress tracking.
    total_blocks = 1
    for arr_dim, block_dim in zip(in_array.shape, block_shape):
        total_blocks *= (arr_dim + block_dim - 1) // block_dim

    # Fall back to sequential mode when Ray is not importable.
    use_ray = use_ray and RAY_AVAILABLE

    # Throttle progress logging to roughly 10 lines total so huge
    # (terabyte-scale) arrays do not flood the log.
    log_interval = max(1, total_blocks // 10)

    if use_ray:
        # Initialize Ray if not already initialized
        if not ray.is_initialized():
            logger.info(" Initializing Ray for parallel processing...")

            import os

            # These env vars configure Ray's memory monitor and MUST
            # be set before ray.init() is called.
            os.environ.setdefault("RAY_memory_monitor_refresh_ms", "250")
            # Kill workers at 85% memory usage
            os.environ.setdefault("RAY_memory_usage_threshold", "0.85")

            ray_config = {
                "ignore_reinit_error": True,
                # 8 GB for the object store
                "object_store_memory": int(8 * 1024 * 1024 * 1024),
            }

            # Only constrain CPUs when explicitly requested; omitting
            # num_cpus lets Ray use all available CPUs (same effect as
            # passing num_cpus=None).
            if ray_num_cpus is not None:
                ray_config["num_cpus"] = ray_num_cpus
                logger.info(f" Limiting Ray to {ray_num_cpus} CPUs to prevent OOM")

            ray.init(**ray_config)
            actual_cpus = ray.cluster_resources().get("CPU", 0)
            logger.info(f" Ray initialized with {actual_cpus} CPUs and 8 GB object store")
            logger.info(" Memory monitor will kill workers at 85% memory usage")

        logger.info(f" Writing {total_blocks} blocks IN PARALLEL using Ray (block shape: {block_shape})...")
        sys.stdout.flush()

        # Put both arrays in the object store ONCE so every task shares
        # the same object references, instead of re-serializing the
        # (potentially huge) arrays on every .remote() submission.
        in_ref = ray.put(in_array)
        out_ref = ray.put(out_array)

        # Submit all blocks as Ray tasks
        futures = [
            _process_block_remote.remote(in_ref, out_ref, sl, block_idx, total_blocks)
            for block_idx, sl in enumerate(
                BlockedArrayWriter.gen_slices(in_array.shape, block_shape)
            )
        ]

        # Wait for tasks to complete and report progress as they finish.
        completed = 0
        while futures:
            # Wait for at least one task to complete
            done, futures = ray.wait(futures, num_returns=1, timeout=1.0)
            completed += len(done)
            if done:
                progress_pct = (completed / total_blocks) * 100
                logger.info(f" Progress: {completed}/{total_blocks} blocks ({progress_pct:.1f}%)")
                sys.stdout.flush()

        logger.info(f" ✓ All {total_blocks} blocks written successfully using Ray!")

    else:
        # Sequential processing (original implementation)
        if not RAY_AVAILABLE:
            logger.warning(" Ray not available, falling back to sequential processing")
        else:
            logger.info(" Using sequential processing (Ray disabled)")

        logger.info(f" Writing {total_blocks} blocks SEQUENTIALLY (block shape: {block_shape})...")
        sys.stdout.flush()

        block_idx = 0
        for sl in BlockedArrayWriter.gen_slices(in_array.shape, block_shape):
            # Read block from input array
            block = in_array[sl]

            # If it's a dask array, compute it to get a numpy array
            if hasattr(block, 'compute'):
                block = block.compute()

            # Write block directly to output array (writes to S3)
            out_array[sl] = block

            # Log progress AFTER the write (so counts reflect completed
            # work) and only ~10 times total via log_interval.
            block_idx += 1
            if block_idx % log_interval == 0 or block_idx == total_blocks:
                progress_pct = (block_idx / total_blocks) * 100
                logger.info(f" Progress: {block_idx}/{total_blocks} blocks ({progress_pct:.1f}%)")

        logger.info(f" ✓ All {total_blocks} blocks written successfully!")

    sys.stdout.flush()

@staticmethod
def get_block_shape(arr, target_size_mb=409600, mode="cycle", chunks=None):
def get_block_shape(arr, target_block_size_mb=409600, mode="cycle", chunks=None):
"""
Given the shape and chunk size of a pre-chunked
array, determine the optimal block shape closest
to target_size. Expanded block dimensions are
to target block size. Expanded block dimensions are
an integer multiple of the chunk dimension
to ensure optimal access patterns.

Args:
arr: the input array
target_size_mb: target block size in megabytes,
target_block_size_mb: target block size in megabytes,
default is 409600 mode: strategy.
Must be one of "cycle", or "iso"

Expand All @@ -259,7 +352,7 @@ def get_block_shape(arr, target_size_mb=409600, mode="cycle", chunks=None):
return expand_chunks(
chunks,
arr.shape[-3:],
target_size_mb * 1024**2,
target_block_size_mb * 1024**2,
arr.itemsize,
mode,
)
87 changes: 63 additions & 24 deletions Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def convert_array_to_zarr(
"clevel": 3,
"shuffle": Blosc.SHUFFLE,
},
target_size_mb: Optional[int] = 24000,
target_block_size_mb: Optional[int] = 24000,
use_ray: Optional[bool] = True,
ray_num_cpus: Optional[int] = None,
dont_copy_fullscale: Optional[bool] = False,
):
"""
Converts an array to zarr format
Expand Down Expand Up @@ -153,11 +156,24 @@ def convert_array_to_zarr(
scale_factor = [int(s) for s in scale_factor]
voxel_size = [float(v) for v in voxel_size]

new_channel_group = root_group.create_group(
name=stack_name, overwrite=True
)
# If dont_copy_fullscale is True, try to open existing group instead of overwriting
if dont_copy_fullscale:
try:
# Try to open existing group in read-write mode
new_channel_group = root_group[stack_name]
logger.info(f"Opened existing channel group: {stack_name}")
except KeyError:
# Group doesn't exist, create it
new_channel_group = root_group.create_group(
name=stack_name, overwrite=True
)
logger.info(f"Created new channel group: {stack_name}")
else:
new_channel_group = root_group.create_group(
name=stack_name, overwrite=True
)

# Writing OME-NGFF metadata
# Writing OME-NGFF metadata (always write to ensure metadata is up to date)
write_ome_ngff_metadata(
group=new_channel_group,
arr_shape=dataset_shape,
Expand All @@ -174,43 +190,59 @@ def convert_array_to_zarr(
origin = [0,0,0]
)

# Writing first multiscale by default
pyramid_group = new_channel_group.create_dataset(
name="0",
shape=dataset_shape,
chunks=chunk_size,
dtype=array.dtype,
compressor=compressor,
dimension_separator="/",
overwrite=True,
)

# Writing multiscales
# Handle both numpy arrays and dask arrays
if isinstance(array, da.Array):
# Already a dask array, rechunk if needed
previous_scale = da.rechunk(array, chunks=pyramid_group.chunks)
previous_scale = da.rechunk(array, chunks=chunk_size)
else:
# Convert numpy array to dask array
previous_scale = da.from_array(array, pyramid_group.chunks)
previous_scale = da.from_array(array, chunk_size)

block_shape = list(
BlockedArrayWriter.get_block_shape(
arr=previous_scale,
target_size_mb=target_size_mb,
target_block_size_mb=target_block_size_mb,
chunks=chunk_size,
)
)
block_shape = extra_axes + tuple(block_shape)

logger.info(f"Writing {n_lvls} pyramid levels...")
# Determine start level based on dont_copy_fullscale flag
start_level = 1 if dont_copy_fullscale else 0
levels_to_write = n_lvls - start_level

if dont_copy_fullscale:
logger.info(f"Skipping level 0 (fullscale) - will only write levels 1-{n_lvls-1}")
# Open existing level 0 dataset to read from for computing level 1
try:
existing_level0 = new_channel_group["0"]
pyramid_group = existing_level0
logger.info(f"Using existing level 0 from {output_path} for downsampling")
except KeyError:
logger.error(f"Level 0 not found at {output_path}/0. Cannot skip fullscale copy if level 0 doesn't exist.")
raise ValueError(f"Level 0 not found at {output_path}/0. Set dont_copy_fullscale=False or ensure level 0 exists.")
else:
logger.info(f"Writing {n_lvls} pyramid levels...")
# Writing first multiscale by default
pyramid_group = new_channel_group.create_dataset(
name="0",
shape=dataset_shape,
chunks=chunk_size,
dtype=array.dtype,
compressor=compressor,
dimension_separator="/",
overwrite=True,
)

for level in range(0, n_lvls):
if not level:
for level in range(start_level, n_lvls):
if level == 0:
# Only reached if dont_copy_fullscale is False
array_to_write = previous_scale
logger.info(f"Level {level}/{n_lvls-1}: Writing full resolution - shape {array_to_write.shape}")

else:
# Read from the previous level (either written level 0 or existing level 0)
previous_scale = da.from_zarr(pyramid_group, pyramid_group.chunks)
new_scale_factor = (
[1] * (len(previous_scale.shape) - len(scale_factor))
Expand Down Expand Up @@ -238,8 +270,15 @@ def convert_array_to_zarr(
)

logger.info(f"Level {level}/{n_lvls-1}: Writing to storage...")
BlockedArrayWriter.store(array_to_write, pyramid_group, block_shape)
logger.info(f"Level {level}/{n_lvls-1}: ✓ Complete ({level+1}/{n_lvls} levels done)")
import sys
sys.stdout.flush()
try:
BlockedArrayWriter.store(array_to_write, pyramid_group, block_shape, use_ray=use_ray, ray_num_cpus=ray_num_cpus)
logger.info(f"Level {level}/{n_lvls-1}: ✓ Complete ({level-start_level+1}/{levels_to_write} levels done)")
except Exception as e:
logger.error(f"Level {level}/{n_lvls-1}: FAILED with error: {e}")
logger.exception("Full traceback:")
raise

if __name__ == "__main__":
BASE_PATH = "/data"
Expand Down
Loading