
Commit c1fc8cc

cjyabraham and kyliewd authored
Adds blog post "Enabling Fast Gradient Clipping and Ghost Clipping in Opacus" (pytorch#1711)
* Adds blog post "Enabling Fast Gradient Clipping and Ghost Clipping in Opacus"
* Rename 2024-08-19-clipping-in-opacus.md to 2024-08-20-clipping-in-opacus.md

Signed-off-by: Chris Abraham <[email protected]>
Co-authored-by: Kylie Wagar-Dirks <[email protected]>
1 parent 567fa17 commit c1fc8cc

3 files changed: +362 -0 lines changed

---
layout: blog_detail
title: "Enabling Fast Gradient Clipping and Ghost Clipping in Opacus"
author: Enayat Ullah, Huanyu Zhang, Will Bullock, Ilya Mironov
---

## Introduction and Context

[Differentially Private Stochastic Gradient Descent (DP-SGD)](https://arxiv.org/abs/1607.00133) is the canonical method for training machine learning models with differential privacy. It involves the following two modifications to its non-private counterpart, Stochastic Gradient Descent:

1. **Per-sample gradient clipping**: Clip the gradient of every sample in the mini-batch so that its norm is at most a pre-specified value, the "Clipping Norm" C, in every iteration.

2. **Noise addition**: Add Gaussian noise, whose variance depends on the clipping norm and the privacy parameters, to the average clipped gradient in every iteration.

The first change, **per-sample gradient clipping**, introduces additional complexities since, in general, it requires instantiating **per-sample gradients**.

[Opacus](http://opacus.ai) is a PyTorch implementation of DP-SGD. Opacus addresses the above task by employing [hook functions](https://medium.com/pytorch/differential-privacy-series-part-2-efficient-per-sample-gradient-computation-in-opacus-5bf4031d9e22), which allow intervening on specific events, such as forward and backward passes. For more details about Opacus, we encourage readers to review the previous blog posts: [DP-SGD Algorithm Explained](https://bit.ly/dp-sgd-algorithm-explained), [Efficient Per-Sample Gradient Computation in Opacus](https://medium.com/pytorch/differential-privacy-series-part-2-efficient-per-sample-gradient-computation-in-opacus-5bf4031d9e22) and [Efficient Per-Sample Gradient Computation for More Layers in Opacus](https://pytorch.medium.com/differential-privacy-series-part-3-efficient-per-sample-gradient-computation-for-more-layers-in-39bd25df237).

While Opacus provides substantial efficiency gains compared to naive approaches, the memory cost of instantiating per-sample gradients is significant. In particular, memory usage is proportional to the batch size times the number of trainable parameters. Consequently, memory limits Opacus to small batch sizes and/or small models, significantly restricting its range of applications.

We introduce [Fast Gradient Clipping](https://arxiv.org/abs/2009.03106) and [Ghost Clipping](https://arxiv.org/abs/2110.05679) to Opacus, which enable developers and researchers to perform gradient clipping without instantiating the per-sample gradients. As an example, this allows for fine-tuning 7M parameters of BERT on a single 16 GB GPU with a batch size of 1024, with memory usage comparable to that of PyTorch (without applying DP-SGD). In contrast, the previous version of Opacus supported a maximum batch size of roughly 256 in the same setting. We provide a [tutorial](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb) on how to use Fast Gradient Clipping in Opacus with the aforementioned task as an example.
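
To put the memory cost in perspective, here is a rough back-of-the-envelope calculation (illustrative, not a measurement): per-sample gradients for 7M trainable parameters stored as 32-bit floats take about $ 7\text{M} \times 4 \text{ bytes} \approx 28 \text{ MB} $ per sample, i.e., roughly 7 GB at a batch size of 256 and over 28 GB at a batch size of 1024, on top of the model, activations, and optimizer state.
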

## Fast Gradient Clipping and Ghost Clipping

The key idea behind these techniques is based on the following observation: if per-sample gradient norms are known, then gradient clipping can be achieved by backpropagation on a re-weighted loss function $ \bar{L} $. This loss function is defined as $ \bar{L} = \sum_{i} R_{i} L_{i} $, where $ R_i = \min\left(\frac{C}{C_i}, 1\right) $ are the clipping coefficients computed from the per-sample gradient norms $ C_i $, and $ L_i $ are the per-sample losses.

The above idea may seem circular at first glance, as it appears to require instantiating per-sample gradients in order to calculate per-sample gradient norms. However, for certain widely used components of neural network architectures, such as fully connected/linear layers, it is indeed possible to obtain per-sample gradient norms in a single backpropagation pass without the need for per-sample gradients. This suggests a workflow that involves two backpropagation passes: the first to compute per-sample gradient norms, and the second to compute the aggregated (not per-sample) clipped gradient. The second backpropagation is simply the standard batched backpropagation.
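
The minimal sketch below illustrates the two-pass idea end to end. It is not Opacus' implementation: the per-sample norms in the first pass are computed naively, one sample at a time, purely for clarity. Fast Gradient Clipping and Ghost Clipping, described next, obtain them far more efficiently.

```py
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss(reduction="none")  # keep per-sample losses L_i
x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
C = 1.0  # clipping norm

# Pass 1: per-sample gradient norms C_i (computed naively here for illustration)
norms = []
for i in range(x.shape[0]):
    loss_i = criterion(model(x[i : i + 1]), y[i : i + 1]).sum()
    grads = torch.autograd.grad(loss_i, tuple(model.parameters()))
    norms.append(torch.cat([g.flatten() for g in grads]).norm())
C_i = torch.stack(norms)

# Pass 2: one standard backward on the re-weighted loss L_bar = sum_i R_i * L_i
R = (C / C_i).clamp(max=1.0)  # clipping coefficients R_i = min(C / C_i, 1)
model.zero_grad()
loss_bar = (R.detach() * criterion(model(x), y)).sum()
loss_bar.backward()  # .grad now holds the sum of the clipped per-sample gradients
```
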

![backpropagation diagram](/assets/images/clipping-in-opacus/fg1.jpg){:style="max-width:800px; display:block; margin-left: auto; margin-right: auto; width:100%"}

![backpropagation diagram](/assets/images/clipping-in-opacus/fg2.png){:style="max-width:400px; display:block; margin-left: auto; margin-right: auto; width:100%"}

_Figure 1: Comparison between vanilla **Opacus** (top left), **Fast Gradient Clipping** (top right), and **Ghost Clipping** (bottom). Gradient instantiations that become memory bottlenecks are marked in red. Vanilla Opacus has to instantiate the full **per-sample gradients**. **Fast Gradient Clipping** instantiates per-sample gradients for each layer to compute their norms, and releases them as soon as the backward pass moves on to the next layer. **Ghost Clipping** works directly from **per-sample activation gradients** and **per-sample activations**, and avoids the need for gradient instantiation._

[**Fast Gradient Clipping**](https://arxiv.org/abs/2009.03106)

In Fast Gradient Clipping, the per-sample gradient norm is calculated in three steps (a minimal sketch follows the list):

1. For each layer, the per-sample gradient is instantiated and its norm is calculated.
2. The per-sample gradient is then immediately discarded.
3. The (squared) per-sample gradient norms of each layer are summed up to obtain the overall (squared) per-sample gradient norm.
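
Below is a minimal sketch of these three steps for a single linear layer, assuming the per-sample activations and activation gradients have already been captured (for example, with hooks). The tensor names and sizes are illustrative, not Opacus internals.

```py
import torch

B, d_in, d_out = 16, 32, 8
activations = torch.randn(B, d_in)  # per-sample inputs to the layer
backprops = torch.randn(B, d_out)   # per-sample gradients of the loss w.r.t. the layer outputs

# Step 1: instantiate the per-sample weight gradients (B x d_out x d_in) and take their norms.
per_sample_grads = torch.einsum("bo,bi->boi", backprops, activations)
per_layer_sq_norms = per_sample_grads.flatten(start_dim=1).pow(2).sum(dim=1)

# Step 2: the per-sample gradients are discarded right away.
del per_sample_grads

# Step 3: summing the squared norms over all layers gives the overall (squared)
# per-sample gradient norms; with a single layer it is simply:
overall_sq_norms = per_layer_sq_norms
```
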

[**Ghost Clipping**](https://arxiv.org/abs/2110.05679)

Extending the approach of Fast Gradient Clipping, Ghost Clipping uses the [fact](https://arxiv.org/abs/1510.01799) that for **linear layers**[^1], per-sample gradient norms can be calculated just from **activation gradients** and **activations**. In particular, let `backprops` and `activations` be the per-sample activation gradients and activations, of dimensions `batch_size ✕ output_width` and `batch_size ✕ input_width`, respectively. The per-sample gradient is the outer product of the two, which takes `O(batch_size ✕ input_width ✕ output_width)` time and space.

The [ghost clipping trick](https://arxiv.org/abs/1510.01799) instead calculates the (squared) norms of `backprops` and `activations`, sample-wise, and takes their product, which gives the (squared) norm of the gradient. This takes `O(batch_size ✕ (input_width + output_width))` time and `O(batch_size)` space to store. Since the **per-sample activations** and **per-sample activation gradients** are already stored, additional memory is needed only for storing the norms.
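
The identity behind the trick is easy to verify numerically. The sketch below (illustrative names and sizes, not Opacus code) checks that the product of the squared sample-wise norms of `activations` and `backprops` equals the squared norm of the explicit per-sample gradient.

```py
import torch

B, d_in, d_out = 16, 32, 8
activations = torch.randn(B, d_in)
backprops = torch.randn(B, d_out)

# Ghost Clipping: O(batch_size * (input_width + output_width)) work, O(batch_size) extra memory.
ghost_sq_norms = activations.pow(2).sum(dim=1) * backprops.pow(2).sum(dim=1)

# Sanity check against the explicit per-sample gradients (outer products).
explicit = torch.einsum("bo,bi->boi", backprops, activations)
explicit_sq_norms = explicit.flatten(start_dim=1).pow(2).sum(dim=1)
assert torch.allclose(ghost_sq_norms, explicit_sq_norms, rtol=1e-4)
```
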

**Relationship between Fast Gradient Clipping and Ghost Clipping**

1. Fast Gradient Clipping and Ghost Clipping are complementary techniques. Fast Gradient Clipping can be applied to any type of layer, while Ghost Clipping is a strictly better technique for supported layers.

2. Our implementation automatically switches to Fast Gradient Clipping when a layer is not supported by Ghost Clipping.

### How to use Fast Gradient Clipping in Opacus

The training loop is identical to that of the standard PyTorch loop. As before, we use `PrivacyEngine()`, which "sanitizes" the model and optimizer. To enable Ghost Clipping, the argument `grad_sample_mode="ghost"` is used. Additionally, `make_private()` takes the loss criterion as an extra input and sanitizes it. This allows us to hide the two backward passes, and the loss rescaling between them, inside `loss.backward()`.

```py
import torch.nn as nn
from opacus import PrivacyEngine

criterion = nn.CrossEntropyLoss()  # example loss function

privacy_engine = PrivacyEngine()
model_gc, optimizer_gc, criterion_gc, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=noise_multiplier,
    max_grad_norm=max_grad_norm,
    criterion=criterion,
    grad_sample_mode="ghost",
)

# The training loop below is identical to that of PyTorch

for input_data, target_data in train_loader:
    output_gc = model_gc(input_data)  # Forward pass
    optimizer_gc.zero_grad()
    loss = criterion_gc(output_gc, target_data)
    loss.backward()  # Two backward passes and the loss rescaling happen here
    optimizer_gc.step()  # Add noise and update the model
```

Internally, before the first pass, we enable *hooks*, which allow us to capture layer-wise values corresponding to forward and backward calls. These are used to compute the per-sample gradient norms. We then compute the clipping coefficients, rescale the loss function, and disable the hooks, which lets us use the standard PyTorch backward pass.
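
For intuition, here is a minimal sketch of how such hooks can capture per-layer activations and activation gradients in plain PyTorch; it is a simplified illustration, not Opacus' actual hook machinery.

```py
import torch
import torch.nn as nn

captured = {}  # layer -> {"activations": ..., "backprops": ...}

def forward_hook(layer, inputs, output):
    captured.setdefault(layer, {})["activations"] = inputs[0].detach()

def backward_hook(layer, grad_input, grad_output):
    captured.setdefault(layer, {})["backprops"] = grad_output[0].detach()

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
handles = []
for layer in model.modules():
    if isinstance(layer, nn.Linear):
        handles.append(layer.register_forward_hook(forward_hook))
        handles.append(layer.register_full_backward_hook(backward_hook))

x = torch.randn(4, 16)
model(x).sum().backward()  # first pass: hooks capture per-layer activations and backprops

for h in handles:  # disable the hooks so the second pass is a plain PyTorch backward
    h.remove()
```
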

### Memory Complexity Analysis

Consider a multi-layer neural network with the following properties:

* **L**: Number of layers
* **d**: Maximum layer width
* **B**: Batch size
* **K**: Number of non-supported/non-linear layers

The memory overhead of DP-SGD with Ghost Clipping compared to plain (PyTorch) SGD is an additive O(BL), required to store the per-sample gradient norms for all layers. Further, if there is a non-supported layer (if K ≥ 1), then there is an additional O(Bd<sup>2</sup>) memory to instantiate the per-sample gradients of that layer.
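
As an illustrative calculation (with assumed values, not drawn from the benchmarks below): with $ B = 1024 $, $ L = 100 $, and norms stored as 32-bit floats, the O(BL) term amounts to roughly $ 1024 \times 100 \times 4 $ bytes ≈ 0.4 MB, whereas a single non-supported layer of width $ d = 1024 $ would add on the order of $ 1024 \times 1024^2 \times 4 $ bytes ≈ 4 GB for its per-sample gradients.
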

### Memory Benchmarking

We provide results on the memory usage for a variety of settings.

#### Fine-Tuning BERT

We consider the problem of [privately fine-tuning](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb) the last three layers of BERT for a text classification task. The base model has over 100M parameters, of which we fine-tune only the last three layers, `BertEncoder`, `BertPooler`, and `Classifier`, comprising roughly 7.6M parameters. The experiments are run on a P100 GPU with 16 GB of memory.
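
For concreteness, freezing a model down to these blocks might look like the sketch below. It assumes a Hugging Face `BertForSequenceClassification` model; the attribute names, checkpoint, and label count are illustrative assumptions, not taken from the tutorial.

```py
from transformers import BertForSequenceClassification

# Illustrative checkpoint and label count (assumptions, not the tutorial's exact setup).
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

for p in model.parameters():  # freeze everything...
    p.requires_grad = False
for block in (model.bert.encoder.layer[-1], model.bert.pooler, model.classifier):
    for p in block.parameters():  # ...then unfreeze the last encoder block, pooler, and classifier
        p.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable / 1e6:.2f}M")  # close to the ~7.6M quoted above
```
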

The following table reports the maximum memory and time taken per iteration for the various methods:

<table class="table table-bordered">
  <tr>
    <td rowspan="3"></td>
    <td colspan="9" style="text-align:center"><strong>Batch size</strong></td>
  </tr>
  <tr>
    <td colspan="2" style="text-align:center"><strong>B = 32</strong></td>
    <td colspan="2" style="text-align:center"><strong>B = 128</strong></td>
    <td colspan="2" style="text-align:center"><strong>B = 512</strong></td>
    <td colspan="2" style="text-align:center"><strong>B = 1024</strong></td>
    <td><strong>B = 2048</strong></td>
  </tr>
  <tr>
    <td><strong>Mem</strong></td>
    <td><strong>Time</strong></td>
    <td><strong>Mem</strong></td>
    <td><strong>Time</strong></td>
    <td><strong>Mem</strong></td>
    <td><strong>Time</strong></td>
    <td><strong>Mem</strong></td>
    <td><strong>Time</strong></td>
    <td></td>
  </tr>
  <tr>
    <td><strong>PyTorch SGD</strong></td>
    <td>236 MB</td>
    <td>0.15 s</td>
    <td>1.04 GB</td>
    <td>0.55 s</td>
    <td>5.27 GB</td>
    <td>2.1 s</td>
    <td>12.7 GB</td>
    <td>4.2 s</td>
    <td>OOM</td>
  </tr>
  <tr>
    <td><strong>DP-SGD</strong></td>
    <td>1,142 MB</td>
    <td>0.21 s</td>
    <td>4.55 GB</td>
    <td>0.68 s</td>
    <td colspan="2" style="text-align:center">OOM</td>
    <td colspan="2" style="text-align:center">OOM</td>
    <td>OOM</td>
  </tr>
  <tr>
    <td><strong>FGC DP-SGD</strong></td>
    <td>908 MB</td>
    <td>0.21 s</td>
    <td>3.6 GB</td>
    <td>0.75 s</td>
    <td colspan="2" style="text-align:center">OOM</td>
    <td colspan="2" style="text-align:center">OOM</td>
    <td>OOM</td>
  </tr>
  <tr>
    <td><strong>GC DP-SGD</strong></td>
    <td>362 MB</td>
    <td>0.21 s</td>
    <td>1.32 GB</td>
    <td>0.67 s</td>
    <td>5.27 GB</td>
    <td>2.5 s</td>
    <td>12.7 GB</td>
    <td>5 s</td>
    <td>OOM</td>
  </tr>
</table>

In terms of peak memory footprint, DP-SGD > FGC DP-SGD ≫ GC DP-SGD ≈ PyTorch SGD. Further, the runtimes are similar because most of the parameters are frozen and the forward pass takes up most of the time.

#### Synthetic Setup: Memory Profiling

We consider the following setup to profile the memory used by PyTorch SGD, vanilla DP-SGD, and Ghost Clipping DP-SGD (GC DP-SGD); a PyTorch definition of this network is sketched after the list.

* 2-layer fully connected neural network
* Input: 5120
* Hidden: 2560
* Output: 1280
* Total number of model parameters = 15.6M
* Model size = 62.5 MB
* Batch size: various values, as shown in the table below
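
A minimal PyTorch definition matching these sizes might look as follows; the choice of activation and the use of bias terms are assumptions for illustration.

```py
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(5120, 2560),  # input 5120 -> hidden 2560
    nn.ReLU(),              # activation choice is an assumption
    nn.Linear(2560, 1280),  # hidden 2560 -> output 1280
)

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 2**20:.1f}M parameters, {num_params * 4 / 2**20:.1f} MB in fp32")
# -> 15.6M parameters, 62.5 MB, matching the setup above
```
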

The table below summarizes the max memory increase (in MB) broken down by stages of the training loop for each of the methods.

<table class="table table-bordered">
  <tr>
    <td><strong>Batch Size</strong></td>
    <td><strong>Method</strong></td>
    <td><strong>Model to GPU</strong></td>
    <td><strong>Forward</strong></td>
    <td><strong>First Backward</strong></td>
    <td><strong>Second Backward</strong></td>
    <td><strong>Optimizer Step</strong></td>
  </tr>
  <tr>
    <td rowspan="3">32</td>
    <td><strong>PyTorch SGD</strong></td>
    <td>62.5</td>
    <td>0.5</td>
    <td>62.5</td>
    <td>N/A</td>
    <td>0</td>
  </tr>
  <tr>
    <td><strong>Vanilla DP-SGD</strong></td>
    <td>62.5</td>
    <td>0.47</td>
    <td>3,663</td>
    <td>N/A</td>
    <td>162.5</td>
  </tr>
  <tr>
    <td><strong>GC DP-SGD</strong></td>
    <td>62.5</td>
    <td>0.47</td>
    <td>63.13</td>
    <td>50</td>
    <td>125</td>
  </tr>
  <tr>
    <td rowspan="3">2<sup>17</sup></td>
    <td><strong>PyTorch SGD</strong></td>
    <td>62.5</td>
    <td>1920</td>
    <td>1932.5</td>
    <td>N/A</td>
    <td>0</td>
  </tr>
  <tr>
    <td><strong>Vanilla DP-SGD</strong></td>
    <td colspan="5" style="text-align:center">OOM</td>
  </tr>
  <tr>
    <td><strong>GC DP-SGD</strong></td>
    <td>62.5</td>
    <td>1920</td>
    <td>2625</td>
    <td>1932.5</td>
    <td>125</td>
  </tr>
</table>

#### Industry Use Case

We tested Ghost Clipping DP-SGD on an internal Meta use case, consisting of a model with roughly 100B parameters, of which 40M are trainable. Our initial results show that Ghost Clipping DP-SGD reduces memory usage by 95% compared to vanilla DP-SGD, and achieves memory usage comparable to PyTorch SGD.

## Conclusion

In this post, we describe implementations of Fast Gradient Clipping and Ghost Clipping in Opacus that enable memory-efficient training of machine learning models with differential privacy. Currently, the Ghost Clipping implementation only applies to linear layers, but, as outlined in [part 3 of the series](https://pytorch.medium.com/differential-privacy-series-part-3-efficient-per-sample-gradient-computation-for-more-layers-in-39bd25df237), it can be extended to “generalized” linear layers such as convolutions and multi-head attention. The current techniques require two explicit backpropagation steps, which increases runtime. We will explore developments on top of Ghost Clipping, such as the [Book-Keeping algorithm](https://arxiv.org/abs/2210.00038), to mitigate this.

To learn more about Opacus, visit [opacus.ai](https://opacus.ai/) and [github.com/pytorch/opacus](https://github.com/pytorch/opacus).

## Acknowledgements

We thank Iden Kalemaj, Darren Liu, Karthik Prasad, Hao Shi, Igor Shilov, Davide Testuggine, Eli Uriegas, Haicheng Wang, and Richard Zou for valuable feedback and suggestions.

[^1]: There are [ways](https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf) to extend Ghost Clipping to non-linear layers.