본문 바로가기
논문 리뷰/Quantization

[Paper Review] A Comprehensive Overhaul of Feature Distillation

by hyeon827 2025. 2. 1.

 

논문 원본 : https://arxiv.org/abs/1904.01866

 

A Comprehensive Overhaul of Feature Distillation

We investigate the design aspects of feature distillation methods achieving network compression and propose a novel feature distillation method in which the distillation loss is designed to make a synergy among various aspects: teacher transform, student t

arxiv.org

 

3. Approach

 

 

   3.1. Distillation position

  • Neural Network에서 Activation Function은 매우 중요한 요소
  • Neural Network의 Non-linearity는 Activation Function 덕분에 생김
  • Computer Vision 분야에서 가장 많이 사용되는 Activation Function은 ReLU
  • ReLU는 입력 값이 양수이면 그대로 유지하고, 음수이면 0으로 고정하는 간단한 방식으로 작동
  • 이 과정에서 음수 값을 제거하여 불필요한 정보가 역전파되는 것을 막아줌
  • ReLU의 이러한 특징을 고려하여 knowledge distillation을 설계하면 필요한 정보만 효과적으로 전달할 수 있음
  • 그러나 기존 연구들 대부분은 ReLU의 역할을 깊이 고려하지 않음

   Strategy

  • 기존 대부분의 distillation 방법들은 Layer Block이 끝나는 지점 (ex. ResNet의 Residual Block) 에서 distillation을 수행
  • 이 과정에서 ReLU를 고려하지 않음
    • ReLU는 음수 값을 0으로 바꿔서 중요한 정보가 일부 손실될 수 있음
    • 즉, 학생 모델(Student)은 ReLU를 거친 후 손실된 정보만 받게 됨
  • 본 논문은 ReLU를 통과하기 전에 distillation을 수행
  • 이렇게 하면 ReLU를 통과하기 전에 교사(Teacher) 모델의 정보를 학생(Student) 모델이 받을 수 있어 더 많은 정보를 보존할 수 있음

 

 

 

   3.2. Loss function

  • Teacher의 featrue value는 ReLU 적용 이전 값이므로
    • positive value : 중요한 정보 -> Student도 같은 값을 출력해야 함
    • negative value : 불필요한 값 -> Student도 음수 값이어야 함
  • Student Network가 Teacher의 음수 값보다 더 작은 값을 출력해야 함
  • 이를 위해 Margin ReLU 사용
    • 여기서 m은 0보다 작은 margin 값
    • 즉, Teacehr의 음수 값을 보완하여 Student가 따라가기 쉽게 조정
    • 이전 연구에서는 m을 고정된 값을 사용했지만, 이는 네트워크의 가중치 정보를 반영하지 않음
    • 따라서, 본 연구는 각 채널별로 Teacher의 음수 값 평균을 margin 값으로 설정하는 방식 제안
    • 이 기대값은 훈련 과정에서 직접 계산할 수도 있고, 이전 Batch Normalization Layer의 파라미터를 활용하여 계산할 수 도 있음

Margin ReLU

 

 

  • Student Network 변환
    • 1x1 Convolution + Batch Normalization을 사용해 교사 네트워크의 출력을 학습하는 회귀 모델을 만듦
      • 교수님께서 언급하신 adapter
  • 정리하자면 Student Network는 Network가 제공하는 "Margin이 적용된 Feature Map"을 따라가도록 학습하게 됨

 

   Distance Function

  • distillation function도 ReLU를 고려한 방식으로 변경해야 함
  • Teacher의 음수 값보다 Student의 값이 크면 오차를 계산하고, 음수 값보다 작다면 오차를 0으로 설정하여 학습하지 않도록 함
  • 이를 수식으로 표현하면 Partial L2 Distance는 아래와 같음

Partial L2 Distance

 

   Final Loss Function

  • Teacher Transform : Margin ReLU
  • Student Transform : 1×1 Convolution 기반 회귀 모델
  • Distance Function : Partial L2 Distance dp

  • 최종 네트워크의 loss function은 기본적인 task loss와 distillation loss의 합
    • CIFAR 같은 32×32 입력 이미지 네트워크에서는 3개 층을 증류
    • ImageNet 같은 대형 네트워크에서는 4개 층을 증류

 

 

 

   3.3. Batch normalization

  • Student Network는 배치 단위로 정규화된 데이터를 받으며 학습함
  • 그런데 Teacher Network의 BN 레이어를 평가 모드로 설정하면 문제 발생
    • Teacher는 학습 중의 배치 단위 평균과 분산이 아닌, 미리 저장된 값을 사용하게 됨
    • 즉, Student가 배치 단위로 정규화된 feature를 학습하는데, Teacher는 평가 모드로 고정된 feature를 제공하면 정규화 방식이 달라져서 효과가 떨어짐
  • Teacher Network의 BN을 학습 모드로 설정해야 함→ Teahcer의 특징값이 학생과 동일한 방식으로 정규화됨
  • Student Network의 변환 과정에서도 BN을 추가해야 함→ Student도 BN를 적용하여 Teacher의 출력을 더 잘 따라가도록 도움