Deep Learning

Generate Text Decoding Methods

ju_young 2022. 9. 21. 18:03
728x90

Greedy Search

Greedy Search는 다음 단어로 올 확률이 높은 단어를 선택하는 간단한 알고리즘이다.

하지만 이 알고리즘은 위 그림처럼 낮은 확률 뒤에 존재하는 높은 확률을 가지는 단어를 놓친다는 것이다.

 

Beam Search

Beam Search는 뒷 부분에 높은 확률을 가지는 단어를 놓칠 risk를 줄일 수 있다.

위 처럼 전체를 보았을때 가장 높은 확률을 선택하는 알고리즘으로 다음과 같이 진행된다.

  • time step 1. 가장 가능성이 높은 ("The", "nice")과 두 번째로 가능성이 높은 ("The", "dog")도 선택한다.
  • time step 2. ("The", "dog", "has")가 0.4x0.9=0.36으로 ("The", "nice". "woman") 0.5x0.4=0.2보다 높은 확률을 가진다. 따라서 "The dog has"라고 예측한다.

Beam Search는 Greedy보다 더 높은 확률의 sequence를 찾겠지만 항상 그렇다고 보장할 수는 없다.

 

다음 예시를 확인해보자.

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I'm not sure if I'll ever be able to walk with him again. I'm not sure if I'll

위 결과에서 반복되는 단어가 존재하게된다. 이 문제를 해결하기위해 n-grams penalty라는 것이 생기게된다.

n-grams penalty은 가장 많이 존재하는 n-gram이 두 번 나오지 않도록 다음 단어의 확률 0으로 수정해버린다. 하지만 n-grams penalty는 사용할때 주의해야한다. 만약 반복으로 나와야하는 태스크라면 사용하지 말아야한다.

 

Sampling

일반적으로 sampling은 조건부 확률 분포에따라 다음 단어 $w_t$를 랜덤으로 선택하는 것을 의미한다.

위에서 사용한 예시를 시각화해보면 다음과 같다.

이것은 sampling을 사용했을때 전혀 deterministic하지않다는 것을 보여준다.

이렇게 이상하게 출력하는 것을 개선하기위해 softmax에 temperature라고 불리는 것을 곱하여 조건부 확률 분포를 더 sharp하게 만든다.

 

Top-K Sampling

다음 단어로 올 가능성이 높은 K개의 단어를 필터링하고 이 K개의 단어들로 다시 확률을 다시 계산해가는 sampling 이다. GPT2에서 바로 이 sampling을 사용했다고 한다.

한 가지 Top-K sampling에 대한 아쉬운 점이 있다면 다음 단어로 필터링될 단어의 개수가 dynamic하지 않다는 것이다. 이것은 확률 분포가 아주 sharp하거나 flat할때 문제가 될 수 있다.

 

Top-p (nucleus) sampling

그냥 가장 확률 높은 K의 단어를 sampling하는 대신에 Top-p sampling에서는 누적 확률이 확률 p를 초과하는 단어들을 선택한다. 이 방법으로 선택되는 단어의 수는 다음 단어의 확률 분포에 따라 dynamic하게 감소하거나 증가할 수 있다.

 

예를 들어서 p=0.92라고 한다면 다음과 같이 단어가 선택된다.

  • "nice" + "dog" + "car" + "woman" + "guy" + "man" + "people" + "big" + "house" > 0.92
  • "drives" + "is" + "turns" > 0.92

(흠... 문서에서는 초과하는 단어를 선택한다고 했는데 p 미만의 단어들을 선택한다는게 맞는거 아닌가...?)

 

[Reference]

https://huggingface.co/blog/how-to-generate

728x90