For example, given logits, dim, and boundary:
boundary = torch.tensor([[0, 3, 4, 8, 0],
                         [1, 3, 5, 7, 9]])
# representing sections that look like:
# [[00012222_],
#  [_00112233]]
# in shape: (2, 9)
# (the sections cannot be extracted by plain slicing, since the boundaries differ across the batch)
logits = torch.rand(2, 9, 100)
result = blocky_softmax(logits, dim = 1, boundary = boundary)
# result[:, :, 0] may look like:
# [[0.33, 0.33, 0.33, 1.00, 0.25, 0.25, 0.25, 0.25, 0.0 ],
#  [0.0 , 0.50, 0.50, 0.50, 0.50, 0.50, 0.50, 0.50, 0.50]]
# The other 99 slices look similar, with each block summing to 1.
We want the softmax applied along dim = 1, with the sections also partitioning this dim.
My current PyTorch implementation uses a for loop over the blocks. It is slow and costs too much memory; it looks like:
import torch

def blocky_softmax(logits, splits, map_inf_to = None):
    _, batch_len, _ = logits.shape
    exp_logits = logits.exp() # [2, 9, 100]
    batch_seq_idx = torch.arange(batch_len, device = logits.device)[None, :]
    base = torch.zeros_like(logits)
    _, n_blocks = splits.shape
    for nid in range(1, n_blocks):
        start = splits[:, nid - 1, None]
        end = splits[:, nid, None]
        area = batch_seq_idx >= start
        area &= batch_seq_idx < end
        area.unsqueeze_(dim = 2) # (2, 9, 1) boolean mask for the current block
        blocky_z = (area * exp_logits).sum(dim = 1, keepdim = True) # the block's partition function
        blocky_z = area * blocky_z # broadcast the block sum back to its positions
        base = base + blocky_z
    if map_inf_to is not None:
        good_base = base > 0
        ones = torch.ones_like(base)
        base = torch.where(good_base, base, ones) # avoid division by zero outside any block
        exp_logits = torch.where(good_base, exp_logits, ones * map_inf_to)
    return exp_logits / base
This implementation is slowed down, and its memory use inflated, by a factor of n_blocks, even though the sections could be processed in parallel. If there is no off-the-shelf function for this, should I write a CUDA/C++ extension? I hope you can help with my issue.
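One loop-free direction I have been considering (an untested sketch; blocky_softmax_v and eps are just placeholder names, and map_inf_to is omitted for brevity): derive a per-position section id from boundary, accumulate each section's sum of exp(logits) with scatter_add_, then gather each sum back to every position in that section.
import torch

def blocky_softmax_v(logits, boundary, eps = 1e-12):
    # logits: (B, T, C); boundary: (B, n_splits) per-batch split points
    B, T, C = logits.shape
    starts = boundary[:, :-1]                          # (B, n_blocks)
    ends = boundary[:, 1:]                             # (B, n_blocks)
    pos = torch.arange(T, device = logits.device)
    member = (pos[None, :, None] >= starts[:, None, :]) \
           & (pos[None, :, None] < ends[:, None, :])   # (B, T, n_blocks) one-hot membership
    in_block = member.any(dim = 2)                     # (B, T); '_' positions are False
    seg = member.long().argmax(dim = 2)                # (B, T) section id per position
    exp_logits = logits.exp() * in_block[:, :, None]   # zero out positions outside any block
    idx = seg[:, :, None].expand(B, T, C)              # index along the section axis
    sums = torch.zeros(B, starts.shape[1], C, device = logits.device, dtype = logits.dtype)
    sums.scatter_add_(1, idx, exp_logits)              # per-section partition function
    base = sums.gather(1, idx)                         # broadcast each section's sum back
    return exp_logits / base.clamp_min(eps)            # masked positions come out as 0
For numerical stability one would normally subtract each section's max logit before exponentiating; I left that out to keep the sketch short.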
For further generalization, I hope to allow discontinuities in boundary/sections:
sections = torch.tensor([[ 0, 0, 0, -1, 2, 3, 2, 3, 0, 3],
                         [-1, 0, 0, 1, 2, 1, 2, 1, -1, 1]])
# [[000_232303],
#  [_0012121_1]]
Thank you for reading :)
I realize that scatter_add and gather perfectly solve the problem.
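For completeness, here is roughly what I mean, as an untested sketch (blocky_softmax_sections and eps are placeholder names): the section labels can serve directly as scatter indices, so discontinuous sections need no special handling, and -1 is treated as a masked position that receives probability 0.
import torch

def blocky_softmax_sections(logits, sections, eps = 1e-12):
    # logits: (B, T, C); sections: (B, T) integer label per position,
    # labels may repeat discontinuously, and -1 marks a masked position
    B, T, C = logits.shape
    valid = sections >= 0                              # (B, T)
    seg = sections.clamp_min(0)                        # map -1 to a dummy label 0
    exp_logits = logits.exp() * valid[:, :, None]      # masked positions contribute 0
    idx = seg[:, :, None].expand(B, T, C)              # index along the label axis
    n_sections = int(sections.max().item()) + 1
    sums = torch.zeros(B, n_sections, C, device = logits.device, dtype = logits.dtype)
    sums.scatter_add_(1, idx, exp_logits)              # accumulate each label's sum
    base = sums.gather(1, idx)                         # each position reads its label's sum
    return exp_logits / base.clamp_min(eps)

# e.g. with the sections tensor above:
result = blocky_softmax_sections(torch.rand(2, 10, 100), sections)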