Source code for torch_struct.semirings.semirings

import torch

has_genbmm = False
try:
    import genbmm

    has_genbmm = True
except ImportError:
    pass


def matmul(cls, a, b):
    dims = 1
    act_on = -(dims + 1)
    a = a.unsqueeze(-1)
    b = b.unsqueeze(act_on - 1)
    c = cls.times(a, b)
    for d in range(act_on, -1, 1):
        c = cls.sum(c.transpose(-2, -1))
    return c
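
# Illustrative sketch (not part of the original source): for semirings that do
# not override `matmul`, this helper broadcasts `times` over the shared inner
# dimension and reduces it with the semiring `sum`. In the log semiring this
# reproduces an explicit logsumexp contraction, e.g.
#
#     a = torch.randn(5, 3, 4)                     # (batch, m, k) log-potentials
#     b = torch.randn(5, 4, 2)                     # (batch, k, n) log-potentials
#     c = LogSemiring.matmul(a, b)                 # (batch, m, n)
#     ref = (a.unsqueeze(-1) + b.unsqueeze(-3)).logsumexp(dim=-2)
#     assert torch.allclose(c, ref, atol=1e-5)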


class Semiring:
    """
    Base semiring class.

    Based on description in:

    * Semiring parsing :cite:`goodman1999semiring`

    """

    @classmethod
    def matmul(cls, a, b):
        "Generalized tensordot. Classes should override."
        return matmul(cls, a, b)

    @classmethod
    def size(cls):
        "Additional *ssize* first dimension needed."
        return 1

    @classmethod
    def dot(cls, a, b):
        "Dot product along last dim."
        a = a.unsqueeze(-2)
        b = b.unsqueeze(-1)
        return cls.matmul(a, b).squeeze(-1).squeeze(-1)

    @staticmethod
    def fill(c, mask, v):
        mask = mask.to(c.device)
        return torch.where(
            mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c
        )

    @classmethod
    def times(cls, *ls):
        "Multiply a list of tensors together"
        cur = ls[0]
        for l in ls[1:]:
            cur = cls.mul(cur, l)
        return cur

    @classmethod
    def convert(cls, potentials):
        "Convert to semiring by adding an extra first dimension."
        return potentials.unsqueeze(0)

    @classmethod
    def unconvert(cls, potentials):
        "Unconvert from semiring by removing extra first dimension."
        return potentials.squeeze(0)

    @staticmethod
    def sum(xs, dim=-1):
        "Sum over *dim* of tensor."
        raise NotImplementedError()

    @classmethod
    def plus(cls, a, b):
        return cls.sum(torch.stack([a, b], dim=-1))


class _Base(Semiring):
    zero = torch.tensor(0.0)
    one = torch.tensor(1.0)

    @staticmethod
    def mul(a, b):
        return torch.mul(a, b)

    @staticmethod
    def prod(a, dim=-1):
        return torch.prod(a, dim=dim)


class _BaseLog(Semiring):
    zero = torch.tensor(-1e5)
    one = torch.tensor(-0.0)

    @staticmethod
    def sum(xs, dim=-1):
        return torch.logsumexp(xs, dim=dim)

    @staticmethod
    def mul(a, b):
        return a + b

    @staticmethod
    def prod(a, dim=-1):
        return torch.sum(a, dim=dim)

    # @classmethod
    # def matmul(cls, a, b):
    #     return super(cls).matmul(a, b)


class StdSemiring(_Base):
    """
    Implements the counting semiring (+, *, 0, 1).
    """

    @staticmethod
    def sum(xs, dim=-1):
        return torch.sum(xs, dim=dim)

    @classmethod
    def matmul(cls, a, b):
        "Standard matrix product (uses genbmm for banded matrices when available)."
        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
            return b.multiply(a.transpose())
        else:
            return torch.matmul(a, b)
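
# Quick check (illustrative, not part of the original source): in the standard
# (+, *) semiring the generic helper agrees with ordinary matrix multiplication.
#
#     a, b = torch.rand(3, 4), torch.rand(4, 2)
#     assert torch.allclose(matmul(StdSemiring, a, b), torch.matmul(a, b))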
class LogSemiring(_BaseLog):
    """
    Implements the log-space semiring (logsumexp, +, -inf, 0).

    Gradients give marginals.
    """

    @classmethod
    def matmul(cls, a, b):
        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
            return b.multiply_log(a.transpose())
        else:
            return _BaseLog.matmul(a, b)
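
# Sketch (illustrative, not from the original source): the docstring's note that
# "gradients give marginals" can be seen on a single dot product, where the
# gradient of logsumexp(a + b) with respect to `a` is the posterior over entries.
#
#     a = torch.randn(4, requires_grad=True)
#     b = torch.randn(4)
#     LogSemiring.dot(a, b).backward()
#     assert torch.allclose(a.grad, torch.softmax(a + b, dim=-1), atol=1e-5)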
class MaxSemiring(_BaseLog):
    """
    Implements the max semiring (max, +, -inf, 0).

    Gradients give argmax.
    """

    @classmethod
    def matmul(cls, a, b):
        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
            return b.multiply_max(a.transpose())
        else:
            return matmul(cls, a, b)

    @staticmethod
    def sum(xs, dim=-1):
        return torch.max(xs, dim=dim)[0]

    @staticmethod
    def sparse_sum(xs, dim=-1):
        m, a = torch.max(xs, dim=dim)
        return m, (torch.zeros(a.shape).long(), a)
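
# Sketch (illustrative, not from the original source): in the max semiring a
# matrix product keeps the best-scoring inner index, and `sparse_sum` also
# returns the argmax needed to recover that choice.
#
#     a = torch.randn(3, 4)
#     b = torch.randn(4, 2)
#     c = MaxSemiring.matmul(a, b)                          # (3, 2)
#     ref = (a.unsqueeze(-1) + b.unsqueeze(-3)).max(dim=-2)[0]
#     assert torch.allclose(c, ref)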
def KMaxSemiring(k):
    "Implements the k-max semiring (kmax, +, [-inf, -inf..], [0, -inf, ...])."

    class KMaxSemiring(_BaseLog):
        zero = torch.tensor([-1e5 for i in range(k)])
        one = torch.tensor([0 if i == 0 else -1e5 for i in range(k)])

        @staticmethod
        def size():
            return k

        @classmethod
        def convert(cls, orig_potentials):
            potentials = torch.zeros(
                (k,) + orig_potentials.shape,
                dtype=orig_potentials.dtype,
                device=orig_potentials.device,
            )
            potentials = cls.fill(potentials, torch.tensor(True), cls.zero)
            potentials[0] = orig_potentials
            return potentials

        @staticmethod
        def unconvert(potentials):
            return potentials[0]

        @staticmethod
        def sum(xs, dim=-1):
            if dim == -1:
                xs = xs.permute(tuple(range(1, xs.dim())) + (0,))
                xs = xs.contiguous().view(xs.shape[:-2] + (-1,))
                xs = torch.topk(xs, k, dim=-1)[0]
                xs = xs.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
                assert xs.shape[0] == k
                return xs
            assert False

        @staticmethod
        def sparse_sum(xs, dim=-1):
            if dim == -1:
                xs = xs.permute(tuple(range(1, xs.dim())) + (0,))
                xs = xs.contiguous().view(xs.shape[:-2] + (-1,))
                xs, xs2 = torch.topk(xs, k, dim=-1)
                xs = xs.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
                xs2 = xs2.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
                assert xs.shape[0] == k
                return xs, (xs2 % k, xs2 // k)
            assert False

        @staticmethod
        def mul(a, b):
            a = a.view((k, 1) + a.shape[1:])
            b = b.view((1, k) + b.shape[1:])
            c = a + b
            c = c.contiguous().view((k * k,) + c.shape[2:])
            ret = torch.topk(c, k, 0)[0]
            assert ret.shape[0] == k
            return ret

    return KMaxSemiring


class KLDivergenceSemiring(Semiring):
    """
    Implements a KL-divergence semiring.

    Computes both the log-values of two distributions and the running KL
    divergence between them.

    Based on descriptions in:

    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
    """

    zero = torch.tensor([-1e5, -1e5, 0.0])
    one = torch.tensor([0.0, 0.0, 0.0])

    @staticmethod
    def size():
        return 3

    @staticmethod
    def convert(xs):
        values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
        values[0] = xs[0]
        values[1] = xs[1]
        values[2] = 0
        return values

    @staticmethod
    def unconvert(xs):
        return xs[-1]

    @staticmethod
    def sum(xs, dim=-1):
        assert dim != 0
        d = dim - 1 if dim > 0 else dim
        part_p = torch.logsumexp(xs[0], dim=d)
        part_q = torch.logsumexp(xs[1], dim=d)
        log_sm_p = xs[0] - part_p.unsqueeze(d)
        log_sm_q = xs[1] - part_q.unsqueeze(d)
        sm_p = log_sm_p.exp()
        return torch.stack(
            (
                part_p,
                part_q,
                torch.sum(
                    xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d
                ),
            )
        )

    @staticmethod
    def mul(a, b):
        return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))

    @classmethod
    def prod(cls, xs, dim=-1):
        return xs.sum(dim)


class CrossEntropySemiring(Semiring):
    """
    Implements a cross-entropy expectation semiring.

    Computes both the log-values of two distributions and the running
    cross-entropy between them.

    Based on descriptions in:

    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
    """

    zero = torch.tensor([-1e5, -1e5, 0.0])
    one = torch.tensor([0.0, 0.0, 0.0])

    @staticmethod
    def size():
        return 3

    @staticmethod
    def convert(xs):
        values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
        values[0] = xs[0]
        values[1] = xs[1]
        values[2] = 0
        return values

    @staticmethod
    def unconvert(xs):
        return xs[-1]

    @staticmethod
    def sum(xs, dim=-1):
        assert dim != 0
        d = dim - 1 if dim > 0 else dim
        part_p = torch.logsumexp(xs[0], dim=d)
        part_q = torch.logsumexp(xs[1], dim=d)
        log_sm_p = xs[0] - part_p.unsqueeze(d)
        log_sm_q = xs[1] - part_q.unsqueeze(d)
        sm_p = log_sm_p.exp()
        return torch.stack(
            (part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d))
        )

    @staticmethod
    def mul(a, b):
        return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))

    @classmethod
    def prod(cls, xs, dim=-1):
        return xs.sum(dim)
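
# Sketch (illustrative, not from the original source): KMaxSemiring(k) tracks the
# k best scores in its extra leading dimension, so converting a vector of scores
# and reducing it recovers the ordinary top-k values.
#
#     Top2 = KMaxSemiring(2)
#     scores = torch.randn(5)
#     x = Top2.convert(scores)              # shape (2, 5); row 0 holds the scores
#     best2 = Top2.sum(x)                   # shape (2,): the two largest scores
#     assert torch.allclose(best2, torch.topk(scores, 2).values)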
class EntropySemiring(Semiring):
    """
    Implements an entropy expectation semiring.

    Computes both the log-values and the running distributional entropy.

    Based on descriptions in:

    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
    * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
    """

    zero = torch.tensor([-1e5, 0.0])
    one = torch.tensor([0.0, 0.0])

    @staticmethod
    def size():
        return 2

    @staticmethod
    def convert(xs):
        values = torch.zeros((2,) + xs.shape).type_as(xs)
        values[0] = xs
        values[1] = 0
        return values

    @staticmethod
    def unconvert(xs):
        return xs[1]

    @staticmethod
    def sum(xs, dim=-1):
        assert dim != 0
        d = dim - 1 if dim > 0 else dim
        part = torch.logsumexp(xs[0], dim=d)
        log_sm = xs[0] - part.unsqueeze(d)
        sm = log_sm.exp()
        return torch.stack((part, torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d)))

    @staticmethod
    def mul(a, b):
        return torch.stack((a[0] + b[0], a[1] + b[1]))

    @classmethod
    def prod(cls, xs, dim=-1):
        return xs.sum(dim)
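
# Sketch (illustrative, not from the original source): reducing a converted score
# vector with the entropy semiring yields the log-partition function and the
# Shannon entropy of the corresponding softmax distribution.
#
#     scores = torch.randn(6)
#     part, ent = EntropySemiring.sum(EntropySemiring.convert(scores))
#     p = torch.softmax(scores, dim=-1)
#     assert torch.allclose(part, scores.logsumexp(-1))
#     assert torch.allclose(ent, -(p * p.log()).sum())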
def TempMax(alpha):
    class _TempMax(_BaseLog):
        """
        Implements a max forward, hot softmax backward.
        """

        @staticmethod
        def sum(xs, dim=-1):
            pass

        @staticmethod
        def sparse_sum(xs, dim=-1):
            m, _ = torch.max(xs, dim=dim)
            a = torch.softmax(alpha * xs, dim)
            return m, (torch.zeros(a.shape[:-1]).long(), a)

    return _TempMax
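
# Sketch (illustrative, not from the original source): TempMax(alpha) behaves like
# max in the forward pass while exposing a temperature-controlled softmax over the
# reduced dimension through `sparse_sum`; larger alpha concentrates it on the argmax.
#
#     Sharp = TempMax(alpha=10.0)
#     xs = torch.randn(3, 5)
#     m, (_, soft) = Sharp.sparse_sum(xs, dim=-1)   # m: (3,), soft: (3, 5)
#     assert torch.allclose(m, xs.max(dim=-1).values)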