Source code for torch_struct.autoregressive

import torch
from .semirings import MaxSemiring, KMaxSemiring, TempMax
from torch.distributions.distribution import Distribution


class AutoregressiveModel(torch.nn.Module):
    """
    User should implement as their favorite RNN / Transformer / etc.
    """

    def forward(self, inputs, state=None):
        r"""
        Compute the logits for all tokens in a batched sequence :math:`p(y_{t+1}, \ldots, y_{T} \mid y_1, \ldots, y_t)`

        Parameters:
            inputs (batch_size x N): next token indices to update the representation
            state (tuple of batch_size x ...): everything needed for conditioning.

        Returns:
            logits (*batch_size x N x C*): next set of logits.

            state (*tuple of batch_size x ...*): next state of the model.
        """
        pass
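# An illustrative sketch (not part of the library) of one way to satisfy the
# interface above with an embedding + LSTM. The name `LSTMModel`, the hidden
# size, and the (h, c) state layout are assumptions made for this example; the
# only contract relied on is forward(inputs, state) -> (logits, state).
class LSTMModel(AutoregressiveModel):
    def __init__(self, n_classes, hidden_dim):
        super().__init__()
        self.embed = torch.nn.Embedding(n_classes, hidden_dim)
        self.rnn = torch.nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.proj = torch.nn.Linear(hidden_dim, n_classes)

    def forward(self, inputs, state=None):
        # inputs: batch x N token indices (as passed by Autoregressive below);
        # state: tuple of batch x hidden_dim tensors.
        h, c = state
        out, (h, c) = self.rnn(self.embed(inputs), (h.unsqueeze(0), c.unsqueeze(0)))
        # Return logits of shape batch x N x C and the updated state tuple.
        return self.proj(out), (h.squeeze(0), c.squeeze(0))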
def wrap(state, ssize):
    return state.contiguous().view(ssize, -1, *state.shape[1:])


def unwrap(state):
    return state.contiguous().view(-1, *state.shape[2:])
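# A quick shape round-trip for the helpers above (shapes are illustrative
# assumptions): `unwrap` folds a leading semiring/sample dimension into the
# batch dimension, and `wrap` splits it back out.
def _wrap_unwrap_demo():
    s = torch.randn(4, 3, 16)                 # ssize x batch x hidden
    flat = unwrap(s)                          # -> (4 * 3) x hidden
    assert flat.shape == (12, 16)
    assert wrap(flat, 4).shape == (4, 3, 16)  # back to ssize x batch x hidden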
class Autoregressive(Distribution):
    """
    Autoregressive sequence model utilizing beam search.

    * batch_shape -> Given by initializer
    * event_shape -> N x C sequence of choices

    Parameters:
        model (AutoregressiveModel): A lazily computed autoregressive model.
        initial_state (tuple of tensors, batch_shape x ...): initial state of autoregressive model.
        n_classes (int): number of classes in each time step
        n_length (int): max length of sequence

    A short usage sketch follows the class definition below.
    """

    def __init__(
        self,
        model,
        initial_state,
        n_classes,
        n_length,
        normalize=True,
        start_class=0,
        end_class=None,
    ):
        self.model = model
        self.init = initial_state
        self.n_length = n_length
        self.n_classes = n_classes
        self.start_class = start_class
        self.normalize = normalize
        event_shape = (n_length, n_classes)
        batch_shape = initial_state[0].shape[:1]
        self.device = initial_state[0].device
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

    def log_prob(self, value, sparse=False):
        """
        Compute log probability over values :math:`p(z)`.

        Parameters:
            value (tensor): One-hot events (*sample_shape x batch_shape x N x C*),
                or class indices (*sample_shape x batch_shape x N*) when `sparse` is True.
            sparse (bool): whether `value` is given as sparse class indices.

        Returns:
            log_probs (*sample_shape x batch_shape*)
        """
        value = value.long()
        if not sparse:
            sample, batch_shape, n_length, n_classes = value.shape
            value = (
                (value * torch.arange(n_classes).view(1, 1, n_classes)).sum(-1).long()
            )
        else:
            sample, batch_shape, n_length = value.shape

        value = torch.cat(
            [
                torch.zeros(sample, batch_shape, 1, device=value.device)
                .fill_(self.start_class)
                .long(),
                value,
            ],
            dim=2,
        )
        value = unwrap(value)
        state = tuple(
            (unwrap(i.unsqueeze(0).expand((sample,) + i.shape)) for i in self.init)
        )

        logits, _ = self.model(value, state)
        b2, n2, c2 = logits.shape
        assert (
            (b2 == sample * batch_shape)
            and (n2 == n_length + 1)
            and (c2 == self.n_classes)
        ), "Model should return logits of shape `batch x N x C` "

        if self.normalize:
            log_probs = logits.log_softmax(-1)
        else:
            log_probs = logits

        scores = log_probs[:, :-1].gather(2, value[:, 1:].unsqueeze(-1)).sum(-1).sum(-1)
        return wrap(scores, sample)

    def _beam_search(self, semiring, gumbel=False):
        beam = semiring.fill(
            torch.zeros((semiring.size(),) + self.batch_shape, device=self.device),
            torch.tensor(True),
            semiring.one,
        )
        ssize = semiring.size()

        def take(state, indices):
            return tuple(
                (
                    s.contiguous()[
                        (
                            indices * self.batch_shape[0]
                            + torch.arange(
                                self.batch_shape[0], device=self.device
                            ).unsqueeze(0)
                        )
                        .contiguous()
                        .view(-1)
                    ]
                    for s in state
                )
            )

        tokens = (
            torch.zeros((ssize * self.batch_shape[0])).long().fill_(self.start_class)
        )
        state = tuple(
            (unwrap(i.unsqueeze(0).expand((ssize,) + i.shape)) for i in self.init)
        )

        # Beam Search
        all_beams = []
        all_logits = []
        for t in range(0, self.n_length):
            logits, state = self.model(unwrap(tokens).unsqueeze(1), state)
            b2, n2, c2 = logits.shape
            assert (
                (b2 == ssize * self.batch_shape[0])
                and (n2 == 1)
                and (c2 == self.n_classes)
            ), "Model should return logits of shape `batch x N x C` "
            for s in state:
                assert (
                    s.shape[0] == ssize * self.batch_shape[0]
                ), "Model should return state tuple with shapes `batch x ...` "
            logits = wrap(logits.squeeze(1), ssize)
            if gumbel:
                logits = logits + torch.distributions.Gumbel(0.0, 1.0).sample(
                    logits.shape
                )
            if self.normalize:
                logits = logits.log_softmax(-1)
            all_logits.append(logits)
            ex_beam = beam.unsqueeze(-1) + logits
            ex_beam.requires_grad_(True)
            all_beams.append(ex_beam)
            beam, (positions, tokens) = semiring.sparse_sum(ex_beam)
            state = take(state, positions)

        # Back pointers
        v = beam
        all_m = []
        for k in range(v.shape[0]):
            obj = v[k].sum(dim=0)
            marg = torch.autograd.grad(
                obj, all_beams, create_graph=True, only_inputs=True, allow_unused=False
            )
            marg = torch.stack(marg, dim=2)
            all_m.append(marg.sum(0))

        return torch.stack(all_m, dim=0), v, torch.stack(all_logits, dim=2)
    def greedy_max(self):
        """
        Compute "argmax" using greedy search.

        Returns:
            greedy_path (*batch x N x C*)
            greedy_max (*batch*)
            logits (*batch x N x C*)
        """
        a, b, c = self._beam_search(MaxSemiring)
        return a.squeeze(0), b.squeeze(0), c.squeeze(0)
    def greedy_tempmax(self, alpha):
        """
        Compute differentiable scheduled sampling using greedy search.

        Based on:

        * Differentiable Scheduled Sampling for Credit Assignment :cite:`goyal2017differentiable`

        Parameters:
            alpha (float): temperature parameter

        Returns:
            greedy_path (*batch x N x C*)
            greedy_max (*batch*)
            logits (*batch x N x C*)
        """
        a, b, c = self._beam_search(TempMax(alpha), alpha)
        return a.squeeze(0), b.squeeze(0), c.squeeze(0)
    def beam_topk(self, K):
        """
        Compute "top-k" using beam search.

        Parameters:
            K (int): number of beams (paths) to return

        Returns:
            paths (*K x batch x N x C*)
        """
        return self._beam_search(KMaxSemiring(K))[0]
    def _beam_max(self, K):
        return self._beam_search(KMaxSemiring(K))[1]
    def sample_without_replacement(self, sample_shape=torch.Size()):
        """
        Compute sampling without replacement using Gumbel trick.

        Based on:

        * Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
          Sampling Sequences Without Replacement :cite:`DBLP:journals/corr/abs-1903-06059`

        Parameters:
            sample_shape (torch.Size): number of samples to draw (K is taken
                from the first dimension)

        Returns:
            paths (*K x batch x N x C*)
        """
        K = sample_shape[0]
        return self._beam_search(KMaxSemiring(K), gumbel=True)[0]
    def sample(self, sample_shape=torch.Size()):
        r"""
        Compute structured samples from the distribution :math:`z \sim p(z)`.

        Parameters:
            sample_shape (torch.Size): number of samples

        Returns:
            samples (*sample_shape x batch_shape x event_shape*)
        """
        sample_shape = sample_shape[0]
        state = tuple(
            (
                unwrap(i.unsqueeze(0).expand((sample_shape,) + i.shape))
                for i in self.init
            )
        )
        all_tokens = []
        tokens = (
            torch.zeros((sample_shape * self.batch_shape[0]))
            .long()
            .fill_(self.start_class)
        )
        for t in range(0, self.n_length):
            logits, state = self.model(tokens.unsqueeze(-1), state)
            logits = logits.squeeze(1)
            tokens = torch.distributions.Categorical(logits=logits).sample((1,))[0]
            all_tokens.append(tokens)
        v = wrap(torch.stack(all_tokens, dim=1), sample_shape)
        return torch.nn.functional.one_hot(v, self.n_classes)
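# A usage sketch for the class above. It reuses the illustrative `LSTMModel`
# defined earlier; the sizes are arbitrary assumptions, and the shapes in the
# comments follow the docstrings above.
def _autoregressive_demo():
    batch, N, C, H = 2, 5, 7, 16
    init = (torch.zeros(batch, H), torch.zeros(batch, H))
    dist = Autoregressive(LSTMModel(C, H), init, n_classes=C, n_length=N)

    # Greedy decoding: path is batch x N x C (one-hot), score is batch.
    path, score, logits = dist.greedy_max()

    # Top-K beam search and Gumbel-top-k sampling without replacement,
    # each returning paths of shape K x batch x N x C.
    topk_paths = dist.beam_topk(K=4)
    swor_paths = dist.sample_without_replacement(torch.Size([4]))

    # Ancestral sampling, then scoring the one-hot samples: sample x batch.
    samples = dist.sample(torch.Size([3]))
    log_probs = dist.log_prob(samples)
    return path, score, topk_paths, swor_paths, samples, log_probs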