본문 바로가기
NLP

[TDS] SimpleTOD에 대한 이해 (End-to-end TDS)

by 방구석 몽상가 2021. 9. 1.
SimpleTOD

이 글은 SimpleTOD (A Simple Language Model for Task-Oriented Dialogue)의 방법론 이해를 위한 글이다.

Abstract

Task-Oriented dialogue systems은 주로 세 가지 subtasks(e.g., NLU, DST, NLG)로 분해되어 해결되었었다. 간단하고 통합된 접근으로 SoTA 성능을 달성하기 위해, 모든 sub-tasks를 single sequence prediction 문제로 recast해 훈련된 단일 causal LM인 SimpleTOD를 이용하여 task-oriented dialogue 접근을 간단히 한다. 또한, SimpleTOD는 GPT-2와 같은 pre-trained, open domain, causal LM으로부터 transfer learning을 진행한다. 실험을 통해 noisy annotations에 robustness함을 발견했으며, end-to-end setting에서 주요 metrics 성능을 향상시켰다.

Contributions

1) SimpleTOD - DST를 위한 SoTA 생성 모델

2) SimpleTOD는 end-to-end setting에서 DST, action decisions, response generation metrics에 대해 SoTA 성능을 성취한 첫 모델

3) 분석을 통해, SimpleTOD는 Noisy-labeled annotations에 robust한 dialogue state tracker라는 것을 증명

4) user/systemendof(segment) tokens의 중요성에 대한 ablations

5) SimpleTOD의 larger versions은 end-to-end MultiWOZ에 대해 항상 더 좋은 것은 아니라는 것을 보여주면서 pre-training의 중요성에 대한 ablations

6) MultiWOZ 2.1에서 noisy annotations의 발견하여 목록화

Methods

Task-oriented dialogue (TOD)는 다음 세 가지 sub-tasks에 대해 평가된다.

1) Dialogue state (belief state) tracking

2) Dialogue management (action/decision prediction)

3) Response generation

이러한 분해 없이, single-model, end-to-end approach, SimpleTOD를 사용하여 문제를 해결한다.

SimpleTOD

Turn 에서 user가 input 를 제공하면 시스템은 response 를 생성한다. 추론 동안 하나의 response를 생성하기 위해 SimpleTOD는 context 로써 모든 이전 turns을 읽어낸다. 그리고 Belief state 를 생성한다.

는 특정 도메인 내 slots에 대한 values를 기록한 triplets 목록이다: {(domain, slot_name, value), ...}. 이 belief state는 database를 query하는데 사용된다. DB search는 DB로부터 belief state의 조건들을 만족시키는 rows를 반환한다. 반환된 rows는 추후 응답을 lexicalization (어휘화)하는데 사용된다. 그러나, SimpleTOD는 aggregated DB search 결과 만을 입력으로 취한다. 여기서 는 반환된 rows의 수와 예약 상태 정보 여부를 포함한다. 즉, 검색 결과를 직접 사용하지 않고 얼마나 나왔는지 정도만 사용하겠다는 건데 직관적으로는 이해가 잘 되지 않지만 실험적으로 결정했다고 한다. 그리고 DB search 결과는 delexicalized 응답을 lexicalization 하는데 사용된다고 하는데 예시를 통해 한 번 살펴보자.

일단 Context 는 이전의 context와 현재 user의 utterance로 이루어진 것을 알 수 있으며, user/system 태그로 구분되어진다. 그 다음 context, belief state, action, response를 구분하기 위해 시작 태그 뿐 아니라 endof 태그가 붙는다. Turn 2에서 SimpleTOD에 의해 생성된 응답을 보면, [train_id], [value_time]과 같이 delexicalized reponse를 생성하는 것을 볼 수 있다. 이때, 태그들은 slot 속성에 해당하고 검색된 rows에서 해당 slot의 value를 얻어 response를 lexicalization한다.

다시 돌아와 를 계산한 뒤, SimpleTOD는 action 를 결정하기 위해 단일 sequence로써 함께 결합(concat)된 , , 를 조건화한다.

Action 는 또 다른 triplets (domain, action_type, slot_name)의 목록으로 생성된다. Delexicalized response 는 이전 모든 정보를 단일 sequence로써 결합해 조건화하여 생성된다.

마지막으로 위에 예시에서 이야기했듯, Belief state와 DB 검색 결과의 정보를 결합하여 response을 사람이 읽을 수 있는 text로 lexicalization한다.

Loss function

단일 training sequence는 concatenation 로 구성된다. 이는 sequence 에 대해 joint probability을 모델링할 수 있도록 한다. 다음의 Causal Language modeling을 통해 joint probability 를 학습한다.

확률의 chain rule을 사용하여 해당 분포를 다음과 같이 인수분해할 수 있다. 최종적으로 Dataset 에 대해 negative log-likelihood를 최소화하도록 neural network를 훈련한다.

Architecture

조건부 분포를 학습하기 위해 Transformer의 variant를 훈련시킨다. n개의 tokens을 가진 하나의 sequence는 차원의 n개의 vectors를 가진 하나의 sequence로 임베딩된다. 각 vector는 학습된 token embedding과 sinusoidal positional embedding의 합이다.

