# Source code for torch_struct.linearchain

r"""

A linear-chain dynamic program.

Considers parameterized functions of the form :math:`f: {\cal Y} \rightarrow \mathbb{R}`.

Combinatorial set :math:`y_{1:N} \in {\cal Y}` with each :math:`y_n \in \{1, \ldots, C\}`.

Function factors as :math:`f(y) = \prod_{n=2}^N \phi(n, y_n, y_{n-1})`, i.e. one factor per pair of adjacent positions.

Example use cases:

* Part-of-Speech Tagging
* Sequence Labeling
* Hidden Markov Models

"""


import torch
from .helpers import _Struct


class LinearChain(_Struct):
    """
    Represents structured linear-chain CRFs, generalizing HMM smoothing,
    tagging models, and anything with chain-like dynamics.

    Edge potentials have shape ``b x (N-1) x C x C``; entry
    ``[b, t, z_t, z_{t-1}]`` scores moving from state ``z_{t-1}`` at position
    ``t`` to state ``z_t`` at position ``t + 1`` (see ``to_parts``).
    """

    def _check_potentials(self, edge, lengths=None):
        """
        Validate edge potentials and normalize ``lengths``.

        Parameters:
            edge : potentials whose dimensions (via ``_get_dimension``) are
                batch x (N-1) x C x C
            lengths : optional b tensor (or sequence of ints) of chain
                lengths; defaults to N for every batch element

        Returns:
            (edge converted by the semiring, batch, N, C, lengths), where
            ``lengths`` is a tensor on the same device as ``edge``.
        """
        batch, N_1, C, C2 = self._get_dimension(edge)
        edge = self.semiring.convert(edge)
        N = N_1 + 1

        if lengths is None:
            lengths = torch.LongTensor([N] * batch).to(edge.device)
        else:
            # Caller-supplied lengths may be a list or live on another device;
            # ``logpartition`` compares them against tensors on edge.device.
            lengths = torch.as_tensor(lengths, device=edge.device)
            assert lengths.max() <= N, "Length longer than edge scores"
            assert lengths.max() == N, "One length must be at least N"
        assert C == C2, "Transition shape doesn't match"
        return edge, batch, N, C, lengths

    def logpartition(self, log_potentials, lengths=None, force_grad=False):
        """
        Compute forward pass by linear scan.

        Positions at or beyond each element's ``length - 1`` keep the
        initial chart value (semiring ``one`` on the diagonal), so they act
        as identity steps. The chart is padded to ``bin_N`` positions (from
        ``self._bin_length(N - 1)``, presumably the next power of two) and
        reduced with ``log_N`` rounds of pairwise semiring matrix products.

        Parameters:
            log_potentials : b x (N-1) x C x C edge potentials
            lengths : optional b tensor of chain lengths
            force_grad : forwarded to ``self._chart``

        Returns:
            (v, [log_potentials]) where ``v`` is the semiring sum over all
            chains and ``log_potentials`` is the semiring-converted input.
        """
        # Setup
        semiring = self.semiring
        ssize = semiring.size()
        log_potentials, batch, N, C, lengths = self._check_potentials(
            log_potentials, lengths
        )
        log_N, bin_N = self._bin_length(N - 1)
        chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)

        # Init: semiring ``one`` on the diagonal of every C x C step matrix.
        init = torch.zeros_like(chart).bool()
        init.diagonal(0, 3, 4).fill_(True)
        chart = semiring.fill(chart, init, semiring.one)

        # Length mask: copy potentials into a padded buffer, then keep
        # potentials only before each length and the initial chart after it.
        big = torch.zeros(
            ssize,
            batch,
            bin_N,
            C,
            C,
            dtype=log_potentials.dtype,
            device=log_potentials.device,
        )
        big[:, :, : N - 1] = log_potentials
        c = chart[:, :, :].view(ssize, batch * bin_N, C, C)
        lp = big[:, :, :].view(ssize, batch * bin_N, C, C)
        mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c)
        mask = mask >= (lengths - 1).view(batch, 1)
        mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
        lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
        c.data[:] = semiring.fill(c.data, ~mask, semiring.zero)

        # Exactly one of (chart, potentials) is non-zero at each position, so
        # the semiring sum selects it while keeping the autograd path.
        c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))

        # Scan: each round multiplies adjacent pairs, halving the length.
        for _ in range(log_N):
            chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
        v = semiring.sum(semiring.sum(chart[:, :, 0].contiguous()))
        return v, [log_potentials]

    @staticmethod
    def to_parts(sequence, extra, lengths=None):
        """
        Convert a sequence representation to edges.

        Parameters:
            sequence : b x N long tensor in [0, C-1]
            extra : number of states C
            lengths : optional b long tensor of chain lengths (defaults to N)

        Returns:
            edge : b x (N-1) x C x C markov indicators
                        (t x z_t x z_{t-1}), on ``sequence``'s device
        """
        C = extra
        batch, N = sequence.shape
        device = sequence.device
        labels = torch.zeros(batch, N - 1, C, C, dtype=torch.long, device=device)
        if lengths is None:
            lengths = torch.LongTensor([N] * batch)

        # Mark every transition z_{t} -> z_{t+1} in a single indexed write;
        # the four index tensors broadcast to b x (N-1) and never collide.
        labels[
            torch.arange(batch, device=device).view(batch, 1),
            torch.arange(N - 1, device=device).view(1, N - 1),
            sequence[:, 1:],
            sequence[:, :-1],
        ] = 1

        # Clear transitions at or past each element's length.
        for b in range(batch):
            labels[b, lengths[b] - 1 :, :, :] = 0
        return labels

    @staticmethod
    def from_parts(edge):
        """
        Convert edges to sequence representation.

        Parameters:
            edge : b x (N-1) x C x C markov indicators
                        (t x z_t x z_{t-1})

        Returns:
            (sequence, C) where ``sequence`` is a b x N long CPU tensor in
            [0, C-1] and ``C`` is the number of states. Positions with no
            active edge (e.g. padding) are left as 0.
        """
        batch, N_1, C, _ = edge.shape
        N = N_1 + 1
        labels = torch.zeros(batch, N).long()
        # Each nonzero index is (b, t, z_t, z_{t-1}); visiting them in
        # nonzero() order keeps the original "later entry wins" behavior.
        for b, t, cur, prev in edge.nonzero().tolist():
            if t == 0:
                labels[b, 0] = prev
            labels[b, t + 1] = cur
        return labels, C

    # Adapters
    @staticmethod
    def hmm(transition, emission, init, observations):
        """
        Convert HMM log-probs to a linear chain.

        Parameters:
            transition: C x C, indexed [z_t, z_{t-1}] by the broadcast below
            emission: V x C
            init: C
            observations: b x N between [0, V-1]

        Returns:
            edges: b x (N-1) x C x C
        """
        _, C = emission.shape
        batch, N = observations.shape
        scores = torch.zeros(batch, N - 1, C, C).type_as(emission)
        scores[:, :, :, :] += transition.view(1, 1, C, C)
        # The initial-state scores attach to z_{t-1} of the first edge.
        scores[:, 0, :, :] += init.view(1, 1, C)
        obs = emission[observations.view(batch * N), :]
        # Emission of position t+1 scores z_t of edge t ...
        scores[:, :, :, :] += obs.view(batch, N, C, 1)[:, 1:]
        # ... and the first emission scores z_{t-1} of the first edge.
        scores[:, 0, :, :] += obs.view(batch, N, 1, C)[:, 0]

        return scores

    @staticmethod
    def _rand(min_n=2):
        """
        Sample random potentials for testing.

        Returns:
            (potentials of shape b x N x C x C, (b, N + 1)) with b and C drawn
            from [2, 3] and N drawn from [min_n, 3].
        """
        b = torch.randint(2, 4, (1,))
        N = torch.randint(min_n, 4, (1,))
        C = torch.randint(2, 4, (1,))
        return torch.rand(b, N, C, C), (b.item(), (N + 1).item())