Source code for torch_struct.semirings.sparse_max

import torch
from .semirings import _BaseLog


class SparseMaxSemiring(_BaseLog):
    """
    Implements differentiable dynamic programming with a sparsemax semiring (sparsemax, +, -inf, 0).

    Sparsemax gradients give a sparser set of marginal-like terms.

    * From softmax to sparsemax: A sparse model of attention and multi-label classification :cite:`martins2016softmax`
    * Differentiable dynamic programming for structured prediction and attention :cite:`mensch2018differentiable`
    """

    @staticmethod
    def sum(xs, dim=-1):
        return _SimplexProject.apply(xs, dim)
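

# A minimal usage sketch (our illustration, not part of the library; the name
# `_demo_sparsemax_sum` is hypothetical). It relies on the projection helpers
# defined below in this module, and shows that `sum` collapses a dimension
# while its gradient is the sparse simplex projection itself, so entries far
# from the maximum receive exactly zero mass, unlike softmax.
def _demo_sparsemax_sum():
    xs = torch.tensor([[2.0, 1.9, -1.0]], requires_grad=True)
    val = SparseMaxSemiring.sum(xs, dim=-1)  # shape (1,)
    val.sum().backward()
    # Gradient is the projection of xs onto the simplex: the support has
    # size 2, and the -1.0 entry gets exactly zero.
    print(xs.grad)  # approximately tensor([[0.5500, 0.4500, 0.0000]])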
class _SimplexProject(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim, z=1):
        w_star = project_simplex(input, dim)
        ctx.save_for_backward(input, w_star.clone(), torch.tensor(dim))
        # Value at the projection: inner product with w_star minus an
        # L2 penalty on w_star.
        x = input.mul(w_star).sum(dim) - w_star.norm(p=2, dim=dim)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        input, w_star, dim = ctx.saved_tensors
        w_star.requires_grad_(True)

        grad_input = None
        if ctx.needs_input_grad[0]:
            # The first-order gradient of the sum is the projection w_star
            # itself; wrapping it in _SparseMaxGrad makes the gradient
            # differentiable again, enabling second-order terms.
            wstar = _SparseMaxGrad.apply(w_star, dim)
            grad_input = grad_output.unsqueeze(dim).mul(wstar)
        return grad_input, None, None


class _SparseMaxGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w_star, dim):
        ctx.save_for_backward(w_star, dim)
        return w_star

    @staticmethod
    def backward(ctx, grad_output):
        w_star, dim = ctx.saved_tensors
        return sparsemax_grad(grad_output, w_star, dim.item()), None


def project_simplex(v, dim, z=1):
    # Euclidean projection of v onto the simplex {w : w >= 0, sum(w) = z},
    # via the sort-and-threshold algorithm. Broadcasting of `ind` assumes
    # `dim` is the last dimension, as in SparseMaxSemiring.sum.
    v_sorted, _ = torch.sort(v, dim=dim, descending=True)
    cssv = torch.cumsum(v_sorted, dim=dim) - z
    ind = torch.arange(1, 1 + v.shape[dim]).to(dtype=v.dtype, device=v.device)
    cond = v_sorted - cssv / ind >= 0
    # k is the support size; tau the shared threshold subtracted from every
    # coordinate before clamping at zero.
    k = cond.sum(dim=dim, keepdim=True)
    tau = cssv.gather(dim, k - 1) / k.to(dtype=v.dtype)
    w = torch.clamp(v - tau, min=0)
    return w


def sparsemax_grad(dout, w_star, dim):
    # Jacobian-vector product of the simplex projection: restrict dout to
    # the support of w_star, then center it there (the projection is
    # invariant to adding a constant on the support).
    out = dout.clone()
    supp = w_star > 0
    out[w_star <= 0] = 0
    nnz = supp.to(dtype=dout.dtype).sum(dim=dim, keepdim=True)
    out = out - (out.sum(dim=dim, keepdim=True) / nnz)
    out[w_star <= 0] = 0
    return out
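

# A minimal sketch of the helpers in isolation (the `__main__` guard is ours,
# added for illustration; values assume z=1, the probability simplex).
if __name__ == "__main__":
    v = torch.tensor([[0.8, 0.3, -0.5], [3.0, 0.0, 0.0]])
    w = project_simplex(v, dim=-1)
    # Each row is a valid distribution: nonnegative, summing to z=1, with
    # entries below the threshold tau clamped exactly to zero.
    print(w)             # approx. tensor([[0.75, 0.25, 0.00], [1.00, 0.00, 0.00]])
    print(w.sum(dim=-1))  # tensor([1., 1.])
    # sparsemax_grad zeroes the gradient off the support and centers it on
    # the support; a constant direction lies in the Jacobian's null space,
    # so propagating all-ones yields zero.
    dout = torch.ones_like(w)
    print(sparsemax_grad(dout, w, dim=-1))  # tensor of zeros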