Triton Tutorial Practice: 04 Low-Memory Dropout
學習目標
- The limitations of naive implementations of Dropout with PyTorch.
- Parallel pseudo-random number generation in Triton.
Memory Dropout 是什麼
Memory Dropout 是一種防止 Over-fitting 的方法。
隨機地忽略(丟失)神經網絡中的一些神經元及其連接,這樣模型在每次訓練時會學到不同的特徵。這有助於使模型更加 general,從而在未見過的數據上表現更好。
在 SRIVASTAVA2014 中被提出來的方式
基本上的實作就是 vector 中的每一個元素都有一個機率 p 會被設定為 0,
而因為這樣會導致 vector 最後的總數值降低,會在 softmax 等運算中出現延伸的問題 (increase the norm / artificial decrease),所以會把沒有被選為 0 的數值均勻地放大,
做法就是將沒有被選為 0 的數值除以(1-p)
(舉例來說 10個1有3個選起來變為0,原本的總數值從10降為7,所以把7的每一個元素都除以0.7,最後的總量又會是10)
Seed Dropout
上面的做法有以下兩點衍伸的狀況
- need to store the dropout mask for backpropagation (why?)
- dropout state management can get very tricky when using recompute/checkpointing
引用 SALMON 的演算法,有以下特點
(1) has a smaller memory footprint
(2) requires less data movement
(3) simplifies
Q:什麼是 memory footprint?
重點是改寫成
random = tl.rand(seed, offsets)
x_keep = random > p
Exercise
練習
Reference
-
[SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
-
[SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014
-
https://isamu-website.medium.com/understanding-triton-tutorials-part-2-f6839ce50ae7