Vectors의 sequence는 matrix 로 stack되고 개의 attention layers에 의해 처리된다. 번째 layer는 두 개 blocks으로 구성된다. 첫 번째 block은 heads의 multi-head attention을 사용한다. Causal mask는 future tokens이 attention되는 것을 방지하기 위해 사용된다. 두 번째 block은 inputs을 inner dimension 로 project하는 feedforward network를 사용한다. 이때, activation 함수는 ReLU를 사용한다. Parameter and 에 의해 다음과 같이 연산된다.

기존 transformer와 같이 LayerNorm과 residual connection을 추가하여 최종적으로 각 block의 연산은 다음과 같다.

Block 1

Block 2

마지막 layer는 를 연산한다. 자세한 내용은 여기를 참고하자.

Training Details

모델의 input은 DistilGPT2에서의 pretrained BPE codes를 사용하여 token화 된다. 여기서 BPE (Byte Pair Encoding)[참고]는 OOV 문제를 해결하기 위해 훈련 데이터의 모든 단어를 character 단위 혹은 유니코드 단위로 모두 나눈 다음 연속적으로 가장 많이 등장한 단어(글자) 쌍을 찾아 하나의 단어로 만드는 것을 반복한다. 최대 sequence 길이는 1024 tokens으로 제한했다.

Input Representation

위 table은 SimpleTOD 입력의 schematic representation이다. 단일 훈련 시퀀스는 context, belief state, DB search results, action decisions, system response의 결합으로 구성된다. 각 input token과 연관된 output state는 다음 token을 예측하는데 사용된다.

추론동안, SimpleTOD는 토큰마다 해당 시퀀스 토큰을 생성하지만 DB에 query하기 위해 belief states가 생성된 후 멈춘다. 그 다음 DB outputs은 위와 같이 요약되어 input sequence 끝에 결합된 뒤 토큰마다 생성을 계속한다. 최종적으로 delexicalized reponses는 slots과 values를 DB 결과 내 정보로 대체함으로써 lexicalization 된다.

Experiments

Dataset Details

평가를 위해 MultiWOZ 데이터셋을 사용한다. Police와 hospital 도메인은 valid/test splits이 없기 때문에 평가로부터 제외된다. SimpleTOD는 케임브리지 대학의 기술 보고서에서 설명된 pre-processing을 따라 delexicalized 시스템 응답에 대해 훈련된다. 최근 noisy state values가 제거된 MultiWOZ 2.1이 릴리즈됐지만 모든 이전의 작업들이 2.0에서 평가됐기 때문에 2.0과 2.1에 대해 둘 다 평가한다.

Evaluation Details

모든 metrics은 original MultiWOZ의 guidance를 따르며, Structured Fusion Networks for Dialog의 combined score를 계산한다.

Joint goal accuracy

Dialogue state tracking (i.e. belief state tracking)의 성능을 평가하기 위해 사용된다. oracle belief states외 비교함으로써 생성된 belief states의 정확도를 측정한다. 모든 예측된 values가 정확하게 oracle values와 매치될 때만 정답으로 계산된다.

Inform and Success rates

Action prediction의 성능을 평가하기 위해 사용된다. Inform rate는 시스템에서 제공된 entities가 얼마나 정확한가를 의미하며, success rate는 시스템이 사용자가 요청한 모든 속성에 응답할 수 있는가를 의미한다.

BLEU score

Response generation을 평가하기 위해 사용되며, 생성된 응답의 fluency를 측정한다.

Combined score

Action과 response generation의 총 평가로, 로써 계산된다.

Dialogue State Tracking

위에서 비교된 baselines은 dialogue context의 더 나은 representation을 학습하기 위해 bidirectional encoder를 사용한다. 그러나 SimpleTOD는 unidirectional decoder를 사용하고 추가적인 양방향 encoder를 사용하지 않는다. 또한, extra supervision을 이용하지 않았다. 그렇지만 다른 모델과 비교해 SoTA를 달성했다.

Action and Response Generation

이전 작업은 oracle DB Search 결과들을 훈련동안 supervision으로, 추론동안 입력으로 사용한다. 비교를 위해 DB 결과를 사용한 결과도 나타내지만 DB search 정보 없이 실험한 결과는 놀라운 결과를 보여준다. 그리고 Dynamic DB search 결과를 사용하는 setting에서의 결과도 보여준다. 이 setting은 matched DB entries의 수로 훈련하고, 추론할 때 생성된 belief states로부터 동적으로 이를 계산한다.

DB Search 결과를 완전히 무시하는게 최고의 결과를 보여주는데 왜 그럴까? 저자는 생성된 belief states가 DB 내 정보와 몇몇 경우에서 충돌하는 것을 발견했다고 한다. 예를 들어, 식당 이름 중 불일치가 존재하는 경우가 있다. Target belief states는 'pizza hut fenditton'이지만 DB에선 'pizza hut fen ditton' 인 경우가 이에 해당한다.

Decoding and Multi-turns

Pre-trained weights로 초기화되기 때문에 cost가 높은 decoding 전략을 사용할 필요가 없다고 한다. 그래서 SimpleTOD는 간단한 greedy decoding을 채택한다.

초기의 belief state errors는 추가적인 turns이 진행되면서 증가된 context에 의해 추후 제거된다.

댓글