Deep Learning

LSTM (Long Short Term Memory)과 GRU(Gated Recurrent Unit)

ju_young 2022. 11. 14. 16:41
728x90

Long Short Term Memory

Long Short Term Memory는 위에서 얘기한 문제들을 해결하기위해 고안해낸 구조이다. 간단히 말하면 Long Term Memory와 Short Term Meomory를 같이 고려하여 계산함으로서 극복하겠다는 것이다.

위에있는 A라는 셀의 내부를 자세히 살펴보면 다음과 같은 구조를 가진다.

총 4개의 gate로 이루어져있으며 입력값은 이전 cell state, 이전 hidden state, X_t 3개가 있고 출력은 Next cell state, Next hidden state, h_t 3개가 있지만 실질적은 출력은 h_t가 된다. 여기서 LSTM(Long Short Term Memory)의 핵심은 cell state인데 간단하게 말하면 정보가 변하지않고 그대로 흐르게 해주는 역할을 해주는 녀석이라고 이해하면 될 것 같다.

  • Forget Gate

우선 forget gate는 sigmoid layer를 사용하여 입력받은 데이터를 버릴건지 말건지 0과 1로 결정하여 cell state에 넘겨준다. 위에서는 잠재변수 h_t-1과 현재 입력 값 x_t에 각각 가중치를 곱하여 $W_f \cdot [h_{t-1} , x_t] + b_f$로 계산한다. 그 후에 sigmoid 함수를 적용하는 것이다. 수식을 다시 정리하면 다음과 같다.
$$
f_t = \sigma(W_{hf} h_{t-1} + W_{xf}x_t + b_f)
$$

  • Input Gate

input gate는 입력값 x_t와 잠재변수 h_t-1에 가중치 행렬을 곱한후 sigmoid를 적용한 i_t와 tanh를 적용한 $\tilde{C}_t$를 얻는다. 여기서 i_t는 sigmoid를 사용하여 현 시점의 중요도를 출력한 값이며 $\tilde{C}_t$로 정규화 작업을 시키고 cell state로 추가할 candidate를 결정한다.

 

$\tilde{C}_t$에서 tanh를 해주는 이유는 i_t에서 나온 0~1값으로 vanishing/exploding gradient problem이 일어날 수 있기때문에 -1~1값으로 정규화하는 것이라고 한다. 또한 cell state에 추가할 candidate를 결정하기위해 사용한다고 한다.

  • Update Cell

forget gate에서 얻은 f_t값은 이전 타임 스템의 cell state 값 C_t-1과의 곱셈을 한다. 그리고 input gate에서 얻어진 C_t은 이전에 설명한 것처럼 sigmoid의 결과값 i_t과 곱셈을 한다. 그리고 이 두 개의 값을 더하면 cell state가 update된다.

  • Output Gate

output gate는 update gate를 통해 얻은 C_t를 tanh에 적용하고 잠재변수 h_t-1과 입력값 x_t에 가중치 행렬을 곱한 값 o_t를 simoid에 적용한다음 곱한다. 그러면 tanh 함수로 인해 -1~1 값이 출력될테고 이 출력된 값 h_t는 다음 hidden state의 잠재변수로서 쓰이게 된다.

 

정리하자면 이렇다. Forget Gate를 통해 과거의 정보를 얼마나 반영할지를 결정하고 Input Gate를 통해 현 시점이 실제 가지고 있는 정보가 얼마나 중요한지를 반영 한 후에 과거의 정보와 현 시점의 정보 중요도를 반영하여 Update한다. 마지막으로 Output Gate를 통해 update된 cell state(과거의 중요한 모든 정보)를 hidden state(다음 타임 스텝에 당장 필요한 정보)로 출력할 값을 만든다.

Gated Recurrent Unit(GRU)

GRU는 LSTM과 달리 cell state가 없고 오직 hidden state로만 이루어져있으며 reset gate, update gate 두 개의 gate만 존재하는 간략한 구조이다.

여기서 중요한 점은 update gate가 이전 LSTM의 forget gate와 input gate를 합쳤다는 것이다. 일단 reset gate부터 하나씩 살펴보자.

  • Reset Gate

reset gate는 과거의 정보를 적당히 리셋시키는 gate로 입력값 x_t와 이전 hidden state의 잠재변수 h_t-1을 어떻게 합칠 것인지 결정해준다. 위 수식에서는 r_t에 해당된다.

  • Update Gate

update gate는 위에서도 언급했듯이 LSTM의 forget gate와 input gate를 합친 것으로 과거와 현재의 정보를 얼마나 반영할지 결정하는 gate이다. sigmoid를 적용하여 출력된 z_t가 바로 현 시점 정보의 중요도이고 1 - z_t를 통해 과거 정보의 중요도를 얻을 수 있다.

  • Candidate

reset gate의 결과를 tanh에 적용하여 현 시점의 정보 candidate를 얻는다.

  • Update Hidden State

얻은 candidate와 update gate의 결과를 이용하여 현 시점의 hidden state 정보를 update하게 된다. 즉, 과거의 중요한 정보$((1-z_t)*h_{t-1})$와 현 시점의 중요한 정보$(z_t * \tilde{h}_t) $를 더한다는 것이다.

728x90