Use NCCL with ParallelMLIPPredictUnitRay when possible #1587
base: main
Changes from all commits: 8ea8417, 1082d63, 163f0ae, 593402a, 79263cd, fe61759, 11c36c6, ca60f49, 3a3753d, 5d3c6a0, 88ea2cc, 3f711c5, 1a824f9, eaf99ce, 1f1f251, 314f9e5, 9b41c50, 59f8547

Diff view

@@ -12,6 +12,7 @@
 import math
 import os
 import random
+import sys
 from collections import defaultdict
 from contextlib import nullcontext
 from functools import wraps
|
@@ -340,8 +341,7 @@ def move_tensors_to_cpu(data):
     return data
 
 
-@remote
-class MLIPWorker:
+class MLIPWorkerLocal:
     def __init__(
         self,
         worker_id: int,
|
|
@@ -361,54 +361,61 @@ def __init__(
         )
         self.master_port = get_free_port() if master_port is None else master_port
         self.is_setup = False
+        self.last_received_atomic_data = None
 
     def get_master_address_and_port(self):
         return (self.master_address, self.master_port)
 
+    def get_device_for_local_rank(self):
+        return get_device_for_local_rank()
+
     def _distributed_setup(
         self,
-        worker_id: int,
-        master_port: int,
-        world_size: int,
-        predictor_config: dict,
-        master_address: str,
     ):
         # initialize distributed environment
         # TODO, this wont work for multi-node, need to fix master addr
-        logging.info(f"Initializing worker {worker_id}...")
-        setup_env_local_multi_gpu(worker_id, master_port, master_address)
-        # local_rank = int(os.environ["LOCAL_RANK"])
-        device = predictor_config.get("device", "cpu")
+        logging.info(f"Initializing worker {self.worker_id}...")
+        setup_env_local_multi_gpu(self.worker_id, self.master_port, self.master_address)
+
+        device = self.predictor_config.get("device", "cpu")
         assign_device_for_local_rank(device == "cpu", 0)
         backend = "gloo" if device == "cpu" else "nccl"
         dist.init_process_group(
             backend=backend,
-            rank=worker_id,
-            world_size=world_size,
+            rank=self.worker_id,
+            world_size=self.world_size,
         )
-        gp_utils.setup_graph_parallel_groups(world_size, backend)
-        self.predict_unit = hydra.utils.instantiate(predictor_config)
+        gp_utils.setup_graph_parallel_groups(self.world_size, backend)
+        self.predict_unit = hydra.utils.instantiate(self.predictor_config)
+        self.device = get_device_for_local_rank()
         logging.info(
-            f"Worker {worker_id}, gpu_id: {ray.get_gpu_ids()}, loaded predict unit: {self.predict_unit}, "
-            f"on port {self.master_port}, with device: {get_device_for_local_rank()}, config: {self.predictor_config}"
+            f"Worker {self.worker_id}, gpu_id: {ray.get_gpu_ids()}, loaded predict unit: {self.predict_unit}, "
+            f"on port {self.master_port}, with device: {self.device}, config: {self.predictor_config}"
         )
+        self.is_setup = True
 
-    def predict(self, data: AtomicData) -> dict[str, torch.tensor] | None:
+    def predict(
+        self, data: AtomicData, use_nccl: bool = False
+    ) -> dict[str, torch.tensor] | None:
         if not self.is_setup:
-            self._distributed_setup(
-                self.worker_id,
-                self.master_port,
-                self.world_size,
-                self.predictor_config,
-                self.master_address,
-            )
-            self.is_setup = True
+            self._distributed_setup()
 
         out = self.predict_unit.predict(data)
-        out = move_tensors_to_cpu(out)
         if self.worker_id == 0:
-            return out
-        else:
-            return None
+            return move_tensors_to_cpu(out)
+
+        if self.worker_id != 0 and use_nccl:
+            self.last_received_atomic_data = data.to(self.device)
+            while True:
+                torch.distributed.broadcast(self.last_received_atomic_data.pos, src=0)
+                self.predict_unit.predict(self.last_received_atomic_data)
+
+        return None
+
+
+@remote
+class MLIPWorker(MLIPWorkerLocal):
+    pass
 
 
 @requires(ray_installed, message="Requires `ray` to be installed")

Review thread on class MLIPWorker(MLIPWorkerLocal):

Reviewer: Do we need a separate class/inheritance structure here? Can we just add the functionality to MLIPWorker, plus an additional function that is used for rank 0?

Author: I tried this; I got an error from Ray, something about not being able to initialize a Ray actor directly :'(
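
For context on that reply: Ray actor classes cannot be constructed with a plain ClassName() call, only through .remote(), which is presumably why the PR keeps an undecorated MLIPWorkerLocal that the driver can instantiate directly for rank 0 and adds a thin @remote subclass for the Ray-managed ranks. A minimal sketch of that pattern, with hypothetical names and simplified bodies (not the PR's actual code):

```python
import ray


class WorkerLocal:
    """Plain class: the driver process can construct this directly (rank 0)."""

    def __init__(self, worker_id: int):
        self.worker_id = worker_id

    def predict(self, x: float) -> float:
        # Placeholder for the real inference work.
        return 2.0 * x


@ray.remote
class Worker(WorkerLocal):
    """Ray actor wrapper: must be instantiated via Worker.remote(...)."""


if __name__ == "__main__":
    ray.init()
    local_rank0 = WorkerLocal(worker_id=0)      # lives in the driver process
    remote_worker = Worker.remote(worker_id=1)  # lives in a separate Ray process
    print(local_rank0.predict(1.0))
    print(ray.get(remote_worker.predict.remote(1.0)))
    ray.shutdown()
```

Calling Worker(worker_id=1) directly would raise the actor-instantiation error the author mentions, which is what forces the local/remote split.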
|
|
@@ -434,6 +441,7 @@ def __init__(
             seed=seed,
             atom_refs=atom_refs,
         )
+        self.inference_settings = inference_settings
         self._dataset_to_tasks = copy.deepcopy(_mlip_pred_unit.dataset_to_tasks)
 
         predict_unit_config = {
|
@@ -446,6 +454,16 @@ def __init__(
             "atom_refs": atom_refs,
             "assert_on_nans": assert_on_nans,
         }
+
+        logging.basicConfig(
+            level=logging.INFO,
+            force=True,
+            stream=sys.stdout,
+            format="%(asctime)s %(levelname)s [%(processName)s] %(name)s: %(message)s",
+        )
+        # Optional: keep Ray/uvicorn chatty logs in check
+        logging.getLogger("ray").setLevel(logging.INFO)
+        logging.getLogger("uvicorn").setLevel(logging.INFO)
         if not ray.is_initialized():
             ray.init(
                 logging_level=logging.INFO,
|
|
@@ -454,6 +472,8 @@ def __init__(
                 # },
             )
 
+        self.atomic_data_on_device = None
+
         num_nodes = math.ceil(num_workers / num_workers_per_node)
         num_workers_on_node_array = [num_workers_per_node] * num_nodes
         if num_workers % num_workers_per_node > 0:
|
|
@@ -473,7 +493,7 @@ def __init__(
             placement_groups.append(pg)
             ray.get(pg.ready())  # Wait for each placement group to be scheduled
 
-        # place rank 0 on placement group 0
+        # Need to still place worker to occupy space, otherwise ray double books this GPU
         rank0_worker = MLIPWorker.options(
             num_gpus=num_gpu_per_worker,
             scheduling_strategy=PlacementGroupSchedulingStrategy(
|
|
@@ -482,11 +502,18 @@ def __init__(
                 placement_group_capture_child_tasks=True,  # Ensure child tasks also run in this PG
             ),
         ).remote(0, num_workers, predict_unit_config)
-        master_addr, master_port = ray.get(
-            rank0_worker.get_master_address_and_port.remote()
-        )
+
+        local_gpu_or_cpu = ray.get(rank0_worker.get_device_for_local_rank.remote())
+        os.environ[CURRENT_DEVICE_TYPE_STR] = local_gpu_or_cpu
+
+        self.workers = []
+        self.local_rank0 = MLIPWorkerLocal(
+            worker_id=0,
+            world_size=num_workers,
+            predictor_config=predict_unit_config,
+        )
+        master_addr, master_port = self.local_rank0.get_master_address_and_port()
         logging.info(f"Started rank0 on {master_addr}:{master_port}")
-        self.workers = [rank0_worker]
 
         # next place all ranks in order and pack them on placement groups
         # ie: rank0-7 -> placement group 0, 8->15 -> placement group 1 etc.
|
|
@@ -520,20 +547,21 @@ def __init__(
             self.workers.append(actor)
             worker_id += 1
 
-    def predict(
-        self, data: AtomicData, undo_element_references: bool = True
-    ) -> dict[str, torch.tensor]:
+    def predict(self, data: AtomicData) -> dict[str, torch.tensor]:
         # put the reference in the object store only once
-        # this data transfer should be made more efficienct by using a shared memory transfer + nccl broadcast
-        data_ref = ray.put(data)
-        futures = [w.predict.remote(data_ref) for w in self.workers]
-        # just get the first result that is ready since they are identical
-        # the rest of the futures should go out of scope and memory garbage collected
-        # ready_ids, _ = ray.wait(futures, num_returns=1)
-        # result = ray.get(ready_ids[0])
-        # result = ray.get(futures)
-        # return result[0]
-        return ray.get(futures[0])
+        if not self.inference_settings.merge_mole or self.atomic_data_on_device is None:
+            data_ref = ray.put(data)
+            # this will put the ray workers into an infinite loop listening for broadcasts
+            _futures = [
+                w.predict.remote(data_ref, use_nccl=self.inference_settings.merge_mole)
+                for w in self.workers
+            ]
+            self.atomic_data_on_device = data.clone()
+        else:
+            self.atomic_data_on_device.pos = data.pos.to(self.local_rank0.device)
+            torch.distributed.broadcast(self.atomic_data_on_device.pos, src=0)
+
+        return self.local_rank0.predict(self.atomic_data_on_device)
 
     @property
     def dataset_to_tasks(self) -> dict[str, list]:

Review thread on the NCCL broadcast loop

Reviewer: Not sure I understand this while loop.

Author: I renamed one of the vars (md -> use_nccl). If we are using NCCL, then Ray workers other than worker 0 should never leave the predict loop; they just wait and listen on NCCL. That way we don't incur the overhead of communicating with them through Ray. The only other worker, worker 0, is in the same process as the driver, so there is no overhead there.

Reviewer: Oh, I see. Alternatively, you could pass None to the data input in the case where you are broadcasting? I guess that makes sense if you want to eliminate any Ray gRPC communication overhead here.
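
To make the control flow in that exchange concrete, here is a minimal sketch (not the PR's code) of the broadcast pattern: ranks other than 0 block inside torch.distributed.broadcast waiting for new positions from rank 0 and rerun inference on every wake-up, while rank 0 broadcasts and returns its own result. It assumes a process group has already been initialized, that every rank holds a positions tensor of identical shape, and that model stands in for the predict unit.

```python
import torch
import torch.distributed as dist


def listen_and_predict(positions: torch.Tensor, model) -> None:
    """Ranks != 0: never return; wake up on each broadcast and recompute."""
    while True:
        # Blocks until rank 0 broadcasts; `positions` is overwritten in place.
        dist.broadcast(positions, src=0)
        model(positions)  # outputs are discarded; only rank 0 returns results


def rank0_step(positions: torch.Tensor, model) -> torch.Tensor:
    """Rank 0 (same process as the driver): push new positions, then compute."""
    dist.broadcast(positions, src=0)
    return model(positions)
```

In the PR the broadcast payload is AtomicData.pos, which is why, once the workers have entered the loop, only the positions (not the rest of the graph) can change between predict calls.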