diff --git a/torchdrug/layers/functional/functional.py b/torchdrug/layers/functional/functional.py index bbb989c..0e48997 100644 --- a/torchdrug/layers/functional/functional.py +++ b/torchdrug/layers/functional/functional.py @@ -163,11 +163,22 @@ def variadic_sum(input, size): """ Compute sum over sets with variadic sizes. - Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + Suppose there are :math:`N` sets, and the sizes of all sets :math:`\sum_{i=0}^{N-1} n_i` are summed to :math:`B`. - Parameters: + Input: input (Tensor): input of shape :math:`(B, ...)` - size (LongTensor): size of sets of shape :math:`(N,)` + size (LongTensor): size of sets of shape :math:`(n_0, n_1, ..., n_{N-1})` + + Output: + value (Tensor): output of shape :math:`(N, ...)` + + Example: + >>> input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9,], [4, 5, 6], [7, 8, 9,]]) + >>> size = torch.tensor([1,1,3]) + >>> print(variadic_sum(input, size)) + tensor([[ 1, 2, 3], + [ 4, 5, 6], + [18, 21, 24]]) """ index2sample = torch.repeat_interleave(size) index2sample = index2sample.view([-1] + [1] * (input.ndim - 1)) @@ -358,7 +369,7 @@ def variadic_sort(input, size, descending=False): input (Tensor): input of shape :math:`(B, ...)` size (LongTensor): size of sets of shape :math:`(N,)` descending (bool, optional): return ascending or descending order - + Returns (Tensor, LongTensor): sorted values and indexes """