Processing math: 100%

Deep Learning

[Paper] A Dual-Stage Attention-Based Recurrent Neural Network (DA-RNN)

ju_young 2022. 2. 16. 21:50
728x90

A Dual-Stage Attention-Based Recurrent Neural Network (DA-RNN)

 

Instroduction

  • encoder-decoder networks의 문제는 sequence의 길이가 길어질수록 성능이 낮아지는 것이다.

 

  • 위 문제를 해결하기 위해 attention based encoder-decoder network를 사용
  • 하지만 time series prediction task에는 맞지 않음

 

  • time series prediction task를 목적으로 하는 모델인 dual-stage attnetion-based recurrent neural network (DA-RNN)을 제안
  • 1 stage: 모든 time steps에서 이전 encoder hidden state를 참고하여 각 time step에서의 연관된 driving series를 adaptively 추출하기위해 새로운 attention mechanism을 개발
  • 2 stage: temporal attention mechanism은 모든 time steps의 연관된 encoder hidden state들을 선택하여 사용

 

  • 가장 많이 연관된 input feature들을 adaptively 선택
  • long-term temporal dependency를 억제

 

Model

input sequence는 n개의 driving series로 이루어진 X=(x1,x2,...,xT)

  • ht: time t에서의 encoder의 hidden state
  • f1: LSTM or GRU (논문에서는 LSTM을 사용)

 

Input attention

  • ht1: encoder LSTM unit에서의 previous hidden state
  • st1: encoder LSTM unit에서의 previous cell state
  • xk: k-th input driving series

 

 

Attention weights (encoder)

  • αkt: softmax(ekt)

 

New input at time t

 

Update encoder hidden state

  • f1: LSTM

 

 

Encoder


Decoder

 

 

Temporal attention

  • dt1: decoder LSTM unit에서의 previous hidden state
  • st1: decoder LSTM unit에서의 previous cell state
  • hi: encoder hidden state

 

Attention weights (decoder)

  • βkt: softmax(ekt)

 

Context vector

 

  • weighted sum of all the encoder hidden states(hi)

 

New decoder input

  • yt1: decoder input (target series)
  • ct1: computed context vector

 

Update decoder hidden state

  • f2: LSTM

 

Final Prediction

  • dT: decoder hidden state
  • cT: context vector

 

Github

728x90