Triton Tutorial Practice: 04 Low-Memory Dropout

學習目標

  • The limitations of naive implementations of Dropout with PyTorch.
  • Parallel pseudo-random number generation in Triton.

Memory Dropout 是什麼

Memory Dropout 是一種防止 Over-fitting 的方法。
隨機地忽略(丟失)神經網絡中的一些神經元及其連接,這樣模型在每次訓練時會學到不同的特徵。這有助於使模型更加 general,從而在未見過的數據上表現更好。

在 SRIVASTAVA2014 中被提出來的方式

基本上的實作就是 vector 中的每一個元素都有一個機率 p 會被設定為 0,
而因為這樣會導致 vector 最後的總數值降低,會在 softmax 等運算中出現延伸的問題 (increase the norm / artificial decrease),所以會把沒有被選為 0 的數值均勻地放大,
做法就是將沒有被選為 0 的數值除以(1-p)
(舉例來說 10個1有3個選起來變為0,原本的總數值從10降為7,所以把7的每一個元素都除以0.7,最後的總量又會是10)

Seed Dropout

上面的做法有以下兩點衍伸的狀況

  1. need to store the dropout mask for backpropagation (why?)
  2. dropout state management can get very tricky when using recompute/checkpointing

引用 SALMON 的演算法,有以下特點
(1) has a smaller memory footprint
(2) requires less data movement
(3) simplifies

Q:什麼是 memory footprint?

重點是改寫成

    random = tl.rand(seed, offsets)
    x_keep = random > p

Exercise

練習

Reference

關於

AI Computing / 武術 / 登山 / IT / - 貪多而正努力咀嚼的人生小吃貨