본문 바로가기
새우의 테크/랜덤

[ML] RNN 에서의 gradient 계산

by 오새우 2022. 10. 11.

아래와 같은 RNN formulation에서의 gradient 계산은 어떻게 되는지 정리해 보겠습니다.

Loss 는 Cross Entropy Loss를 가정하겠습니다.

간단한 Chain Rule을 통해 다음과 같이 계산할 수 있습니다. 마지막 외적으로 계산되는 항의 계산은 스킵했습니다. 

MSE loss 로하면 비슷하게 나올 것 같은데, CE loss 로 해도 동일하게 나오는지 나중에 따로 유도한다면 올리겠습니다.

여기서 눈여겨 볼 만한 부분은, V에 대한 gradient는 오로지 현재 state의 값들에만 의존한다는 것입니다.

좀 다르게, W에 대한 gradient는 뒷 state 에도 영향을 받습니다. 여기서 파랑 네모에 있는 부분을 보면, gradient가 이전 state들에서의  gradient의 곱으로 계산되는 부분이 있는데, 흔히 gradient가 0 에서 1 사이의 값을 갖는 것을 생각하면 (tanh, sigmoid) gradient vanishing, 혹은 exploding현상이 쉽게 발생하는 것을 이해할 수 있습니다.

 

Gradient explosion의 경우는 값이 높은 gradient를 clip 해버리는 clipping method를 통해 해결할 수 있고, Gradient vanishing 문제의 경우, identitiy matrix 로 recurrent weight initialization을 하고, ReLU activation을 사용함으로써 해결할 수 있다고 (empirically) 알려져 있습니다. 참고하시면 좋을 것 같습니다.

댓글