Triton Tutorial Practice: 04 Low-Memory Dropout

學習目標

  • The limitations of naive implementations of Dropout with PyTorch.
  • Parallel pseudo-random number generation in Triton.

Memory Dropout 是什麼

Memory Dropout 是一種防止 Over-fitting 的方法。
隨機地忽略(丟失)神經網絡中的一些神經元及其連接,這樣模型在每次訓練時會學到不同的特徵。這有助於使模型更加 general,從而在未見過的數據上表現更好。

在 SRIVASTAVA2014 中被提出來的方式

基本上的實作就是 vector 中的每一個元素都有一個機率 p 會被設定為 0,
而因為這樣會導致 vector 最後的總數值降低,會在 softmax 等運算中出現延伸的問題 (increase the norm / artificial decrease),所以會把沒有被選為 0 的數值均勻地放大,
做法就是將沒有被選為 0 的數值除以(1-p)
(舉例來說 10個1有3個選起來變為0,原本的總數值從10降為7,所以把7的每一個元素都除以0.7,最後的總量又會是10)

Seed Dropout

上面的做法有以下兩點衍伸的狀況

  1. need to store the dropout mask for backpropagation
    • 因為 training 時改變梯度的變動要對應這裡被 drop 的結構
  2. dropout state management can get very tricky when using recompute/checkpointing
    • 如果為了避免異常而要把狀態儲存下來,那被 drop 的 state 要怎麼存也會變的非常 tricky

引用 SALMON 的演算法,有以下特點
(1) has a smaller memory footprint
(2) requires less data movement
(3) simplifies

Q:什麼是 memory footprint?

重點是改寫成

    random = tl.rand(seed, offsets)
    x_keep = random > p

使用 rand seed 來指定的話其實只要有 seed number
某種程度上來講位置就是可以用 seed 來復現位置
也就是只要儲存 seed 的話就可以取得對應的位置

Exercise

練習

Reference