Tensor Considered Harmful
Alexander Rush - @harvardnlp
TL;DR: Despite its ubiquity in deep learning, Tensor is broken. It forces bad habits such as exposing private dimensions, broadcasting based on absolute position, and keeping type information in documentation. This post presents a proof-of-concept of an alternative approach, named tensors, in which every dimension has a name. This change eliminates the need for indexing, dim arguments, einsum-style unpacking, and documentation-based coding. The prototype PyTorch library accompanying this blog post is available as namedtensor.
Changelog
- See Part 2 as well.
- Updated the syntax of the prototype to be a subset of xarray wherever possible.
- Dropped the einops-style string DSL notation to be more explicit.
Implementations
- Jon Malmaud points out that the xarray project has goals very similar to this note's, with the addition of extensive Pandas and scientific computing support.
- Tongfei Chen’s Nexus project proposes statically type-safe tensors in Scala.
- Stephan Hoyer and Eric Christiansen have a labeled tensor library for TensorFlow, Labeled Tensor, that takes the same approach.
- Nishant Sinha has a TSA library that uses type annotations to define dimension names.
Tensor Traps
This post is about the tensor class, a multi-dimensional array object that is the central object of deep learning frameworks such as Torch, TensorFlow, and Chainer, as well as NumPy. Tensors carry around a blob of storage and expose a tuple of dimension information to users.
torch.Size([6, 96, 96, 3])
Here there are 4 dimensions, corresponding to batch_size, height, width, and channels. Most of the time you can figure this out from a comment in the code that looks like this:
# batch_size x height x width x channels
This approach is concise and pseudo-mathy. However, from a programming point of view, it is not a great way to build complex software.
Trap 1: Privacy by Convention
Code that manipulates tensors does so by referring to positions in the dimension tuple. If you want to rotate the image, you read the comment, decide which dimensions need to be changed, and alter them.
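As a stand-in for the post's PyTorch code, here is a NumPy sketch of this comment-driven style (NumPy shares PyTorch's positional-dimension conventions; the random array and the rotate helper are illustrative):

```python
import numpy as np

# ims: batch_size x height x width x channels (documented only in this comment)
ims = np.random.rand(6, 96, 96, 3)

def rotate(ims):
    # batch_size x height x width x channels
    rotated = np.swapaxes(ims, 1, 2)
    # batch_size x width x height x channels
    return rotated

print(rotate(ims).shape)  # (6, 96, 96, 3)
```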
This code is simple and in theory well documented. However, it does not reflect the semantics of the target function. The property of rotation is independent of the batch, or for that matter, the channels. The function should not have to account for these dimensions in determining the dimensions to alter.
This leads to two problems. First, it’s quite worrisome that if we pass in a singleton image, this function runs without error but silently does the wrong thing.
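A NumPy sketch of this failure, using the same hypothetical rotate helper: an unbatched height x width x channels image is accepted without complaint, but the swap lands on the wrong axes.

```python
import numpy as np

def rotate(ims):
    # Assumes batch_size x height x width x channels; swaps positions 1 and 2.
    return np.swapaxes(ims, 1, 2)

single = np.random.rand(96, 96, 3)   # height x width x channels, no batch axis
print(rotate(single).shape)          # (96, 3, 96): width and channels were swapped
```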
torch.Size([96, 3, 96])
However, even more worrisome is that the function may actually use the batch dimension by mistake and mix together properties of different images. This can lead to nasty bugs that would be easy to avoid if this dimension were hidden from the code.
Trap 2: Broadcasting by Alignment
The most useful aspect of Tensors is that they can quickly do array operations without explicit for loops. For this to work, dimensions need to be directly aligned so that they can be broadcast. Again, this is done by convention and code documentation that makes it “easy” to line up dimensions. For instance, let’s assume we want to apply a mask to the above image.
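A NumPy sketch of the failing mask (an illustration, not the post's original code: it multiplies by a 0/1 mask where the post applied a mask directly, and NumPy raises ValueError rather than PyTorch's RuntimeError):

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)              # batch x height x width x channels
mask = np.random.randint(0, 2, size=(96, 96))   # height x width

try:
    masked = ims * mask
    failed = False
except ValueError:
    # Broadcasting aligns from the right, so (96, 96) is matched against
    # (width, channels) = (96, 3) and the trailing 96 vs 3 mismatch fails.
    failed = True
    print("Broadcasting fail", mask.shape, ims.shape)
```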
'Broadcasting fail torch.Size([96, 96]) torch.Size([6, 96, 96, 3])'
This fails because even though we knew that we were building a height- and width-shaped mask, the rules of broadcasting do not have the correct semantics: they align dimensions from the right, so the mask is lined up against width and channels. To make this work, you are encouraged to use either view or squeeze, my least favorite functions.
Note that we do not need to do this for the left-most dimensions, so there is a bit of abstraction here. However, in real code, dozens of right-side views and squeezes quickly become unreadable.
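Continuing the NumPy sketch, the fix adds a trailing size-1 axis to the mask, which is what view or unsqueeze would do in PyTorch. The leading batch axis is padded on automatically:

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)              # batch x height x width x channels
mask = np.random.randint(0, 2, size=(96, 96))   # height x width

# (96, 96) -> (96, 96, 1); broadcasting then pads the batch axis on the left.
masked = ims * mask[:, :, np.newaxis]
print(masked.shape)  # (6, 96, 96, 3)
```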
Trap 3: Access by Comments
It is possible that you look at the top two issues and think that as long as you are careful, these issues will be caught by runtime errors. However, even when used carefully, the combination of broadcasting and indexing can lead to problems that are very tough to catch.
Here we assume that the coder is trying to combine two tensors using both reduction operations and dimension indexing. (Honestly, at this point I have forgotten what the dimensions stand for.)
The main point, though, is that this code will run fine for whatever value dim is given. The comment here might describe what is happening, but the code itself never throws a runtime error.
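A NumPy sketch of this kind of bug (the combine helper is hypothetical, not the post's code): two per-image summaries are reduced over dim and broadcast together, and every choice of dim runs.

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)         # batch x height x width x channels
a = ims[1].mean(axis=1, keepdims=True)     # height x 1 x channels
b = ims[2].mean(axis=0, keepdims=True)     # 1 x width x channels

def combine(a, b, dim):
    # Reduce both inputs over `dim`, then rely on broadcasting to add them.
    return a.mean(axis=dim) + b.mean(axis=dim)

shapes = {dim: combine(a, b, dim).shape for dim in (0, 1, 2)}
print(shapes)  # {0: (96, 3), 1: (96, 3), 2: (96, 96)}
```

At most one of these can be what the author meant, yet nothing fails; dim=2 even yields a plausible-looking 96 x 96 array.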
Named Tensor: A Prototype
Based on these issues, I think deep learning code should move to a better central object. Several have been proposed. Here, for fun, I will develop a new prototype with the following goals.
1) Dimensions should have human-readable names.
2) No function should have a dim argument.
3) Broadcasting should be by name matching.
4) Transposition should be explicit.
5) Ban dimension-based indexing.
6) Private dimensions should be protected.
To experiment with these ideas, I have built a library known as NamedTensor. Currently it is PyTorch-specific, but in theory a similar idea could be used in other frameworks. The code is available at github.com/harvardnlp/namedtensor.
Proposal 1: Assigning Names
The core of the library is an object that wraps a tensor and provides names for each dimension. Here we simply wrap a given torch tensor with dimension names.
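To make the idea concrete without depending on namedtensor's actual API, here is a hypothetical minimal wrapper in NumPy that stores one name per axis and reports its shape as an ordered name-to-size mapping:

```python
from collections import OrderedDict
import numpy as np

class Named:
    """Hypothetical minimal stand-in for a named tensor: an array plus one name per axis."""

    def __init__(self, values, names):
        if values.ndim != len(names):
            raise ValueError("need exactly one name per dimension")
        self.values = values
        self.names = tuple(names)

    @property
    def shape(self):
        # Sizes keyed by name, in storage order.
        return OrderedDict(zip(self.names, self.values.shape))

named_ims = Named(np.zeros((6, 96, 96, 3)), ("batch", "height", "width", "channels"))
print(named_ims.shape)
```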
OrderedDict([('batch', 6), ('height', 96), ('width', 96), ('channels', 3)])
Alternatively, the library has wrappers for the PyTorch constructors that return named tensors directly.
Most simple operations simply keep the dimension names of their inputs.
Proposal 2: Accessors and Reduction
The first benefit of names is that they can replace dim and axis style arguments entirely. For example, let’s say we wanted to sort each column.
Another common operation is a reduction, where one or more dimensions are pooled out.
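A sketch of both operations with hypothetical helpers (not namedtensor functions) that look axes up by name, so callers never pass a positional dim:

```python
import numpy as np

def sort_by(values, names, dim):
    # Look the axis up by name instead of passing a positional dim.
    return np.sort(values, axis=names.index(dim)), names

def mean_over(values, names, *dims):
    # Pool out the named dimensions; their names disappear from the result.
    axes = tuple(names.index(d) for d in dims)
    kept = tuple(n for n in names if n not in dims)
    return values.mean(axis=axes), kept

names = ("batch", "height", "width", "channels")
ims = np.random.rand(6, 96, 96, 3)
sorted_ims, _ = sort_by(ims, names, "height")      # sort each column
pooled, pooled_names = mean_over(ims, names, "height", "width")
print(pooled.shape, pooled_names)  # (6, 3) ('batch', 'channels')
```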
Proposal 3: Broadcasting and Contraction
The names also provide the basis for broadcasting operations. When a binary operation is applied to two named tensors, the library first ensures that all dimensions are matched by name and then applies standard broadcasting. To demonstrate, let’s return to the masking example above. Here we simply declare the names of the dimensions of our mask and ask the library to figure out the broadcasting.
Similar operations can be used for standard matrix operations such as addition and multiplication.
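One way name-based broadcasting can work, sketched in NumPy with hypothetical helpers: permute each operand to a shared name order, insert size-1 axes for missing names, and only then apply ordinary broadcasting.

```python
import numpy as np

def align_to(values, names, target):
    # Permute to `target` order, inserting size-1 axes for names this array lacks.
    present = [n for n in target if n in names]
    out = np.transpose(values, [names.index(n) for n in present])
    shape = [values.shape[names.index(n)] if n in names else 1 for n in target]
    return out.reshape(shape)

def mul(a, a_names, b, b_names):
    # Match dimensions by name first, then apply ordinary broadcasting.
    target = tuple(a_names) + tuple(n for n in b_names if n not in a_names)
    return align_to(a, a_names, target) * align_to(b, b_names, target), target

names = ("batch", "height", "width", "channels")
ims = np.random.rand(6, 96, 96, 3)
mask = np.random.randint(0, 2, size=(96, 96))

masked, masked_names = mul(ims, names, mask, ("height", "width"))
# The same mask stored as width x height gives an identical result.
masked_t, _ = mul(ims, names, mask.T, ("width", "height"))
print(masked.shape, masked_names)  # (6, 96, 96, 3) ('batch', 'height', 'width', 'channels')
```

Because alignment is by name, the mask's storage order does not matter, and no manual view or unsqueeze appears in user code.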
A more general feature is the dot method for tensor contraction between named tensors. Tensor contraction, the machinery behind einsum, is an elegant way of thinking about generalizations of dot products, matrix-vector products, matrix-matrix products, etc.
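A sketch of named contraction built on np.einsum (the dot helper here is hypothetical): each name is mapped to an einsum subscript letter, shared names are matched, and the listed names are summed out.

```python
import numpy as np

def dot(a, a_names, b, b_names, *contract):
    # Give every distinct name its own einsum subscript letter.
    letters = {}
    for n in tuple(a_names) + tuple(b_names):
        if n not in letters:
            letters[n] = chr(ord("a") + len(letters))
    # Names not being contracted survive, in first-appearance order.
    out_names = tuple(n for n in letters if n not in contract)
    spec = "{},{}->{}".format(
        "".join(letters[n] for n in a_names),
        "".join(letters[n] for n in b_names),
        "".join(letters[n] for n in out_names),
    )
    return np.einsum(spec, a, b), out_names

x = np.ones((96, 96, 3))
x_names = ("height", "width", "channels")

r1, n1 = dot(x, x_names, np.ones(96), ("height",), "height")
r2, n2 = dot(x, x_names, np.ones(96), ("width",), "width")
r3, n3 = dot(x, x_names, np.ones((96, 96)), ("height", "width"), "height", "width")
for r, n in ((r1, n1), (r2, n2), (r3, n3)):
    print(dict(zip(n, r.shape)))
```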
OrderedDict([('width', 96), ('channels', 3)])
OrderedDict([('height', 96), ('channels', 3)])
OrderedDict([('channels', 3)])
Similar notation can be used for sparse indexing (inspired by the einindex library). This is useful for embedding lookups and other sparse operations.
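A sketch of a named lookup with np.take (the take helper and the vocab/hidden/batch/seqlen names are illustrative; the einindex-style notation itself is not reproduced): the indexed axis is replaced by the index array's own named axes, which is exactly an embedding lookup.

```python
import numpy as np

def take(values, names, dim, index, index_names):
    # np.take puts the index array's axes where the `dim` axis used to be.
    axis = names.index(dim)
    out = np.take(values, index, axis=axis)
    return out, names[:axis] + tuple(index_names) + names[axis + 1:]

table = np.random.rand(10, 4)               # vocab x hidden
tokens = np.array([[1, 2, 3], [4, 5, 9]])   # batch x seqlen
emb, emb_names = take(table, ("vocab", "hidden"), "vocab", tokens, ("batch", "seqlen"))
print(emb.shape, emb_names)  # (2, 3, 4) ('batch', 'seqlen', 'hidden')
```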
Proposal 4: Shifting Dimensions
Behind the scenes, all of the named tensors are acting as tensor objects. As such, things like the order and stride of dimensions do matter. Operations like transpose and view are crucial for maintaining this, but are unfortunately quite error-prone.
Instead, consider a domain-specific language, shift, that borrows heavily from Alex Rogozhnikov’s excellent einops package.
Standard calls to transpose dimensions.
Calls for splitting and stacking together dimensions.
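The shift DSL is not reproduced here; instead, a NumPy sketch with hypothetical transpose, split, and stack helpers shows the bookkeeping such calls perform, with single-letter names chosen to echo the shapes printed in this section. Reshape-based split and stack only make sense for axes that are adjacent in storage order, which is the ordering concern raised above.

```python
import numpy as np

def transpose(values, names, *order):
    # Reorder axes by name rather than by position.
    return np.transpose(values, [names.index(n) for n in order]), tuple(order)

def split(values, names, dim, new_dims):
    # Split one named axis into several; new_dims maps each new name to its size.
    axis = names.index(dim)
    shape = values.shape[:axis] + tuple(new_dims.values()) + values.shape[axis + 1:]
    return values.reshape(shape), names[:axis] + tuple(new_dims) + names[axis + 1:]

def stack(values, names, dims, new_name):
    # Merge adjacent named axes (in storage order) into one axis called new_name.
    axis = names.index(dims[0])
    if names[axis:axis + len(dims)] != tuple(dims):
        raise ValueError("dims must be adjacent and in order")
    size = int(np.prod(values.shape[axis:axis + len(dims)]))
    shape = values.shape[:axis] + (size,) + values.shape[axis + len(dims):]
    return values.reshape(shape), names[:axis] + (new_name,) + names[axis + len(dims):]

img = np.zeros((96, 96, 3))
img_names = ("h", "w", "c")
t, t_names = transpose(img, img_names, "w", "h", "c")
s, s_names = split(img, img_names, "h", {"height": 8, "q": 12})
print(dict(zip(s_names, s.shape)))   # {'height': 8, 'q': 12, 'w': 96, 'c': 3}

ims = np.zeros((6, 96, 96, 3))
m, m_names = stack(ims, ("b", "h", "w", "c"), ("b", "h"), "bh")
print(dict(zip(m_names, m.shape)))   # {'bh': 576, 'w': 96, 'c': 3}
```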
OrderedDict([('height', 8), ('q', 12), ('w', 96), ('c', 3)])
OrderedDict([('bh', 576), ('w', 96), ('c', 3)])
Ops can be chained.
Just for fun, here are some of the crazier examples from einops in this notation.
Proposal 5: Ban Indexing
Generally, indexing is discouraged in this named tensor paradigm. Instead, use functions like index_select above.
There are some useful named alternative functions pulled over from torch. For example, unbind pulls apart a dimension into a tuple.
The function get directly selects a slice from a named dimension.
Finally, narrow can be used to replace fancy indexing. However, you must give a new dim name (since it can no longer broadcast).
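Hypothetical NumPy versions of these three operations (not namedtensor's implementations, and "narrowheight" is an illustrative name), each taking a dimension name instead of a position:

```python
import numpy as np

def unbind(values, names, dim):
    # Split the named axis into a tuple of slices, each without that name.
    axis = names.index(dim)
    rest = names[:axis] + names[axis + 1:]
    return tuple((np.take(values, i, axis=axis), rest) for i in range(values.shape[axis]))

def get(values, names, dim, i):
    # Select one position along the named axis and drop that name.
    axis = names.index(dim)
    return np.take(values, i, axis=axis), names[:axis] + names[axis + 1:]

def narrow(values, names, dim, new_name, start, length):
    # Keep a contiguous range along `dim`; the result gets a fresh name because
    # its size no longer matches the original and should not broadcast against it.
    axis = names.index(dim)
    index = [slice(None)] * values.ndim
    index[axis] = slice(start, start + length)
    return values[tuple(index)], names[:axis] + (new_name,) + names[axis + 1:]

names = ("batch", "height", "width", "channels")
ims = np.random.rand(6, 96, 96, 3)
r, g, b = unbind(ims, names, "channels")
first, first_names = get(ims, names, "batch", 0)
crop, crop_names = narrow(ims, names, "height", "narrowheight", 10, 20)
```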
Proposal 6: Private Dimensions
Finally, named tensor attempts to let you directly hide dimensions that should not be accessed by internal functions. The function mask_to will keep around a left-side mask that protects any earlier dimensions from manipulation by functions. The simplest example uses a mask to drop the batch dimension.
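A hypothetical sketch of how such a mask could behave (the Guarded class and its mask_to are illustrative, not namedtensor's implementation; protecting the named dimension itself, along with everything to its left, is an assumption based on the error message below):

```python
import numpy as np

class Guarded:
    """Hypothetical named array whose protected dimensions cannot be operated on."""

    def __init__(self, values, names, protected=()):
        self.values, self.names = values, tuple(names)
        self.protected = frozenset(protected)

    def mask_to(self, dim):
        # Protect `dim` and every dimension to its left.
        return Guarded(self.values, self.names, self.names[: self.names.index(dim) + 1])

    def _axis(self, dim):
        if dim in self.protected:
            raise ValueError("Dimension %s is masked" % dim)
        return self.names.index(dim)

    def mean(self, dim):
        axis = self._axis(dim)
        kept = self.names[:axis] + self.names[axis + 1:]
        return Guarded(self.values.mean(axis=axis), kept, self.protected)

ims = Guarded(np.random.rand(6, 96, 96, 3), ("batch", "height", "width", "channels"))
safe = ims.mask_to("batch")
print(safe.mean("height").names)  # ('batch', 'width', 'channels')

try:
    safe.mean("batch")
except ValueError as err:
    message = str(err)
print("Error received:", message)  # Error received: Dimension batch is masked
```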
'Error received: Dimension batch is masked'
This is a weak dynamic check and can be turned off by internal functions. In future versions, perhaps we can add function annotations to lift non-named functions to respect these properties.
Example: Neural Attention
To demonstrate why these choices lead to better encapsulation properties, let’s consider a real-world deep learning example.
This example was proposed by my colleague Tim Rocktäschel in a blog post describing einsum (https://rockt.github.io/2018/04/30/einsum). Tim’s code was proposed as a better alternative to raw PyTorch. While I agree that einsum is a step forward, it still falls into many of the traps described above.
Consider the problem of neural attention, which requires computing,
\[\begin{align*}
\mathbf{M}_t &= \tanh(\mathbf{W}^y\mathbf{Y}+(\mathbf{W}^h\mathbf{h}_t+\mathbf{W}^r\mathbf{r}_{t-1})\otimes \mathbf{e}_L) & \mathbf{M}_t &\in\mathbb{R}^{k\times L}\\
\alpha_t &= \text{softmax}(\mathbf{w}^T\mathbf{M}_t) & \alpha_t &\in\mathbb{R}^L\\
\mathbf{r}_t &= \mathbf{Y}\alpha^T_t + \tanh(\mathbf{W}^t\mathbf{r}_{t-1}) & \mathbf{r}_t &\in\mathbb{R}^k
\end{align*}\]
First, we set up the parameters.
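The post's parameter setup and both implementations are not reproduced here. As a stand-in, here is a NumPy sketch of the parameters and one step of the equations above, using einsum with an added batch dimension. The sizes are arbitrary, and M is stored as (batch, L, k) rather than k x L:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, L, k = 3, 5, 7  # arbitrary illustrative sizes

# Parameters named after the symbols in the equations.
W_y, W_h, W_r, W_t = (rng.standard_normal((k, k)) for _ in range(4))
w = rng.standard_normal(k)

# Inputs: Y has one k-vector per position; h_t and r_{t-1} are k-vectors.
Y = rng.standard_normal((batch, L, k))
h_t = rng.standard_normal((batch, k))
r_prev = rng.standard_normal((batch, k))

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_step(Y, h_t, r_prev):
    WyY = np.einsum("ij,blj->bli", W_y, Y)                  # W^y Y, per position
    Wh_Wr = np.einsum("ij,bj->bi", W_h, h_t) + np.einsum("ij,bj->bi", W_r, r_prev)
    M = np.tanh(WyY + Wh_Wr[:, None, :])                    # ⊗ e_L via broadcasting
    alpha = softmax(np.einsum("k,blk->bl", w, M), axis=1)   # softmax over the L positions
    r = np.einsum("blk,bl->bk", Y, alpha) + np.tanh(np.einsum("ij,bj->bi", W_t, r_prev))
    return M, alpha, r

M, alpha, r = attention_step(Y, h_t, r_prev)
```

Note the [:, None, :] needed to realize the ⊗ e_L repeat: even with einsum, one positional unsqueeze survives.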
Now consider the tensor-based einsum implementation of this function.
This implementation is an improvement over the naive PyTorch implementation. It removes many of the views and transposes that would be necessary to make this work. However, it still uses squeeze, references the private batch dim, and uses comments that are not enforced.
Consider instead the namedtensor version:
This code avoids all three traps.
(Trap 1) The code never mentions the batch dim.
(Trap 2) All broadcasting is done directly with contractions, there are no views.
(Trap 3) Operations across dims are explicit. For instance, the softmax is clearly over the seqlen.
Conclusion / Request for Help
Tools for deep learning help researchers implement standard models, but they also impact what researchers try. Current models can be built fine with the tools we have, but the programming practices are not going to scale to new models.
(For instance, one space we have been working on recently is discrete latent variable models, which often have many problem-specific variables, each with its own variable dimension. This setting breaks the current tensor paradigm almost immediately.)
This blog post is just a prototype of where this approach could go. If you are interested, I would love contributors to help build out this library properly. If you want to send a PR to namedtensor, here are some ideas:
1) Extending beyond PyTorch: Can we generalize this approach in a way that supports NumPy and TensorFlow?
2) Interacting with PyTorch Modules: Can we “lift” PyTorch modules with type annotations, so that we know how they change inputs?
3) Error Checking: Can we add annotations to functions giving pre- and postconditions so that dimensions are automatically checked?