import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
from .linearchain import LinearChain
from .cky import CKY
from .semimarkov import SemiMarkov
from .alignment import Alignment
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .semirings import (
LogSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
GumbelCRFSemiring,
)
class StructDistribution(Distribution):
r"""
Base structured distribution class.
    Dynamic distribution over structures :math:`p(z)` of length N.
Implemented based on gradient identities from:
* Inside-outside and forward-backward algorithms are just backprop :cite:`eisner2016inside`
* Semiring Parsing :cite:`goodman1999semiring`
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
Parameters:
log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
lengths (long tensor, batch_shape) : integers for length masking
"""
arg_constraints = {
"log_potentials": constraints.real,
"lengths": constraints.nonnegative_integer
}
def __init__(self, log_potentials, lengths=None, args={}, validate_args=False):
batch_shape = log_potentials.shape[:1]
event_shape = log_potentials.shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
self.args = args
super().__init__(batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args)
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)
    def log_prob(self, value):
"""
Compute log probability over values :math:`p(z)`.
Parameters:
value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
Returns:
log_probs (*sample_shape x batch_shape*)
"""
d = value.dim()
batch_dims = range(d - len(self.event_shape))
v = self._struct().score(
self.log_potentials,
value.type_as(self.log_potentials),
batch_dims=batch_dims,
)
return v - self.partition
@lazy_property
def entropy(self):
"""
Compute entropy for distribution :math:`H[z]`.
Returns:
entropy (*batch_shape*)
"""
return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)
    def cross_entropy(self, other):
        """
        Compute cross-entropy between distributions p (self) and q (other), :math:`H[p, q]`.
Parameters:
other : Comparison distribution
Returns:
cross entropy (*batch_shape*)
"""
return self._struct(CrossEntropySemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
    def kl(self, other):
        """
        Compute KL-divergence between distributions p (self) and q (other), :math:`KL[p || q] = H[p, q] - H[p]`.
        Parameters:
            other : Comparison distribution
        Returns:
            kl divergence (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
@lazy_property
def max(self):
r"""
        Compute the max of the distribution :math:`\max p(z)`.
Returns:
max (*batch_shape*)
"""
return self._struct(MaxSemiring).sum(self.log_potentials, self.lengths)
@lazy_property
def argmax(self):
r"""
        Compute the argmax of the distribution :math:`\arg\max p(z)`.
Returns:
argmax (*batch_shape x event_shape*)
"""
return self._struct(MaxSemiring).marginals(self.log_potentials, self.lengths)
    def kmax(self, k):
r"""
Compute the k-max for distribution :math:`k\max p(z)`.
        Parameters:
k : Number of solutions to return
Returns:
kmax (*k x batch_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).sum(
self.log_potentials, self.lengths, _raw=True
)
    def topk(self, k):
        r"""
        Compute the k-argmax for distribution :math:`k\max p(z)`.
        Parameters:
            k : Number of solutions to return
        Returns:
            topk (*k x batch_shape x event_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).marginals(
self.log_potentials, self.lengths, _raw=True
)
@lazy_property
def mode(self):
return self.argmax
@lazy_property
def marginals(self):
"""
Compute marginals for distribution :math:`p(z_t)`.
        Can be used in higher-order calculations.
Returns:
marginals (*batch_shape x event_shape*)
"""
return self._struct(LogSemiring).marginals(self.log_potentials, self.lengths)
@lazy_property
def count(self):
"Compute the log-partition function."
ones = torch.ones_like(self.log_potentials)
ones[self.log_potentials.eq(-float("inf"))] = 0
return self._struct(StdSemiring).sum(ones, self.lengths)
    def gumbel_crf(self, temperature=1.0):
        """
        Compute relaxed structure samples with straight-through gradients
        using the Gumbel-CRF semiring at the given temperature.
        """
        with torch.enable_grad():
            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
                self.log_potentials, self.lengths
            )
            return st_gumbel
# @constraints.dependent_property
# def support(self):
# pass
# @property
# def param_shape(self):
# return self._param.size()
@lazy_property
def partition(self):
"Compute the log-partition function."
return self._struct(LogSemiring).sum(self.log_potentials, self.lengths)
    def sample(self, sample_shape=torch.Size()):
r"""
Compute structured samples from the distribution :math:`z \sim p(z)`.
Parameters:
            sample_shape (torch.Size): number of samples; currently must be one-dimensional
Returns:
samples (*sample_shape x batch_shape x event_shape*)
"""
assert len(sample_shape) == 1
nsamples = sample_shape[0]
samples = []
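        # MultiSampledSemiring encodes up to 10 independent samples in one
        # forward pass; redraw every 10th iteration and decode sample
        # ((k % 10) + 1) from the packed tensor below.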
for k in range(nsamples):
if k % 10 == 0:
sample = self._struct(MultiSampledSemiring).marginals(
self.log_potentials, lengths=self.lengths
)
sample = sample.detach()
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
samples.append(tmp_sample)
return torch.stack(samples)
    def to_event(self, sequence, extra, lengths=None):
"Convert simple representation to event."
return self.struct.to_parts(sequence, extra, lengths=lengths)
    def from_event(self, event):
"Convert event to simple representation."
return self.struct.from_parts(event)
def _struct(self, sr=None):
return self.struct(sr if sr is not None else LogSemiring)
class LinearChainCRF(StructDistribution):
r"""
Represents structured linear-chain CRFs with C classes.
For reference see:
* An introduction to conditional random fields :cite:`sutton2012introduction`
Example application:
* Bidirectional LSTM-CRF Models for Sequence Tagging :cite:`huang2015bidirectional`
Event shape is of the form:
Parameters:
log_potentials (tensor) : event shape (*(N-1) x C x C*) e.g.
:math:`\phi(n, z_{n+1}, z_{n})`
lengths (long tensor) : batch_shape integers for length masking.
Compact representation: N long tensor in [0, ..., C-1]
Implementation uses linear-scan, forward-pass only.
* Parallel Time: :math:`O(\log(N))` parallel merges.
* Forward Memory: :math:`O(N \log(N) C^2)`
"""
struct = LinearChain
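# A minimal usage sketch for LinearChainCRF (the batch/N/C sizes are
# illustrative assumptions; shapes follow the class docstring above):
def _example_linear_chain_crf():
    batch, N, C = 2, 6, 4
    # Potentials phi(n, z_{n+1}, z_n), one C x C table per transition.
    log_potentials = torch.randn(batch, N - 1, C, C)
    dist = LinearChainCRF(log_potentials)
    logZ = dist.partition        # log-partition, (batch,)
    parts = dist.argmax          # one-hot best structure, (batch, N-1, C, C)
    logp = dist.log_prob(parts)  # log-probability of that structure, (batch,)
    return logZ, parts, logp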
class AlignmentCRF(StructDistribution):
    r"""
    Represents basic alignment algorithms, i.e. dynamic time warping, Needleman-Wunsch, and Smith-Waterman.
Event shape is of the form:
Parameters:
log_potentials (tensor) : event_shape (*N x M x 3*), e.g.
:math:`\phi(i, j, op)`
        Ops are 0 -> (i, j-1), 1 -> (i-1, j-1), and 2 -> (i-1, j)
local (bool): if true computes local alignment (Smith-Waterman), else Needleman-Wunsch
max_gap (int or None): the maximum gap to allow in the dynamic program
lengths (long tensor) : batch shape integers for length masking.
Implementation uses convolution and linear-scan. Use max_gap for long sequences.
* Parallel Time: :math:`O(\log (M + N))` parallel merges.
* Forward Memory: :math:`O((M+N)^2)`
"""
struct = Alignment
arg_constraints = {
"log_potentials": constraints.real,
"local": constraints.boolean,
"max_gap": constraints.nonnegative_integer,
"lengths": constraints.nonnegative_integer
}
def __init__(self, log_potentials, local=False, lengths=None, max_gap=None, validate_args=False):
self.local = local
self.max_gap = max_gap
super().__init__(log_potentials, lengths, validate_args=validate_args)
def _struct(self, sr=None):
return self.struct(
sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
)
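# A minimal usage sketch for AlignmentCRF (sizes are illustrative
# assumptions; op 1 is the diagonal/match move per the docstring):
def _example_alignment_crf():
    batch, N, M = 2, 5, 5
    log_potentials = torch.randn(batch, N, M, 3)
    dist = AlignmentCRF(log_potentials)
    return dist.partition  # (batch,)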
class HMM(StructDistribution):
    r"""
    Represents hidden Markov model smoothing with C hidden states.
Event shape is of the form:
Parameters:
        transition (tensor): log-probabilities (*C x C*) :math:`p(z_n \mid z_{n-1})`
        emission (tensor): log-probabilities (*V x C*) :math:`p(x_n \mid z_n)`
        init (tensor): log-probabilities (*C*) :math:`p(z_1)`
observations (long tensor): indices (*batch x N*) between [0, V-1]
Compact representation: N long tensor in [0, ..., C-1]
Implemented as a special case of linear chain CRF.
"""
def __init__(self, transition, emission, init, observations, lengths=None, validate_args=False):
log_potentials = HMM.struct.hmm(transition, emission, init, observations)
super().__init__(log_potentials, lengths, validate_args=validate_args)
struct = LinearChain
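# A minimal usage sketch for HMM (sizes are illustrative; the
# normalization axes of the log-probability tables are assumptions):
def _example_hmm():
    C, V, batch, N = 3, 10, 2, 7
    transition = torch.randn(C, C).log_softmax(0)  # p(z_n | z_{n-1}), assumed axis
    emission = torch.randn(V, C).log_softmax(0)    # p(x_n | z_n), assumed axis
    init = torch.randn(C).log_softmax(0)           # p(z_1)
    observations = torch.randint(0, V, (batch, N))
    dist = HMM(transition, emission, init, observations)
    return dist.marginals  # edge marginals, (batch, N-1, C, C)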
class SemiMarkovCRF(StructDistribution):
    r"""
    Represents a semi-Markov or segmental CRF with C classes of max width K.
Event shape is of the form:
Parameters:
log_potentials : event shape (*N x K x C x C*) e.g.
:math:`\phi(n, k, z_{n+1}, z_{n})`
lengths (long tensor) : batch shape integers for length masking.
Compact representation: N long tensor in [-1, 0, ..., C-1]
Implementation uses linear-scan, forward-pass only.
* Parallel Time: :math:`O(\log(N))` parallel merges.
* Forward Memory: :math:`O(N \log(N) C^2 K^2)`
"""
struct = SemiMarkov
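# A minimal usage sketch for SemiMarkovCRF (sizes are illustrative):
def _example_semimarkov_crf():
    batch, N, K, C = 2, 8, 3, 4
    # Potentials phi(n, k, z_{n+1}, z_n) over segments of width up to K.
    log_potentials = torch.randn(batch, N, K, C, C)
    dist = SemiMarkovCRF(log_potentials)
    return dist.partition  # (batch,)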
class DependencyCRF(StructDistribution):
r"""
Represents a projective dependency CRF.
Reference:
* Bilexical grammars and their cubic-time parsing algorithms :cite:`eisner2000bilexical`
Event shape is of the form:
Parameters:
log_potentials (tensor) : event shape (*N x N*) head, child or (*N x N x L*) head,
child, labels with arc scores with root scores on diagonal
e.g. :math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i).
lengths (long tensor) : batch shape integers for length masking.
    Compact representation: N long tensor in [0, ..., N] (indexing is +1)
Implementation uses linear-scan, forward-pass only.
* Parallel Time: :math:`O(N)` parallel merges.
    * Forward Memory: :math:`O(N \log(N))`
"""
def __init__(self, log_potentials, lengths=None, args={}, multiroot=True, validate_args=False):
super(DependencyCRF, self).__init__(log_potentials, lengths, args, validate_args=validate_args)
self.struct = DepTree
setattr(self.struct, "multiroot", multiroot)
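# A minimal usage sketch for DependencyCRF (sizes are illustrative):
def _example_dependency_crf():
    batch, N = 2, 6
    # Arc scores phi(i, j) with root scores on the diagonal.
    arc_scores = torch.randn(batch, N, N)
    dist = DependencyCRF(arc_scores)
    return dist.argmax  # one-hot arcs, (batch, N, N)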
class TreeCRF(StructDistribution):
r"""
Represents a 0th-order span parser with NT nonterminals. Implemented using a
fast CKY algorithm.
For example usage see:
* A Minimal Span-Based Neural Constituency Parser :cite:`stern2017minimal`
Event shape is of the form:
Parameters:
log_potentials (tensor) : event_shape (*N x N x NT*), e.g.
:math:`\phi(i, j, nt)`
lengths (long tensor) : batch shape integers for length masking.
Implementation uses width-batched, forward-pass only
* Parallel Time: :math:`O(N)` parallel merges.
* Forward Memory: :math:`O(N^2)`
Compact representation: *N x N x NT* long tensor (Same)
"""
struct = CKY_CRF
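# A minimal usage sketch for TreeCRF (sizes are illustrative):
def _example_tree_crf():
    batch, N, NT = 2, 6, 3
    # Span scores phi(i, j, nt) for each span (i, j) and nonterminal nt.
    span_scores = torch.randn(batch, N, N, NT)
    dist = TreeCRF(span_scores)
    return dist.marginals  # span marginals, (batch, N, N, NT)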
class SentCFG(StructDistribution):
"""
Represents a full generative context-free grammar with
non-terminals NT and terminals T.
Event shape is of the form:
Parameters:
log_potentials (tuple) : event tuple with event shapes
terms (*N x T*)
rules (*NT x (NT+T) x (NT+T)*)
root (*NT*)
lengths (long tensor) : batch shape integers for length masking.
Implementation uses width-batched, forward-pass only
* Parallel Time: :math:`O(N)` parallel merges.
* Forward Memory: :math:`O(N^2 (NT+T))`
Compact representation: (*N x N x NT*) long tensor
"""
struct = CKY
def __init__(self, log_potentials, lengths=None, validate_args=False):
batch_shape = log_potentials[0].shape[:1]
event_shape = log_potentials[0].shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
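        # The potentials are a tuple, so skip StructDistribution.__init__
        # and initialize the torch Distribution base class directly.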
super(StructDistribution, self).__init__(
batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args
)
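# A minimal usage sketch for SentCFG (sizes are illustrative; the three
# potentials follow the tuple shapes in the docstring above):
def _example_sent_cfg():
    batch, N, NT, T = 2, 5, 3, 4
    terms = torch.randn(batch, N, T)
    rules = torch.randn(batch, NT, NT + T, NT + T)
    root = torch.randn(batch, NT)
    dist = SentCFG((terms, rules, root))
    return dist.partition  # (batch,)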
class NonProjectiveDependencyCRF(StructDistribution):
r"""
Represents a non-projective dependency CRF.
For references see:
* Non-projective dependency parsing using spanning tree algorithms :cite:`mcdonald2005non`
* Structured prediction models via the matrix-tree theorem :cite:`koo2007structured`
Event shape is of the form:
Parameters:
log_potentials (tensor) : event shape (*N x N*) head, child with
arc scores with root scores on diagonal e.g.
:math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i).
    Compact representation: N long tensor in [0, ..., N] (indexing is +1)
    Note: Does not currently implement argmax (Chu-Liu/Edmonds) or sampling.
"""
arg_constraints = {
"log_potentials": constraints.real
}
def __init__(self, log_potentials, lengths=None, args={}, multiroot=False, validate_args=False):
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args, validate_args=validate_args)
self.multiroot = multiroot
@lazy_property
def marginals(self):
"""
Compute marginals for distribution :math:`p(z_t)`.
Algorithm is :math:`O(N^3)` but very fast on batched GPU.
Returns:
marginals (*batch_shape x event_shape*)
"""
return deptree_nonproj(self.log_potentials, self.multiroot, self.lengths)
def sample(self, sample_shape=torch.Size()):
raise NotImplementedError()
@lazy_property
def partition(self):
"""
Compute the partition function.
"""
return deptree_part(self.log_potentials, self.multiroot, self.lengths)
@lazy_property
def argmax(self):
"""
Use Chiu-Liu Algorithm. :math:`O(N^2)`
(Currently not implemented)
"""
pass
    @lazy_property
    def entropy(self):
        """
        Compute entropy :math:`H[z]`. (Currently not implemented.)
        """
        pass
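# A minimal usage sketch for NonProjectiveDependencyCRF (sizes are
# illustrative; partition and marginals use the matrix-tree theorem):
def _example_nonprojective_crf():
    batch, N = 2, 6
    arc_scores = torch.randn(batch, N, N)
    dist = NonProjectiveDependencyCRF(arc_scores)
    return dist.partition, dist.marginals  # (batch,), (batch, N, N)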