import torch.nn.functional as F
import torch
from .core import NamedTensorBase, assert_match
from .utils import make_tuple
class NamedTensor(NamedTensorBase):
def __getitem__(self, index):
if isinstance(index, dict):
cur = self
for k, v in index.items():
if isinstance(v, slice):
cur = cur.narrow(k, v.start, v.stop - v.start)
elif isinstance(v, NamedTensor):
cur = cur.index_select(k, v)
else:
cur = cur.get(k, v)
return cur
elif isinstance(index, NamedTensor):
if (
index.type() == "torch.ByteTensor"
or index.type() == "torch.cuda.ByteTensor"
):
return self.masked_select(index)
raise RuntimeError("Masked namedtensor must be byte tensor.")
else:
raise RuntimeError("Index must be dict or namedtensor.")
def __setitem__(self, index, val):
if isinstance(val, NamedTensor):
copy = True
else:
copy = False
if isinstance(index, dict):
cur = self
for k, v in index.items():
if isinstance(v, slice):
cur = cur.narrow(k, v.start, v.stop - v.start)
elif isinstance(v, NamedTensor):
assert len(index) == 1
if copy:
cur.index_copy_(k, v, val)
else:
cur.index_fill_(k, v, val)
return self
else:
cur = cur.get(k, v)
if copy:
cur.copy_(val)
else:
cur.fill_(val)
elif isinstance(index, NamedTensor):
if (
index.type() == "torch.ByteTensor"
or index.type() == "torch.cuda.ByteTensor"
):
if copy:
return self.masked_scatter_(index, val)
else:
return self.masked_fill_(index, val)
raise RuntimeError("Masked namedtensor must be byte tensor.")
else:
raise RuntimeError("Index must be dict or namedtensor.")
return self
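    # Assignment sketch (hedged, same illustrative names as above):
    #
    #   t[{"batch": slice(0, 2)}] = 0.0   # fill the narrowed slice in place
    #   t[{"batch": 0}] = other           # copy another NamedTensor into that slice
    #   t[byte_mask] = 0.0                # masked_fill_ wherever the ByteTensor mask is set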
def copy_(self, other):
return self._setter(other, "copy_")
def _setter(self, other, method, vals=[]):
order = other._mask_broadcast_order(self._schema._names)
other = other._force_order(order)
args = [other.values] + vals
getattr(self.values, method)(*args)
return self
    def get(self, name, idx):
        "Return a namedtensor indexed at `idx` along dimension `name` (the dimension is dropped)"
dim = self._schema.get(name)
return self._new(
self.values.narrow(dim, torch.tensor(idx), 1).squeeze(dim), name
)
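    # e.g. (hedged): t.get("batch", 1) returns the slice at index 1 with the
    # "batch" dimension removed from the schema.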
    def renorm(self, p, name, maxnorm):
        "Apply :py:meth:`torch.Tensor.renorm` over `name`"
        results = self._tensor.renorm(p, self._schema.get(name), maxnorm)
        return self._new(results)
    def dot(self, names, *others):
"Contract dimension `names` with each of the other tensors"
from .torch_base import ntorch
return ntorch.dot(names, *((self,) + others))
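    # e.g. (hedged): a.dot("dim", b) contracts the shared "dim" dimension,
    # roughly a named einsum that sums the elementwise product over "dim"
    # while broadcasting the remaining named dimensions.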
# def access(self, dims):
# term = dims.split() + [d for d in self._schema._names if d not in dims]
# return self.transpose(*term)._tensor
# def debug(self):
# print(self.shape)
# return self
def augment(self, axis_op, add, dim=None, **kwargs):
return self.op(axis_op, dim=dim, _add=add, **kwargs)
def reduce(self, axis_op, reduced, dim=None, **kwargs):
return self.op(axis_op, dim=dim, _drop=reduced, **kwargs)
def reduce2(self, other, axis_op, reduced, dim=None, **kwargs):
return self.op2(other, axis_op, dim=dim, _drop=reduced, **kwargs)
    def op(self, *axis_ops, dim=None, _drop=None, _add=None, **kwargs):
        "Apply ops that may change dimension sizes"
func_args = {}
if dim is not None:
func_args["dim"] = self._schema.get(dim)
_drop = make_tuple(_drop)
for v in _drop:
self._schema.get(v)
cur = self._tensor
for axis_op in axis_ops:
cur = axis_op(cur, **func_args)
for k, vs in kwargs.items():
for v in make_tuple(vs):
self._schema.get(v)
if _add is None and _drop is None:
assert len(cur.shape) == len(
self._tensor.shape
), "In shape %s, Out shape %s" % (cur.shape, self._tensor.shape)
out = self._new(
cur,
drop=_drop,
add=make_tuple(_add),
updates={
(v[0] if isinstance(v, tuple) else v): k
for k, v in kwargs.items()
},
)
# for k, v in self.shape.items():
# assert k not in out.shape or v == out.shape[k], (
# "name needs to change for updated dimensions"
# + str(axis_ops)
# + str(k)
# )
return out
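    # Sketch of op / reduce (hedged; names are illustrative):
    #
    #   t = NamedTensor(torch.randn(3, 4), ("batch", "dim"))
    #   t.op(torch.sigmoid)                    # elementwise, schema unchanged
    #   t.reduce(torch.sum, "dim", dim="dim")  # "dim" is translated to a positional
    #                                          # index for torch.sum and dropped
    #                                          # from the output schema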
def op2(self, y, axis_op, dim=None, _drop=None, **kwargs):
return self.op(lambda x: axis_op(x, y.values), _drop=_drop, **kwargs)
def __neg__(self):
return self.neg()
def __add__(self, b):
return self.add(b)
def __radd__(self, b):
return self.add(b)
def __sub__(self, b):
return self.sub(b)
def __rsub__(self, b):
return -self.sub(b)
def __mul__(self, b):
return self.mul(b)
def __rmul__(self, b):
return self.mul(b)
def __div__(self, b):
return self.div(b)
def __truediv__(self, b):
return self.div(b)
def __eq__(self, b):
return self.eq(b)
def __ne__(self, b):
return self.ne(b)
def __lt__(self, b):
return self.lt(b)
def __gt__(self, b):
return self.gt(b)
def __le__(self, b):
return self.le(b)
def __ge__(self, b):
return self.ge(b)
def __getattr__(self, methodname):
if methodname in dir(self._tensor):
method = getattr(self._tensor, methodname)
if methodname in self._noshift | self._noshift_args:
def call(*args, **kwargs):
return self._new(method(*args, **kwargs))
call.__doc__ = method.__doc__
elif methodname in self._noshift_nn:
method = getattr(F, methodname)
def call(*args, **kwargs):
return self._new(method(self.values, *args, **kwargs))
call.__doc__ = method.__doc__
elif methodname in self._noshift_dim:
def call(dim, *args, **kwargs):
return self._new(
method(self._schema.get(dim), *args, **kwargs)
)
call.__doc__ = method.__doc__
elif methodname in self._noshift_nn_dim:
method = getattr(F, methodname)
def call(dim, *args, **kwargs):
return self._new(
method(
self.values,
dim=self._schema.get(dim),
*args,
**kwargs
)
)
call.__doc__ = method.__doc__
elif methodname in self._inline:
def call(*args, **kwargs):
method(*args, **kwargs)
return self
call.__doc__ = method.__doc__
elif methodname in self._info:
call = method
elif methodname in self._reduce:
def call(dim=None, *args, **kwargs):
cur = self
method = getattr(cur._tensor, methodname)
if dim is None:
return NamedTensor(method(*args, **kwargs), ())
dim = make_tuple(dim)
method = getattr(self._tensor, methodname)
for d in dim:
cur = cur._new(
method(cur._schema.get(d), *args, **kwargs), d
)
method = getattr(cur._tensor, methodname)
return cur
call.__doc__ = self._reduce_doc + method.__doc__
elif methodname in self._reduce_multi:
def call(dim, *args, **kwargs):
method = getattr(self._tensor, methodname)
results = method(self._schema.get(dim), *args, **kwargs)
return tuple((self._new(r, dim) for r in results))
call.__doc__ = self._reduce_doc + method.__doc__
elif methodname in self._core:
from .torch_base import ntorch
method = getattr(ntorch, methodname)
def call(*args, **kwargs):
return method(self, *args, **kwargs)
call.__doc__ = method.__doc__
elif methodname in self._binop:
def call(other, *args):
if isinstance(other, NamedTensor):
b = other
order = self._broadcast_order(b._schema._names)
a1 = self._force_order(order)
b1 = b._force_order(order)
method = getattr(a1._tensor, methodname)
assert_match(a1, b1)
return a1._new(method(b1._tensor, *args))
else:
method = getattr(self._tensor, methodname)
return self._new(method(other, *args))
call.__doc__ = method.__doc__
else:
raise NotImplementedError(methodname)
return call
raise NotImplementedError(methodname)
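    # Dispatch sketch (hedged; names are illustrative). Unknown attributes are
    # resolved against the method sets below, so most torch.Tensor methods can
    # be called with named dimensions:
    #
    #   t = NamedTensor(torch.randn(3, 4), ("batch", "dim"))
    #   t.relu()                    # _noshift: schema preserved
    #   t.sum("dim")                # _reduce: "dim" is reduced out of the schema
    #   t.softmax("dim")            # _noshift_nn_dim: named dim -> index for F.softmax
    #   values, idx = t.max("dim")  # _reduce_multi: returns a tuple of NamedTensors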
def __dir__(self):
return (
set(self.__class__.__dict__.keys())
| self._noshift
| self._noshift_args
            | self._noshift_nn
            | self._noshift_dim
            | self._noshift_nn_dim
| self._info
| self._reduce
| self._reduce_multi
| self._binop
| self._inline
| self._core
)
# Torch Ops
# Return a tensor of the same dimensions
_noshift = {
"abs",
"acos",
"asin",
"atan",
"byte",
"ceil",
"clone",
"contiguous",
"cos",
"cosh",
"cpu",
"cuda",
"detach",
"double",
"exp",
"expm1",
"float",
"floor",
"frac",
"half",
"int",
"long",
"log",
"reciprocal",
"relu",
"round",
"rsqrt",
"short",
"sigmoid",
"sign",
"sin",
"sinh",
"sqrt",
"neg",
"to",
"tan",
"tanh",
"trunc",
}
_noshift_args = {"pow", "fmod", "clamp"}
_noshift_nn = {"relu"}
_noshift_nn_dim = {"softmax", "log_softmax"}
_noshift_dim = {"cumprod", "cumsum"}
# Return a non-tensor info object
_info = {
"dim",
"is_contigious",
"is_pinned",
"storage",
"storage_offset",
"storage_offset",
"tolist",
"stride",
"all",
"any",
"backward",
"numpy",
"item",
"type",
}
_reduce_doc = """
    NamedTensor modifies this method to take a named `dim` as
    the argument instead of a dimension index. Otherwise the
    documentation is the same as below.
====================
"""
# Takes a dim arg and reduces it.
_reduce = {
"argmax",
"argmin",
"logsumexp",
"mean",
"prod",
"std",
"sum",
"squeeze",
}
_reduce_multi = {"min", "max", "sort", "unbind", "median"}
_extra = {"masked_fill", "type_as"}
# Broadcast and apply.
_binop = {"add", "sub", "div", "mul", "eq", "ne", "lt", "gt", "le", "ge"}
    # In-place ops (modify the underlying tensor and return self).
_inline = {
"fill_",
"random_",
"abs_",
"acos_",
"asin_",
"atan_",
"ceil_",
"clamp_",
"cos_",
"cosh_",
"exp_",
"floor_",
"fmod_",
"log_",
"pow_",
"round_",
"rsqrt_",
"sigmoid_",
"sign_",
"sin_",
"sinh_",
"sqrt_",
"sub_",
"tan_",
"tanh_",
}
_core = {
"gather",
"nonzero",
"scatter_",
"tril",
"triu",
"narrow",
"masked_select",
"masked_scatter",
"masked_fill_",
"index_select",
"index_copy_",
"index_fill_",
"topk",
"equal",
"unique",
"chunk",
}
# def gather(self, dim, index, index_dim):
# """
# Apply gather where `self_dim` is reduced out
# based on `index` from `index_dim`.
# """
# from .torch_base import ntorch
# return ntorch.gather(self, dim, index, index_dim)
# def scatter_(self, dim, index, src, index_dim):
# """
# Apply scatter where `dim` gets the
# scattered values of `src` based in `index` along `index_dim`.
# """
# from .torch_base import ntorch
# ntorch.scatter_(self, dim, index, src, index_dim)
# def narrow(self, name, start, end):
# "Narrow into the `kwargs` dimension and rename it"
# from .torch_base import ntorch
# return ntorch.narrow(self, name, start, end)
# def masked_select(self, mask, name="on"):
# "Applies `mask` and returns a 1D tensor with name `name`"
# from .torch_base import ntorch
# return ntorch.masked_select(self, mask, name)
# def masked_fill_(self, mask, val):
# from .torch_base import ntorch
# return ntorch.masked_fill_(self, mask, val)
# def masked_scatter_(self, mask, source):
# from .torch_base import ntorch
# return ntorch.masked_scatter_(self, mask, source)
# def relu(self):
# "Apply relu"
# return self._new(F.relu(self._tensor))
# def softmax(self, name):
# "Apply softmax over dim `name`"
# return self._new(F.softmax(self._tensor, dim=self._schema.get(name)))
# def log_softmax(self, name):
# "Apply log softmax over dim `name`"
# return self._new(
# F.log_softmax(self._tensor, dim=self._schema.get(name))
# )