논문 원본 : https://arxiv.org/abs/1904.01866
A Comprehensive Overhaul of Feature Distillation
We investigate the design aspects of feature distillation methods achieving network compression and propose a novel feature distillation method in which the distillation loss is designed to make a synergy among various aspects: teacher transform, student t
arxiv.org
3. Approach
3.1. Distillation position
- Neural Network에서 Activation Function은 매우 중요한 요소
- Neural Network의 Non-linearity는 Activation Function 덕분에 생김
- Computer Vision 분야에서 가장 많이 사용되는 Activation Function은 ReLU
- ReLU는 입력 값이 양수이면 그대로 유지하고, 음수이면 0으로 고정하는 간단한 방식으로 작동
- 이 과정에서 음수 값을 제거하여 불필요한 정보가 역전파되는 것을 막아줌
- ReLU의 이러한 특징을 고려하여 knowledge distillation을 설계하면 필요한 정보만 효과적으로 전달할 수 있음
- 그러나 기존 연구들 대부분은 ReLU의 역할을 깊이 고려하지 않음
Strategy
- 기존 대부분의 distillation 방법들은 Layer Block이 끝나는 지점 (ex. ResNet의 Residual Block) 에서 distillation을 수행
- 이 과정에서 ReLU를 고려하지 않음
- ReLU는 음수 값을 0으로 바꿔서 중요한 정보가 일부 손실될 수 있음
- 즉, 학생 모델(Student)은 ReLU를 거친 후 손실된 정보만 받게 됨
- 본 논문은 ReLU를 통과하기 전에 distillation을 수행
- 이렇게 하면 ReLU를 통과하기 전에 교사(Teacher) 모델의 정보를 학생(Student) 모델이 받을 수 있어 더 많은 정보를 보존할 수 있음
3.2. Loss function
- Teacher의 featrue value는 ReLU 적용 이전 값이므로
- positive value : 중요한 정보 -> Student도 같은 값을 출력해야 함
- negative value : 불필요한 값 -> Student도 음수 값이어야 함
- Student Network가 Teacher의 음수 값보다 더 작은 값을 출력해야 함
- 이를 위해 Margin ReLU 사용
- 여기서 m은 0보다 작은 margin 값
- 즉, Teacehr의 음수 값을 보완하여 Student가 따라가기 쉽게 조정
- 이전 연구에서는 m을 고정된 값을 사용했지만, 이는 네트워크의 가중치 정보를 반영하지 않음
- 따라서, 본 연구는 각 채널별로 Teacher의 음수 값 평균을 margin 값으로 설정하는 방식 제안
- 이 기대값은 훈련 과정에서 직접 계산할 수도 있고, 이전 Batch Normalization Layer의 파라미터를 활용하여 계산할 수 도 있음
- Student Network 변환
- 1x1 Convolution + Batch Normalization을 사용해 교사 네트워크의 출력을 학습하는 회귀 모델을 만듦
- 교수님께서 언급하신 adapter
- 1x1 Convolution + Batch Normalization을 사용해 교사 네트워크의 출력을 학습하는 회귀 모델을 만듦
- 정리하자면 Student Network는 Network가 제공하는 "Margin이 적용된 Feature Map"을 따라가도록 학습하게 됨
Distance Function
- distillation function도 ReLU를 고려한 방식으로 변경해야 함
- Teacher의 음수 값보다 Student의 값이 크면 오차를 계산하고, 음수 값보다 작다면 오차를 0으로 설정하여 학습하지 않도록 함
- 이를 수식으로 표현하면 Partial L2 Distance는 아래와 같음
Final Loss Function
- Teacher Transform : Margin ReLU
- Student Transform : 1×1 Convolution 기반 회귀 모델
- Distance Function : Partial L2 Distance dp
- 최종 네트워크의 loss function은 기본적인 task loss와 distillation loss의 합
- CIFAR 같은 32×32 입력 이미지 네트워크에서는 3개 층을 증류
- ImageNet 같은 대형 네트워크에서는 4개 층을 증류
3.3. Batch normalization
- Student Network는 배치 단위로 정규화된 데이터를 받으며 학습함
- 그런데 Teacher Network의 BN 레이어를 평가 모드로 설정하면 문제 발생
- Teacher는 학습 중의 배치 단위 평균과 분산이 아닌, 미리 저장된 값을 사용하게 됨
- 즉, Student가 배치 단위로 정규화된 feature를 학습하는데, Teacher는 평가 모드로 고정된 feature를 제공하면 정규화 방식이 달라져서 효과가 떨어짐
- Teacher Network의 BN을 학습 모드로 설정해야 함→ Teahcer의 특징값이 학생과 동일한 방식으로 정규화됨
- Student Network의 변환 과정에서도 BN을 추가해야 함→ Student도 BN를 적용하여 Teacher의 출력을 더 잘 따라가도록 도움