❓ Questions and Help
Let's say my dataloader yields a dict on each iteration, and the members of this dict have different numbers of dimensions:
{
"input_ids": shape = (batch, seq),
"masks": shape = (batch, seq, seq),
}
pl.MpDeviceLoader appears to accept only one sharding annotation. I'm currently using it like this:
data_loader = pl.MpDeviceLoader(
    data_loader,
    dev,
    input_sharding=xs.ShardingSpec(mesh, ('data', None, None)))
Obviously, ('data', None, None) is not valid for input_ids, which has only 2 dimensions, yet this seems to work. What is the proper way to use MpDeviceLoader in this case?
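To make the mismatch concrete, here is a minimal pure-Python sketch (no torch_xla required) of the per-key partition specs I would expect to need: shard the batch dimension along 'data' and replicate everything else. The helper name `partition_spec_for` and the example sizes (batch=8, seq=128) are hypothetical and only for illustration. Whether MpDeviceLoader can take one spec per key is exactly what I'm asking.

```python
from typing import Dict, Optional, Tuple

PartitionSpec = Tuple[Optional[str], ...]


def partition_spec_for(shape: Tuple[int, ...], batch_axis: str = "data") -> PartitionSpec:
    """Shard dim 0 along `batch_axis` and replicate every other dim.

    Hypothetical helper, not part of torch_xla.
    """
    if len(shape) == 0:
        raise ValueError("cannot shard a scalar along the batch axis")
    return (batch_axis,) + (None,) * (len(shape) - 1)


# Hypothetical example shapes: batch=8, seq=128.
batch_shapes = {
    "input_ids": (8, 128),       # (batch, seq)
    "masks": (8, 128, 128),      # (batch, seq, seq)
}

# One partition spec per dict key, each matching that tensor's rank.
specs: Dict[str, PartitionSpec] = {
    name: partition_spec_for(shape) for name, shape in batch_shapes.items()
}

for name, spec in specs.items():
    print(name, spec)
# input_ids ('data', None)
# masks ('data', None, None)
```

The point is that input_ids needs a 2-element spec while masks needs a 3-element one, so a single ('data', None, None) annotation cannot match both ranks.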