
# Tensor tracker

API documentation | Example

Flexibly track outputs and grad-outputs of torch.nn.Module.

Installation:

```bash
pip install git+https://github.com/graphcore-research/pytorch-tensor-tracker
```

Usage:

Use `tensor_tracker.track(module)` as a context manager to start capturing tensors from within your module's forward and backward passes:

```python
import tensor_tracker

with tensor_tracker.track(module) as tracker:
    module(inputs).backward()

print(tracker)  # => Tracker(stashes=8, tracking=0)
```

The tracker is now filled with stashes, containing copies of forward/backward tensors at (sub)module outputs. (Note that this can consume a lot of memory.) A self-contained run is sketched below.
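For a runnable end-to-end sketch, a toy setup might look like this (the small `nn.Sequential` model and random inputs are illustrative assumptions, not part of the library):

```python
import torch
import torch.nn as nn

import tensor_tracker

# Illustrative toy model and inputs (assumptions for this sketch)
module = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
inputs = torch.randn(16, 8)

with tensor_tracker.track(module) as tracker:
    # .sum() reduces the output to a scalar so that backward() runs
    module(inputs).sum().backward()

print(tracker)  # e.g. => Tracker(stashes=..., tracking=0)
```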

The tracker behaves like a list of `Stash` objects, each with an attached `value`, usually a tensor or tuple of tensors. We can also use `to_frame()` to get a Pandas DataFrame of summary statistics:

```python
print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
#     ...]

display(tracker.to_frame())
```

*(Image: tensor tracker `to_frame()` output.)*
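As a hedged sketch, the stash list can also be summarised by hand. The `name`, `grad` and `value` attributes below follow the `Stash` repr shown above; treat the exact fields as assumptions:

```python
import torch

# Print a one-line summary per stash, skipping non-tensor values
# (a Stash value may be a tuple of tensors rather than a single tensor)
for stash in tracker:
    if torch.is_tensor(stash.value):
        direction = "bwd" if stash.grad else "fwd"
        print(f"{stash.name:12s} {direction} std={stash.value.std().item():.3f}")
```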

See the documentation for more information. For a more practical example, see our demo of visualising transformer activations and gradients using UMAP. To use on IPU with PopTorch, please see Usage (PopTorch).

## License

Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License (LICENSE).

Our dependencies are (see `requirements.txt`):

| Component | About | License |
| --- | --- | --- |
| torch | Machine learning framework | BSD 3-Clause |

We also use additional Python dependencies for development/testing (see `requirements-dev.txt`).
