본문 바로가기
논문 리뷰/Quantization

[Paper Review] Self-Supervised Quantization-Aware Knowledge Distillation

by hyeon827 2024. 12. 1.

 

논문 원본 : https://arxiv.org/abs/2403.11106

 

Self-Supervised Quantization-Aware Knowledge Distillation

Quantization-aware training (QAT) and Knowledge Distillation (KD) are combined to achieve competitive performance in creating low-bit deep learning models. However, existing works applying KD to QAT require tedious hyper-parameter tuning to balance the wei

arxiv.org

 

Abstract

  • Quantization-aware training (QAT)와 Knowledge Distillation (KD)를 결합하여 low-bit 딥러닝 모델을 효과적으로 만드는 연구가 이루어지고 있음
  • 하지만 기존의 QAT에 KD를 적용한 연구는 다음과 같은 한계 존재
    • 다양한 손실 항목의 weight를 조절하기 위한 hyperparameter tuning이 번거로움
    • label이 있는 학습 데이터를 사용할 수 있어야 한다는 가정
    • 좋은 성능을 얻기 위해 복잡하고 계산 비용이 많이 드는 학습 절차가 필요
  • 이 문제를 해결하기 위해 본 논문에서는 Self-Supervised Quantization-Aware Knowledge Distillation(SQAKD) 프레임워크 제안
  • SQAKD는 다양한 quantization 함수의 forward 및 backward 계산 과정을 통합하여, 다양한 QAT task를 유연하게 통합할 수 있음
  • 또한, QAT를 공동 최적화 문제로 정의하여, label 없이도 KD를 위해 full-precision 모델과 low bit 모델 간의 KL-Loss를 최소화하고, quantization의 이산화 오류 최소화

 

1. Introduction

  • Deep neural networks (DNNs)는 많은 계산과 메모리를 필요
  • IoT 및 다양한 디바이스에서 딥러닝의 사용이 급증하면서, 자원이 많이 필요한 DNN과 자원이 제한된 디바이스 간의 불균형 문제가 심화되고 있음
  • Quantization는 이러한 문제를 해결하기 위한 주요 모델 압축 기법 중 하나로, 모델의 weight나 activation 값을 낮은 bit precision으로 변환하여 자원을 절약
  • 특히 QAT는 사전 학습된 모델을 기반으로 재학습하며 quantization을 수행하는 방식으로 주목 받고 있음
  • 하지만 기존의 QAT 방법은 아래와 같은 한계가 있음
    • quantization으로 인해 정확도 손실이 큼
    • 다양한 모델 아키텍처 (VGG, ResNet, MobileNet 등)에서 일관된 성능을 내지 못함
    • 모델 간 직관이나 이론이 다르고 일반화가 어려움
    • 1~3 bit와 같은 low bit 네트워크에서는 성능이 좋지 않음
  • KD (Knowledge Distillation)을 QAT에 적용하는 최근 연구들은 고정밀 모델(Teacher)의 Knowledge를 저정밀 모델(Student)에 전달하여 정확도 손실을 줄이려고 함
  • 하지만 이러한 방법들에도 한계가 존재
    • 여러 손실 항목의 weight를 조정하기 위한 hyperparameter tuning이 번거로움
    • label이 있는 데이터가 필요하다는 가정이 있음
    • 높은 성능을 위해 복잡하고 계산 집약적인 학습 절차가 필요
    • 특정 KD 접근법이나 quantization 기법에만 국화되어 일반적으로 일관된 성능을 내지 못함
  • 본 연구에서는 SQAKD (Self-Supervised Quantization-Aware Knowledge Distillation)이라는 간단하지만 효과적인 프레임워크를 제안
  • SQAKD의 주요 특징은 아래와 같음
    • quantization 과정 통합 : 다양한 quantization 함수의 forward 및 backward 동작을 통합하여, 여러 QAT 작업을 유연하게 결합
    • label 없는 학습 : label 데이터 없이 KD와 quantization error를 동시에 최소화하는 공동 최적화 문제로 QAT를 정의
    • 간소화된 학습 과정 : 단일 학습 단계만 필요하여 계산 비용을 줄이고 사용성과 재현성을 향상시킴
    • hyperparameter 불필요 : KL 손실만 사용하여 복잡한 손실 weight 조정 불필요
    • 우수한 성능 : 다양한 모델 (VGG, ResNet, MobileNet 등)에서 기존 최첨단 QAT와 KD 연구보다 높은 정확도와 빠른 수렴 속도를 보임

 

