본문 바로가기
논문 리뷰/Quantization

[Paper Review] OverComing Oscillations in Quantization-Aware Training

by hyeon827 2024. 11. 28.

 

논문 원본 : https://arxiv.org/abs/2203.11086

 

Overcoming Oscillations in Quantization-Aware Training

When training neural networks with simulated quantization, we observe that quantized weights can, rather unexpectedly, oscillate between two grid-points. The importance of this effect and its impact on quantization-aware training (QAT) are not well-underst

arxiv.org

 

 

Abstract

  • 딥러닝 모델을 simulated quantization으로 학습할 때, quantization된 weight가 두 개의 graid-point 사이에서 예상치 못하게 oscillation (진동)하는 현상을 관찰
  • 본 논문에서는 이러한 weight oscillation 현상을 깊이 탐구하며, 해당 문제는 모델이 잘못된 batch normalization 통계치를 사용하거나, 학습 과정에서 불필요한 noise가 추가되게 만들어 정확도 저하가 발생할 수 있음을 보임
  • 특히 4비트 이하의 낮은 비트 양자화에서 이 문제가 두드러지며, MobileNet과 EfficientNets 같은 효율적인 네트워크에서 더 큰 영향을 미침
  • 여러 기존 QAT 알고리즘이 이 문제를 해결하려 했지만 대부분 실패함
  • 본 논몬에서는 oscillation dampening(진동 완화)과 iterative weight freezing(반복적 가중치 고정)이라는 새로운 QAT(Quantization-Aware Training) 알고리즘 제안
  • 제안된 방법을 MobileNetV2, MobileNetV3, EfficientNet-lite에 적용했을 때, 3비트와 4비트 양자화에서도 정확도를 크게 향상시킴

 

Introduction

  • Quantization은 신경망 모델의 가중치를 작은 숫자 범위로 압축해서 빠르고 효율적으로 작동하게 만드는 방법
  • 예를 들어 32비트 수를 8비트로 줄이는 것
  • 이렇게 하면 모델 크기가 줄어들고, 엣지 디바이스에 신경망을 배포할 때 전력 소비가 줄며, 추론 속도가 빨라짐
  • Quantization은 전력 절약과 속도 향상에 있어 명확한 이점에도 불구하고, 정밀도가 줄어들기 때문에 추가적인 noise를 발생시킴 
  • 그러나 최근 연구에 따르면, 신경망은 해당 noise에 강하며 PTQ (Post-Training Quantization) 기법을 사용하여  8비트 Quantization으로 정확도의 하락을 최소화할 수 있다는 사실이 밝혀짐
  • 하지만 PTQ는 4비트 이하 Quantization에서 성능이 떨어짐
  • 한편, QAT (Quantization-Aware Training) 는 저비트 Quantization을 달성하면서도 거의 전체 정밀도의 정확도를 유지할 수 있는 기법
  • 학습 또는 fine-tuning 중에 Quantization 작업을 시뮬레이션하여 신경망이 Quantization된 noise에 적응하도록 하여, PTQ보다 더 나은 결과를 얻을 수 있음 
  • 본 논문에서는 QAT (Quantization-Aware Training) 중 발생하는 weight oscillation 현상에 집중
  • QAT에 사용되는 STE (Straight-through estimator)라는 기법이 weight를 무작위로 oscillation하게 만들어 noise 발생
  • 이로 인해 batch normalization에서 잘못된 통계를 계산하게 만들어 검증 정확도가 떨어지게 됨
  • 본 논문에서는 oscillation dampening과 iterative weight freezing이라는 새로운 기법 제안

 

2. Oscillations in QAT

 

   2.1 Quantization-aware training

  • 신경망을 양자화하는 가장 효과적인 방법 중 하나는 Simulated Quantization를 활용해 네트워크를 훈련하는 것
  • forward pass 동안, floating-point weight와 활성화 값이 quantization function q()를 통해 Quantization됨

  • ŵ : Quantization된 weight
  • w : 입력된 부동 소수점 weight
  • s : 스케일링 계수
  • n,p : Quantization 하한 및 상한값
  • clip : 값이 n과 p 사이에 유지되도록 제한하는 함수
  • Quantization된 weight ŵ는 추론에 사용되고, 부동 소수점 w는 latent weights 또는 shadow weights로 불리며 최적화에 사용
  • Quantization 과정의 반올림 함수는 미분이 불가능하다는 문제점 존재
  • 따라서 이 문제를 해결하기 위해 STE 기법 사용
  • Quantization 범위 안에 있으면 gradient를 1, 아니면 0
  • 양자화 가능한 영역에서만 학습이 이루어지고, 계산이 간소화

 

 

   2.2 Oscillations

  • STE를 사용하면 latent weight가 Quantization state의 경계 근처에서 weight가 진동하는 현상 발생 (stochastic oscillation) 
  • Quantization state q(w*)의 경계 위에서는 gradient가 weight를 아래로 밀어냄
  • 경계 아래에서는 gradient가 weight를 위로 밀어냄
  • 이는 learning rate에 관계없이 발생하며, learning rate를 줄이면 oscillation 폭은 줄지만 빈도는 변하지 않음
  • 이 현상은 weight가 Quantization state에서 얼마나 가까운지 직접적으로 연관
  • 예를 들어 Quantization state에서 멀수록 oscillation의 빈도는 낮아짐

 

 

   2.3 Oscillations in practice

  • 실제로 대규모 신경망에서도 이러한 Oscillations이 나타나며, 이는 최적화에 중대한 영항을 미침
  • 아래 그림은 MobileNetV2의 첫 번째 Depth-wise separable layer에서  3비트 quantization이 적용된 weight가 학습 후반부 1000번의 iteration동안 어떻게 변화하는지를 보여줌

 

 

  • 아래 그림은 네트워크가 수렴한 후에도, 많은 latent weight가 grid points 사이의 decision boundary에 위치하고 있음을 확인
  • 이는 상당수의 weight가 수렴하지 않고 oscillation 한다는 관찰을 더욱 뒷받침
  • 이러한 진동은 Batch Normalization 통계의 잘못된 추정, 네트워크 최적화에 부정적 영향을 미치는 문제를 야기

 

 

   2.3.1 The EFFECT ON BATCH-NORMALIZATION

  • 훈련 중 batch normalization은 각 layer의 출력에 대해 EMA(지수 이동 평균)을 사용하여 평균과 분산을 추적
  • 이는 추론 시 실제 데이터 분포를 근사하기 위해 사용
  • QAT에서는 weight가 진동
  • weight oscillation으로 인해 inter weight가 iterations 중 빠르게 변화
  • 이러한 빠르고 큰 변화는 layer output distirbution에 중대한 분포 변화를 유발
  • 결과적으로, EMA 통계가 훼손되며, 정확도에 상당한 저하를 초래
  • weight 비트 폭이 작을수록 quantization level 간 간격이 커지므로 weight가 한 level에서 다른 level로 이동할 때 변화 폭이 커짐
  • 출력 채널당 weght 수가 적을수록 한 weight의 oscillation이 EMA에 크게 반영
  • MobileNetV2는 Depth-wise separable convolution 구조 때문에 채널당 weight가 적음
  • 그래서 weight oscillation의 영향을 크게 받으므로, 실제 분포와 EMA 통계 간 차이가 큼 (KL 발산 값이 큼)
  • ResNet은 Full Convolution 구조로 채널당 weight가 많아 oscillation의 영향이 평균적으로 상쇄
  • KL 발산 값이 작아 EMA 통계가 더 잘 유지

 

  • BN Re-estimation을 통한 완화
  • 훈련이 끝난 후, 소규모 데이터셋으로 다시 batch normalize 통계를 계산해 EMA 값을 교체
  • 이를 통해 MobileV2의 정확도 개선
  • 비트 폭이 작을수록 효과가 더 큼

 

 

   2.3.2 The EFFECT ON TRAINING

  • weight oscillation은 BN 통계에 악영향을 미칠 뿐만 아니라 훈련 과정 자체에도 부정적인 영향을 줄 수 있음
  • MobileNetV2(3비트 quantization)를 대상으로 실험
  • oscillation하는 weight의 두 상태(예: w↑ ) 중 각 상태에서 얼마나 오래 머물렀는지를 계산
  • 이 확률에 따라 무작위로 oscillation하는 weight 값을 sampling하여 새로운 네트워크를 만듦
  • 무작위로 sampling한 네트워크들의 평균 훈련 손실은 기존 네트워크와 비슷
  • 가장 잘 sampling된 네트워크는 기존 네트워크보다 훨씬 낮은 훈련 손실 기록
  • 이는 oscillation weight가 최적화에 방해가 되고 있다는 것을 보여줌
  • AdaRound를 사용한 binary optimization
  • oscillation하는 가중치들을 찾아 어떤 값으로 반올림할지 결정
  • 어떤 반올림 조합이 네트워크 손실을 가장 낮게 만드는가를 기준으로 최적화
  • 최적의 반올림 조합을 찾기 위해 다양한 조합을 시도하며 손실 값을 점차 줄여나감
  • oscillation freezing technique
  • 훈련 초기에 weight가 oscillation하지 못하도록 억제하는 기법
  • AdaRound로 최적화한 네트워크보다 더 높은 검증 정확도 기록
  • oscillation은 훈련 후반뿐 아니라 초반에도 큰 문제

 

 

