12 changes: 5 additions & 7 deletions src/fairchem/core/models/uma/escn_md.py
@@ -647,13 +647,11 @@ def forward(

outputs[energy_key] = {"energy": energy} if self.wrap_property else energy

embeddings = emb["node_embedding"].detach()
if gp_utils.initialized():
embeddings = gp_utils.gather_from_model_parallel_region(embeddings, dim=0)

outputs["embeddings"] = (
{"embeddings": embeddings} if self.wrap_property else embeddings
)
if not gp_utils.initialized():
embeddings = emb["node_embedding"].detach()
outputs["embeddings"] = (
{"embeddings": embeddings} if self.wrap_property else embeddings
)

if self.regress_stress:
grads = torch.autograd.grad(
122 changes: 75 additions & 47 deletions src/fairchem/core/units/mlip_unit/predict.py
@@ -12,6 +12,7 @@
import math
import os
import random
import sys
from collections import defaultdict
from contextlib import nullcontext
from functools import wraps
@@ -340,8 +341,7 @@ def move_tensors_to_cpu(data):
return data


@remote
class MLIPWorker:
class MLIPWorkerLocal:
def __init__(
self,
worker_id: int,
@@ -361,54 +361,61 @@ def __init__(
)
self.master_port = get_free_port() if master_port is None else master_port
self.is_setup = False
self.last_received_atomic_data = None

def get_master_address_and_port(self):
return (self.master_address, self.master_port)

def get_device_for_local_rank(self):
return get_device_for_local_rank()

def _distributed_setup(
self,
worker_id: int,
master_port: int,
world_size: int,
predictor_config: dict,
master_address: str,
):
# initialize distributed environment
# TODO, this wont work for multi-node, need to fix master addr
logging.info(f"Initializing worker {worker_id}...")
setup_env_local_multi_gpu(worker_id, master_port, master_address)
# local_rank = int(os.environ["LOCAL_RANK"])
device = predictor_config.get("device", "cpu")
logging.info(f"Initializing worker {self.worker_id}...")
setup_env_local_multi_gpu(self.worker_id, self.master_port, self.master_address)

device = self.predictor_config.get("device", "cpu")
assign_device_for_local_rank(device == "cpu", 0)
backend = "gloo" if device == "cpu" else "nccl"
dist.init_process_group(
backend=backend,
rank=worker_id,
world_size=world_size,
rank=self.worker_id,
world_size=self.world_size,
)
gp_utils.setup_graph_parallel_groups(world_size, backend)
self.predict_unit = hydra.utils.instantiate(predictor_config)
gp_utils.setup_graph_parallel_groups(self.world_size, backend)
self.predict_unit = hydra.utils.instantiate(self.predictor_config)
self.device = get_device_for_local_rank()
logging.info(
f"Worker {worker_id}, gpu_id: {ray.get_gpu_ids()}, loaded predict unit: {self.predict_unit}, "
f"on port {self.master_port}, with device: {get_device_for_local_rank()}, config: {self.predictor_config}"
f"Worker {self.worker_id}, gpu_id: {ray.get_gpu_ids()}, loaded predict unit: {self.predict_unit}, "
f"on port {self.master_port}, with device: {self.device}, config: {self.predictor_config}"
)
self.is_setup = True

def predict(self, data: AtomicData) -> dict[str, torch.tensor] | None:
def predict(
self, data: AtomicData, use_nccl: bool = False
) -> dict[str, torch.tensor] | None:
if not self.is_setup:
self._distributed_setup(
self.worker_id,
self.master_port,
self.world_size,
self.predictor_config,
self.master_address,
)
self.is_setup = True
self._distributed_setup()

out = self.predict_unit.predict(data)
out = move_tensors_to_cpu(out)
if self.worker_id == 0:
return out
else:
return None
return move_tensors_to_cpu(out)

if self.worker_id != 0 and use_nccl:
self.last_received_atomic_data = data.to(self.device)
while True:
Contributor: Not sure I understand this while loop.

Contributor (Author): I renamed one of the vars (md -> use_nccl). If we are using NCCL, then Ray workers other than worker 0 should never leave the predict loop; they just wait and listen on NCCL. That way we don't incur the overhead of communicating with them through Ray. The only other worker, worker 0, is in the same process as the driver, so there is no overhead there.

Contributor: Oh, I see. Alternatively, you could pass None to the data input in the case where you are broadcasting? I guess that makes sense if you want to eliminate any Ray gRPC communication overhead here.

torch.distributed.broadcast(self.last_received_atomic_data.pos, src=0)
self.predict_unit.predict(self.last_received_atomic_data)

return None


@remote
class MLIPWorker(MLIPWorkerLocal):
Contributor: Do we need a separate class/inheritance structure here? Can we just add the functionality to MLIPWorker and an additional function that is used for rank 0?

Contributor (Author): I tried this; I got an error from Ray, something about not being able to initialize a Ray actor directly :'(

pass


@requires(ray_installed, message="Requires `ray` to be installed")
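
The plain-class-plus-actor-subclass workaround described in the thread above matches Ray's actor semantics: a class decorated with `@ray.remote` cannot be instantiated directly in the driver process, only through `.remote()`, so keeping the logic in an undecorated base class lets the driver hold an ordinary in-process instance. A minimal sketch of the pattern, assuming `ray` is installed; the class names here are illustrative, not part of the PR:

```python
import ray

class EchoLocal:  # plain Python class: the driver can instantiate it directly
    def ping(self) -> str:
        return "pong"

@ray.remote
class Echo(EchoLocal):  # Ray actor: must be created via .remote()
    pass

ray.init()
local = EchoLocal()    # ordinary in-process object, no Ray overhead
actor = Echo.remote()  # remote actor handle
# Echo() would raise a TypeError: actor classes cannot be called directly.
print(local.ping())                  # -> pong
print(ray.get(actor.ping.remote()))  # -> pong
ray.shutdown()
```
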
@@ -434,6 +441,7 @@ def __init__(
seed=seed,
atom_refs=atom_refs,
)
self.inference_settings = inference_settings
self._dataset_to_tasks = copy.deepcopy(_mlip_pred_unit.dataset_to_tasks)

predict_unit_config = {
@@ -446,6 +454,16 @@ def __init__(
"atom_refs": atom_refs,
"assert_on_nans": assert_on_nans,
}

logging.basicConfig(
level=logging.INFO,
force=True,
stream=sys.stdout,
format="%(asctime)s %(levelname)s [%(processName)s] %(name)s: %(message)s",
)
# Optional: keep Ray/uvicorn chatty logs in check
logging.getLogger("ray").setLevel(logging.INFO)
logging.getLogger("uvicorn").setLevel(logging.INFO)
if not ray.is_initialized():
ray.init(
logging_level=logging.INFO,
@@ -454,6 +472,8 @@ def __init__(
# },
)

self.atomic_data_on_device = None

num_nodes = math.ceil(num_workers / num_workers_per_node)
num_workers_on_node_array = [num_workers_per_node] * num_nodes
if num_workers % num_workers_per_node > 0:
@@ -473,7 +493,7 @@ def __init__(
placement_groups.append(pg)
ray.get(pg.ready()) # Wait for each placement group to be scheduled

# place rank 0 on placement group 0
# Still need to place a worker here to occupy the slot, otherwise Ray double-books this GPU
rank0_worker = MLIPWorker.options(
num_gpus=num_gpu_per_worker,
scheduling_strategy=PlacementGroupSchedulingStrategy(
@@ -482,11 +502,18 @@ def __init__(
placement_group_capture_child_tasks=True, # Ensure child tasks also run in this PG
),
).remote(0, num_workers, predict_unit_config)
master_addr, master_port = ray.get(
rank0_worker.get_master_address_and_port.remote()

local_gpu_or_cpu = ray.get(rank0_worker.get_device_for_local_rank.remote())
os.environ[CURRENT_DEVICE_TYPE_STR] = local_gpu_or_cpu

self.workers = []
self.local_rank0 = MLIPWorkerLocal(
worker_id=0,
world_size=num_workers,
predictor_config=predict_unit_config,
)
master_addr, master_port = self.local_rank0.get_master_address_and_port()
logging.info(f"Started rank0 on {master_addr}:{master_port}")
self.workers = [rank0_worker]

# next place all ranks in order and pack them on placement groups
# ie: rank0-7 -> placement group 0, 8->15 -> placement group 1 etc.
@@ -520,20 +547,21 @@ def __init__(
self.workers.append(actor)
worker_id += 1

def predict(
self, data: AtomicData, undo_element_references: bool = True
) -> dict[str, torch.tensor]:
def predict(self, data: AtomicData) -> dict[str, torch.tensor]:
# put the reference in the object store only once
# this data transfer should be made more efficient by using a shared memory transfer + nccl broadcast
data_ref = ray.put(data)
futures = [w.predict.remote(data_ref) for w in self.workers]
# just get the first result that is ready since they are identical
# the rest of the futures should go out of scope and memory garbage collected
# ready_ids, _ = ray.wait(futures, num_returns=1)
# result = ray.get(ready_ids[0])
# result = ray.get(futures)
# return result[0]
return ray.get(futures[0])
if not self.inference_settings.merge_mole or self.atomic_data_on_device is None:
data_ref = ray.put(data)
# this will put the ray workers into an infinite loop listening for broadcasts
_futures = [
w.predict.remote(data_ref, use_nccl=self.inference_settings.merge_mole)
for w in self.workers
]
self.atomic_data_on_device = data.clone()
else:
self.atomic_data_on_device.pos = data.pos.to(self.local_rank0.device)
torch.distributed.broadcast(self.atomic_data_on_device.pos, src=0)

return self.local_rank0.predict(self.atomic_data_on_device)

@property
def dataset_to_tasks(self) -> dict[str, list]:
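
The rank-0 broadcast flow discussed in the threads above can be seen end to end in a minimal, self-contained sketch, assuming only `torch` is installed. The gloo backend, the fixed port, the step count, and names like `run_rank` are illustrative stand-ins for the PR's NCCL setup; a real worker loop never exits:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

NUM_STEPS = 3  # bounded here so the demo terminates

def run_rank(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29511",
        rank=rank,
        world_size=world_size,
    )
    positions = torch.zeros(4, 3)  # stand-in for AtomicData.pos
    for step in range(NUM_STEPS):
        if rank == 0:
            positions = torch.full((4, 3), float(step))  # new positions arrive
        # Non-zero ranks block here instead of waiting on Ray RPCs;
        # broadcast overwrites `positions` in place with rank 0's tensor.
        dist.broadcast(positions, src=0)
        # ... every rank would now run the model on `positions` ...
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run_rank, args=(2,), nprocs=2)
```
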
9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -15,6 +15,8 @@
import pytest
import torch

import fairchem.core.common.gp_utils as gp_utils
from fairchem.core.common import distutils

@pytest.fixture()
def command_line_inference_checkpoint(request):
@@ -134,3 +136,10 @@ def water_xyz_file(tmp_path_factory):
fpath = d / "water.xyz"
fpath.write_text(contents)
return str(fpath)


@pytest.fixture(autouse=True)
def setup_before_each_test():
if gp_utils.initialized():
gp_utils.cleanup_gp()
distutils.cleanup()
10 changes: 10 additions & 0 deletions tests/core/units/mlip_unit/test_predict.py
@@ -11,6 +11,8 @@
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings
from fairchem.core.units.mlip_unit.predict import ParallelMLIPPredictUnit
from tests.conftest import seed_everywhere
import fairchem.core.common.gp_utils as gp_utils
from fairchem.core.common import distutils

FORCE_TOL = 1e-4
ATOL = 5e-4
@@ -154,6 +156,10 @@ def test_parallel_predict_unit(workers, device):
for _ in range(runs):
pp_results = ppunit.predict(atomic_data)

if gp_utils.initialized():
gp_utils.cleanup_gp()
distutils.cleanup()

seed_everywhere(seed)
normal_predict_unit = pretrained_mlip.get_predict_unit(
"uma-s-1p1", device=device, inference_settings=ifsets
@@ -227,6 +233,10 @@ def test_parallel_predict_unit_batch(workers, device):
for _ in range(runs):
pp_results = ppunit.predict(atomic_data)

if gp_utils.initialized():
gp_utils.cleanup_gp()
distutils.cleanup()

seed_everywhere(seed)
normal_predict_unit = pretrained_mlip.get_predict_unit(
"uma-s-1p1", device=device, inference_settings=ifsets