2. Background and Related Works

 

   Quantization

  • PTQ (Post-Training Quantization) : 사전 학습된 모델을 quantization하지만 재학습 없이 수행. 이는 정확도 손실이 크다는 단점
  • QAT (Quantization-Aware Training) : quantization를 재학습 과정에 포함하며, 더 나은 성능을 보임

 

   QAT의 발전

  • 초기 연구들은 quantization 과정으 forward 및 backward pass를 설계하는데 초점을 맞춤
    • BNN과 XNOR-Net : 채널별 스케일링 도입
    • DoReFa-Net : 모든 필터에 대해 공통 스케일러 사용
    • 최근 연구들은 quantization을 위한 학습 가능한 매개변수를 활용하여 클리핑 범위 및 quantization 간격을 조정
      • PACT, LSQ, APoT, DSQ : 입력 값을 제한하는 클리핑 범위를 개선
      • QIL, EWGS : quantization 레벨 간 간격 (step size)를 최적화
  • 하지만 QAT는 여전히 아래와 같은 문제 존재
    • 정확도 손실의 수준이 제각각
    • 다양한 직관적인 접근 방식이 동기에 깔려 있지만, 이론적 통합이 부족
    • 특정 알고리즘이 모든 모델 아키텍처에 대해 일관적인 성능을 제공하지 못함. 또한, 대분분의 QAT 알고리즘은 4bit 이상의 high bit 네트워크에 초점이 맞춰져 있으며, low bit 네트워크에서는 성는이 떨어짐

 

   Knowledge Distillation (KD)

  • KD는 큰 모델 (Teacher)의 knowledge를 작은 모델 (Student)로 전달하여 성능을 향상시키는 기술
  • 초기 연구 : Teacher와 Student의 Softmax 출력 간의 KL divergence을 최소화하고 label 기반의 교차 엔트로피 손실을 추가
  • 하지만 KD가 quantization 모델 (Student)에서 제대로 이해되고 사용되는 사례는 제한적

 

   Knowledge Distillation + Quantization

  • KD를 QAT와 결합하여 low bit 네트워크의 정확도 손실을 줄이는 연구들이 이루어짐
    • Apprentice (AP) : 3가지 phases를 통해 4bit 또는 3진법 네트워크 성능 향상
    • QKD : quantization과 KD를 3단계 (자기 학습, 공동 학습, 지도 학습)로 조율
    • SPEQ : Student의 매개변수를 사용해 Teacher를 구성하고 stochastic bit precision을 도입
    • PTG : 4단계 학습 전략을 통해 quantizaion weight와 activation를 최적화하고 점진적으로 bit 수를 줄임
    • CMT-KD : 여러 quantization된 Teacher 간 협업 학습과 Teacher-Student 간 상호 학습을 촉진

 

3.  Methodology

 

   3.1 QAT as Constrained Optimization

  • 본 연구에서는 QAT의 일반화된 이론적 틀을 제안하며, 다양한 quantization 함수의 forward 및 backward 동작을 통하여 최적화 문제로 정의

 

 

   Forward Propagation

  • quantization 함수 Qaunt()는 full precision 입력 x를 quantization된 출력 xq로 변환
  • x는 네트워크의 activation 값 또는 weight일 수 있음

 

  • Clipping 단계
  • x를 제한된 범위로 정규화하여 클리핑된 표현 xc를 만듦
  • v, m : 클링 범위의 하한과 상한
  • {pi} : quantization에 필요한 학습 가능한 파라미터

 

 

  • Quantization 단계
  • 클리핑된 값 xc를 round 함수를 포함한 R()로 discrete인 값 xq로 변환
  • b : bit 너비
  • {qi} : 학습 가능한 파라미터 (필요하지 않을 수도 있음)

 

  • 통합 공식
  • quantization 함수 Quant(⋅)는 아래와 같이 표현
  • 여기서 α는 클리핑과 quantization 함수의 모든 파라미터를 포함

 

 

   Back Propagation

  • quantization 함수 Q(⋅)는 이산적 특성 때문에 미분이 불가능 (round 함수)
  • 이를 해결하기 위해 대부분의 QAT 연구는 STE 사용

 

  • 본 연구는 단순한 STE 대신, quantization 오류를 반영하는 새로운 공식을 제안
  • 여기서 μ는 양자화된 값과 원래 값 간의 차이를 반영하는 비음수 값

 

 

   Optimization Objective

  • QAT를 "quantization으로 인한 오류"와 "모델 예측과 실제 label 간의 차이"를 동시에 최소화하는 문제로 정의
  • Wf / Af : full precision weight / activation 값
  • Wq / Aq : quantization weight / activation 

 

 

 

