Advanced: Semirings¶
[1]:
import torch
import torch_struct
import matplotlib.pyplot as plt
[2]:
# Create two random tensors to experiment with.
v1 = torch.rand(5, 10)
v2 = torch.rand(5, 10)
plt.imshow(v1)
plt.show()
plt.imshow(v2)
plt.show()
[37]:
def show(x, title):
    plt.title(title)
    plt.imshow(x.detach())
    plt.show()

def run(sr):
    # Convert to semiring form (adds the semiring dimension).
    s1, s2 = sr.convert(v1), sr.convert(v2)
    s1.requires_grad_(True)
    s2.requires_grad_(True)

    # Semiring product, then semiring sum over the last dim.
    s = sr.sum(sr.times(s1, s2))

    # Compute gradients.
    s.sum().backward()
    show(s, "Sum")

    # Show the gradients.
    show(sr.unconvert(s1.grad), "v1 grad")
    show(sr.unconvert(s2.grad), "v2 grad")

    # A structured example: alignment marginals under this semiring.
    m = torch_struct.Alignment(sr).marginals(torch.rand(1, 10, 10, 3), _raw=True)
    show(m.sum(-1).sum(0).sum(0).transpose(0, 1), "Alignment example")
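The helpers above only touch the generic semiring interface: convert/unconvert add and remove the semiring dimension, times is the semiring product, and sum reduces the last dimension with the semiring sum. A minimal sketch of what this means for the log semiring (assuming, as for log-space, that the product is + and the sum is logsumexp over the last dimension):
[ ]:
# Minimal sketch of the generic semiring ops used by `run` above.
# Assumption: for LogSemiring, `times` is elementwise + and `sum` is
# logsumexp over the last dimension.
sr = torch_struct.LogSemiring
a, b = torch.rand(5, 10), torch.rand(5, 10)
s = sr.unconvert(sr.sum(sr.times(sr.convert(a), sr.convert(b))))
print(torch.allclose(s, torch.logsumexp(a + b, dim=-1)))  # expected: True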
Log¶
- class torch_struct.LogSemiring[source]¶
Implements the log-space semiring (logsumexp, +, -inf, 0).
Gradients give marginals.
[4]:
run(torch_struct.LogSemiring)
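A small check of "gradients give marginals", sketched on a linear-chain CRF. This is a hedged sketch: it assumes the LinearChainCRF distribution wrapper with its partition and marginals properties, and log-potentials of shape (batch, N-1, C, C).
[ ]:
# Sketch: marginals are the gradient of the log-partition function.
log_potentials = torch.rand(1, 6, 3, 3, requires_grad=True)
dist = torch_struct.LinearChainCRF(log_potentials)
dist.partition.sum().backward()
print(torch.allclose(log_potentials.grad, dist.marginals.detach(), atol=1e-4))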
Max¶
- class torch_struct.MaxSemiring[source]¶
Implements the max semiring (max, +, -inf, 0).
Gradients give argmax.
[5]:
run(torch_struct.MaxSemiring)
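Under the max semiring the same call returns a one-hot indicator of the best structure. A sketch using the struct-level API (assuming LinearChain(sr).sum and .marginals mirror the Alignment(sr) calls used by run above):
[ ]:
# Sketch: MaxSemiring "marginals" are an indicator of the argmax structure,
# so scoring that structure recovers the max score.
edge = torch.rand(1, 6, 3, 3)
struct = torch_struct.LinearChain(torch_struct.MaxSemiring)
argmax = struct.marginals(edge)
best = struct.sum(edge)
print(torch.allclose((argmax * edge).sum(), best.sum(), atol=1e-4))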
Sampled¶
- class torch_struct.SampledSemiring[source]¶
Implements a sampling semiring (logsumexp, +, -inf, 0).
“Gradients” give a sample.
This is an exact forward-filtering, backward-sampling approach.
[8]:
run(torch_struct.SampledSemiring)
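Each call under this semiring performs a fresh exact draw, so repeated calls usually return different structures. A sketch (same struct-level assumption as above):
[ ]:
# Sketch: every call is an independent exact sample from the distribution.
edge = torch.rand(1, 6, 3, 3)
struct = torch_struct.LinearChain(torch_struct.SampledSemiring)
draw1 = struct.marginals(edge)
draw2 = struct.marginals(edge)
print(draw1.shape)               # one-hot over edges, like the argmax output
print((draw1 - draw2).abs().sum())  # usually nonzero: independent draws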
- class torch_struct.MultiSampledSemiring[source]¶
Implements a multi-sampling semiring (logsumexp, +, -inf, 0).
“Gradients” give up to 16 samples with replacement.
[9]:
run(torch_struct.MultiSampledSemiring)
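The distribution wrapper's sample method draws several structures at once, batching draws in the spirit of this semiring. A sketch (the exact output layout is an assumption):
[ ]:
# Sketch: batched sampling through the distribution wrapper, which packs
# multiple draws into a small number of dynamic-programming passes.
dist = torch_struct.LinearChainCRF(torch.rand(1, 6, 3, 3))
draws = dist.sample((8,))
print(draws.shape)  # leading dimension indexes the 8 draws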
Entropy¶
- class torch_struct.EntropySemiring[source]¶
Implements an entropy expectation semiring.
Computes both the log-values and the running distributional entropy.
Based on descriptions in the expectation-semiring literature.
[10]:
run(torch_struct.EntropySemiring)
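The entropy of the structured distribution then comes out of a single forward pass. A sketch using the struct-level call (assuming sum under this semiring returns the entropy) alongside the distribution wrapper's entropy property:
[ ]:
# Sketch: entropy of the edge distribution in one dynamic-programming pass.
edge = torch.rand(1, 6, 3, 3)
ent = torch_struct.LinearChain(torch_struct.EntropySemiring).sum(edge)
print(ent)
# The distribution wrapper exposes the same quantity as a property.
print(torch_struct.LinearChainCRF(edge).entropy)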
Sparsemax¶
- class torch_struct.SparseMaxSemiring[source]¶
Implements differentiable dynamic programming with a sparsemax semiring (sparsemax, +, -inf, 0).
Sparsemax gradients give a sparser set of marginal-like terms.
[13]:
run(torch_struct.SparseMaxSemiring)
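In contrast to the dense log-semiring marginals, many entries here are exactly zero. A sketch (same struct-level assumption as before):
[ ]:
# Sketch: sparsemax marginals place exactly zero weight on most edges.
edge = torch.rand(1, 6, 3, 3)
m = torch_struct.LinearChain(torch_struct.SparseMaxSemiring).marginals(edge)
print((m == 0).float().mean())  # fraction of edges with exactly zero mass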
[122]:
all_sr = [("marginals", torch_struct.LogSemiring),
("argmax", torch_struct.MaxSemiring),
("sample", torch_struct.SampledSemiring),
("multi-sample", torch_struct.MultiSampledSemiring),
("multi-sample", torch_struct.MultiSampledSemiring),
("kmax", torch_struct.KMaxSemiring(4)),
("sparsemax", torch_struct.SparseMaxSemiring),
]
fig=plt.figure(figsize=(10, 9))
columns = 4
rows = 3
for i in range(2,5):
fig.add_subplot(rows, columns, i)
x = torch.arange(15).float()
y1 = torch.sin(x)
plt.plot(x, y1, "-x")
plt.xlim(0,15)
plt.ylim(-1, 4)
plt.axis("off")
for i in [5, 9]:
fig.add_subplot(rows, columns, i)
x = torch.arange(10).float()
y2 = torch.sin(1.4 * x)
plt.plot(y2, x, "-x")
plt.ylim(-4, 14)
plt.xlim(-3, 1)
plt.axis("off")
v = torch.zeros(1, 15, 10, 3)
v[0, :, :, 1] = (y1.unsqueeze(1) - y2.unsqueeze(0)).abs()
v[0, :, :, 0] = 0
v[0, :, :, 2] = 0
for i, (title, sr) in enumerate(all_sr, 2+columns):
if i == 9:
continue
m = torch_struct.Alignment(sr).marginals(v, _raw=True)
fig.add_subplot(rows, columns, i)
plt.title(title)
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.ylim(-0.5,10.5)
plt.xlim(-0.5,14.5)
plt.imshow(m.sum(-1).sum(0).sum(0).transpose(0,1).detach())
plt.savefig("show.png",)
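The grid also includes KMaxSemiring(4), which was not introduced above: it is a semiring factory whose instances track the k best scores, so its "marginals" cover the four highest-scoring structures rather than a single argmax. A hedged sketch on the linear chain (the raw output layout is an assumption):
[ ]:
# Sketch: KMaxSemiring(k) tracks the k best structures.
edge = torch.rand(1, 6, 3, 3)
kmax = torch_struct.LinearChain(torch_struct.KMaxSemiring(4)).marginals(edge, _raw=True)
print(kmax.shape)  # assumed: leading dimension of size 4, one slice per structure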