Performance Tips
================

:py:mod:`torchkbnufft` is written primarily with the goal of scaling parallelism within
the PyTorch framework. The performance bottleneck of the package comes from two sources:
1) advanced indexing and 2) multiplications. Multiplications are handled in a way that
scales well, but advanced indexing does not, due to
`limitations with PyTorch <https://github.com/pytorch/pytorch/issues/29973>`_.
As a result, the package handles growth in problem size very well as long as that
growth is independent of the indexing bottleneck, such as:

1. Scaling the batch dimension.
2. Scaling the coil dimension.

Generally, you can just add to these dimensions and the package will perform well
without adding much compute time, as in the sketch below. If you're chasing more speed,
some strategies that might be helpful are listed further down.
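
As a quick illustration, here is a minimal sketch (the sizes and random data are made
up for this example) that scales both the batch and coil dimensions at once:

.. code-block:: python

    import numpy as np
    import torch
    import torchkbnufft as tkbn

    im_size = (256, 256)
    nufft_ob = tkbn.KbNufft(im_size=im_size)

    # hypothetical sizes; both dims scale with little extra compute time
    batch_size, num_coils, klength = 8, 16, 4096

    x = torch.randn(batch_size, num_coils, *im_size, dtype=torch.complex64)
    # k-space trajectory in radians/voxel, shape (2, klength)
    ktraj = torch.tensor(
        np.random.uniform(-np.pi, np.pi, size=(2, klength))
    ).to(torch.float)

    # outputs (batch_size, num_coils, klength) k-space data
    kdata = nufft_ob(x, ktraj)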

Using Batched K-space Trajectories
----------------------------------

As of version ``1.1.0``, :py:mod:`torchkbnufft` can use batched k-space trajectories.
If you pass in a variable for ``omega`` with dimensions
``(N, length(im_size), klength)``, the package will parallelize the execution of all
trajectories in the ``N`` dimension. This is useful when ``N`` is very large, as might
occur in dynamic imaging settings. The following shows an example:

.. code-block:: python

    import torch
    import torchkbnufft as tkbn
    import numpy as np
    from skimage.data import shepp_logan_phantom

    batch_size = 12

    x = shepp_logan_phantom().astype(complex)
    im_size = x.shape
    # convert to tensor, unsqueeze batch and coil dimension
    # output size: (batch_size, 1, ny, nx)
    x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
    x = x.repeat(batch_size, 1, 1, 1)

    klength = 64
    ktraj = np.stack(
        (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
    )
    # convert to tensor, unsqueeze batch dimension
    # output size: (batch_size, 2, klength)
    ktraj = torch.tensor(ktraj).to(torch.float)
    ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)

    nufft_ob = tkbn.KbNufft(im_size=im_size)
    # outputs a (batch_size, 1, klength) vector of k-space data
    kdata = nufft_ob(x, ktraj)

This code will then compute the 12 different radial spokes while parallelizing as much
as possible.
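
If you also need the adjoint, :py:class:`~torchkbnufft.KbNufftAdjoint` should accept
the same batched trajectory format (a sketch continuing the example above; verify the
shapes against your version's documentation):

.. code-block:: python

    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)

    # outputs a (batch_size, 1, ny, nx) batch of gridded images
    x_adj = adj_ob(kdata, ktraj)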

Lowering the Precision
----------------------

A simple way to save both memory and compute time is to decrease the precision. PyTorch
normally operates at a default 32-bit floating point precision, but if you're converting
data from NumPy then you might have some data at 64-bit floating point precision. To use
32-bit precision, simply do the following:

.. code-block:: python

    image = image.to(dtype=torch.complex64)
    ktraj = ktraj.to(dtype=torch.float32)
    forw_ob = forw_ob.to(image)

    data = forw_ob(image, ktraj)

The ``forw_ob.to(image)`` command will automagically determine the type for both real
and complex tensors registered as buffers under ``forw_ob``, so you should be able to
do this safely in your code.

In many cases, the tradeoff for going from 64-bit to 32-bit is not severe, so you can
safely use 32-bit precision.
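
If you want to verify the tradeoff on your own data, one option is to run the forward
operator at both precisions and compare (a sketch reusing the hypothetical ``image``,
``ktraj``, and ``forw_ob`` variables from above):

.. code-block:: python

    image_64 = image.to(torch.complex128)
    kdata_64 = forw_ob.to(image_64)(image_64, ktraj.to(torch.float64))

    image_32 = image.to(torch.complex64)
    kdata_32 = forw_ob.to(image_32)(image_32, ktraj.to(torch.float32))

    # relative error introduced by dropping to 32-bit precision
    rel_err = torch.norm(kdata_32.to(torch.complex128) - kdata_64) / torch.norm(kdata_64)
    print(f"relative error: {rel_err.item():.2e}")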

Lowering the Oversampling Ratio
-------------------------------

If you create a :py:class:`~torchkbnufft.KbNufft` object using the following code:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size)

then by default it will use a 2-factor oversampled grid. For some applications, this can
be overkill. If you can sacrifice some accuracy for your application, you can use a
smaller grid with 1.25-factor oversampling by altering how you initialize NUFFT objects
like :py:class:`~torchkbnufft.KbNufft`:

.. code-block:: python

    grid_size = tuple(int(el * 1.25) for el in im_size)
    forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size)

Using Fewer Interpolation Neighbors
-----------------------------------

Another major speed factor is how many neighbors you use for interpolation. By default,
:py:mod:`torchkbnufft` uses 6 nearest neighbors in each dimension. If you can sacrifice
accuracy, you can get more speed by using fewer neighbors. To do so, alter how you
initialize NUFFT objects like :py:class:`~torchkbnufft.KbNufft`:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=4)

If you know that you can be less accurate in one dimension (e.g., the z-dimension), then
you can use fewer neighbors in only that dimension:

.. code-block:: python

    # assumes a 3D im_size, e.g., (nz, ny, nx)
    forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=(4, 6, 6))

Package Limitations
-------------------

As mentioned earlier, batches and coils scale well, primarily because they don't touch
the package's advanced-indexing bottleneck. Where :py:mod:`torchkbnufft` does not scale
well is:

1. Very long k-space trajectories.
2. More imaging dimensions (e.g., 3D).

For these settings, you can first try some of the strategies above (lowering the
precision, using fewer neighbors, using a smaller grid). In some cases, lowering the
precision a bit and using a GPU can still give strong performance, as in the sketch
below. If you're still waiting too long for compute after trying all of these, you may
be running into the limits of the package.
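
As an example, the following sketch combines lower precision with GPU execution
(reusing the hypothetical ``image``, ``ktraj``, and ``forw_ob`` variables from the
precision example above):

.. code-block:: python

    if torch.cuda.is_available():
        device = torch.device("cuda")

        # move the data to the GPU at 32-bit precision
        image = image.to(device=device, dtype=torch.complex64)
        ktraj = ktraj.to(device=device, dtype=torch.float32)
        # move the NUFFT object's buffers to match the data
        forw_ob = forw_ob.to(image)

        data = forw_ob(image, ktraj)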