3.2 Analysis of KD in QAT

 

   Does KD perform well in QAT?

  • KD는 널리 사용되지만, QAT 문제를 해결하는 데 있어 KD의 효과에 대한 철저한 연구는 부족
  • 본 연구에서는 사전 학습된 full precision 네트워크를 "Teacher"로 사용하고, 동일한 구조의 low bit 네트워크를 "Student"로 사용하여 KD를 QAT에 적용
  • 손실 함수 L은 cross entropy loss Lce와 distillation loss LDistill의 선형 결합으로 정의
    • : Student 모델의 예측값과 실제 라벨 간의 차이
    • : 두 손실의 중요도를 조절하는 hyperparameter

 

   Are both the cross-entropy loss and distillation loss necessary?

  • quantization 네트워크는 full precision 네트워크에 비해 표현 능력이 낮기 때문에, KD가 요구하는 다양한 손실 항목을 최적화하기 어려움
  • quantization 과정에서 weight와 activation 값에 추가적인 noise가 발생하여, KD 성능이 저하될 수 있음

 

  • 3가지 실험 시나리오
  • KL-Loss만 최소화 (λ=1)

  • CE-Loss만 최소화 (λ=0)

  • KL-Loss와 CE-Loss를 동시에 최소화  (λ=0.5)

 

  • 결과 분석
  • KL-Loss만 최소화 ()
    • KL-Loss를 줄이는 것만으로 CE-Loss도 효과적으로 감소
    • 이는 Student 모델이 라벨 없이도 Teacher의 분포를 학습하여 정확한 예측을 생성할 수 있음을 의미
  • CE-Loss만 최소화 () 또는 CE와 KL을 함께 최소화 ()
    • CE-Loss가 포함되면 KL-Loss를 충분히 줄이지 못함
    • 이는 CE-Loss와 KL-Loss가 함께 작동할 때 서로 간섭이 생겨 성능이 저하될 수 있음을 나타냄
  • 즉, KL-Loss를 단독으로 사용하는 것이 QAT에서 가장 효과적
  • KL-Loss만 사용할 경우 hyperparameter λ가 필요 없어 손실 함수가 더 단순해짐

 

 

3.3 Optimization via Self-Supervised KD

  • CE-Loss를 제외하고 KL-Loss만 사용하는 방식으로 최적화 진행
  • KL Divergence : Teacher와 Student 출력 분포 간의 차이를 측정
  • S: Softmax 함수
  • hT , hS : Teacher와 Student의 penultimate layer(출력 직전 layer)에서의 출력 
  • ρ: Temperature로, 분포를 부드럽게 만들어 학습이 더 안정적이게 만듦

 

  • Teacher 모델 weight는 freeze되어 학습 과정에서 변경되지 않음
  • Teacher는 forward propagation만 수행하며, Student가 학습할 "참조 분포"를 제공
  • Student forward pass 시 weight는 quantization되어 사용
  • 하지만 내부적으로 full precision 값 유지
  • 역전파 과정에서는 quantization된 값이 아닌, full precision 값의 gradient를 계산해 업데이트
  • 학습이 완료되면, Student는 quantizatio된 상태에서 최적의 성능을 유지하도록 설계된 가중치 보유

 

