논문 원본 : https://arxiv.org/abs/2403.11106
Self-Supervised Quantization-Aware Knowledge Distillation
Quantization-aware training (QAT) and Knowledge Distillation (KD) are combined to achieve competitive performance in creating low-bit deep learning models. However, existing works applying KD to QAT require tedious hyper-parameter tuning to balance the wei
arxiv.org
Abstract
- Quantization-aware training (QAT)와 Knowledge Distillation (KD)를 결합하여 low-bit 딥러닝 모델을 효과적으로 만드는 연구가 이루어지고 있음
- 하지만 기존의 QAT에 KD를 적용한 연구는 다음과 같은 한계 존재
- 다양한 손실 항목의 weight를 조절하기 위한 hyperparameter tuning이 번거로움
- label이 있는 학습 데이터를 사용할 수 있어야 한다는 가정
- 좋은 성능을 얻기 위해 복잡하고 계산 비용이 많이 드는 학습 절차가 필요
- 이 문제를 해결하기 위해 본 논문에서는 Self-Supervised Quantization-Aware Knowledge Distillation(SQAKD) 프레임워크 제안
- SQAKD는 다양한 quantization 함수의 forward 및 backward 계산 과정을 통합하여, 다양한 QAT task를 유연하게 통합할 수 있음
- 또한, QAT를 공동 최적화 문제로 정의하여, label 없이도 KD를 위해 full-precision 모델과 low bit 모델 간의 KL-Loss를 최소화하고, quantization의 이산화 오류 최소화
1. Introduction
- Deep neural networks (DNNs)는 많은 계산과 메모리를 필요
- IoT 및 다양한 디바이스에서 딥러닝의 사용이 급증하면서, 자원이 많이 필요한 DNN과 자원이 제한된 디바이스 간의 불균형 문제가 심화되고 있음
- Quantization는 이러한 문제를 해결하기 위한 주요 모델 압축 기법 중 하나로, 모델의 weight나 activation 값을 낮은 bit precision으로 변환하여 자원을 절약
- 특히 QAT는 사전 학습된 모델을 기반으로 재학습하며 quantization을 수행하는 방식으로 주목 받고 있음
- 하지만 기존의 QAT 방법은 아래와 같은 한계가 있음
- quantization으로 인해 정확도 손실이 큼
- 다양한 모델 아키텍처 (VGG, ResNet, MobileNet 등)에서 일관된 성능을 내지 못함
- 모델 간 직관이나 이론이 다르고 일반화가 어려움
- 1~3 bit와 같은 low bit 네트워크에서는 성능이 좋지 않음
- KD (Knowledge Distillation)을 QAT에 적용하는 최근 연구들은 고정밀 모델(Teacher)의 Knowledge를 저정밀 모델(Student)에 전달하여 정확도 손실을 줄이려고 함
- 하지만 이러한 방법들에도 한계가 존재
- 여러 손실 항목의 weight를 조정하기 위한 hyperparameter tuning이 번거로움
- label이 있는 데이터가 필요하다는 가정이 있음
- 높은 성능을 위해 복잡하고 계산 집약적인 학습 절차가 필요
- 특정 KD 접근법이나 quantization 기법에만 국화되어 일반적으로 일관된 성능을 내지 못함
- 본 연구에서는 SQAKD (Self-Supervised Quantization-Aware Knowledge Distillation)이라는 간단하지만 효과적인 프레임워크를 제안
- SQAKD의 주요 특징은 아래와 같음
- quantization 과정 통합 : 다양한 quantization 함수의 forward 및 backward 동작을 통합하여, 여러 QAT 작업을 유연하게 결합
- label 없는 학습 : label 데이터 없이 KD와 quantization error를 동시에 최소화하는 공동 최적화 문제로 QAT를 정의
- 간소화된 학습 과정 : 단일 학습 단계만 필요하여 계산 비용을 줄이고 사용성과 재현성을 향상시킴
- hyperparameter 불필요 : KL 손실만 사용하여 복잡한 손실 weight 조정 불필요
- 우수한 성능 : 다양한 모델 (VGG, ResNet, MobileNet 등)에서 기존 최첨단 QAT와 KD 연구보다 높은 정확도와 빠른 수렴 속도를 보임
2. Background and Related Works
Quantization
- PTQ (Post-Training Quantization) : 사전 학습된 모델을 quantization하지만 재학습 없이 수행. 이는 정확도 손실이 크다는 단점
- QAT (Quantization-Aware Training) : quantization를 재학습 과정에 포함하며, 더 나은 성능을 보임
QAT의 발전
- 초기 연구들은 quantization 과정으 forward 및 backward pass를 설계하는데 초점을 맞춤
- BNN과 XNOR-Net : 채널별 스케일링 도입
- DoReFa-Net : 모든 필터에 대해 공통 스케일러 사용
- 최근 연구들은 quantization을 위한 학습 가능한 매개변수를 활용하여 클리핑 범위 및 quantization 간격을 조정
- PACT, LSQ, APoT, DSQ : 입력 값을 제한하는 클리핑 범위를 개선
- QIL, EWGS : quantization 레벨 간 간격 (step size)를 최적화
- 하지만 QAT는 여전히 아래와 같은 문제 존재
- 정확도 손실의 수준이 제각각
- 다양한 직관적인 접근 방식이 동기에 깔려 있지만, 이론적 통합이 부족
- 특정 알고리즘이 모든 모델 아키텍처에 대해 일관적인 성능을 제공하지 못함. 또한, 대분분의 QAT 알고리즘은 4bit 이상의 high bit 네트워크에 초점이 맞춰져 있으며, low bit 네트워크에서는 성는이 떨어짐
Knowledge Distillation (KD)
- KD는 큰 모델 (Teacher)의 knowledge를 작은 모델 (Student)로 전달하여 성능을 향상시키는 기술
- 초기 연구 : Teacher와 Student의 Softmax 출력 간의 KL divergence을 최소화하고 label 기반의 교차 엔트로피 손실을 추가
- 하지만 KD가 quantization 모델 (Student)에서 제대로 이해되고 사용되는 사례는 제한적
Knowledge Distillation + Quantization
- KD를 QAT와 결합하여 low bit 네트워크의 정확도 손실을 줄이는 연구들이 이루어짐
- Apprentice (AP) : 3가지 phases를 통해 4bit 또는 3진법 네트워크 성능 향상
- QKD : quantization과 KD를 3단계 (자기 학습, 공동 학습, 지도 학습)로 조율
- SPEQ : Student의 매개변수를 사용해 Teacher를 구성하고 stochastic bit precision을 도입
- PTG : 4단계 학습 전략을 통해 quantizaion weight와 activation를 최적화하고 점진적으로 bit 수를 줄임
- CMT-KD : 여러 quantization된 Teacher 간 협업 학습과 Teacher-Student 간 상호 학습을 촉진
3. Methodology
3.1 QAT as Constrained Optimization
- 본 연구에서는 QAT의 일반화된 이론적 틀을 제안하며, 다양한 quantization 함수의 forward 및 backward 동작을 통하여 최적화 문제로 정의
Forward Propagation
- quantization 함수 Qaunt()는 full precision 입력 x를 quantization된 출력 xq로 변환
- x는 네트워크의 activation 값 또는 weight일 수 있음
- Clipping 단계
- x를 제한된 범위로 정규화하여 클리핑된 표현 xc를 만듦
- v, m : 클링 범위의 하한과 상한
- {pi} : quantization에 필요한 학습 가능한 파라미터
- Quantization 단계
- 클리핑된 값 xc를 round 함수를 포함한 R(⋅)로 discrete인 값 xq로 변환
- b : bit 너비
- {qi} : 학습 가능한 파라미터 (필요하지 않을 수도 있음)
- 통합 공식
- quantization 함수 Quant(⋅)는 아래와 같이 표현
- 여기서 α는 클리핑과 quantization 함수의 모든 파라미터를 포함
Back Propagation
- quantization 함수 Q(⋅)는 이산적 특성 때문에 미분이 불가능 (round 함수)
- 이를 해결하기 위해 대부분의 QAT 연구는 STE 사용
- 본 연구는 단순한 STE 대신, quantization 오류를 반영하는 새로운 공식을 제안
- 여기서 μ는 양자화된 값과 원래 값 간의 차이를 반영하는 비음수 값
Optimization Objective
- QAT를 "quantization으로 인한 오류"와 "모델 예측과 실제 label 간의 차이"를 동시에 최소화하는 문제로 정의
- Wf / Af : full precision weight / activation 값
- Wq / Aq : quantization weight / activation
3.2 Analysis of KD in QAT
Does KD perform well in QAT?
- KD는 널리 사용되지만, QAT 문제를 해결하는 데 있어 KD의 효과에 대한 철저한 연구는 부족
- 본 연구에서는 사전 학습된 full precision 네트워크를 "Teacher"로 사용하고, 동일한 구조의 low bit 네트워크를 "Student"로 사용하여 KD를 QAT에 적용
- 손실 함수 L은 cross entropy loss Lce와 distillation loss LDistill의 선형 결합으로 정의
- : Student 모델의 예측값과 실제 라벨 간의 차이
- : 두 손실의 중요도를 조절하는 hyperparameter
Are both the cross-entropy loss and distillation loss necessary?
- quantization 네트워크는 full precision 네트워크에 비해 표현 능력이 낮기 때문에, KD가 요구하는 다양한 손실 항목을 최적화하기 어려움
- quantization 과정에서 weight와 activation 값에 추가적인 noise가 발생하여, KD 성능이 저하될 수 있음
- 3가지 실험 시나리오
- KL-Loss만 최소화 (λ=1)
- CE-Loss만 최소화 (λ=0)
- KL-Loss와 CE-Loss를 동시에 최소화 (λ=0.5)
- 결과 분석
- KL-Loss만 최소화 ()
- KL-Loss를 줄이는 것만으로 CE-Loss도 효과적으로 감소
- 이는 Student 모델이 라벨 없이도 Teacher의 분포를 학습하여 정확한 예측을 생성할 수 있음을 의미
- CE-Loss만 최소화 () 또는 CE와 KL을 함께 최소화 ()
- CE-Loss가 포함되면 KL-Loss를 충분히 줄이지 못함
- 이는 CE-Loss와 KL-Loss가 함께 작동할 때 서로 간섭이 생겨 성능이 저하될 수 있음을 나타냄
- 즉, KL-Loss를 단독으로 사용하는 것이 QAT에서 가장 효과적
- KL-Loss만 사용할 경우 hyperparameter λ가 필요 없어 손실 함수가 더 단순해짐
3.3 Optimization via Self-Supervised KD
- CE-Loss를 제외하고 KL-Loss만 사용하는 방식으로 최적화 진행
- KL Divergence : Teacher와 Student 출력 분포 간의 차이를 측정
- S: Softmax 함수
- hT , hS : Teacher와 Student의 penultimate layer(출력 직전 layer)에서의 출력
- ρ: Temperature로, 분포를 부드럽게 만들어 학습이 더 안정적이게 만듦
- Teacher 모델 weight는 freeze되어 학습 과정에서 변경되지 않음
- Teacher는 forward propagation만 수행하며, Student가 학습할 "참조 분포"를 제공
- Student forward pass 시 weight는 quantization되어 사용
- 하지만 내부적으로 full precision 값 유지
- 역전파 과정에서는 quantization된 값이 아닌, full precision 값의 gradient를 계산해 업데이트
- 학습이 완료되면, Student는 quantizatio된 상태에서 최적의 성능을 유지하도록 설계된 가중치 보유
4. Evaluation
4.1 Improvements on SOTA QAT Methods
CIFAR-10과 CIFAR-100에서의 결과
- RestNet과 VGG 모델을 1.2.4 bit로 quantization하여 테스트
- SQAKD (EWGS)는 기존 EWGS보다 모든 bit quantization 상황에서 정확도 크게 향상
Tiny-ImageNet에서의 결과
- PACT, LSQ, DoReFA 같은 QAT 방법과 비교
- bit가 낮을수록 정보 손실이 더 크지만, SQAKD는 KD를 활용하여 이를 완화
경량화 모델 (MobileNet, ShuffleNet 등) quantization 결과
- 이미 가벼운 모델에서도 SQAKD가 성능 향상을 제공
- MobileNet-V2에서는 8 bit quantization 시 full-precision 모델보다 높은 정확도를 기록
4.2 Comparison with SOTA KD Methods
- 기존 KD방법은 EWGS 상황에서 ground-truth label의 supervision이 없으면 수렴하지 못함
- 제안된 unsupervised 방식의 SQAKD는 기존 supervised KD 방법보다 CIFA-10에서 0.36%~3.4%, CIFAR-100에서 0.09%~17.09% 더 높은 성능을 보여 줌
- FSP는 SQAKD 다음으로 높은 정확도를 보인 방법으로, SQAKD와의 비교를 추가로 수행
- SQAKD는 EWGS와 FSP보다 수렴 속도가 훨씬 빠름
- KL-Loss는 SQAKD가 더 빠르게 줄어들고 최종 값도 더 낮음
- 이는 SQAKD가 CE-Loss와 KL-Loss 모두 효과적으로 최소화하여 더 빠르고 정확한 distillation을 가능하게 한다는 것을 보여줌
4.3 Comparison with SOTA Methods Applying both QAT and KD
ResNet (2-bit)
- 정확도 하락 폭이 가장 작았으며, 다른 방법보다 0.04%~3.06% 낮은 하락률 기록
AlexNet (CIFAR-100)
- SQAKD와 PTG만 full precision 모델보다 더 높은 정확도를 기록했으나, SQAKD의 향상폭이 PTG보다 더 큼
4.4 Infernce Speedup
- SQAKD는 모델 bit 너비를 줄여 복잡도를 낮춤으로써, 실제 추론 속도도 빨라짐
- Tiny-ImageNet에서 ResNet-18, MobileNet-V2, ShuffleNet-V2, SqueezeNet 같은 다양한 모델 아키텍처로 8 bit quantization 시 3배의 속도 향상 달성
5. Ablation Study
Analysis of Loss Surface
- SQAKD와 독립형 EWGS를 사용하여 CIFAR-10 데이터셋에서 훈련된 2 bit ResNet-20과 full precision ResNet-20의 3D 손실 곡면과 2D 등고선을 비교
- SQAKD는 quantization된 모델이 더 평평하고 부드러운 손실 곡면을 가지도록 함
- 이는 훈련 과정에서 안정성을 높이고 최적화를 쉽게 만듦
Flexibility for various forward and backward combinations
- SQAKD는 SOTA QAT(PACT, EWGS)의 forward 및 backward 기법을 모듈식으로 통합
- Tiny-ImageNet에서 4 bit ShuffleNet-V2
- STE backward를 사용해 PACT를 14.02% 개선
- EWGS backward를 사용해 추가로 0.77% 더 개선
- CIFAR-10에서 2 bit ResNet-20
- STE backward로 EWGS를 0.29% 개선, EWGS backward로 0.39% 개선
Effect of Temperature
- Knowledge Distillation에서 temperature (ρ)는 모델이 teacher 모델의 'dark knowledge'를 더 잘 학습하도록 분포를 부드럽게 만듦
- VGG-13 (CIFAR-100)과 ResNet-20 (CIRAR-10)에서 temperature 범위 (ρ ∈ [1, 10])를 실험한 결과, ρ = 4에서 가장 좋은 성능을 보임
Effect of Initialization
- Tiny-ImageNet에서 4 bit VGG-11
- SQAKD는 초기화 방식에 따라 PACT, LSQ, DoReFa 정확도를 0.05%~18.97%까지 향상
- 모든 경우에서 full precision teacher 초기화가 랜덤 초기화보다 우수
6. Conclusion
- SQAKD의 특징과 장점
- 비감독 학습(Self-Supervised): 라벨 데이터가 필요 없음
- QAT 연구 접근성 향상: 간단한 학습 절차, 낮은 학습 비용, hyperparameter tuning 필요 없음
- 통합 프레임워크 제공: QAT 방법들의 forward 및 backward 동작을 통합하여 최적화
- SQAKD의 성과
- 기존 KD 및 QAT 방법을 크게 능가
- QAT를 위한 새로운 기준을 세우며, 모델 양자화 연구의 새로운 방향을 제시