"""CKY inside algorithm and marginals for torch_struct (torch_struct.cky)."""

import torch
from .helpers import _Struct, Chart

A, B = 0, 1


[docs]class CKY(_Struct): def logpartition(self, scores, lengths=None, force_grad=False): semiring = self.semiring # Checks terms, rules, roots = scores rules.requires_grad_(True) ssize = semiring.size() batch, N, T = terms.shape _, NT, _, _ = rules.shape S = NT + T terms, rules, roots = ( semiring.convert(terms).requires_grad_(True), semiring.convert(rules).requires_grad_(True), semiring.convert(roots).requires_grad_(True), ) if lengths is None: lengths = torch.LongTensor([N] * batch).to(terms.device) # Charts beta = [Chart((batch, N, N, NT), rules, semiring) for _ in range(2)] span = [None for _ in range(N)] v = (ssize, batch) term_use = terms + 0.0 # Split into NT/T groups NTs = slice(0, NT) Ts = slice(NT, S) rules = rules.view(ssize, batch, 1, NT, S, S) def arr(a, b): return rules[..., a, b].contiguous().view(*v + (NT, -1)).transpose(-2, -1) matmul = semiring.matmul times = semiring.times X_Y_Z = arr(NTs, NTs) X_Y1_Z = arr(Ts, NTs) X_Y_Z1 = arr(NTs, Ts) X_Y1_Z1 = arr(Ts, Ts) for w in range(1, N): all_span = [] v2 = v + (N - w, -1) Y = beta[A][: N - w, :w, :] Z = beta[B][w:, N - w :, :] X1 = matmul(matmul(Y.transpose(-2, -1), Z).view(*v2), X_Y_Z) all_span.append(X1) Y_term = term_use[..., : N - w, :, None] Z_term = term_use[..., w:, None, :] Y = Y[..., -1, :].unsqueeze(-1) X2 = matmul(times(Y, Z_term).view(*v2), X_Y_Z1) Z = Z[..., 0, :].unsqueeze(-2) X3 = matmul(times(Y_term, Z).view(*v2), X_Y1_Z) all_span += [X2, X3] if w == 1: X4 = matmul(times(Y_term, Z_term).view(*v2), X_Y1_Z1) all_span.append(X4) span[w] = semiring.sum(torch.stack(all_span, dim=-1)) beta[A][: N - w, w, :] = span[w] beta[B][w:N, N - w - 1, :] = span[w] final = beta[A][0, :, NTs] top = torch.stack([final[:, i, l - 1] for i, l in enumerate(lengths)], dim=1) log_Z = semiring.dot(top, roots) return log_Z, (term_use, rules, roots, span[1:]) def marginals(self, scores, lengths=None, _autograd=True, _raw=False): """ Compute the marginals of a CFG using CKY. 
Parameters: scores : terms : b x n x T rules : b x NT x (NT+T) x (NT+T) root: b x NT lengths : lengths in batch Returns: v: b tensor of total sum spans: bxNxT terms, (bxNTx(NT+S)x(NT+S)) rules, bxNT roots """ terms, rules, roots = scores batch, N, T = terms.shape _, NT, _, _ = rules.shape v, (term_use, rule_use, root_use, spans) = self.logpartition( scores, lengths=lengths, force_grad=True ) def marginal(obj, inputs): obj = self.semiring.unconvert(obj).sum(dim=0) marg = torch.autograd.grad( obj, inputs, create_graph=True, only_inputs=True, allow_unused=False, ) spans_marg = torch.zeros( batch, N, N, NT, dtype=scores[1].dtype, device=scores[1].device ) span_ls = marg[3:] for w in range(len(span_ls)): x = span_ls[w].sum(dim=0, keepdim=True) spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(x) rule_marg = self.semiring.unconvert(marg[0]).squeeze(1) root_marg = self.semiring.unconvert(marg[1]) term_marg = self.semiring.unconvert(marg[2]) assert term_marg.shape == (batch, N, T) assert root_marg.shape == (batch, NT) assert rule_marg.shape == (batch, NT, NT + T, NT + T) return (term_marg, rule_marg, root_marg, spans_marg) inputs = (rule_use, root_use, term_use) + tuple(spans) if _raw: paths = [] for k in range(v.shape[0]): obj = v[k : k + 1] marg = marginal(obj, inputs) paths.append(marg[-1]) paths = torch.stack(paths, 0) obj = v.sum(dim=0, keepdim=True) term_marg, rule_marg, root_marg, _ = marginal(obj, inputs) return term_marg, rule_marg, root_marg, paths else: return marginal(v, inputs) def score(self, potentials, parts): terms, rules, roots = potentials[:3] m_term, m_rule, m_root = parts[:3] b = m_term.shape[0] return ( m_term.mul(terms).view(b, -1).sum(-1) + m_rule.mul(rules).view(b, -1).sum(-1) + m_root.mul(roots).view(b, -1).sum(-1) ) @staticmethod def to_parts(spans, extra, lengths=None): NT, T = extra batch, N, N, S = spans.shape assert S == NT + T terms = torch.zeros(batch, N, T) rules = torch.zeros(batch, NT, S, S) roots = torch.zeros(batch, NT) for b in 
range(batch): roots[b, :] = spans[b, 0, lengths[b] - 1, :NT] terms[b, : lengths[b]] = spans[ b, torch.arange(lengths[b]), torch.arange(lengths[b]), NT: ] cover = spans[b].nonzero() left = {i: [] for i in range(N)} right = {i: [] for i in range(N)} for i in range(cover.shape[0]): i, j, A = cover[i].tolist() left[i].append((A, j, j - i + 1)) right[j].append((A, i, j - i + 1)) for i in range(cover.shape[0]): i, j, A = cover[i].tolist() B = None for B_p, k, a_span in left[i]: for C_p, k_2, b_span in right[j]: if k_2 == k + 1 and a_span + b_span == j - i + 1: B, C = B_p, C_p break if j > i: assert B is not None, "%s" % ((i, j, left[i], right[j], cover),) rules[b, A, B, C] += 1 return terms, rules, roots @staticmethod def from_parts(chart): terms, rules, roots = chart batch, N, N, NT, S, S = rules.shape assert terms.shape[1] == N spans = torch.zeros(batch, N, N, S, dtype=rules.dtype, device=rules.device) rules = rules.sum(dim=-1).sum(dim=-1) for n in range(N): spans[:, torch.arange(N - n - 1), torch.arange(n + 1, N), :NT] = rules[ :, n, torch.arange(N - n - 1) ] spans[:, torch.arange(N), torch.arange(N), NT:] = terms return spans, (NT, S - NT) @staticmethod def _intermediary(spans): batch, N = spans.shape[:2] splits = {} cover = spans.nonzero() left, right = {}, {} for k in range(cover.shape[0]): b, i, j, A = cover[k].tolist() left.setdefault((b, i), []) right.setdefault((b, j), []) left[b, i].append((A, j, j - i + 1)) right[b, j].append((A, i, j - i + 1)) for x in range(cover.shape[0]): b, i, j, A = cover[x].tolist() if i == j: continue b_final = None c_final = None k_final = None for B_p, k, a_span in left.get((b, i), []): if k > j: continue for C_p, k_2, b_span in right.get((b, j), []): if k_2 == k + 1 and a_span + b_span == j - i + 1: k_final = k b_final = B_p c_final = C_p break if b_final is not None: break assert k_final is not None, "%s %s %s %s" % (b, i, j, spans[b].nonzero()) splits[(b, i, j)] = k_final, b_final, c_final return splits @classmethod def 
to_networkx(cls, spans): cur = 0 N = spans.shape[1] n_nodes = int(spans.sum().item()) cover = spans.nonzero().cpu() order = torch.argsort(cover[:, 2] - cover[:, 1]) left = {} right = {} ordered = cover[order] label = ordered[:, 3] a = [] b = [] topo = [[] for _ in range(N)] for n in ordered: batch, i, j, _ = n.tolist() # G.add_node(cur, label=A) if i - j != 0: a.append(left[(batch, i)][0]) a.append(right[(batch, j)][0]) b.append(cur) b.append(cur) order = max(left[(batch, i)][1], right[(batch, j)][1]) + 1 else: order = 0 left[(batch, i)] = (cur, order) right[(batch, j)] = (cur, order) topo[order].append(cur) cur += 1 indices = left return (n_nodes, a, b, label), indices, topo