-
Well, inside shard_map nothing is sharded anymore, so all you can do is track which mesh axes you are varying across. Can you explain more about your use case? Maybe there is a different way to do what you want.
-
While looking through the new shard_map examples I had a thought about a potential feature that I think is missing. I am not sure how often it would be useful, but I think it could make debugging clearer.
Let's assume you have a multidimensional array that is sharded along one axis and you pass it to a function using shard_map to interact with the data along the sharded axis.
Then the array `x` inside shard_map will have a smaller shape along the axis that is sharded. But when writing more complex code inside shard_map, I think it might be harder to keep track of which arrays are "normal" arrays and which ones are "shard_map views into sharded arrays".
`x.vma` does tell you that `x` is `shard_map`ped, but it does not tell us which of the dimensions of `x` is sharded. And I guess this will get worse if there are multiple sharded dimensions.
Or in other words: if this is easily possible and others find this useful too, would it be possible to also attach to `x` the information that `gpus` applies along the second axis of `x`?
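As a possible stopgap, the information asked for here can be recovered from the `PartitionSpec` passed as `in_specs`, outside of shard_map. The sketch below is user-side bookkeeping only, not a JAX feature: `sharded_dims` and `local_shape` are hypothetical helper names, and `axis_sizes` is an assumed plain dict from mesh axis name to size.

```python
from jax.sharding import PartitionSpec as P


def sharded_dims(spec, ndim):
    """Map each array dimension to the tuple of mesh axes it is split over.

    Hypothetical helper: dimensions that are not split are left out.
    """
    # A spec may be shorter than the array rank; missing entries mean "not split".
    entries = tuple(spec) + (None,) * (ndim - len(spec))
    result = {}
    for dim, entry in enumerate(entries):
        if entry is None:
            continue
        # An entry is either one axis name or a tuple of axis names.
        result[dim] = entry if isinstance(entry, tuple) else (entry,)
    return result


def local_shape(global_shape, spec, axis_sizes):
    """Per-device block shape, assuming every split divides evenly."""
    shape = list(global_shape)
    for dim, axes in sharded_dims(spec, len(global_shape)).items():
        for axis in axes:
            shape[dim] //= axis_sizes[axis]
    return tuple(shape)


spec = P(None, "gpus")
print(sharded_dims(spec, ndim=2))             # {1: ('gpus',)}
print(local_shape((4, 8), spec, {"gpus": 4}))  # (4, 2)
```

This keeps the dimension-to-axis mapping next to the spec that defines it, but it does not follow an array through reshapes or transposes inside the mapped function, which is where built-in tracking would help most.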