ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning
https://openreview.net/forum?id=r1xMH1BtvB
ELECTRA 논문에 대해 간단하게 요약한 리뷰입니다.
3줄 요약
- 기존의 MLM에서 전체 데이터의 15%만 loss를 계산하고 학습하는 등의 방식이 비효율적이기 때문에 전체 데이터를 학습하도록 하면 훨씬 효율적으로 학습할 수 있음
- MLM과 달리 MASK 토큰 자리에 대체 단어를 생성하고 전체 단어가 original인지 corrupted token인지 예측하는 replaced token detection을 통해 학습이 진행
- 대체 단어를 generator로 생성하고 이를 discriminator로 진짜인지 가짜인지 판별하는 방식으로 학습. Generator의 파라미터는 discriminator보다 작은 것이 좋고 둘 사이의 가중치를 공유하는 방식이 성능이 잘나옴
0. ABSTRACT
- 많은 계산량이 필요한 BERT 같은 MLM과 달리 간단한 generator network를 통해 [MASK] 대신 대체 토큰을 생성하여 해당 토큰이 생성된 토큰인지 original인지를 예측하는 discriminative model을 학습하는데, 이를 replaced token detection task라고 함.
- 모든 input token을 판별하기 때문에 masking된 token만 예측하는 MLM 보다 더 효율적이고 같은 model size, data, compute를 가진 BERT보다 더 좋은 성능을 냄
- 작은데 쎔. RoBERTa, XLNet의 1/4 만큼의 계산량으로 성능 비슷하고 같은 계산하면 더 좋음.
1. INTRODUCTION
- 기존의 SOTA 모델은 denoising autoencoders 방식으로 학습한다고 볼 수 있는데, MLM은 MASK를 씌운 15%의 토큰에 대해서만 학습하기 때문에 계산 비용을 낭비하게 됨.
- replaced token detection을 추구함. MASK 토큰 대신 생성된 대체 단어를 넣어 줌으로써 input을 corrupts 해주고 모든 토큰에 대해 진짠지 대체 단어인지 판별하는 discriminator 방식으로 사전학습 함. 이 방식으로 MASK 토큰은 사전학습에서만 사용되고 finetuning에선 사용되지 않는 현상을 해결함.
- Discriminator 방식은 모든 input token에 대해 학습하기 때문에 MLM보다 효율적임. GAN이랑 비슷해 보일 순 있는데 텍스트에 GAN을 적용하기 어렵기 때문에 maximum likelihood로 학습한다는 점에서 ‘adversarial’ 하지는 않음.
- BERT보다 훨씬 학습 속도도 빠르고 효율적인데 성능도 더 좋음. Large model은 RoBERTa, XLNet이랑 성능이 비슷한데 파라미터도 적고 계산량은 1/4 수준임. 암튼 좋음
2. METHOD
- [MASK] 토큰 자리에 generator를 이용해 단어를 생성하고 그 단어가 대체인지 진짜인지를 맞추는 과정
- 각 network는 transforemr encoder로 구성. position t에서 generator는 softmax layer 를 거쳐 t시점의 x를 generation할 확률을 출력.
- GAN과 다른 점으로 generator가 정답과 유사한 토큰을 생성해 냈다면 fake가 아니라 real로 사용된다는 점과, discriminator를 속이기 위해 adversarially하게 학습하지 않고 maximum likelihood로 학습한다는 것이 있음
- Generator는 MLM loss, discriminator는 아래에 있는 loss 사용, generator loss와 discriminator loss의 합을 최소화하도록 학습.
3. EXPERIMENTS
3.2. Model Extensions
- Weight Sharing: Generator와 discriminator의 크기가 같으면 모든 transformer weight 공유 가능. 그리고 가중치를 공유한게 성능이 더 잘나옴.공유했을 때 성능이 잘 나온 이유: Discriminator는 입력 토큰만 학습하는데, generator는 출력 레이어에서 softmax를 통해 사전에 있는 모든 토큰에 대해서 밀도 있게 학습. ELECTRA는 결국 discriminator만을 취해서 사용하는데, generator의 weight를 다 공유하면 입력 출력을 다 학습하게 되므로 훨씬 효과적으로 학습이 됨.
- 근데 small generator를 쓰는게 더 효율적이어서, embedding 값만 공유함. 여기서 임베딩 크기는 discriminator hidden state.
- Smaller Generators: generator의 크기가 discriminator만큼이 되면 MLM보다 계산량이 2배이고, generator의 크기가 1/4~1/2 수준일 때 가장 잘 작동함. Generator가 쎄면 학습에 악영향
- Training Algorithms: two-stage electra, adversarial electra 전부 해 봤는데, generator와 discriminator를 jointly 학습시키는 방식 성능이 가장 높았음
3.3. Small Models, Large Models
- single GPU에서도 빠르게 학습 되도록 BERT보다 가볍게 모델 만듬
sequence length: 512->128
batch size: 256 -> 128
hidden size: 768 -> 256
token embedding: 768 -> 128 - 성능 보면 작은데도 불구하고 우수함. 큰걸로 쓰면 더 좋음, Large모델도 역시 가장 좋음
3.5. Efficiency Analysis
ELECTRA가 성능이 잘나오는 이유를 파악하기 위해 다음과 같은 실험 실행
- ELECTRA 15%: discriminator loss를 input에서 masking된 15%만 계산
- Replace MLM: Discriminator를 MLM 학습을 하되, [MASK]로 는 안쓰고 generator가 만든 토큰 씀. 사전학습때만 MASK 쓰고 finetuning에는 안쓰는 문제인 discrepancy 문제를 해결했을 때 얼마나 좋아지는지 보여주기 위함.
- ALL-Tokens MLM: Replace MLM처럼 하는데, 15% 토큰만 치환하는 게 아니고 모든 토큰을 generator가 생성한 토큰으로 치환. 15%만 사용했을 때 보다 성능이 더 올라간 것을 확인. 다 쓰는게 좋다.
5. Conclusion
Replaced Token Detection을 쓰면 MLM보다 더 효율적이고 downstream tasks에도 결과가 더 잘나옴. 계산량도 적은데 결과도 잘나왔음. 무조건 더 큰 파라미터와 계산량을 추구하기보다 적은 컴퓨팅 파워로도 학습할 수 있도록 더 효율적인 방법도 생각해 봅시다.