Instance-Conditional Knowledge Distillation for Object Detection (ICD)
Introduction
High performance의 Deep Learning Networks의 성능을 얻기 위해서는, 불가피하게 많은 양의 parameters를 수반하게 되며, 이는 high computational cost와 memory를 필요하게 만든다. 따라서, Resource-limited devices에서 object detection과 같은 실용적인 application을 사용하기 위해, network pruning, quantization, mobile architecture design, 그리고 knowledge distillation (KD)의 방법들이 등장했다. 그 중에서 KD는 추가적인 inference time 부담과 수정없이 network의 효율성과 성능을 개선시킬 수 있기 때문에 많이 적용되고 있다.
지금까지의 많은 KD들은 image classification에 집중되어 연구가 진행되었고, 다음의 이유로 인해 이를 object detection에 그대로 적용하기는 힘들다고 알려져 있다.
- 물체의 localization이 고려되지 않았다.
- 서로 다른 위치에 분포되어 있는 여러개의 물체들이 하나의 이미지에 나타나 있다 (서로 다른 위치로부터의 representation은 서로 다른 contribution를 가지므로 distillation를 어렵게 만든다, i.e., imbalance issue).
이를 해결하기 위해 기존의 논문들은 다음과 같은 방법을 사용했다.
- classification과 localization에 대한 정보를 함께 가진 intermediate representations를 distillation한다.
- imbalance 문제를 해결하기 위해 different feature selection 방법을 사용한다.
- Proposal-based: RPN에 의해 예측된 proposal regions이나 detector를 활용
- Rule-based: pre-design된 rule로 선별된 regions을 활용
; 하지만, 이러한 방법들은 유익한 informative regions를 무시한다. - Attention-based: discriminative area에 대한 hints를 제공하지만, detection을 위한 activation과 knowledge사이의 관계가 clear하지 않다.
이에 따라, 본 논문에서는 Instance-Conditional knowledge Distillation (ICD)를 제안함으로써, knowledge를 feature selection과 연결하는 explicit solution을 제공한다. 이를 위해,
- transformer decoder를 통해 instance-aware attention으로 지식과 각 instance 사이의 correlation을 측정함으로써, 서로 다른 instance와 관련된 지식을 찾도록 하는 conditional decoding module를 설계한다.
- 여기서, 사람이 관찰한 instance를 query로 사용하고, teacher’s representation (decomposed into key and value)과의 사이를 scaled-product attention을 통해 correlation을 계산한다.
결과적으로, decoder를 통해 구분된 features들을 distillation하고, instance-aware attention을 통해 weighted된다.
Related Works
Object Detection
Object detection에서의 detector은 크게 two-stage 또는 one-stage detectors로 구분할 수 있다. Two-stage는 일반적으로 Region Proposal Network (RPN)을 사용하여, 초기의 대략적인 prediction을 얻고, detection heads를 통해 이를 refine하는 과정을 거친다. 대표적으로 Faster F-CNN이 있다. 반면, one-stage detectors의 경우, feature map에 대해 바로 예측을 하기 때문에 더 빠르다고 알려져 있다. 대표적으로 RetinaNet이 있다.
Method
Overview
일반적으로, detection의 유용한 지식들은 intermediate features에 불균등하게 분포되어 있다. 이를 개선하기 위해 본 논문은, 다음과 같은 식을 통해, instance-conditional knowledge를 전달한다.
\[\mathcal{L}_=\sum_{i=1}^N \mathcal{L}_d\left(\kappa_i^{\mathcal{S}}, \kappa_i^{\mathcal{T}}\right)\]여기서, \(\kappa_i^\mathcal{T}=\mathcal{G}(\mathcal{T},\mathrm{y}_i)\)는 condition \((\mathrm{y}_i)\)과 teacher representations \((\mathcal{T})\)에 대한 knowledge를 의미하고, \(\mathcal{G}\)는 instance-conditional decoding module를 나타내며, 이는 auxiliary loss를 통해 최적화된다.
Instance-Conditional Knowledge \((\kappa_i)\)
Instance-conditional decoding module \((\mathcal{G})\)을 통해 instance condition이 주어졌을 때, unconditional knowledge \((\mathcal{T})\)로부터 instance-conditional knowledge \((\kappa^\mathcal{T}_i)\)를 구할 수 있다.
- Unconditional knowledge \((\mathcal{T})\)는 teacher detector로부터 사용가능한 모든 정보를 의미한다. Multi-scale representations로 나타내면 \(\mathcal{T}=\left[X_p\in\mathbb{R}^{D \times H_p \times W_p}\right]_{p\in\mathcal{P}}\)가 되며, 여기서 \(\mathcal{P}\)는 spatial resolutions, \(D\)는 channel dimension을 의미한다. Spatial dimension를 따라 서로 다른 scales의 representations을 concatenation하게 되면 \(A^\mathcal{T}\in\mathbb{R}^{L\times D}\)가 되며, 여기서 \(L=\sum_{p\in\mathcal{P}}H_p\times W_p\)는 scale를 걸쳐서 모든 pixels의 수를 합한 것이다.
- Instance condition은 \(\mathcal{Y}=(\mathrm{y}_i)^N_{i=1}\)로 나타내며, 여기서 \(N\)는 object의 수를, \(\mathrm{y}_i=(c_i,\mathbf{b}_i)\)는 \(i\)-th instance에 대한 annotation를 의미한다. 즉, category \(c_i\)와 box location \(\mathbf{b}_i=(x_i,y_i,w_i,h_i)\)이다.
각 instance에 대해 학습가능한 embedding를 만들기 위해, annotation을 hidden space상에서의 query feature vector \((\mathbf{q}_i)\)로 mapping하게 되며, 이때 아래의 식처럼, 원하는 knowledge를 얻기위해 condition을 지정하게 된다.
\[\mathbf{q}_i=\mathcal{F}_q(\mathcal{E}(\mathrm{y}_i)),\quad\mathbf{q}_i\in\mathbb{R}^D\]여기서 \(\mathcal{E}(\cdot)\)는 instance encoding function을, \(\mathcal{F}_q\)는 MLP를 나타낸다.
\(\mathbf{q}_i\)가 주어졌을 때, \(\mathcal{T}\)로부터의 knowledge를 찾기 위해서 correlation을 측정하는데, 이는 dot-product attention을 사용하여 얻을 수 있으며, 각 head \(j\)는 3개의 선형 layers \((\mathcal{F}^k_j,\mathcal{F}^k_q,\mathcal{F}^k_v)\)에 상응한다.
아래의 식과 같이, key feature \((\mathrm{K}^\mathcal{T}_j)\)는 teacher의 representation \((\mathrm{A}^\mathcal{T})\)를 positional embeddings \((\mathrm{P}\in\mathbb{R}^{L\times d})\)과 projection하여 계산할 수 있다.
\[\mathrm{K}_j^{\mathcal{T}}=\mathcal{F}_j^k\left(\mathrm{A}^{\mathcal{T}}+\mathcal{F}_{p e}(\mathrm{P})\right), \mathrm{K}_j^{\mathcal{T}} \in \mathbb{R}^{L \times d}\]여기서, \(\mathcal{F}_{p e}\)는 position embeddings를 통한 선형 projection을 의미한다. 유사하게, value feature와 query는 아래와 같다.
\[\mathrm{V}_j^{\mathcal{T}}=\mathcal{F}_j^v\left(\mathrm{A}^{\mathcal{T}}\right), \quad \mathrm{V}_j^{\mathcal{T}} \in \mathbb{R}^{L \times d}\] \[\mathbf{q}_{i j}=\mathcal{F}_j^q\left(\mathbf{q}_i\right), \quad \mathbf{q}_{i j} \in \mathbb{R}^d\]마지막으로, \(j\)-th head에 의한 \(i\)-th instance의 instance-aware attention mask \(\mathbf{m}_{ij}\)는 \(\mathrm{K}_j^\mathcal{T}\)와 \(\mathbf{q}_{ij}\)사이의 normalized dot-product를 통해 구할 수 있다.
\[\mathbf{m}_{i j}=\operatorname{softmax}\left(\frac{\mathrm{K}_j^{\mathcal{T}} \mathbf{q}_{i j}}{\sqrt{d}}\right), \mathbf{m}_{i j} \in \mathbb{R}^L\]최종적으로, querying along the key feature \((\mathbf{m}_{i j})\)와 value features \((\mathrm{V}_j^{\mathcal{T}})\)는 representations과 instances간의 correlation을 의미하기 때문에 instance-condition knowledge는 \(\kappa_i^{\mathcal{T}}=\left\{\left(\mathbf{m}_{i j}, \mathrm{V}_j^{\mathcal{T}}\right)\right\}_{j=1}^M\)가 되며, 이는 \(i\)-th instance에 상응하는 knowledge를 encoding한다.
Auxiliary Task
앞서 설명한 decoding module \((\mathcal{G})\)을 최적화하기 위해 auxiliary tasks를 활용한다. 우선, 함수 \(\mathcal{F}_{agg}\)로 instance-conditional knowledge를 통합 (instance-level aggregated information \(\mathbf{g}_i^{\mathcal{T}}\))하여 객체를 식별하고 위치를 파악할 수 있다.
\[\mathbf{g}_i^{\mathcal{T}}=\mathcal{F}_{a g g}\left(\kappa_i^{\mathcal{T}}, \mathbf{q}_i\right), \mathbf{g}_i^{\mathcal{T}} \in \mathbb{R}^D\]Instance-level aggregated information \(\mathbf{g}_i^{\mathcal{T}}\)이 충분한 instance 단서를 유지하도록 하기위해, 아래와 같이 instance-sensitive tasks를 사용하여 최적화한다.
\[\mathcal{L}_{a u x}=\mathcal{L}_{i n s}\left(\mathbf{g}_i^{\mathcal{T}}, \mathcal{H}\left(\mathrm{y}_i\right)\right)\]여기서 \(\mathcal{H}\)는 instance information을 targets으로 encode한다.
다만, 위 식에서 \(i\)-th instance annotation에 해당하는 \(\mathrm{y}_i\)를 \(\mathbf{g}_i^{\mathcal{T}}\)와 \(\mathcal{H}\left(\mathrm{y}_i\right)\)를 통해서 구할수 있기 때문에 위 식을 바로 사용하면 trivial solution이 된다. 이는 결과적으로, teacher representation \(\mathcal{T}\)를 무시하고 parameters를 학습하게 될 가능성이 있다는 것을 의미한다. 이를 해결하기 위해, encoding function \(\mathcal{E}(\cdot)\)에 대한 정보를 drop하여 aggregation function \(\mathcal{F}_{agg}\)이 \(\mathcal{T}\)로부터 hint를 얻도록한다.
이러한 information dropping은 instance condition에 대한 정확한 annotation을 불확실한 것으로 교체하는 것을 의미하는데, bounding box annotations의 경우에는, 아래의 식과 같이 rough box center \(\left(x_i', y_i'\right)\)와 rough scales indicators \([\log_2(w_i)], [\log_2(h_2)]\)를 활용하여 얻을 수 있다.
\[\left\{\begin{array}{l} x_i^{\prime}=x_i+\phi_x w_i, \\ y_i^{\prime}=y_i+\phi_y h_i, \end{array}\right.\]여기서, \(\left(w_i,w_i\right)\)는 bounding box의 width와 height를 나타내며, \(\phi_x, \phi_y\)는 uniform distribution \(\Phi \sim U[-3, 3]\)에서 샘플된 값을 의미한다. 결과적으로 coarse information을 얻어 \(\mathcal{E}\)를 통해 instance encoding을 얻을 수 있다.
Aggregated representation \(\mathbf{g}_i^\mathcal{T}\)는 auxiliary task로 최적화되는 데, 이를 위해 \(\mathcal{F}_{obj}\)와 \(\mathcal{F}_{reg}\)를 각각 도입하여, identification과 localization results를 예측한다. 아래의 식과 같이 real-fake identification을 최적화하기 위해 binary cross entropy loss (BCE)를, regression을 최적화하기 위해 \(L1\) loss를 적용한다.
\[\mathcal{L}_{aux}=\mathcal{L}_{BCE}\left(\mathcal{F}_{obj}\left(\mathbf{g}_i^\mathcal{T}\right),\delta_{obj}(\mathrm{y}_i)\right)+\mathcal{L}_{1}\left(\mathcal{F}_{reg}\left(\mathbf{g}_i^\mathcal{T}\right),\delta_{reg}(\mathrm{y}_i)\right)\]여기서, \(\delta_{obj}(\cdot)\)은 indicator로서, \(\mathrm{y}_i\)가 real이면 1을 fake이면 0을 뱉어낸다.
Instance-Conditional Distillation
Conditional knowledge distillation를 하기 위해, student representations의 projected value features \(\mathrm{V}_j^\mathcal{S}\in\mathbb{R}^{L\times d}\)를 얻고, feature와 각 instance 사이의 correlations을 측정하는 instance-aware attention mask \(\mathbf{m}_{ij}\)를 사용함으로써, 아래와 같이 distillation loss를 설계할 수 있다.
\[\mathcal{L}_{distill} = \frac{1}{MN_r}\sum^{M}_{j=1}\sum^{N}_{i=1}\delta_{obj}(\mathrm{y}_i)\cdot\left<\mathbf{m}_{ij},\mathcal{L}_{MSE}\left(\mathrm{V}^\mathcal{S}_j,\mathrm{V}^\mathcal{T}_j\right)\right>\]여기서, \(N_r=\sum^{N}_{i=1}\delta_{obj}(\mathrm{y}_i), (N_r\le N)\)는 real instance의 수를 나타내고, \(\mathcal{L}_{MSE}\left(\mathrm{V}_j^\mathcal{S},\mathrm{V}_j^\mathcal{T}\right)\in\mathbb{R}^L\)는 pixel-wise mean-square error를 표현하며, \(\left<\cdot,\cdot\right>\)은 inner product를 위한 Dirac notation이다. Supervised learning loss \(\mathcal{L}_{det}\)를 포함하는 전체 loss function은 아래와 같다.
\[\mathcal{L}_{total}=\mathcal{L}_{det} + \mathcal{L}_{aux}+\lambda\mathcal{L}_{distill}\]여기서, \(\mathcal{L}_{det}\)와 \(\mathcal{L}_{distill}\)에 대한 gradient만 student network로 back-propagation되며 (student networks를 학습할 때 활용), \(\mathcal{L}_{aux}\)의 gradient는 instance-conditional decoding function \(\mathcal{G}\)과 auxiliary task와 관련된 modules만을 update한다.
Conclusion
본 논문은 human observed instances와 연관된 knowledge를 찾고 선택하기 위한 instance-feature cross attention를 활용하는 Instance-Conditional knowledge Distillation (ICD)를 제안했다. 본 방법은 instance를 query로, teacher’s representation을 key로 encode하며, knowledge를 찾는 방법을 decoder에게 학습하기 위해, auxiliary task를 설계했다. 제안된 방법은 다양한 detectors에 대해 일관되게 성능을 향상시켰으며, 몇몇 student networks는 teacher networks보다 성능이 뛰어났다.