Transcript of "TADAM: Task dependent adaptive metric for improved few-shot learning"
TADAM: Task dependent adaptive metric for improved few-shot learning (NeurIPS 2018)
B. N. Oreshkin, P. Rodriguez, and A. Lacoste (Element AI)
Jungtaek Kim ([email protected])
Machine Learning Group, Department of Computer Science and Engineering, POSTECH,
77 Cheongam-ro, Nam-gu, Pohang 37673, Gyeongsangbuk-do, Republic of Korea
March 11, 2019
Table of Contents
Motivation
Contributions
Metric Scaling
Task Conditioning
Auxiliary Task Co-Training
Overall Architecture
Experimental Results
Feature-wise Linear Modulation (FiLM)
Motivation
▶ Two recent approaches have attracted significant attention in the few-shot learning domain: Matching Networks and Prototypical Networks.
▶ In both approaches, the support set and the query set are embedded with a neural network, and nearest-neighbor classification is used given a metric in the embedded space.
▶ This paper extends the very notion of the metric space by making it task dependent via conditioning the feature extractor on the specific task.
▶ The authors find a solution in exploiting the interaction between the conditioned feature extractor and the training procedure based on auxiliary co-training on a simpler task.
Contributions
▶ Metric Scaling: This paper proposes metric scaling to improve the performance of few-shot algorithms, and mathematically analyzes its effects on objective function updates.
▶ Task Conditioning: It uses a task encoding network to extract a task representation based on the support set. This representation influences the behavior of the feature extractor through FiLM.
▶ Auxiliary Task Co-Training: The authors show that co-training the feature extractor on a conventional classification task reduces training complexity and provides better generalization.
Metric Scaling
▶ Prototypical Networks based on Euclidean distance perform better than Matching Networks based on cosine distance.
▶ The authors hypothesize that the improvement can be directly attributed to the interaction of the different scalings of the metrics with the softmax.
▶ Moreover, the dimensionality of the output is known to have a direct impact on the output scale, even for the Euclidean distance.
▶ This paper proposes to scale the distance metric by a learnable temperature α.
Metric Scaling
▶ Therefore, the class probability can be written as
p_φ,α(y = k | x) = softmax(−α d(z, c_k)). (1)
▶ The class-wise cross-entropy function is
J_k(φ, α) = ∑_{x_i ∈ Q_k} [ α d(f_φ(x_i), c_k) + log ∑_j exp(−α d(f_φ(x_i), c_j)) ]. (2)
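Eq. (1) can be sketched in a few lines of NumPy. This is a minimal illustration (not the paper's implementation), assuming squared Euclidean distance and a `(K, D)` array of class prototypes:

```python
import numpy as np

def scaled_class_probs(z, prototypes, alpha):
    """Eq. (1): p(y = k | x) = softmax(-alpha * d(z, c_k)), where d is the
    squared Euclidean distance between the embedded query z and each of the
    K class prototypes c_k, and alpha is the learned temperature."""
    d = np.sum((prototypes - z) ** 2, axis=1)  # d(z, c_k) for each class k
    logits = -alpha * d
    logits -= logits.max()                     # subtract max for numerical stability
    p = np.exp(logits)
    return p / p.sum()
```

A larger α sharpens the distribution toward the nearest prototype; in the paper, α is learned jointly with the embedding parameters φ.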
Metric Scaling
[Figure slide: the gradients of Eq. (2) in the small-α limit (Eq. (3)) and the large-α limit (Eq. (4)); derivation not transcribed.]
Metric Scaling
▶ From Eq. (3), it is clear that for small α values, the first term minimizes the embedding distance between query samples and their corresponding prototypes, while the second term maximizes the embedding distance between the samples and the prototypes of the non-belonging categories.
▶ For large α values (Eq. (4)), the first term is the same as in Eq. (3), while the second term maximizes only the distance of the sample to the closest wrongly assigned prototype c_{j*_i}.
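This concentration effect can be checked numerically. The sketch below (not the paper's exact gradient expressions) computes the softmax weights that the log-sum-exp term of Eq. (2) places on each wrong-class prototype: near-uniform for small α, collapsing onto the closest wrong prototype as α grows.

```python
import numpy as np

def wrong_class_weights(dists, alpha):
    """Softmax weights softmax(-alpha * d_j) over wrong-class distances d_j;
    these weight each wrong prototype's contribution to the gradient of the
    log-sum-exp term in Eq. (2)."""
    logits = -alpha * np.asarray(dists, dtype=float)
    logits -= logits.max()          # numerical stability
    w = np.exp(logits)
    return w / w.sum()
```

With distances [1, 2, 3], small α gives roughly equal weights (the "all non-belonging categories" regime of Eq. (3)); large α puts nearly all weight on the closest wrong prototype (the regime of Eq. (4)).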
Task Conditioning
▶ Up until now, we assumed the feature extractor f_φ(·) to be task-independent.
▶ The authors define a dynamic feature extractor f_φ(x, Γ), where Γ is the set of parameters predicted from a task representation such that the performance of f_φ(x, Γ) is optimized given the support set S.
▶ This is related to the FiLM conditioning layer and conditional batch normalization of the form h_{l+1} = γ ⊙ h_l + β.
▶ The task representation is defined as the mean of the task class centroids,
c̄ = (1/K) ∑_k c_k, (3)
which (i) reduces the dimensionality of the task embedding network (TEN) input and (ii) replaces expensive RNN/CNN/attention modeling.
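A minimal sketch of this conditioning path, assuming NumPy and a hypothetical one-linear-layer TEN (the paper's TEN is a deeper network with additional post-multipliers):

```python
import numpy as np

def task_representation(prototypes):
    """Eq. (3): the task embedding is the mean of the K class centroids c_k."""
    return prototypes.mean(axis=0)

def make_ten(task_dim, num_channels, rng):
    """Hypothetical one-layer TEN mapping the task representation c-bar to
    per-channel FiLM parameters (gamma, beta) for a single conv block."""
    W = rng.normal(scale=0.01, size=(2 * num_channels, task_dim))
    def ten(c_bar):
        out = W @ c_bar
        gamma, beta = out[:num_channels], out[num_channels:]
        return 1.0 + gamma, beta    # predict offsets around the identity (gamma ≈ 1)
    return ten

def condition(h, gamma, beta):
    """Conditional batch-norm style modulation h_{l+1} = gamma ⊙ h_l + beta,
    broadcast over the spatial positions of a (C, H, W) feature map."""
    return gamma[:, None, None] * h + beta[:, None, None]
```

Predicting offsets around the identity means an untrained TEN leaves the feature extractor approximately unchanged, which eases optimization.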
Task Conditioning
[Three figure slides: task embedding network (TEN) architecture and conditioning diagrams; not transcribed.]
Auxiliary Task Co-Training
▶ The TEN introduces additional complexity into the architecture via task conditioning layers inserted after the convolutional and batch norm blocks.
▶ The TEN network is difficult to train; thus, the authors use auxiliary task co-training.
▶ It applies auxiliary co-training with an additional logit head (the usual 64-way classification in the mini-ImageNet case).
▶ The authors anneal the auxiliary task weight using an exponential decay schedule of the form 0.9^⌊20t/T⌋.
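The annealing schedule is a simple step decay; a sketch, assuming t is the current training step and T the total number of steps:

```python
import math

def aux_task_weight(t, T):
    """Auxiliary-task loss weight annealed as 0.9**floor(20*t / T):
    the weight drops by a factor of 0.9 twenty times over training."""
    return 0.9 ** math.floor(20 * t / T)
```

The weight starts at 1.0 and decays to 0.9^20 ≈ 0.12 by the end of training, so the auxiliary 64-way head dominates early and fades as the few-shot objective takes over.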
Overall Architecture
[Figure slide: overall architecture diagram; not transcribed.]
Experimental Results
[Four figure slides: result tables and ablations; not transcribed.]
Feature-wise Linear Modulation (FiLM)
▶ This paper suggests a general-purpose model that can be used in learning visual reasoning.
▶ A FiLM layer carries out a simple, feature-wise affine transformation on a neural network's intermediate features, conditioned on an arbitrary input.
▶ The FiLM model consists of a FiLM-generating linguistic pipeline and a FiLM-ed visual pipeline.
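The two-pipeline structure can be sketched as follows; this is an illustration, not the FiLM paper's implementation, and the linear generator here is a hypothetical stand-in for its learned question encoder:

```python
import numpy as np

def film_layer(features, gamma, beta):
    """FiLM-ed visual pipeline: each channel c of a (C, H, W) feature map
    is scaled by gamma[c] and shifted by beta[c] (feature-wise affine)."""
    return gamma[:, None, None] * features + beta[:, None, None]

def make_film_generator(cond_dim, num_channels, rng):
    """FiLM-generating pipeline: maps a conditioning embedding (e.g. a
    question encoding) to the (gamma, beta) pair for one visual block."""
    W = rng.normal(scale=0.1, size=(2 * num_channels, cond_dim))
    def generator(cond):
        out = W @ cond
        return out[:num_channels], out[num_channels:]
    return generator
```

Because (γ, β) are produced from the conditioning input rather than learned as fixed parameters, the same visual pipeline computes different functions for different inputs, which is what TADAM reuses for task conditioning.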
Feature-wise Linear Modulation (FiLM)
[Three figure slides: FiLM architecture and modulation examples; not transcribed.]