Fork · Anthropic · published Oct 3, 2025

anthropics/torchtyping

forked from patrick-kidger/torchtyping

Description: Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.

Language: Python

License: Apache-2.0

Stars: 11

Forks: 6

Open issues: 0

Created: 2025-10-03T15:25:45Z

Pushed: 2025-10-06T17:17:03Z

Default branch: anthropic-0.1.5

Fork: yes

Parent repository: patrick-kidger/torchtyping

Archived: no

README:

Please use jaxtyping instead

*Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!*

The original torchtyping README is as follows.

---

torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Turn this:

```python
import torch

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

into this:

```python
from torchtyping import TensorType

def batch_outer_product(x: TensorType["batch", "x_channels"],
                        y: TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, then this is for you.

---

Installation

```bash
pip install torchtyping
```

Requires Python >=3.7 and PyTorch >=1.7.0.

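If using `typeguard`, then it must be a compatible version.

```python
from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```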

`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.

If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.

## API

torchtyping.TensorType[shape, dtype, layout, details]

The core of the library.

Each of `shape`, `dtype`, `layout` and `details` is optional.

- The `shape` argument can be any of:
  - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
  - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
  - A `...`: an arbitrary number of dimensions of any sizes.
  - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
  - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to _both_ names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
  - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
  - `None`, which when used in conjunction with `is_named` below, indicates a dimension that must _not_ have a name in the sense of [named tensors](https://pytorch.org/docs/stable/named_tensor.html).
  - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
  - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
  - A `typing.Any`: any size is allowed for this dimension (equivalent to `-1`).
  - Any tuple of the above. For example, `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions, then use for example `TensorType[-1, -1, -1]` for a three-dimensional tensor.
- The `dtype` argument can be any of:
  - `torch.float32`, `torch.float64` etc.
  - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
- The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
- The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built in by default. `torchtyping.is_named` causes the [names of tensor dimensions](https://pytorch.org/docs/stable/named_tensor.html) to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. `TensorType[torch.float32]`.) For discussion on how to customise `torchtyping` with your own `details`, see the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md#custom-extensions).
- Check multiple things at once by just putting them all together inside a single `[]`. For example `TensorType["batch": ..., "length", "channels", float, is_named]`.

torchtyping.patch_typeguard()

`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.

This function is safe to run multiple times. (It does nothing after the first run.)

- If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
- If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions…