Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Dec 5, 2023
1 parent 618de50 commit 685ae22
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,10 @@ def wrapper(*args, **kwargs):
tensor = kwargs["tensor"]
else:
tensor = args[0]
state = PartialState()
if state.device.type != tensor.device.type:
raise RuntimeError(
f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {state.device.type}. "
f"Please move it to the {state.device.type} before calling {operation}."
if PartialState().device.type != tensor.device.type:
raise DistributedOperationException(
f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. "
f"Please move it to the {PartialState().device.type} before calling {operation}."
)
shapes = get_shape(tensor)
output = gather_object([shapes])
Expand Down

0 comments on commit 685ae22

Please sign in to comment.