anthropics/torchtyping
forked from patrick-kidger/torchtyping
Description: Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Language: Python
License: Apache-2.0
Stars: 11
Forks: 6
Open issues: 0
Created: 2025-10-03T15:25:45Z
Pushed: 2025-10-06T17:17:03Z
Default branch: anthropic-0.1.5
Fork: yes
Parent repository: patrick-kidger/torchtyping
Archived: no
README:
# Please use jaxtyping instead
*Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!*
The original torchtyping README is as follows.
---
# torchtyping
Type annotations for a tensor's shape, dtype, names, ...
Turn this:
```python
import torch

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
into this:
```python
from torchtyping import TensorType

def batch_outer_product(
    x: TensorType["batch", "x_channels"],
    y: TensorType["batch", "y_channels"],
) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
with programmatic checking that the shape (dtype, ...) specification is met.
Bye-bye bugs! Say hello to enforced, clear documentation of your code.
If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, then this is for you.
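The return annotation of `batch_outer_product` follows from PyTorch's `unsqueeze` and broadcasting rules. The sketch below traces that shape arithmetic on plain tuples, with no PyTorch needed. `unsqueeze_shape`, `broadcast_pair` and `batch_outer_product_shape` are hypothetical helpers written only for illustration, assuming PyTorch's documented semantics: a negative `unsqueeze` dim counts from the end of the output, and broadcasting right-aligns shapes and stretches size-1 dimensions.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def unsqueeze_shape(shape: Shape, dim: int) -> Shape:
    """Shape-only model of Tensor.unsqueeze: insert a size-1 dimension at `dim`."""
    if dim < 0:
        # Negative dims count from the end of the *output*, which has len(shape) + 1 dims.
        dim += len(shape) + 1
    return shape[:dim] + (1,) + shape[dim:]


def broadcast_pair(a: Shape, b: Shape) -> Shape:
    """Shape-only model of broadcasting two operands together."""
    n = max(len(a), len(b))
    # Right-align the shapes by left-padding the shorter one with 1s.
    a = (1,) * (n - len(a)) + a
    b = (1,) * (n - len(b)) + b
    out = []
    for p, q in zip(a, b):
        if p != q and p != 1 and q != 1:
            raise ValueError(f"cannot broadcast sizes {p} and {q}")
        # A size-1 dimension stretches to match the other operand.
        out.append(q if p == 1 else p)
    return tuple(out)


def batch_outer_product_shape(x: Shape, y: Shape) -> Shape:
    # (batch, x_channels, 1) * (batch, 1, y_channels) -> (batch, x_channels, y_channels)
    return broadcast_pair(unsqueeze_shape(x, -1), unsqueeze_shape(y, -2))


print(batch_outer_product_shape((4, 2), (4, 5)))  # (4, 2, 5)
```

If the two `batch` sizes disagree (say `(4, 2)` and `(3, 5)`), `broadcast_pair` raises, which is the kind of mismatch the `TensorType` annotations are meant to surface with a clearer message.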
---
## Installation

```bash
pip install torchtyping
```
Requires Python >=3.7 and PyTorch >=1.7.0.
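If using `typeguard` then it must be a version …

```python
from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"], y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```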
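In the example, `@typechecked` checks the arguments and return value on every call, and the name `"batch"` must resolve to one size across all of them. To illustrate that idea without any dependencies, here is a toy decorator that keeps one table of name-to-size bindings per call, shared by every argument and then checked against the return value. Everything in it is hypothetical: `check_dims`, `_match` and `FakeTensor` are illustrative names, plain tuples stand in for `TensorType[...]`, and this is not how `typeguard` or `torchtyping` are implemented.

```python
import functools
import inspect


class FakeTensor:
    """Hypothetical stand-in for a tensor: it only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


def _match(spec, shape, bindings, where):
    """Check one shape against a tuple of dimension names / ints, updating `bindings`."""
    if len(spec) != len(shape):
        raise TypeError(f"{where}: expected {len(spec)} dims, got {len(shape)}")
    for dim, size in zip(spec, shape):
        if isinstance(dim, int):
            if dim != -1 and dim != size:
                raise TypeError(f"{where}: expected size {dim}, got {size}")
        elif bindings.setdefault(dim, size) != size:
            # The name was already bound (by an earlier argument) to a different size.
            raise TypeError(
                f"{where}: dimension {dim!r} of inconsistent size. "
                f"Got both {bindings[dim]} and {size}."
            )


def check_dims(fn):
    """Toy decorator: annotations are tuples like ("batch",), checked at call time."""
    sig = inspect.signature(fn)
    hints = fn.__annotations__

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bindings = {}  # fresh per call: names only need to agree within one call
        for name, value in sig.bind(*args, **kwargs).arguments.items():
            if name in hints:
                _match(hints[name], value.shape, bindings, name)
        out = fn(*args, **kwargs)
        if "return" in hints:
            _match(hints["return"], out.shape, bindings, "return")
        return out

    return wrapper


@check_dims
def func(x: ("batch",), y: ("batch",)) -> ("batch",):
    return x  # FakeTensor has no arithmetic, so return x in place of x + y


func(FakeTensor(3), FakeTensor(3))  # works
try:
    func(FakeTensor(3), FakeTensor(1))
except TypeError as e:
    print(e)  # y: dimension 'batch' of inconsistent size. Got both 3 and 1.
```

Creating the bindings table inside `wrapper` rather than in `check_dims` is the key design choice: a name like `"batch"` may take a different size on each call, but must be consistent within one call.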
`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.

If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs.

Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.

## API
### `torchtyping.TensorType[shape, dtype, layout, details]`
The core of the library. Each of `shape`, `dtype`, `layout`, `details` is optional.

- The `shape` argument can be any of:
    - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
    - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors are checked to ensure that the sizes are consistent.
    - A `...`: an arbitrary number of dimensions of any sizes.
    - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
    - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to _both_ names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
    - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
    - `None`, which when used in conjunction with `is_named` below, indicates a dimension that must _not_ have a name in the sense of [named tensors](https://pytorch.org/docs/stable/named_tensor.html).
    - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
    - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
    - A `typing.Any`: any size is allowed for this dimension (equivalent to `-1`).
    - Any tuple of the above. For example, `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use, for example, `TensorType[-1, -1, -1]` for a three-dimensional tensor.
- The `dtype` argument can be any of:
    - `torch.float32`, `torch.float64` etc.
    - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
- The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
- The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built in by default. `torchtyping.is_named` causes the [names of tensor dimensions](https://pytorch.org/docs/stable/named_tensor.html) to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in (rather than just a specific one, as with e.g. `TensorType[torch.float32]`). For discussion on how to customise `torchtyping` with your own `details`, see the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md#custom-extensions).
- Check multiple things at once by just putting them all together inside a single `[]`. For example, `TensorType["batch": ..., "length", "channels", float, is_named]`.
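To make the `shape` rules concrete, here is a small pure-Python model of the `int`, `-1`, `str`, `...`, `str: int`, `str: str` and `str: ...` cases, operating on plain shape tuples. It is a sketch under stated assumptions, not torchtyping's implementation: `Dims`, `match_shape` and the other helpers are hypothetical, the named-tensor `None` forms and `typing.Any` are not modelled, and the error messages are the toy's own. `Dims[...]` exists only so that the slice syntax (`"batch": ...`) can be written just as inside `TensorType[...]`.

```python
from typing import Any, Dict, Tuple


class Dims:
    """Hypothetical helper: Dims[...] returns its subscript as a tuple."""

    def __class_getitem__(cls, item: Any) -> Tuple[Any, ...]:
        return item if isinstance(item, tuple) else (item,)


def _bind(name: str, value: Any, bindings: Dict[str, Any]) -> None:
    # First use of a name records its value; later uses must agree with it.
    if bindings.setdefault(name, value) != value:
        raise TypeError(f"{name!r} bound to both {bindings[name]} and {value}")


def _check_dim(dim: Any, size: int, bindings: Dict[str, Any]) -> None:
    if isinstance(dim, slice):  # "name": int  or  "name": "other_name"
        _bind(dim.start, size, bindings)
        if isinstance(dim.stop, str):
            _bind(dim.stop, size, bindings)
        elif dim.stop != -1 and dim.stop != size:
            raise TypeError(f"expected size {dim.stop}, got {size}")
    elif isinstance(dim, str):  # "name" on its own behaves like "name": -1
        _bind(dim, size, bindings)
    elif dim != -1 and dim != size:  # plain int; -1 means any size
        raise TypeError(f"expected size {dim}, got {size}")


def match_shape(spec: Tuple[Any, ...], shape: Tuple[int, ...],
                bindings: Dict[str, Any]) -> None:
    """Toy check of `shape` against `spec`, sharing `bindings` across calls."""
    def is_variadic(d: Any) -> bool:
        return d is Ellipsis or (isinstance(d, slice) and d.stop is Ellipsis)

    star = [i for i, d in enumerate(spec) if is_variadic(d)]
    if len(star) > 1:
        raise ValueError("at most one '...' per spec")
    if not star:
        if len(spec) != len(shape):
            raise TypeError(f"expected {len(spec)} dims, got {len(shape)}")
        pairs = list(zip(spec, shape))
    else:
        i, n_after = star[0], len(spec) - star[0] - 1
        if len(shape) < len(spec) - 1:
            raise TypeError(f"expected at least {len(spec) - 1} dims, got {len(shape)}")
        end = len(shape) - n_after
        if isinstance(spec[i], slice):  # "name": ... binds the whole run of dims
            _bind(spec[i].start, tuple(shape[i:end]), bindings)
        pairs = list(zip(spec[:i], shape[:i])) + list(zip(spec[i + 1:], shape[end:]))
    for dim, size in pairs:
        _check_dim(dim, size, bindings)


shared = {}
match_shape(Dims["batch": ..., "length": 10, "channels", -1], (2, 5, 10, 3, 7), shared)
print(shared)  # {'batch': (2, 5), 'length': 10, 'channels': 3}
```

Passing the same `shared` dict to a second `match_shape` call is how this toy mirrors the cross-argument consistency check: a later spec that uses `"channels"` with a size other than 3 raises `TypeError`.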
### `torchtyping.patch_typeguard()`
`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.

This function is safe to run multiple times. (It does nothing after the first run.)

- If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example, you could call it at the start of each file using `torchtyping`.
- If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions…