논문 원본 : https://arxiv.org/abs/2203.11086
Overcoming Oscillations in Quantization-Aware Training
When training neural networks with simulated quantization, we observe that quantized weights can, rather unexpectedly, oscillate between two grid-points. The importance of this effect and its impact on quantization-aware training (QAT) are not well-underst
arxiv.org
Abstract
- 딥러닝 모델을 simulated quantization으로 학습할 때, quantization된 weight가 두 개의 graid-point 사이에서 예상치 못하게 oscillation (진동)하는 현상을 관찰
- 본 논문에서는 이러한 weight oscillation 현상을 깊이 탐구하며, 해당 문제는 모델이 잘못된 batch normalization 통계치를 사용하거나, 학습 과정에서 불필요한 noise가 추가되게 만들어 정확도 저하가 발생할 수 있음을 보임
- 특히 4비트 이하의 낮은 비트 양자화에서 이 문제가 두드러지며, MobileNet과 EfficientNets 같은 효율적인 네트워크에서 더 큰 영향을 미침
- 여러 기존 QAT 알고리즘이 이 문제를 해결하려 했지만 대부분 실패함
- 본 논몬에서는 oscillation dampening(진동 완화)과 iterative weight freezing(반복적 가중치 고정)이라는 새로운 QAT(Quantization-Aware Training) 알고리즘 제안
- 제안된 방법을 MobileNetV2, MobileNetV3, EfficientNet-lite에 적용했을 때, 3비트와 4비트 양자화에서도 정확도를 크게 향상시킴
Introduction
- Quantization은 신경망 모델의 가중치를 작은 숫자 범위로 압축해서 빠르고 효율적으로 작동하게 만드는 방법
- 예를 들어 32비트 수를 8비트로 줄이는 것
- 이렇게 하면 모델 크기가 줄어들고, 엣지 디바이스에 신경망을 배포할 때 전력 소비가 줄며, 추론 속도가 빨라짐
- Quantization은 전력 절약과 속도 향상에 있어 명확한 이점에도 불구하고, 정밀도가 줄어들기 때문에 추가적인 noise를 발생시킴
- 그러나 최근 연구에 따르면, 신경망은 해당 noise에 강하며 PTQ (Post-Training Quantization) 기법을 사용하여 8비트 Quantization으로 정확도의 하락을 최소화할 수 있다는 사실이 밝혀짐
- 하지만 PTQ는 4비트 이하 Quantization에서 성능이 떨어짐
- 한편, QAT (Quantization-Aware Training) 는 저비트 Quantization을 달성하면서도 거의 전체 정밀도의 정확도를 유지할 수 있는 기법
- 학습 또는 fine-tuning 중에 Quantization 작업을 시뮬레이션하여 신경망이 Quantization된 noise에 적응하도록 하여, PTQ보다 더 나은 결과를 얻을 수 있음
- 본 논문에서는 QAT (Quantization-Aware Training) 중 발생하는 weight oscillation 현상에 집중
- QAT에 사용되는 STE (Straight-through estimator)라는 기법이 weight를 무작위로 oscillation하게 만들어 noise 발생
- 이로 인해 batch normalization에서 잘못된 통계를 계산하게 만들어 검증 정확도가 떨어지게 됨
- 본 논문에서는 oscillation dampening과 iterative weight freezing이라는 새로운 기법 제안
2. Oscillations in QAT
2.1 Quantization-aware training
- 신경망을 양자화하는 가장 효과적인 방법 중 하나는 Simulated Quantization를 활용해 네트워크를 훈련하는 것
- forward pass 동안, floating-point weight와 활성화 값이 quantization function q(⋅)를 통해 Quantization됨
- ŵ : Quantization된 weight
- w : 입력된 부동 소수점 weight
- s : 스케일링 계수
- n,p : Quantization 하한 및 상한값
- clip : 값이 n과 p 사이에 유지되도록 제한하는 함수
- Quantization된 weight ŵ는 추론에 사용되고, 부동 소수점 w는 latent weights 또는 shadow weights로 불리며 최적화에 사용
- Quantization 과정의 반올림 함수는 미분이 불가능하다는 문제점 존재
- 따라서 이 문제를 해결하기 위해 STE 기법 사용
- Quantization 범위 안에 있으면 gradient를 1, 아니면 0
- 양자화 가능한 영역에서만 학습이 이루어지고, 계산이 간소화
2.2 Oscillations
- STE를 사용하면 latent weight가 Quantization state의 경계 근처에서 weight가 진동하는 현상 발생 (stochastic oscillation)
- Quantization state q(w*)의 경계 위에서는 gradient가 weight를 아래로 밀어냄
- 경계 아래에서는 gradient가 weight를 위로 밀어냄
- 이는 learning rate에 관계없이 발생하며, learning rate를 줄이면 oscillation 폭은 줄지만 빈도는 변하지 않음
- 이 현상은 weight가 Quantization state에서 얼마나 가까운지 직접적으로 연관
- 예를 들어 Quantization state에서 멀수록 oscillation의 빈도는 낮아짐
2.3 Oscillations in practice
- 실제로 대규모 신경망에서도 이러한 Oscillations이 나타나며, 이는 최적화에 중대한 영항을 미침
- 아래 그림은 MobileNetV2의 첫 번째 Depth-wise separable layer에서 3비트 quantization이 적용된 weight가 학습 후반부 1000번의 iteration동안 어떻게 변화하는지를 보여줌
- 아래 그림은 네트워크가 수렴한 후에도, 많은 latent weight가 grid points 사이의 decision boundary에 위치하고 있음을 확인
- 이는 상당수의 weight가 수렴하지 않고 oscillation 한다는 관찰을 더욱 뒷받침
- 이러한 진동은 Batch Normalization 통계의 잘못된 추정, 네트워크 최적화에 부정적 영향을 미치는 문제를 야기
2.3.1 The EFFECT ON BATCH-NORMALIZATION
- 훈련 중 batch normalization은 각 layer의 출력에 대해 EMA(지수 이동 평균)을 사용하여 평균과 분산을 추적
- 이는 추론 시 실제 데이터 분포를 근사하기 위해 사용
- QAT에서는 weight가 진동
- weight oscillation으로 인해 inter weight가 iterations 중 빠르게 변화
- 이러한 빠르고 큰 변화는 layer output distirbution에 중대한 분포 변화를 유발
- 결과적으로, EMA 통계가 훼손되며, 정확도에 상당한 저하를 초래
- weight 비트 폭이 작을수록 quantization level 간 간격이 커지므로 weight가 한 level에서 다른 level로 이동할 때 변화 폭이 커짐
- 출력 채널당 weght 수가 적을수록 한 weight의 oscillation이 EMA에 크게 반영
- MobileNetV2는 Depth-wise separable convolution 구조 때문에 채널당 weight가 적음
- 그래서 weight oscillation의 영향을 크게 받으므로, 실제 분포와 EMA 통계 간 차이가 큼 (KL 발산 값이 큼)
- ResNet은 Full Convolution 구조로 채널당 weight가 많아 oscillation의 영향이 평균적으로 상쇄
- KL 발산 값이 작아 EMA 통계가 더 잘 유지
- BN Re-estimation을 통한 완화
- 훈련이 끝난 후, 소규모 데이터셋으로 다시 batch normalize 통계를 계산해 EMA 값을 교체
- 이를 통해 MobileV2의 정확도 개선
- 비트 폭이 작을수록 효과가 더 큼
2.3.2 The EFFECT ON TRAINING
- weight oscillation은 BN 통계에 악영향을 미칠 뿐만 아니라 훈련 과정 자체에도 부정적인 영향을 줄 수 있음
- MobileNetV2(3비트 quantization)를 대상으로 실험
- oscillation하는 weight의 두 상태(예: w↑와 ) 중 각 상태에서 얼마나 오래 머물렀는지를 계산
- 이 확률에 따라 무작위로 oscillation하는 weight 값을 sampling하여 새로운 네트워크를 만듦
- 무작위로 sampling한 네트워크들의 평균 훈련 손실은 기존 네트워크와 비슷
- 가장 잘 sampling된 네트워크는 기존 네트워크보다 훨씬 낮은 훈련 손실 기록
- 이는 oscillation weight가 최적화에 방해가 되고 있다는 것을 보여줌
- AdaRound를 사용한 binary optimization
- oscillation하는 가중치들을 찾아 어떤 값으로 반올림할지 결정
- 어떤 반올림 조합이 네트워크 손실을 가장 낮게 만드는가를 기준으로 최적화
- 최적의 반올림 조합을 찾기 위해 다양한 조합을 시도하며 손실 값을 점차 줄여나감
- oscillation freezing technique
- 훈련 초기에 weight가 oscillation하지 못하도록 억제하는 기법
- AdaRound로 최적화한 네트워크보다 더 높은 검증 정확도 기록
- oscillation은 훈련 후반뿐 아니라 초반에도 큰 문제
4. Overcoming oscillations in QAT
4.1 Qunatifying Oscillations
- oscillations 문제를 해결하려면, 학습 중에 oscillations를 감지하고 측정할 수 있는 방법이 필요
oscillations이 발생하는 조건
- quantization된 정수 값이 변경되어야 함
- 변화 방향이 이전과 반대여야 함
oscillations의 빈도를 계산하는 방법
- EMA를 사용해 계산
- : 현재 시점의 oscillations 빈도 (새로 계산된 값).
- ot: 현재 시점의 진동 여부 (oscillations이 발생하면 1, 발생하지 않으면 0).
- : 이전까지 계산된 oscillations 빈도.
- m: 현재 값에 얼마나 가중치를 줄지 정하는 비율 (0~1)
4.2 Oscillation Dampening
- oscillation하는 weight는 항상 quantization bin 경계 근처에서 움직임
- 이를 완화하기 위해 regularization를 추가하여, weight가 bin의 중심에 더 가깝게 유지되도록 유도
- ŵ: bin 중심
- clip(w,sn,sp) : quantization weight의 범위
- s,n,p: quantization parameter (scale, lower, upper)
최종 학습 목적 함수
- Ltask: 기존의 학습 손실
- : 규제 항의 가중치
- 이 규제는 quantization하는 weight 뿐만 아니라 움직이는 않는 weight에도 영향을 줄 수 있다는 단점 존재
4.3 Iterative freezing of oscillating weights
- oscillations을 방지하는 또 다른 방법으로, oscillations 빈도가 임계값 fth를 초과하면 가중치를 freeze
freeze 방법
- 각 weight의 oscillations 빈도를 추적
- 빈도가 를 초과하면 해당 가중치를 학습에서 제외
- freeze 시 이전 값(정수 상태)을 유지
- 정수 값 기록을 위한 EMA 사용
- 가장 자주 등장한 상태로 가중치를 동결
- 이 방법은 학습 도중 작은 값을 0으로 동결하는 pruning과 유사
5. Experiments
5.2 Ablation studies
Oscillation dampening
- λ를 증가시킬수록 진동하는 weight 비율이 감소
- 하지만 λ가 너무 크면 최종 정확도가 떨어짐 (유익한 변화까지 억제하기 때문)
- 학습 초반에는 규제를 약하게, 학습 후반에는 점진적으로 강하게 만드는 cosine annealing 스케줄 적용
- 이 방식은 초기에 weight가 자유롭게 이동하도록 하여 최적화를 촉진하고, 학습 후반에는 oscillations를 최소화
- 최적의 λ 스케줄을 사용했을 때 정확도가 post-BN Baseline보다 1% 개선
- pre-BN Baseline보다 5% 이상 개선
Iterative weight freezing
- 임계값이 작아질수록 oscillations 문제가 많이 해결되고, 네트워크에 남아있는 oscillations의 빈도가 줄어듦
- BN 재평가 전후의 정확도 차이가 줄어듦 -> oscillations 문제 해결의 간접적 증거
- 하지만 너무 낮은 임계값은 학습 초기 단계에서 과도한 weight를 freeze -> 최종 정확도 감소
- 임계값을 점진적으로 낮추는 cosine annealing 적용
- 학습 초반에는 freeze를 최소화하고, 학습 후반에는 oscillations를 강하게 억제하여 최종 정확도 개선
- 최적의 freeze 임계값 스케줄을 사용했을 때 정확도가 post-BN Baseline보다 1% 개선
- pre-BN Baseline보다 5% 이상 개선
5.3 Comparison to other QAT methods
MobileNetV2 성능 비교
- weight와 활성화 값 모두 quantization하여 기존의 QAT 기법과 비교
- 3비트, 4비트 양자화 모두에서 제안된 알고리즘이 기존 QAT 기법보다 뛰어난 성능을 보임
- MobileNetV3-Small 와 EfficientNet-Lite에서도 제안된 방법이 최신 성능(SOTA)을 달성
- 기존 LSQ 방법에 비해 정확도가 1% 이상 개선
6. Conclusion
- 본 연구에서는 QAT에서 사용되는 STE 같은 변형 기법들이 weight oscillation을 유발하여 quantization된 네트워크의 성능을 저하시킬 수 있음을 보여줌 (특히 low-bit quantization을 적용한 light-weight networks에서 더욱 심각)
- 해당 문제는 Batch Normalization 통계를 손상시키고 최적화 과정을 방해
- 이에 따라 본 논문은 Oscillation Dampening과 Iterative Weight Freezing을 제안
- 제안된 두 기법은 ImageNet classification에서 low bit로 quantization된 효율적인 모델에서 SOTA 달성