This is to unblock the distributed matmul experiments by @cowanmeg and @samnordmann.
I'll start with the tensor parallelism proposed in the original Megatron-LM paper.
- Only MHA and MLP are sharded.
- Activations are sharded in 2D, along the batch and hidden dimensions. However, the batch-dimension sharding is only for data parallelism, and that dimension is never resharded.
- Weights are sharded in 1D, along the hidden dimension (see the sketch below).
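
To make the scheme concrete, here is a minimal single-process sketch of the Megatron-style MLP sharding: the first GEMM is column-parallel, the second is row-parallel, and summing the per-rank partials stands in for the all-reduce. The shapes and names (`world_size`, `ffn`, etc.) are illustrative, not nvFuser APIs, and the batch-dimension (data-parallel) sharding is omitted since that dimension is never resharded.

```python
# Minimal sketch of Megatron-LM MLP tensor parallelism, simulating the
# "ranks" with plain tensor slices on one device. Illustrative only.
import torch

world_size = 2            # number of tensor-parallel devices being simulated
batch, hidden = 4, 8      # activations: [batch, hidden]
ffn = 4 * hidden          # MLP inner dimension

x  = torch.randn(batch, hidden)
w1 = torch.randn(hidden, ffn)   # first linear: column-parallel (split along ffn)
w2 = torch.randn(ffn, hidden)   # second linear: row-parallel (split along ffn)

# Reference: the unsharded MLP.
ref = torch.nn.functional.gelu(x @ w1) @ w2

# Sharded: each "rank" holds a 1D slice of each weight along ffn,
# computes a partial output, and the partials are summed (the all-reduce).
partials = []
for rank in range(world_size):
    cols = slice(rank * ffn // world_size, (rank + 1) * ffn // world_size)
    h_local = torch.nn.functional.gelu(x @ w1[:, cols])  # column-parallel GEMM
    partials.append(h_local @ w2[cols, :])               # row-parallel GEMM
out = sum(partials)                                      # all-reduce stand-in

assert torch.allclose(ref, out, atol=1e-5)
```

Splitting the first GEMM by columns is what keeps the GeLU local to each rank (it is elementwise), so the only communication in the whole MLP is the single all-reduce after the second GEMM. MHA is sharded the same way, with heads playing the role of the split hidden dimension.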