GPT2와 BART로 대화 생성 모델을 학습하는 방법에 대해 설명하겠다.
여기서 두 화자 간 대화 모델을 가정하고 설명한다. 특히, user의 말에 system이 응답해주는 식이다. 다음 대화를 각 모델이 어떻게 입력으로 받아들이고 어떤 label을 취하는지 살펴보겠다.
예시 대화
xxxxxxxxxx
User: Hello, I need your help.
System: What do you want?
User: Can you speak Korean?
System: A little bit
GPT-2
GPT는 autoregressive model이다. 따라서 현재 output이 다음 input token이 된다. 그러나 학습 시에는 한 번에 학습하기 위해 teacher forcing 방식을 사용한다. 따라서 다음과 같은 형식으로 학습된다.
노란색으로 칠해진 부분은 dialogue history, 초록색으로 칠해진 부분은 System response다. Input representation을 정하는 방법은 여러 가지지만 대게 이런 형식을 취한다. <ctx>
와 </ctx>
는 dialogue context를 구분지어주는 special tokens이지만 생략되기도 한다. <usr>
과 <sys>
은 각각 user token과 system token이다. Label은 각 입력 token의 다음 token이 되기는 하나 dialogue 생성 모델에선 다음 token을 맞추는게 아니라 응답을 생성하도록 학습하는 것이 목적이므로 응답 이외에는 -100으로 labeling해서 loss 계산 시 빠지도록 한다. 참고로 Pytorch 기준 torch.nn.CrossEntropyLoss
의 argument로 ignore_index=-100을 주는 것이 기본이다.
Token type ids는 각 문장을 구분해주는 역할을 한다. 보통 0, 1 binary로 사용하지만 특별히 ctx 토큰에 대해선 다른 type을 주었다. 이 토큰은 생략 가능하다는 것에 주목하자. Huggingface/transformers의 공식 코드를 보면 token type ids의 임베딩을 따로 학습하기 위해 weights가 추가적으로 또 있는 것이 아니라 word token을 임베딩할 때와 동일한 weights를 사용한다.
BART
BART는 표준 Transformer 구조이다. 따라서 encoder는 bidirectional transformer이기 때문에 pad에 attention 되지 않도록 attention mask가 필요하다. Token type id나 attention mask의 자세한 내용은 논문 혹은 Glossary을 참고바란다.
Decoder에만 loss가 계산되는 것을 주목하자. GPT-2와 어느정도 차이는 있지만 거의 같은 구조라는 것을 확인할 수 있다.
댓글