Advanced: Semirings

[1]:
import torch
import torch_struct
import matplotlib.pyplot as plt
[2]:
# Create two random tensors to experiment with.
v1 = torch.rand(5, 10)
v2 = torch.rand(5, 10)
plt.imshow(v1)
plt.show()
plt.imshow(v2)
plt.show()
[Figures: heatmaps of v1 and v2]
[37]:
def show(x, title):
    plt.title(title)
    plt.imshow(x.detach())
    plt.show()
def run(sr):
    # Convert semiring form.
    s1, s2 = sr.convert(v1), sr.convert(v2)
    s1.requires_grad_(True)
    s2.requires_grad_(True)

    # Times and sum out last dim.
    s = sr.sum(sr.times(s1, s2))

    # Compute grad.
    s.sum().backward()

    show(s, "Sum")
    # Show the grads
    show(sr.unconvert(s1.grad), "v1 grad")
    show(sr.unconvert(s2.grad), "v2 grad")
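    # Alignment marginals for a random batch of one 10x10 problem with 3 edge types.
    # _raw=True keeps the semiring dimension, so it is summed out with the batch for display.
    m = torch_struct.Alignment(sr).marginals(torch.rand(1, 10, 10, 3), _raw=True)
    show(m.sum(-1).sum(0).sum(0).transpose(0, 1), "Alignment example")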
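The `run` helper combines two tensors with the semiring's times and then sums out the last dimension with its plus. As a rough plain-PyTorch sketch of that forward computation for three of the semirings below (the `SEMIRINGS` table and `semiring_dot` are hypothetical names, not torch_struct API), swapping only the plus/times pair changes what the same expression computes:

```python
import torch

# Hypothetical stand-ins (not torch_struct API): each semiring is a pair of
# (plus: reduce over the last dim, times: elementwise combine).
SEMIRINGS = {
    "log": (lambda x: torch.logsumexp(x, dim=-1), torch.add),  # LogSemiring-like
    "max": (lambda x: x.max(dim=-1).values, torch.add),        # MaxSemiring-like
    "std": (lambda x: x.sum(dim=-1), torch.mul),               # StdSemiring-like
}

def semiring_dot(a, b, name):
    """Combine a and b elementwise with times, then sum out the last dim with plus."""
    plus, times = SEMIRINGS[name]
    return plus(times(a, b))

torch.manual_seed(0)
a, b = torch.rand(5, 10), torch.rand(5, 10)
results = {name: semiring_dot(a, b, name) for name in SEMIRINGS}
for name, r in results.items():
    print(name, tuple(r.shape))  # one value per row: (5,)
```

Because logsumexp is a smooth upper bound on max, the "log" result is never smaller than the "max" result for the same inputs.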

Log

class torch_struct.LogSemiring

Implements the log-space semiring (logsumexp, +, -inf, 0).

Gradients of the log-partition give marginals.
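"Gradients give marginals" is the identity d log Z / d theta_part = p(part). A minimal sketch on a small linear chain, assuming a hypothetical `edge[t, i, j]` layout of log-potentials (not how torch_struct stores potentials), runs the forward algorithm with logsumexp as plus and + as times, differentiates log Z, and checks the result against brute-force enumeration:

```python
import itertools
import torch

torch.manual_seed(0)
T, C = 4, 3  # sequence length and number of labels (illustrative sizes)
# Hypothetical layout: edge[t, i, j] scores label i at position t followed by label j at t + 1.
edge = torch.randn(T - 1, C, C, requires_grad=True)

def log_partition(edge):
    """Forward algorithm in the log semiring: plus = logsumexp, times = +."""
    alpha = torch.zeros(edge.shape[1])  # log(1) = 0 for every start label
    for t in range(edge.shape[0]):
        # alpha_new[j] = logsumexp_i(alpha[i] + edge[t, i, j])
        alpha = torch.logsumexp(alpha.unsqueeze(1) + edge[t], dim=0)
    return torch.logsumexp(alpha, dim=0)

log_z = log_partition(edge)
log_z.backward()
edge_marginals = edge.grad  # d log Z / d edge[t, i, j] = p(y_t = i, y_{t+1} = j)

# Brute-force check: enumerate all C**T label sequences.
e = edge.detach()
seqs = list(itertools.product(range(C), repeat=T))
seq_scores = torch.stack([sum(e[t, y[t], y[t + 1]] for t in range(T - 1)) for y in seqs])
probs = torch.softmax(seq_scores, dim=0)
brute = torch.zeros(T - 1, C, C)
for p, y in zip(probs, seqs):
    for t in range(T - 1):
        brute[t, y[t], y[t + 1]] += p
print(torch.allclose(brute, edge_marginals, atol=1e-5))  # True
```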

[4]:
run(torch_struct.LogSemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]

Max

class torch_struct.MaxSemiring

Implements the max semiring (max, +, -inf, 0).

Gradients give the argmax (an indicator of the highest-scoring structure).
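Replacing logsumexp with max in the same chain recursion gives the Viterbi score, and its gradient is a 0/1 indicator of the edges on the best sequence. A sketch under the same hypothetical `edge[t, i, j]` layout, checked by enumeration (ties are assumed not to occur with random scores):

```python
import itertools
import torch

torch.manual_seed(0)
T, C = 4, 3
# Hypothetical layout: edge[t, i, j] scores label i at position t followed by label j at t + 1.
edge = torch.randn(T - 1, C, C, requires_grad=True)

def viterbi_score(edge):
    """Forward pass in the max semiring: plus = max, times = +."""
    alpha = torch.zeros(edge.shape[1])
    for t in range(edge.shape[0]):
        # alpha_new[j] = max_i(alpha[i] + edge[t, i, j])
        alpha = (alpha.unsqueeze(1) + edge[t]).max(dim=0).values
    return alpha.max()

best = viterbi_score(edge)
best.backward()
# The gradient is 1 on the edges used by the highest-scoring sequence and 0 elsewhere.
best_path_edges = edge.grad

# Brute-force check: recover the argmax sequence by enumeration.
e = edge.detach()
seqs = list(itertools.product(range(C), repeat=T))
scores = torch.stack([sum(e[t, y[t], y[t + 1]] for t in range(T - 1)) for y in seqs])
y_best = seqs[scores.argmax().item()]
indicator = torch.zeros(T - 1, C, C)
for t in range(T - 1):
    indicator[t, y_best[t], y_best[t + 1]] = 1.0
print(torch.equal(best_path_edges, indicator))
```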

[5]:
run(torch_struct.MaxSemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]

K-Max

[6]:
run(torch_struct.KMaxSemiring(3))
[Figures: Sum, v1 grad, v2 grad, Alignment example]
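For intuition about the K-Max semiring, a common formulation keeps a sorted list of the k best scores: plus merges candidate lists and keeps the top k, and times takes the top k of all pairwise sums. The sketch below (hypothetical helper names and `edge[t, i, j]` layout; not necessarily how `KMaxSemiring` is implemented) uses it to get the k best sequence scores of a small chain. Because each edge contributes a single score, times reduces to shifting every kept score by that edge's value.

```python
import itertools
import torch

def kmax_plus(x, k):
    """Semiring plus: keep the k largest scores among all candidates in the last dim."""
    return x.topk(k, dim=-1).values

def kbest_scores(edge, k):
    """Scores of the k best label sequences of a chain, via a k-max semiring."""
    num_edges, C, _ = edge.shape
    # Semiring one per start label: a single score of 0, then k - 1 empty (-inf) slots.
    alpha = torch.full((C, k), float("-inf"))
    alpha[:, 0] = 0.0
    for t in range(num_edges):
        # Semiring times with a single edge score shifts every kept score: shape (C_i, C_j, k).
        cand = alpha.unsqueeze(1) + edge[t].unsqueeze(-1)
        # Semiring plus over the previous label i: merge the C * k candidates for each j.
        alpha = kmax_plus(cand.permute(1, 0, 2).reshape(C, C * k), k)
    # Final plus over the last label.
    return kmax_plus(alpha.reshape(-1), k)

torch.manual_seed(0)
T, C, k = 4, 3, 3
edge = torch.randn(T - 1, C, C)
top = kbest_scores(edge, k)

# Brute force: score every sequence and take the top k.
all_scores = torch.stack([
    sum(edge[t, y[t], y[t + 1]] for t in range(T - 1))
    for y in itertools.product(range(C), repeat=T)
])
print(torch.allclose(top, all_scores.topk(k).values))  # True
```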

Counting

class torch_struct.StdSemiring

Implements the counting semiring (+, *, 0, 1).
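With all weights equal to 1, the sum-product in this semiring counts structures. As a sketch, assuming a simple monotone alignment lattice with down, diagonal, and right moves (a hypothetical layout, not necessarily the convention used by `torch_struct.Alignment`), the chart below counts lattice paths; for a 3x3 grid of cells the count is the Delannoy number 13:

```python
import torch

def count_alignments(weights):
    """Sum-product over monotone lattice paths (counting semiring: plus = +, times = *).

    Hypothetical layout: weights[i, j, d] weights the move that enters cell (i, j),
    with d = 0 for down (from (i-1, j)), 1 for diagonal (from (i-1, j-1)),
    and 2 for right (from (i, j-1)). With all-ones weights, the result is
    the number of paths from (0, 0) to (N-1, M-1).
    """
    N, M, _ = weights.shape
    chart = torch.zeros(N, M, dtype=weights.dtype)
    chart[0, 0] = 1.0  # semiring one
    for i in range(N):
        for j in range(M):
            if i == 0 and j == 0:
                continue
            total = torch.zeros((), dtype=weights.dtype)  # semiring zero
            if i > 0:
                total = total + chart[i - 1, j] * weights[i, j, 0]
            if i > 0 and j > 0:
                total = total + chart[i - 1, j - 1] * weights[i, j, 1]
            if j > 0:
                total = total + chart[i, j - 1] * weights[i, j, 2]
            chart[i, j] = total
    return chart[-1, -1]

print(count_alignments(torch.ones(3, 3, 3, dtype=torch.float64)).item())  # 13.0
print(count_alignments(torch.ones(4, 4, 3, dtype=torch.float64)).item())  # 63.0
```

With non-unit weights the same chart sums the product of move weights over all paths, which is exactly what (+, *) computes.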

[7]:
run(torch_struct.StdSemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]

Sampled

class torch_struct.SampledSemiring

Implements a sampling semiring (logsumexp, +, -inf, 0).

“Gradients” give a sample.

This is an exact forward-filtering, backward-sampling approach.
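Forward-filtering, backward-sampling can be sketched on a small linear chain, again assuming a hypothetical `edge[t, i, j]` layout of log-potentials: a log-semiring forward pass computes prefix scores, then labels are drawn from last to first, each conditioned on its successor. The empirical frequencies of many draws should match the exact probabilities obtained by enumeration. This illustrates the technique and is not torch_struct's implementation.

```python
import itertools
from collections import Counter

import torch

def ffbs_sample(edge, generator=None):
    """Draw one exact sample y ~ p(y) proportional to exp(sum_t edge[t, y_t, y_{t+1}])."""
    num_edges, C, _ = edge.shape
    # Forward filtering: alphas[t][j] = logsumexp of scores of all prefixes ending in label j at t.
    alphas = [torch.zeros(C)]
    for t in range(num_edges):
        alphas.append(torch.logsumexp(alphas[-1].unsqueeze(1) + edge[t], dim=0))
    # Backward sampling: draw the last label, then each earlier label given its successor.
    y = [0] * (num_edges + 1)
    y[-1] = torch.multinomial(torch.softmax(alphas[-1], dim=0), 1, generator=generator).item()
    for t in range(num_edges - 1, -1, -1):
        # p(y_t = i | y_{t+1}) is proportional to exp(alphas[t][i] + edge[t, i, y_{t+1}]).
        logits = alphas[t] + edge[t, :, y[t + 1]]
        y[t] = torch.multinomial(torch.softmax(logits, dim=0), 1, generator=generator).item()
    return tuple(y)

torch.manual_seed(0)
T, C, n = 3, 2, 5000
edge = torch.randn(T - 1, C, C)
gen = torch.Generator().manual_seed(0)
counts = Counter(ffbs_sample(edge, gen) for _ in range(n))

# Exact sequence probabilities by enumeration, for comparison.
seqs = list(itertools.product(range(C), repeat=T))
scores = torch.stack([sum(edge[t, y[t], y[t + 1]] for t in range(T - 1)) for y in seqs])
exact = torch.softmax(scores, dim=0)
max_err = max(abs(counts[y] / n - exact[k].item()) for k, y in enumerate(seqs))
print(max_err)  # sampling noise only, so this should be small
```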

[8]:
run(torch_struct.SampledSemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]
class torch_struct.MultiSampledSemiring

Implements a multi-sampling semiring (logsumexp, +, -inf, 0).

“Gradients” give up to 16 samples with replacement.

[9]:
run(torch_struct.MultiSampledSemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]

Entropy

class torch_struct.EntropySemiring

Implements an entropy expectation semiring.

Computes both the log-values and the running distributional entropy.

Based on descriptions in:

  • Parameter estimation for probabilistic finite-state transducers [Eis02]

  • First- and second-order expectation semirings with applications to minimum-risk training on translation forests [LE09]

  • Sample Selection for Statistical Grammar Induction []
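One way to see how an entropy semiring works is to carry a pair (log-mass a, entropy h) for each partial collection of structures. Times adds both components (independent parts), and plus over disjoint collections uses the mixture identity H = sum_i w_i h_i + H(w) with w = softmax(a). The sketch below applies this pair semiring to a small chain with a hypothetical `edge[t, i, j]` layout and checks it against the entropy computed by enumeration; torch_struct's `EntropySemiring` may represent these quantities differently.

```python
import itertools
import torch

def entropy_plus(a, h, dim):
    """Semiring plus over `dim`: merge disjoint parts given (log-mass a, entropy h)."""
    log_w = torch.log_softmax(a, dim=dim)  # mixture weight of each part
    w = log_w.exp()
    # Entropy of a mixture of disjoint supports: sum_i w_i * h_i + H(w).
    new_h = (w * h).sum(dim) - (w * log_w).sum(dim)
    return torch.logsumexp(a, dim=dim), new_h

def chain_entropy(edge):
    """Return (log Z, entropy) of p(y) proportional to exp(sum_t edge[t, y_t, y_{t+1}])."""
    num_edges, C, _ = edge.shape
    # Semiring one for every start label: log-mass 0, entropy 0.
    a, h = torch.zeros(C), torch.zeros(C)
    for t in range(num_edges):
        # Semiring times with one edge (a single structure, entropy 0):
        # log-masses add and entropies add, giving (C_i, C_j) candidates.
        cand_a = a.unsqueeze(1) + edge[t]
        cand_h = h.unsqueeze(1).expand(C, C)
        # Semiring plus over the previous label i.
        a, h = entropy_plus(cand_a, cand_h, dim=0)
    # Final plus over the last label.
    return entropy_plus(a, h, dim=0)

torch.manual_seed(0)
T, C = 4, 3
edge = torch.randn(T - 1, C, C)
log_z, ent = chain_entropy(edge)

# Brute force: entropy of the explicit distribution over all C**T sequences.
scores = torch.stack([
    sum(edge[t, y[t], y[t + 1]] for t in range(T - 1))
    for y in itertools.product(range(C), repeat=T)
])
log_p = torch.log_softmax(scores, dim=0)
brute_entropy = -(log_p.exp() * log_p).sum()
print(torch.allclose(ent, brute_entropy, atol=1e-4))  # True
```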

[10]:
run(torch_struct.EntropySemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]

Sparsemax

class torch_struct.SparseMaxSemiring

Implements differentiable dynamic programming with a sparsemax semiring (sparsemax, +, -inf, 0).

Sparsemax gradients give a sparser set of marginal-like terms.

  • From softmax to sparsemax: A sparse model of attention and multi-label classification [MA16]

  • Differentiable dynamic programming for structured prediction and attention [MB18]
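The building block here is the sparsemax projection of [MA16]: the Euclidean projection of a score vector onto the probability simplex, which, unlike softmax, can assign exactly zero probability. A minimal sketch of the sort-based algorithm from that paper (the function name is illustrative; this is not torch_struct's internal code):

```python
import torch

def sparsemax(z, dim=-1):
    """Euclidean projection of z onto the probability simplex along `dim`."""
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    cumsum = z_sorted.cumsum(dim)
    # k = 1..n, shaped to broadcast along `dim`.
    n = z.shape[dim]
    shape = [1] * z.dim()
    shape[dim] = n
    k = torch.arange(1, n + 1, dtype=z.dtype, device=z.device).view(shape)
    # Support condition: 1 + k * z_(k) > sum_{j <= k} z_(j).
    support = (1 + k * z_sorted) > cumsum
    k_z = support.sum(dim=dim, keepdim=True)  # size of the support (always >= 1)
    # Threshold tau so that the kept entries, shifted by tau, sum to 1.
    tau = (cumsum.gather(dim, k_z - 1) - 1) / k_z.to(z.dtype)
    return torch.clamp(z - tau, min=0)

z = torch.tensor([1.0, 0.8, 0.1, -1.0])
print(sparsemax(z))               # approximately [0.6, 0.4, 0.0, 0.0]
print(torch.softmax(z, dim=0))    # every entry stays strictly positive
```

For this input only the two largest scores stay in the support, which is the kind of sparsity the sparsemax semiring carries into its gradients.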

[13]:
run(torch_struct.SparseMaxSemiring)
[Figures: Sum, v1 grad, v2 grad, Alignment example]
[122]:
all_sr = [("marginals", torch_struct.LogSemiring),
          ("argmax", torch_struct.MaxSemiring),
          ("sample", torch_struct.SampledSemiring),
          # Placeholder: this entry maps to subplot 9, which the plotting loop skips.
          ("multi-sample", torch_struct.MultiSampledSemiring),
          ("multi-sample", torch_struct.MultiSampledSemiring),
          ("kmax", torch_struct.KMaxSemiring(4)),
          ("sparsemax", torch_struct.SparseMaxSemiring),
         ]

fig = plt.figure(figsize=(10, 9))
columns = 4
rows = 3

for i in range(2,5):
    fig.add_subplot(rows, columns, i)
    x = torch.arange(15).float()
    y1 = torch.sin(x)
    plt.plot(x, y1, "-x")
    plt.xlim(0,15)
    plt.ylim(-1, 4)

    plt.axis("off")

for i in [5, 9]:
    fig.add_subplot(rows, columns, i)
    x = torch.arange(10).float()
    y2 = torch.sin(1.4 * x)
    plt.plot(y2, x, "-x")
    plt.ylim(-4, 14)
    plt.xlim(-3, 1)
    plt.axis("off")



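# Alignment potentials: only edge type 1 gets a nonzero score, the absolute
# difference between the two curves at each (i, j) pair; types 0 and 2 stay 0.
v = torch.zeros(1, 15, 10, 3)
v[0, :, :, 1] = (y1.unsqueeze(1) - y2.unsqueeze(0)).abs()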

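# Panels go in columns 2-4 of rows 2 and 3 (subplots 6-8 and 10-12).
# Subplot 9 holds the second curve, so the entry that lands on it is skipped.
for i, (title, sr) in enumerate(all_sr, 2 + columns):
    if i == 9:
        continue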
    m = torch_struct.Alignment(sr).marginals(v, _raw=True)
    fig.add_subplot(rows, columns, i)
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    plt.axis("off")
    plt.ylim(-0.5,10.5)
    plt.xlim(-0.5,14.5)

    plt.imshow(m.sum(-1).sum(0).sum(0).transpose(0,1).detach())
plt.savefig("show.png")
[Figure: alignment marginals for each semiring: marginals, argmax, sample, multi-sample, kmax, sparsemax]