Is there a Softmax implementation that works on sections along a dim (blocky Softmax) in PyTorch?


For example, given logits, dim, and boundary,

import torch

boundary = torch.tensor([[0, 3, 4, 8, 0],
                         [1, 3, 5, 7, 9]])

# represents sections that look like this (block index per position, _ = no block):
#    [[00012222_]
#     [_00112233]
#  in shape: (2, 9)
# (sections cannot be sliced)

logits = torch.rand(2, 9, 100)
result = blocky_softmax(logits, dim = 1, boundary = boundary)

# result[:, :, 0] may look like:
#   [[0.33, 0.33, 0.33, 1.00, 0.25, 0.25, 0.25, 0.25, 0.0 ]
#    [0.0,  0.50, 0.50, 0.50, 0.50, 0.50, 0.50, 0.50, 0.50]]
# The other 99 slices look similar, with each block summing to 1.
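To pin down the intended semantics, here is a deliberately slow reference sketch. It assumes `dim = 1` and the `boundary` format above: each consecutive pair `(start, end)` in a row marks one block `[start, end)`, and empty pairs such as `(8, 0)` are padding. The name `blocky_softmax_reference` is hypothetical. It loops over rows and blocks and applies `torch.softmax` to one block slice at a time, so it is only meant as ground truth for testing faster versions.

```python
import torch

def blocky_softmax_reference(logits: torch.Tensor, boundary: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: softmax over dim 1 inside each block.

    logits:   (B, L, D) float tensor
    boundary: (B, K + 1) integer tensor of block boundaries per row
    Positions outside every block stay 0.
    """
    out = torch.zeros_like(logits)
    for b, row in enumerate(boundary.tolist()):
        for start, end in zip(row[:-1], row[1:]):
            if start < end:  # skip empty/padding pairs such as (8, 0)
                # logits[b, start:end] has shape (end - start, D); normalize over the block
                out[b, start:end] = torch.softmax(logits[b, start:end], dim=0)
    return out
```

With all-equal logits this reproduces the `result[:, :, 0]` pattern shown above.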

We want the Softmax to be applied along dim = 1, but the sections also lie along this dim. My current PyTorch implementation uses a for loop over the blocks; it is slow and costs too much memory. It looks like this:

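import torch

def blocky_softmax(logits, splits, map_inf_to = None):
    # logits: [batch, seq, depth]; splits: [batch, n_blocks] boundaries (same format as boundary above)
    _, batch_len, _ = logits.shape
    exp_logits    = logits.exp() # [2, 9, 100]
    batch_seq_idx = torch.arange(batch_len, device = logits.device)[None, :]
    base          = torch.zeros_like(logits)
    _, n_blocks   = splits.shape
    for nid in range(1, n_blocks):
        start = splits[:, nid - 1, None]
        end   = splits[:, nid,     None]
        # mask of positions inside block [start, end) for every row
        area = batch_seq_idx >= start
        area &= batch_seq_idx < end
        area.unsqueeze_(dim = 2)
        # per-block normalizer, broadcast back onto the block's positions
        blocky_z = (area * exp_logits).sum(dim = 1, keepdim = True)
        base = base + area * blocky_z
    if map_inf_to is not None:
        # positions outside every block have base == 0; output map_inf_to there
        good_base = base > 0
        ones = torch.ones_like(base)
        base = torch.where(good_base, base, ones)
        exp_logits = torch.where(good_base, exp_logits, ones * map_inf_to)
    return exp_logits / base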

This implementation is slowed down, and its memory use inflated, by a factor of n_blocks, even though the sections could be processed in parallel. If there is no off-the-shelf function, should I write a CUDA/C++ extension? I hope you can help with this.
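One way to process all sections in parallel without a custom CUDA extension is to build a (batch, n_blocks, seq) membership mask from `boundary` in a single broadcast and let `torch.einsum` compute every block's sum at once. This is a sketch under the question's conventions (blocks are `[start, end)` pairs along dim 1, empty pairs are padding); `blocky_softmax_masked` is a hypothetical name. Like the loop above, it does not subtract a max before `exp`, so very large logits can overflow.

```python
import torch

def blocky_softmax_masked(logits: torch.Tensor, boundary: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: blocky softmax along dim 1, all blocks at once.

    logits:   (B, L, D) float tensor
    boundary: (B, K + 1) integer tensor; block k of row b is
              [boundary[b, k], boundary[b, k + 1]), empty pairs are padding.
    Positions outside every block are returned as 0.
    """
    _, length, _ = logits.shape
    starts = boundary[:, :-1, None]                     # (B, K, 1)
    ends = boundary[:, 1:, None]                        # (B, K, 1)
    pos = torch.arange(length, device=logits.device)    # (L,)
    # mask[b, k, l] = 1 if position l lies in block k of row b
    mask = ((pos >= starts) & (pos < ends)).to(logits.dtype)  # (B, K, L)
    exp = logits.exp()                                  # no max-subtraction, like the loop
    block_sums = torch.einsum("bkl,bld->bkd", mask, exp)      # (B, K, D)
    base = torch.einsum("bkl,bkd->bld", mask, block_sums)     # (B, L, D)
    covered = mask.sum(dim=1) > 0                       # (B, L)
    # Divide by 1 where no block covers the position, then zero those positions.
    safe_base = torch.where(base > 0, base, torch.ones_like(base))
    return exp / safe_base * covered.unsqueeze(-1).to(exp.dtype)
```

The explicit intermediates are the mask (batch × n_blocks × seq) and the block sums (batch × n_blocks × depth). The two einsums are dense products over the mask, so the arithmetic still grows with the number of blocks; the scatter/gather approach avoids that.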

To generalize further, I would also like to support discontinuous sections, given as a section label for each position (-1 for none):

sections = torch.tensor([[ 0,  0,  0, -1,  2,  3,  2,  3,  0,  3],
                         [-1,  0,  0,  1,  2,  1,  2,  1, -1,  1]])
# -1: the position belongs to no section
# [[000_232303]
#  [_0012121_1]]
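The boundary format is a special case of this label format. A minimal sketch of the conversion, assuming blocks do not overlap and using -1 for uncovered positions (`boundary_to_sections` is a hypothetical helper, not a PyTorch function):

```python
import torch

def boundary_to_sections(boundary: torch.Tensor, length: int) -> torch.Tensor:
    """Hypothetical helper: per-row block boundaries -> per-position labels.

    boundary: (B, K + 1) integer tensor; block k is [boundary[:, k], boundary[:, k + 1]).
    Returns a (B, length) int64 tensor with the block index, or -1 if uncovered.
    """
    starts = boundary[:, :-1, None]                      # (B, K, 1)
    ends = boundary[:, 1:, None]                         # (B, K, 1)
    pos = torch.arange(length, device=boundary.device)   # (L,)
    in_block = (pos >= starts) & (pos < ends)            # (B, K, L)
    covered = in_block.any(dim=1)                        # (B, L)
    # With non-overlapping blocks at most one entry is True; argmax finds it.
    labels = in_block.float().argmax(dim=1)              # (B, L)
    return torch.where(covered, labels, torch.full_like(labels, -1))
```

For the `boundary` in the question and `length = 9`, this gives `[[0, 0, 0, 1, 2, 2, 2, 2, -1], [-1, 0, 0, 1, 1, 2, 2, 3, 3]]`, matching the `00012222_` / `_00112233` picture.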

Thank you for reading:)


1 Answer


I realized that scatter_add and gather solve the problem perfectly.
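A minimal sketch of the `scatter_add`/`gather` approach, working on per-position section labels so it also covers the discontinuous case from the question. `scatter_add_` accumulates `exp(logits)` into one bucket per (row, section, slice), and `gather` reads each position's bucket total back as its softmax denominator. The function name, the label convention (-1 = no section), and the extra throwaway bucket for -1 are assumptions of this sketch, not details given in the answer.

```python
import torch

def blocky_softmax_scatter(logits: torch.Tensor, sections: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax over dim 1 within each section label.

    logits:   (B, L, D) float tensor
    sections: (B, L) integer labels; -1 means "not in any section".
    Labels need not be contiguous, so discontinuous sections work too.
    """
    batch, _, depth = logits.shape
    exp = logits.exp()                                  # (B, L, D), no max-subtraction
    # Shift labels by one so -1 becomes bucket 0, a throwaway bucket.
    bucket = (sections + 1).long()                      # (B, L)
    n_buckets = int(bucket.max()) + 1
    index = bucket.unsqueeze(-1).expand(-1, -1, depth)  # (B, L, D)
    # sums[b, s, d] = sum of exp[b, l, d] over positions l with bucket[b, l] == s
    sums = torch.zeros(batch, n_buckets, depth, dtype=exp.dtype, device=exp.device)
    sums.scatter_add_(1, index, exp)
    # base[b, l, d] = sums[b, bucket[b, l], d]: each position's own section total
    base = sums.gather(1, index)
    out = exp / base
    # Positions labelled -1 get 0 instead of a meaningless value.
    return out * (sections >= 0).unsqueeze(-1).to(out.dtype)
```

Each position contributes to exactly one bucket, so the work is proportional to the size of `logits` regardless of how many sections there are. If the logits can be large, a per-section max could be subtracted first, for example with `Tensor.scatter_reduce_(..., reduce="amax")` plus another `gather`; that step is left out here.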