본문 바로가기
NLP

[TDS] A Tailored Pre-Training Model for Task-Oriented Dialog Generation (PRAL)에 대한 이해

by 방구석 몽상가 2021. 9. 5.
PRAL

이 글은 ACL 2021에 게재된 A Tailored Pre-Training Model for Task-Oriented Dialog Generation (PRAL)의 방법론 이해를 위한 글이다.

Abstract

Motivation : TDS를 build하는 현재 접근들은 여전히 상당한 양의 annotations을 요구해 노동력을 많이 요구한다. 그래서 더 적은 supervision을 위해 large-scale language models로 대화 시스템을 개발하는 쪽으로 많은 연구들이 있었다. 그러나 LM을 대화 시스템에 적용하기엔 많은 한계점들이 있다.

1) 대화 시스템을 위한 LM pre-training은 엄청난 양의 훈련 corpora를 요구하지만 고품질의 다양한 대화 데이터를 얻는 건 항상 어렵다.

2) 대화는 다수의 인원으로 구성되고 각 인원은 다른 언어 스타일을 지닌다.
그렇지만 이전의 대화 시스템은 모든 인원에 대해 하나의 단일 언어 모델을 사용하여 dialog generation을 수행한다.

3) 대화는 항상 가변 길이이므로 GPT에서의 고정된 길이의 position embedding은 최적의 결과를 낼 수 없다.

4) 대화는 방대한 양의 commonsense knowledge를 포함하므로 작은 크기의 LM은 이를 놓칠 수 있다.
더욱이, 자연스러운 대화는 context의 좋은 이해를 요구하지만 LM 내 contextual 정보를 보존하는게 어렵다.

이러한 이슈들을 해결하기 위해, 대화 생성에 특화해 설겨한 language model인 Pre-trained Role Alternating Language model (PRAL)를 제안한다.

Contributions

1) They process and present a collection of high-quality dialog datasets suitable for pre-training large-scale language models on dialog systems.

2) They propose PRAL and design several effective techniques to improve the dialog model pretraining.

3) Their pretrained model leads to an increase on success rate on CamRest676 and MultiWOZ dataset, and an improvement on the coherence and diversity scores by 50% on PersuasionForGood.

Methods

PretrainDial Dataset for Pre-training

다음의 13개 dialog corpora의 형식을 통합해 pre-training dataset으로 사용한다. Corpora 종류는 chit-chat부터 task-oriented dialogs까지 다양하게 선택되었다.

데이터셋에서 총 dialogues 갯수는 142,298개, 대화 당 평균 턴 수는 12.66턴이다.

PRAL (Pre-trained Role Alternating Language model)

PRAL은 두 개의 language model을 사용하는 ARDM (Alternating Roles Dialog Model) 아키텍쳐를 채택한다. 여기서 두 language model은 small GPT-2로 초기화된다. ARDM에 대한 자세한 설명은 여기에 설명해놓았다. ARDM과 동일하게 전체 대화 분포는 다음으로 정의된다.

Start Position Randomization

GPT-2는 각 token의 위치 정보를 인코딩하기 위한 position embedding을 사용한다. 최대 position은 1024이고 position index는 항상 0부터 시작한다. 이 position embedding을 dialogue에 적용할 때 발생하는 문제점은 다음과 같다.

1) 대부분의 대화는 1024개 이하의 tokens을 포함하기 때문에 position embedding 내 대부분의 vectors는 0이고 pre-training 동안 업데이트 되지 않는다.

2) Position embedding은 각 token에 대한 위치 정보만 제공하기 때문에 시작 위치를 0으로 고정하는 것은 특정 position index로 특정 text를 잇는 것이다.
예를 들자면, "hi"는 항상 초반에 등장하니까 "hi"의 positional embeddings은 시작 근처로 overfitting 될 가능성이 있다.

이러한 점들을 해결하기 위해, 다음 Start Position Randomization (SPR)을 적용한다.

하나의 Dialog 내 전체 tokens 수를 이라 할때, 최대 start position index은 이고 start position은 0에서 사이의 임의의 숫자로 랜덤하게 설정된다. 이렇게 하면 textual meaning에서 positional 정보가 분리되고 모든 positional embeddings가 업데이트되도록 한다.

Teacher GPT

새로운 corpus에 훈련할 경우, 이전의 지식을 잊어버리는 문제를 완화시키기 위해 continual learning의 한 방식으로써 distillation loss를 사용한다. Distillation loss를 위해 teacher network로 고정된 GPT-2 large를 사용해 학습시키는 모델과의 KL divergence 를 계산한다.

History Discount

각 대화에서 후반 utterances는 더 복잡한 contextual 정보를 집약하고 있기에 더 중요하며, 모델이 context의 consistency를 학습하도록 도울 수 있다. 그러므로 턴 수를 기반으로 각 utterance의 중요성을 re-weight하는 discount factor 를 도입한다.

전체 개의 utterances로 이루어진 대화와 현재 utterance index 에 대해 language model loss로 가중치가 부여된다. 즉, 대화 후반부로 갈 수록 더욱 가중치를 부여함으로써 복잡한 context를 예측하고 일관성 있는 응답을 생성하도록 한다.

Optimization

모델을 최적화하는 language modeling 대한 loss는 다음과 같다.

여기서 번째 utterance의 전체 tokens 수이고, 는 output 확률 분포, 는 ground truth다. 여기에 distillation loss를 더해 최종적인 loss는 다음과 같이 정의된다.

Factor 는 iteration 증가에 따라 지수적으로 감소한다, i.e. .

Experiments

Pre-training details는 ARDM과 동일하게 따라간다. 자세한 내용은 여기에 적어놓았다.

MultiWOZ dataset의 combined score는 102.45로 현재 내가 찾아본 논문들 중에는 SoTA 성능이다. 참고로, SimpleTOD는 92.98, SOLOIST는 95.74, ARDM은 100.7이다.

 

댓글