Model

[1]:
import torch_struct
import torch
import matplotlib.pyplot as plt
import matplotlib
[2]:
matplotlib.rcParams['figure.figsize'] = (7.0, 7.0)

Chain

class torch_struct.LinearChainCRF(log_potentials, lengths=None, args={}, validate_args=False)[source]

Represents structured linear-chain CRFs with C classes.

For reference see:

  • An introduction to conditional random fields [SM+12]

Example application:

  • Bidirectional LSTM-CRF Models for Sequence Tagging [HXY15]

Event shape is of the form:

Parameters
  • log_potentials (tensor) – event shape ((N-1) x C x C) e.g. \(\phi(n, z_{n+1}, z_{n})\)

  • lengths (long tensor) – batch_shape integers for length masking.

Compact representation: N long tensor in [0, …, C-1]

Implementation uses linear-scan, forward-pass only.

  • Parallel Time: \(O(\log(N))\) parallel merges.

  • Forward Memory: \(O(N \log(N) C^2)\)

[3]:
batch, N, C = 3, 7, 2
def show_chain(chain):
    # Sum out the z_{n-1} dimension and put classes on the y-axis.
    plt.imshow(chain.detach().sum(-1).transpose(0, 1))

# batch, N, z_n, z_n_1
log_potentials = torch.rand(batch, N, C, C)
dist = torch_struct.LinearChainCRF(log_potentials)
show_chain(dist.argmax[0])
_images/model_6_0.png
[4]:
show_chain(dist.marginals[0])
_images/model_7_0.png
[5]:
event = dist.to_event(torch.tensor([[0, 1, 0, 1, 1, 1, 0, 1]]), 2)
show_chain(event[0])
_images/model_8_0.png
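
The compact-to-event conversion above pairs with log_prob for supervised training. A minimal, unexecuted sketch (the random potentials stand in for model scores; the gold tags and variable names are invented for illustration):

[ ]:
log_potentials = torch.rand(batch, N, C, C, requires_grad=True)
dist = torch_struct.LinearChainCRF(log_potentials)

tags = torch.randint(0, C, size=(batch, N + 1))  # gold labels, one per position
event = dist.to_event(tags, C)                   # one-hot event representation

loss = -dist.log_prob(event).mean()              # negative log-likelihood
loss.backward()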

Hidden Markov Model

class torch_struct.HMM(transition, emission, init, observations, lengths=None, validate_args=False)[source]

Represents hidden Markov model smoothing with C hidden states.

Event shape is of the form:

Parameters
  • transition (tensor) – log-probabilities (C x C) \(p(z_n \mid z_{n-1})\)

  • emission (tensor) – log-probabilities (V x C) \(p(x_n \mid z_n)\)

  • init (tensor) – log-probabilities (C) \(p(z_1)\)

  • observations (long tensor) – indices (batch x N) in [0, V-1]

Compact representation: N long tensor in [0, …, C-1]

Implemented as a special case of linear chain CRF.

[6]:
batch, V, N, C = 10, 3, 7, 2

transition = torch.rand(C, C).log_softmax(0)
emission = torch.rand(V, C).log_softmax(0)
init = torch.rand(C).log_softmax(0)
observations = torch.randint(0, V, size=(batch, N))

dist = torch_struct.HMM(transition, emission, init, observations)
show_chain(dist.argmax[0])
_images/model_11_0.png
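
Since the HMM is implemented as a linear-chain CRF over the joint log-probabilities, its log-partition is the observation log-likelihood. A small sketch using the distribution above:

[ ]:
# log p(x_{1:N}) for each batch element, summing over all hidden state sequences.
print(dist.partition)  # shape: (batch,)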

Semi-Markov

class torch_struct.SemiMarkovCRF(log_potentials, lengths=None, args={}, validate_args=False)[source]

Represents a semi-Markov or segmental CRF with C classes and segments of maximum width K.

Event shape is of the form:

Parameters
  • log_potentials (tensor) – event shape (N x K x C x C) e.g. \(\phi(n, k, z_{n+1}, z_{n})\)

  • lengths (long tensor) – batch shape integers for length masking.

Compact representation: N long tensor in [-1, 0, …, C-1]

Implementation uses linear-scan, forward-pass only.

  • Parallel Time: \(O(\log(N))\) parallel merges.

  • Forward Memory: \(O(N \log(N) C^2 K^2)\)

[7]:
batch, N, C, K = 3, 10, 2, 6
def show_sm(chain):
    # Sum out the width (K) and z_{n-1} dimensions; classes on the y-axis.
    plt.imshow(chain.detach().sum(1).sum(-1).transpose(0, 1))

# batch, N, K, z_n, z_n_1
log_potentials = torch.rand(batch, N, K, C, C)
log_potentials[:, :, :3] = -1e9  # disallow segments of width < 3
dist = torch_struct.SemiMarkovCRF(log_potentials)
show_sm(dist.argmax[0])
_images/model_14_0.png
[8]:
show_sm(dist.marginals[0])
_images/model_15_0.png
[9]:
# Use -1 to mark positions inside a segment.
event = dist.to_event(torch.tensor([[0, 1, -1, 1, -1, -1, 0, 1,  1, -1, -1]]), (2, 6))
show_sm(event[0])
_images/model_16_0.png
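
from_event (see the base class below) should invert this conversion; a sketch, assuming it returns the compact sequence together with the (C, K) sizes:

[ ]:
# Round-trip: recover the compact sequence (-1 inside segments) from the event.
sequence, extra = dist.from_event(event)
print(sequence)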

Alignment

class torch_struct.AlignmentCRF(log_potentials, local=False, lengths=None, max_gap=None, validate_args=False)[source]

Represents a basic alignment algorithm, e.g. dynamic time warping, Needleman-Wunsch, or Smith-Waterman.

Event shape is of the form:

Parameters
  • log_potentials (tensor) –

    event_shape (N x M x 3), e.g.

    \(\phi(i, j, op)\)

    Ops are 0 -> (i, j-1), 1 -> (i-1, j-1), and 2 -> (i-1, j).

  • local (bool) – if True, computes local alignment (Smith-Waterman); otherwise global alignment (Needleman-Wunsch)

  • max_gap (int or None) – the maximum gap to allow in the dynamic program

  • lengths (long tensor) – batch shape integers for length masking.

Implementation uses convolution and linear-scan. Use max_gap for long sequences.

  • Parallel Time: \(O(\log (M + N))\) parallel merges.

  • Forward Memory: \(O((M+N)^2)\)

[10]:
batch, N, M = 3, 15, 20
def show_deps(tree):
    # Display the (N x M) grid of alignment decisions.
    plt.imshow(tree.detach())

log_potentials = torch.rand(batch, N, M, 3)
dist = torch_struct.AlignmentCRF(log_potentials)
show_deps(dist.argmax[0])
_images/model_19_0.png
[11]:
show_deps(dist.marginals[0])
_images/model_20_0.png
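
The local flag switches to Smith-Waterman local alignment; a sketch reusing the potentials above (local_dist is an illustrative name):

[ ]:
# Local (Smith-Waterman) variant over the same potentials.
local_dist = torch_struct.AlignmentCRF(log_potentials, local=True)
show_deps(local_dist.argmax[0])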

Dependency Tree

class torch_struct.DependencyCRF(log_potentials, lengths=None, args={}, multiroot=True, validate_args=False)[source]

Represents a projective dependency CRF.

Reference:

  • Bilexical grammars and their cubic-time parsing algorithms [Eis00]

Event shape is of the form:

Parameters
  • log_potentials (tensor) – event shape (N x N) head, child or (N x N x L) head, child, label: arc scores, with root scores on the diagonal, e.g. \(\phi(i, j)\) where \(\phi(i, i)\) is (root, i).

  • lengths (long tensor) – batch shape integers for length masking.

Compact representation: N long tensor in [0, …, N] (indexing is +1)

Implementation uses linear-scan, forward-pass only.

  • Parallel Time: \(O(N)\) parallel merges.

  • Forward Memory: \(O(N^2)\)

[12]:
batch, N = 3, 10
def show_deps(tree):
    plt.imshow(tree.detach())

log_potentials = torch.rand(batch, N, N)
dist = torch_struct.DependencyCRF(log_potentials)
show_deps(dist.argmax[0])
_images/model_23_0.png
[13]:
show_deps(dist.marginals[0])
_images/model_24_0.png
[14]:
# Convert from the standard 1-indexed format (0 marks the root head).
event = dist.to_event(torch.tensor([[2, 3, 4, 1, 0, 4]]), None)
show_deps(event[0])
_images/model_25_0.png
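
The multiroot flag (default True) controls whether more than one arc may leave the root; a sketch of the single-root variant (single_root is an illustrative name):

[ ]:
# Single-root projective trees: exactly one arc out of the root.
single_root = torch_struct.DependencyCRF(log_potentials, multiroot=False)
show_deps(single_root.argmax[0])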

Non-Projective Dependency Tree

class torch_struct.NonProjectiveDependencyCRF(log_potentials, lengths=None, args={}, multiroot=False, validate_args=False)[source]

Represents a non-projective dependency CRF.

For references see:

  • Non-projective dependency parsing using spanning tree algorithms [MPRHajič05]

  • Structured prediction models via the matrix-tree theorem [KGCPérezC07]

Event shape is of the form:

Parameters
  • log_potentials (tensor) – event shape (N x N) head, child: arc scores, with root scores on the diagonal, e.g. \(\phi(i, j)\) where \(\phi(i, i)\) is (root, i).

  • lengths (long tensor) – batch shape integers for length masking.

Compact representation: N long tensor in [0, …, N] (indexing is +1)

Note: Does not currently implement argmax (Chu-Liu/Edmonds) or sampling.

[15]:
batch, N = 3, 10
def show_deps(tree):
    plt.imshow(tree.detach())

log_potentials = torch.rand(batch, N, N)
dist = torch_struct.NonProjectiveDependencyCRF(log_potentials)
show_deps(dist.marginals[0])
_images/model_28_0.png
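
Although argmax and sampling are not implemented, the matrix-tree theorem still gives the log-partition (and hence the marginals shown above); a small sketch:

[ ]:
# Log-partition over all non-projective spanning trees.
print(dist.partition)  # shape: (batch,)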

Binary Labeled Tree

class torch_struct.TreeCRF(log_potentials, lengths=None, args={}, validate_args=False)[source]

Represents a 0th-order span parser with NT nonterminals. Implemented using a fast CKY algorithm.

For example usage see:

  • A Minimal Span-Based Neural Constituency Parser [SAK17]

Event shape is of the form:

Parameters
  • log_potentials (tensor) – event_shape (N x N x NT), e.g. \(\phi(i, j, nt)\)

  • lengths (long tensor) – batch shape integers for length masking.

Implementation uses width-batching, forward-pass only.

  • Parallel Time: \(O(N)\) parallel merges.

  • Forward Memory: \(O(N^2)\)

Compact representation: N x N x NT long tensor (same as the event shape)

[16]:
batch, N, NT = 3, 20, 3
def show_tree(tree):
    t = tree.detach()
    # Color each span by its nonterminal label.
    plt.imshow(t[:, :, 0] +
               2 * t[:, :, 1] +
               3 * t[:, :, 2])

log_potentials = torch.rand(batch, N, N, NT)
dist = torch_struct.TreeCRF(log_potentials)
show_tree(dist.argmax[0])
_images/model_31_0.png
[17]:
show_tree(dist.marginals[0])
_images/model_32_0.png
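
Since the argmax is itself a one-hot event, its log-probability comes straight from log_prob (documented in the base class below); a sketch:

[ ]:
# Log-probability of the best tree: its score minus the log-partition.
print(dist.log_prob(dist.argmax))  # shape: (batch,)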

Probabilistic Context-Free Grammar

class torch_struct.SentCFG(log_potentials, lengths=None, validate_args=False)[source]

Represents a full generative context-free grammar with non-terminals NT and terminals T.

Event shape is of the form:

Parameters
  • log_potentials (tuple) – event tuple with event shapes: terms (N x T), rules (NT x (NT+T) x (NT+T)), root (NT)

  • lengths (long tensor) – batch shape integers for length masking.

Implementation uses width-batching, forward-pass only.

  • Parallel Time: \(O(N)\) parallel merges.

  • Forward Memory: \(O(N^2 (NT+T))\)

Compact representation: (N x N x NT) long tensor

[18]:
batch, N, NT, T = 3, 20, 3, 3
def show_prob_tree(tree):
    t = tree.detach().sum(-1).sum(-1)
    # Sum out the rule children, then color each span by its nonterminal.
    plt.imshow(t[:, :, 0] +
               2 * t[:, :, 1] +
               3 * t[:, :, 2])

terminals = torch.rand(batch, N, T)
rules = torch.rand(batch, NT, NT+T, NT+T)
init = torch.rand(batch, NT).log_softmax(-1)

dist = torch_struct.SentCFG((terminals, rules, init))
term, rules, init = dist.argmax
[19]:
# Rules
show_prob_tree(rules[0])
_images/model_36_0.png
[20]:
# Terminals
plt.imshow(term[:1])
[20]:
<matplotlib.image.AxesImage at 0x7f1877a1fb70>
_images/model_37_1.png
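
As with the other models, the log-partition sums over all parses; if the terminal and rule scores were properly normalized log-probabilities, it would be the sentence log-likelihood. A sketch using the distribution above:

[ ]:
# Log-partition over all parse trees of each sentence.
print(dist.partition)  # shape: (batch,)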

Base Class

class torch_struct.StructDistribution(log_potentials, lengths=None, args={}, validate_args=False)[source]

Base structured distribution class.

Dynamic distribution over structures \(z\) of length N, \(p(z)\).

Implemented based on gradient identities from:

  • Inside-outside and forward-backward algorithms are just backprop [Eis16]

  • Semiring Parsing [Goo99]

  • First- and second-order expectation semirings with applications to minimum-risk training on translation forests [LE09]

Parameters
  • log_potentials (tensor, batch_shape x event_shape) – log-potentials \(\phi\)

  • lengths (long tensor, batch_shape) – integers for length masking

property argmax

Compute the argmax for the distribution, \(\arg\max p(z)\).

Returns

argmax (batch_shape x event_shape)

property count

Compute the total number of structures in the support.

cross_entropy(other)[source]

Compute cross-entropy between distributions p (self) and q (other), \(H[p, q]\).

Parameters

other – Comparison distribution

Returns

cross entropy (batch_shape)

property entropy

Compute entropy for distribution \(H[z]\).

Returns

entropy (batch_shape)

from_event(event)[source]

Convert an event to the compact representation.

kl(other)[source]

Compute KL-divergence between distributions p (self) and q (other), \(KL[p \| q] = H[p, q] - H[p]\).

Parameters

other – Comparison distribution

Returns

kl divergence (batch_shape)
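
Both quantities compare two distributions over the same event shape; a minimal sketch with two random chain CRFs (shapes chosen only for illustration):

[ ]:
p = torch_struct.LinearChainCRF(torch.rand(3, 7, 2, 2))
q = torch_struct.LinearChainCRF(torch.rand(3, 7, 2, 2))
print(p.cross_entropy(q))  # H[p, q], shape: (batch,)
print(p.kl(q))             # KL[p || q] = H[p, q] - H[p]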

kmax(k)[source]

Compute the k-max for distribution \(k\max p(z)\).

Parameters

k – Number of solutions to return

Returns

kmax (k x batch_shape)
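
kmax returns the scores of the k best structures; topk (below) returns the structures themselves. A sketch with a small chain CRF (shapes chosen only for illustration):

[ ]:
chain = torch_struct.LinearChainCRF(torch.rand(3, 7, 2, 2))
print(chain.kmax(5).shape)  # (k x batch_shape): scores of the 5 best chains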

log_prob(value)[source]

Compute the log-probability of values, \(\log p(z)\).

Parameters

value (tensor) – One-hot events (sample_shape x batch_shape x event_shape)

Returns

log_probs (sample_shape x batch_shape)

property marginals

Compute marginals for distribution \(p(z_t)\).

Can be used in higher-order calculations, i.e. the marginals are themselves differentiable.

Returns

marginals (batch_shape x event_shape)

property max

Compute the max for the distribution, \(\max p(z)\).

Returns

max (batch_shape)

property partition

Compute the log-partition function.

sample(sample_shape=torch.Size([]))[source]

Compute structured samples from the distribution \(z \sim p(z)\).

Parameters

sample_shape (int) – number of samples

Returns

samples (sample_shape x batch_shape x event_shape)

to_event(sequence, extra, lengths=None)[source]

Convert the compact representation to an event.

topk(k)[source]

Compute the k-argmax for distribution \(k\max p(z)\).

Parameters

k – Number of solutions to return

Returns

topk (k x batch_shape x event_shape)

[28]:
batch, N, C = 3, 7, 2

# batch, N, z_n, z_n_1
log_potentials = torch.rand(batch, N, C, C)
dist = torch_struct.LinearChainCRF(log_potentials, lengths=torch.tensor([N-1, N, N+1]))
show_chain(dist.argmax[0])
plt.show()
show_chain(dist.argmax[1])
_images/model_49_0.png
_images/model_49_1.png
[29]:
show_chain(dist.marginals[0])
plt.show()
show_chain(dist.marginals[1])
_images/model_50_0.png
_images/model_50_1.png
[30]:
def show_samples(samples):
    show_chain(samples[0, 0])
    plt.show()
    show_chain(samples[1, 0])
    plt.show()
    show_chain(samples[0, 1])
[31]:
show_samples(dist.sample((10,)))
_images/model_52_0.png
_images/model_52_1.png
_images/model_52_2.png
[32]:
show_samples(dist.topk(10))
_images/model_53_0.png
_images/model_53_1.png
_images/model_53_2.png
[33]:
# Enumerate
x, _ = dist.enumerate_support()
print(x.shape)
for i in range(10):
    show_chain(x[i][0])
    plt.show()
torch.Size([256, 3, 7, 2, 2])
_images/model_54_1.png
_images/model_54_2.png
_images/model_54_3.png
_images/model_54_4.png
_images/model_54_5.png
_images/model_54_6.png
_images/model_54_7.png
_images/model_54_8.png
_images/model_54_9.png
_images/model_54_10.png
[34]:
plt.imshow(dist.entropy.detach().unsqueeze(0))
[34]:
<matplotlib.image.AxesImage at 0x7f1877c52c18>
_images/model_55_1.png