4. Evaluation

 

   4.1 Improvements on SOTA QAT Methods

 

   CIFAR-10과 CIFAR-100에서의 결과

  • RestNet과 VGG 모델을 1.2.4 bit로 quantization하여 테스트
  • SQAKD (EWGS)는 기존 EWGS보다 모든 bit quantization 상황에서 정확도 크게 향상

 

   Tiny-ImageNet에서의 결과

  • PACT, LSQ, DoReFA 같은 QAT 방법과 비교
  • bit가 낮을수록 정보 손실이 더 크지만, SQAKD는 KD를 활용하여 이를 완화

 

   경량화 모델 (MobileNet, ShuffleNet 등) quantization 결과

  • 이미 가벼운 모델에서도 SQAKD가 성능 향상을 제공
  • MobileNet-V2에서는  8 bit quantization 시 full-precision 모델보다 높은 정확도를 기록

 

 

   4.2 Comparison with SOTA KD Methods

  • 기존 KD방법은 EWGS 상황에서 ground-truth label의 supervision이 없으면 수렴하지 못함
  • 제안된 unsupervised 방식의 SQAKD는 기존 supervised KD 방법보다 CIFA-10에서 0.36%~3.4%, CIFAR-100에서 0.09%~17.09% 더 높은 성능을 보여 줌
  • FSP는 SQAKD 다음으로 높은 정확도를 보인 방법으로, SQAKD와의 비교를 추가로 수행
  • SQAKD는 EWGS와 FSP보다 수렴 속도가 훨씬 빠름
  • KL-Loss는 SQAKD가 더 빠르게 줄어들고 최종 값도 더 낮음
  • 이는 SQAKD가 CE-Loss와 KL-Loss 모두 효과적으로 최소화하여 더 빠르고 정확한 distillation을 가능하게 한다는 것을 보여줌

 

 

   4.3 Comparison with SOTA Methods Applying both QAT and KD

 

   ResNet (2-bit)

  • 정확도 하락 폭이 가장 작았으며, 다른 방법보다 0.04%~3.06% 낮은 하락률 기록

   AlexNet (CIFAR-100)

  • SQAKD와 PTG만 full precision 모델보다 더 높은 정확도를 기록했으나, SQAKD의 향상폭이 PTG보다 더 큼

 

   4.4 Infernce Speedup

  • SQAKD는 모델 bit 너비를 줄여 복잡도를 낮춤으로써, 실제 추론 속도도 빨라짐
  • Tiny-ImageNet에서 ResNet-18, MobileNet-V2, ShuffleNet-V2, SqueezeNet 같은 다양한 모델 아키텍처로 8 bit quantization 시 3배의 속도 향상 달성

 

5. Ablation Study

 

   Analysis of Loss Surface

  • SQAKD와 독립형 EWGS를 사용하여 CIFAR-10 데이터셋에서 훈련된 2 bit ResNet-20과 full precision ResNet-20의 3D 손실 곡면과 2D 등고선을 비교
  • SQAKD는 quantization된 모델이 더 평평하고 부드러운 손실 곡면을 가지도록 함
  • 이는 훈련 과정에서 안정성을 높이고 최적화를 쉽게 만듦

 

   Flexibility for various forward and backward combinations

  • SQAKD는 SOTA QAT(PACT, EWGS)의 forward 및 backward 기법을 모듈식으로 통합
  • Tiny-ImageNet에서 4 bit ShuffleNet-V2
    • STE backward를 사용해 PACT를 14.02% 개선
    • EWGS backward를 사용해 추가로 0.77% 더 개선
  • CIFAR-10에서 2 bit ResNet-20
    • STE backward로 EWGS를 0.29% 개선, EWGS backward로 0.39% 개선

 

   Effect of Temperature

  • Knowledge Distillation에서 temperature (ρ)는 모델이 teacher 모델의 'dark knowledge'를 더 잘 학습하도록 분포를 부드럽게 만듦
  • VGG-13 (CIFAR-100)과 ResNet-20 (CIRAR-10)에서 temperature 범위 (ρ ∈ [1, 10])를 실험한 결과, ρ = 4에서 가장 좋은 성능을 보임

 

   Effect of Initialization

  • Tiny-ImageNet에서 4 bit VGG-11
    • SQAKD는 초기화 방식에 따라 PACT, LSQ, DoReFa 정확도를 0.05%~18.97%까지 향상
    • 모든 경우에서 full precision teacher 초기화가 랜덤 초기화보다 우수

 

 

6. Conclusion

 

  • SQAKD의 특징과 장점
    • 비감독 학습(Self-Supervised): 라벨 데이터가 필요 없음
    • QAT 연구 접근성 향상: 간단한 학습 절차, 낮은 학습 비용, hyperparameter tuning 필요 없음
    • 통합 프레임워크 제공: QAT 방법들의 forward 및 backward 동작을 통합하여 최적화
  • SQAKD의 성과
    • 기존 KD 및 QAT 방법을 크게 능가
    • QAT를 위한 새로운 기준을 세우며, 모델 양자화 연구의 새로운 방향을 제시