# opt_einsum_fx

Optimizing einsums and functions involving them using [`opt_einsum`](https://optimized-einsum.readthedocs.io/en/stable/) and PyTorch [FX](https://pytorch.org/docs/stable/fx.html) compute graphs.

This library currently supports:
 - Fusing multiple einsums into one
 - Optimizing einsums using the [`opt_einsum`](https://optimized-einsum.readthedocs.io/en/stable/) library
 - Fusing multiplication and division with scalar constants, including fusing _through_ operations, like einsum, that commute with scalar multiplication.
 - Placing multiplication by fused scalar constants onto the smallest intermediate in a chain of operations that commute with scalar multiplication.
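
For instance, a function like the following, which chains einsums and multiplies by constant scalars, is the kind of code the library targets (the function itself is just an illustration):

```python
import torch

def fused_target(x, y, w):
    # opt_einsum_fx can fuse these two einsums into a single contraction and
    # accumulate the constant factors 0.5 and 2.0 into one coefficient
    z = 0.5 * torch.einsum("ij,jk->ik", x, y)
    return 2.0 * torch.einsum("ik,km->im", z, w)
```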

Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome!

For more information please see [the docs](https://opt-einsum-fx.readthedocs.io/en/stable/).

## Installation

### PyPI
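
The latest release can be installed from PyPI (the distribution uses the same name as the package):

```bash
$ pip install opt_einsum_fx
```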

You can run the tests with
```bash
$ pytest tests/
```

## Usage

`opt_einsum_fx` is based on [`torch.fx`](https://pytorch.org/docs/stable/fx.html), a framework for converting between PyTorch Python code and a programmatically manipulable compute graph. To use this package, it must be possible to get your function or model as a `torch.fx.Graph`: the limitations of FX's symbolic tracing are discussed [here](https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing).
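
For example, any function or `torch.nn.Module` that FX can trace yields such a graph; a small illustrative module (not part of the package):

```python
import torch
import torch.fx

class ScaledContraction(torch.nn.Module):
    def forward(self, x, w):
        # an einsum followed by multiplication with a constant scalar
        return 0.5 * torch.einsum("bi,oi->bo", x, w)

graph_mod = torch.fx.symbolic_trace(ScaledContraction())
print(graph_mod.graph)  # the FX graph that opt_einsum_fx transforms
```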

### Minimal example
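
A minimal sketch of the workflow (the example function, tensor shapes, and the exact keyword arguments of `optimize_einsums_full` shown here are illustrative; see the docs for the precise API):

```python
import torch
import torch.fx
import opt_einsum_fx

# A batched matrix-matrix-vector product written as a single three-operand einsum
def einmatvecmul(a, b, vec):
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)

# Trace the function into an FX graph, then optimize it using example inputs
graph_mod = torch.fx.symbolic_trace(einmatvecmul)
graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=(
        torch.randn(3, 4, 5),  # a
        torch.randn(3, 5, 2),  # b
        torch.randn(3, 2),     # vec
    ),
)
print(graph_opt.code)
```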

The optimized module computes the same result through pairwise contractions chosen by `opt_einsum`; the generated code ends with something like:

```
def forward(self, a, b, vec):
    ...
    einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None
    return einsum_2
```

The `optimize_einsums_full` function has four passes:

 1. Scalar accumulation --- use the multilinearity of einsum to fuse all constant coefficients and divisors of operands and outputs (illustrated below)
 2. Fusing einsums --- gives greater flexibility to (3)
 3. Optimized contraction with `opt_einsum`
 4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results
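
The scalar passes rely on einsum being linear in each operand, so constant factors on inputs and outputs can be pulled out and combined into one coefficient; a quick check of that identity in plain PyTorch (the tensors here are arbitrary):

```python
import torch

A = torch.randn(4, 5)
x = torch.randn(5)

# Scaling an operand or the result by constants just scales the output,
# so separate factors (here 2.0 and 3.0) can be fused into one (6.0).
lhs = 3.0 * torch.einsum("ij,j->i", 2.0 * A, x)
rhs = 6.0 * torch.einsum("ij,j->i", A, x)
assert torch.allclose(lhs, rhs)
```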

We can measure the performance improvement (this is on a CPU).
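
One way to do this is with `torch.utils.benchmark`, timing `graph_mod` and `graph_opt` from the example above (the batch size and number of timing runs here are arbitrary):

```python
import torch
from torch.utils.benchmark import Timer

batch = 1000
a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 2), torch.randn(batch, 2)

# Time the traced original and the optimized graph on the same inputs
for label, f in [("original", graph_mod), ("optimized", graph_opt)]:
    t = Timer("f(a, b, vec)", globals={"f": f, "a": a, "b": b, "vec": vec})
    print(label, t.timeit(10_000))
```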

Depending on your function and dimensions you may see even larger improvements.

### JIT

Currently, pure Python and TorchScript have different call signatures for `torch.tensordot` and `torch.permute`, both of which can appear in optimized einsums:
```python
graph_script = torch.jit.script(graph_opt)  # => RuntimeError: Arguments for call are not valid...
```
A function is provided to convert `torch.fx.GraphModule`s containing these operations from their Python signatures (the default) to a TorchScript-compatible form:
```python
graph_script = torch.jit.script(opt_einsum_fx.jitable(graph_opt))
```

### More information

More information can be found in docstrings in the source; the tests in [`tests/`](./tests) also serve as usage examples.

## License

`opt_einsum_fx` is distributed under an [MIT license](LICENSE).