4. Overcoming oscillations in QAT

   

   4.1 Qunatifying Oscillations

  • oscillations 문제를 해결하려면, 학습 중에 oscillations를 감지하고 측정할 수 있는 방법이 필요

   oscillations이 발생하는 조건

  • quantization된 정수 값이 변경되어야 함

  • 변화 방향이 이전과 반대여야 함

   

 

   oscillations의 빈도를 계산하는 방법

  • EMA를 사용해 계산

 

 

  • : 현재 시점의 oscillations 빈도 (새로 계산된 값).
  • ot: 현재 시점의 진동 여부 (oscillations이 발생하면 1, 발생하지 않으면 0).
  • : 이전까지 계산된 oscillations 빈도.
  • m: 현재 값에 얼마나 가중치를 줄지 정하는 비율 (0~1)

 

   4.2 Oscillation Dampening

  • oscillation하는 weight는 항상 quantization bin 경계 근처에서 움직임
  • 이를 완화하기 위해 regularization를 추가하여, weight가 bin의 중심에 더 가깝게 유지되도록 유도

 

  • : bin 중심
  • clip(w,sn,sp) : quantization weight의 범위
  • s,n,p: quantization parameter (scale, lower, upper)

 

   최종 학습 목적 함수

 

  • Ltask: 기존의 학습 손실
  • : 규제 항의 가중치

 

  • 이 규제는 quantization하는 weight 뿐만 아니라 움직이는 않는 weight에도 영향을 줄 수 있다는 단점 존재

 

   4.3 Iterative freezing of oscillating weights

  • oscillations을 방지하는 또 다른 방법으로, oscillations 빈도가 임계값 fth를 초과하면 가중치를 freeze 

   freeze 방법

  1. 각 weight의 oscillations 빈도를 추적
  2. 빈도가 를 초과하면 해당 가중치를 학습에서 제외
    • freeze 시 이전 값(정수 상태)을 유지
    • 정수 값 기록을 위한 EMA 사용
    • 가장 자주 등장한 상태로 가중치를 동결
  • 이 방법은 학습 도중 작은 값을 0으로 동결하는 pruning과 유사

 

 

5. Experiments

 

   5.2 Ablation studies

 

   Oscillation dampening

  • λ를 증가시킬수록 진동하는 weight 비율이 감소
  • 하지만 λ가 너무 크면 최종 정확도가 떨어짐 (유익한 변화까지 억제하기 때문)
  • 학습 초반에는 규제를 약하게, 학습 후반에는 점진적으로 강하게 만드는 cosine annealing 스케줄 적용
  • 이 방식은 초기에 weight가 자유롭게 이동하도록 하여 최적화를 촉진하고, 학습 후반에는 oscillations를 최소화
  • 최적의  λ 스케줄을 사용했을 때 정확도가 post-BN Baseline보다 1% 개선
  • pre-BN Baseline보다 5% 이상 개선

 

 

   Iterative weight freezing

  • 임계값이 작아질수록 oscillations 문제가 많이 해결되고, 네트워크에 남아있는 oscillations의 빈도가 줄어듦
  • BN 재평가 전후의 정확도 차이가 줄어듦 -> oscillations 문제 해결의 간접적 증거
  • 하지만 너무 낮은 임계값은 학습 초기 단계에서 과도한 weight를 freeze -> 최종 정확도 감소
  • 임계값을 점진적으로 낮추는 cosine annealing 적용
  • 학습 초반에는 freeze를 최소화하고, 학습 후반에는 oscillations를 강하게 억제하여 최종 정확도 개선
  • 최적의 freeze 임계값 스케줄을 사용했을 때 정확도가 post-BN Baseline보다 1% 개선
  • pre-BN Baseline보다 5% 이상 개선

 

 

   5.3 Comparison to other QAT methods

 

   MobileNetV2 성능 비교

  • weight와 활성화 값 모두 quantization하여 기존의 QAT 기법과 비교
  • 3비트, 4비트 양자화 모두에서 제안된 알고리즘이 기존 QAT 기법보다 뛰어난 성능을 보임
  • MobileNetV3-Small 와 EfficientNet-Lite에서도 제안된 방법이 최신 성능(SOTA)을 달성
  • 기존 LSQ 방법에 비해 정확도가 1% 이상 개선

 

 

6. Conclusion

  • 본 연구에서는 QAT에서 사용되는 STE 같은 변형 기법들이 weight oscillation을 유발하여 quantization된 네트워크의 성능을 저하시킬 수 있음을 보여줌 (특히 low-bit quantization을 적용한 light-weight networks에서 더욱 심각)
  • 해당 문제는 Batch Normalization 통계를 손상시키고 최적화 과정을 방해
  • 이에 따라 본 논문은 Oscillation Dampening과 Iterative Weight Freezing을 제안 
  • 제안된 두 기법은 ImageNet classification에서 low bit로 quantization된 효율적인 모델에서 SOTA 달성