Setup
Few-shot learning task
- Abstract formulation:
  - Training data = labeled examples
  - Inputs:
    - Support set $S = \set{(x_i, y_i)}$ of input-output pairs
    - Unlabeled input $\hat x$
  - Output = the label $\hat y$ of $\hat x$, or more generally, a distribution $p(\hat y\mid\hat x, S)$
- One common instantiation is the $N$-way $K$-shot classification task:
  - In each test instance, choose $N$ classes that do not appear in the training data.
  - Support set $S$ = $K$ labeled examples for each of the $N$ classes
  - Input $\hat x$ = an unlabeled example from one of the $N$ classes
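A minimal Python sketch of sampling one $N$-way $K$-shot episode, assuming the data is a plain list of `(x, y)` pairs. The function name `sample_episode` and the `n_query` parameter (how many query inputs $\hat x$ to draw per class) are hypothetical. At test time `dataset` would contain only the held-out classes; applied to the training classes, the same routine can also produce training episodes.

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way, k_shot, n_query, rng=random):
    """Sample one N-way K-shot episode from a list of (x, y) pairs.

    Returns (support, query): support has k_shot examples for each of the
    n_way sampled classes; query has n_query further examples per class.
    """
    by_label = defaultdict(list)
    for x, y in dataset:
        by_label[y].append(x)
    # Only classes with enough examples for both support and query are eligible.
    eligible = sorted(y for y, xs in by_label.items() if len(xs) >= k_shot + n_query)
    labels = rng.sample(eligible, n_way)
    support, query = [], []
    for y in labels:
        # Draw without replacement so support and query never share an example.
        xs = rng.sample(by_label[y], k_shot + n_query)
        support += [(x, y) for x in xs[:k_shot]]
        query += [(x, y) for x in xs[k_shot:]]
    return support, query
```

For example, `sample_episode(data, n_way=5, k_shot=1, n_query=2)` returns 5 support pairs (one per class) and 10 query pairs drawn from the same 5 classes.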
Matching Networks
Paper: Vinyals et al. (NIPS 2016), Matching Networks for One Shot Learning
Model
- The model is a network that predicts a linear combination of the (one-hot) support labels:$$\hat y = \sum_i a(\hat x, x_i)\,y_i$$ where $a$ is an "attention mechanism":$$a(\hat x, x_i) = \frac{\exp\crab{q(\hat x, x_i, S)}}{\sum_{i'} \exp\crab{q(\hat x, x_{i'}, S)}}$$ for some network $q$.
  - One simple choice for $q$ is the cosine similarity between the embeddings of $\hat x$ and $x_i$.
  - The final $q$ in this paper uses "full context embeddings": it treats $S$ as a sequence, embeds each $x_i$ with a bidirectional LSTM over the embeddings of $S$, embeds $\hat x$ with an LSTM that attends over $S$, and then takes the cosine similarity of the two.
- To be pedantic, we can write$$p(\hat y \mid \hat x, S) = \sum_i a(\hat x, x_i)\,\II[y_i = \hat y]$$ which is a valid distribution since $\sum_i a(\hat x, x_i) = 1$.
- Train with meta-learning. In each episode:
  - Sample a label set $L$ (e.g., 5 labels from the training data)
  - Sample a support set $S$ (e.g., 5 examples for each label in $L$)
  - Sample a batch $B$ of training examples whose labels are in $L$
  - Maximize$$\sum_{(x, y)\in B}\log p(y\mid x, S)$$
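A minimal pure-Python sketch of the prediction and the episode objective, using the simple cosine-similarity choice of $q$ rather than the paper's full-context LSTM embeddings. The embedder defaults to the identity as a stand-in for a learned network; `matching_predict` and `episode_log_likelihood` are hypothetical names, and no gradients are computed, so this only evaluates the objective instead of optimizing it.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def softmax(scores):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matching_predict(x_hat, support, embed=lambda x: x):
    """Return p(y_hat | x_hat, S) as a dict mapping label -> probability.

    q = cosine similarity of embeddings; a(x_hat, x_i) = softmax of q over S;
    p(y_hat) = sum of a(x_hat, x_i) over support examples with y_i == y_hat.
    """
    e_hat = embed(x_hat)
    attn = softmax([cosine(e_hat, embed(x_i)) for x_i, _ in support])
    probs = {}
    for a, (_, y_i) in zip(attn, support):
        probs[y_i] = probs.get(y_i, 0.0) + a
    return probs

def episode_log_likelihood(batch, support, embed=lambda x: x):
    """Episode objective: sum over (x, y) in B of log p(y | x, S)."""
    return sum(math.log(matching_predict(x, support, embed)[y]) for x, y in batch)
```

With `support = [((1.0, 0.0), "cat"), ((0.0, 1.0), "dog")]`, calling `matching_predict((0.9, 0.1), support)` assigns more mass to `"cat"` than to `"dog"`, and the two probabilities sum to 1.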
Relation to other models
- If $a$ is a kernel on $X \times X$, then the model resembles a kernel density estimator.
- If $a$ equals $1/m$ for the $m$ support examples $x_i$ closest to $\hat x$ and 0 otherwise, then picking the most probable label is exactly $m$-nearest-neighbor classification.
- If we treat $S$ as a memory with $x_i$'s as keys and $y_i$'s as values, then the model is a memory mechanism with an extensible memory.
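To make the nearest-neighbor connection concrete, here is a sketch with hard attention: $a(\hat x, x_i) = 1/m$ for the $m$ support points nearest to $\hat x$ and 0 otherwise. Euclidean distance on raw inputs and the name `knn_attention_predict` are assumptions for illustration. The resulting label distribution is the vote share among the $m$ nearest neighbors, so its argmax is the $m$-NN prediction.

```python
import math

def hard_attention(x_hat, support, m, dist):
    """a(x_hat, x_i) = 1/m for the m support points closest to x_hat, else 0."""
    order = sorted(range(len(support)), key=lambda i: dist(x_hat, support[i][0]))
    nearest = set(order[:m])
    return [1.0 / m if i in nearest else 0.0 for i in range(len(support))]

def knn_attention_predict(x_hat, support, m, dist=math.dist):
    """Matching-Networks-style label distribution under hard m-NN attention."""
    attn = hard_attention(x_hat, support, m, dist)
    probs = {}
    for a, (_, y) in zip(attn, support):
        probs[y] = probs.get(y, 0.0) + a
    return probs
```

For one-dimensional points `0, 1` labeled `"a"` and `5, 6, 7` labeled `"b"`, the query `2` with `m=3` has neighbors `1, 0, 5`, giving `"a"` two thirds of the mass, which matches a 3-NN majority vote.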
Prototypical network
Paper: Snell et al. (NIPS 2017), Prototypical Networks for Few-shot Learning
Assumption: There exists an embedding space in which, for each class, examples from that class cluster around a single prototype representation.
Model
- The model produces a prototype vector $c_k$ for each label $k$.
  - Let $S_k$ be the set of support examples with label $k$.
  - In the standard setup, $c_k$ is the uniform average of the embeddings of the examples in $S_k$:$$c_k = \frac{1}{\card{S_k}}\sum_{(x_i, y_i) \in S_k} f(x_i)$$ for some embedder $f$.
  - A prototype vector can also be defined in a zero-shot setup. Given metadata $v_k$ (e.g., a text description) for label $k$ instead of $S_k$, we can set $c_k = g(v_k)$ for some embedder $g$.
- Prototypical Networks rely on the "one prototype per class" assumption and may degrade when it fails (e.g., when a class forms several separate clusters in the embedding space). The hope is that a sufficiently powerful embedder $f$ can learn an embedding space in which the assumption approximately holds.
- The prototype vectors are used to classify a new example $\hat x$:$$p(\hat y = k\mid \hat x, S) = \frac{\exp\crab{-d(f(\hat x), c_k)}}{\sum_{k'}\exp\crab{-d(f(\hat x), c_{k'})}}$$ for some distance function $d$.
  - The paper finds that squared Euclidean distance works better than cosine distance, and that this also holds for Matching Networks.
- For training, use episodic meta-learning as in Matching Networks.
  - The paper finds it helpful to sample MORE classes per episode during training than at test time.
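A minimal sketch of the Prototypical Network head in pure Python, assuming an identity embedder in place of the learned $f$ and squared Euclidean distance for $d$ (the choice the paper reports works best). `prototypes` and `proto_predict` are hypothetical names.

```python
import math

def sq_euclidean(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def prototypes(support, embed=lambda x: x):
    """c_k = mean of f(x_i) over the support examples with label k."""
    sums, counts = {}, {}
    for x, y in support:
        z = embed(x)
        if y not in sums:
            sums[y] = [0.0] * len(z)
            counts[y] = 0
        sums[y] = [s + zi for s, zi in zip(sums[y], z)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}

def proto_predict(x_hat, protos, embed=lambda x: x, dist=sq_euclidean):
    """p(y_hat = k | x_hat, S): softmax over k of -d(f(x_hat), c_k)."""
    z = embed(x_hat)
    labels = list(protos)
    logits = [-dist(z, protos[k]) for k in labels]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return {k: e / total for k, e in zip(labels, exps)}
```

Because `proto_predict` only needs a dictionary of prototypes, the zero-shot variant fits the same interface: fill the dictionary with $c_k = g(v_k)$ instead of support-set averages.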
Relation to other models
- Comparison with Matching Networks:
  - Matching Networks compare $\hat x$ to each support example $x_i$, then average the labels weighted by attention.
  - Prototypical Networks first average the support embeddings within each class into a prototype, then compare $\hat x$ to the prototypes.
  - The two are equivalent when $\card{S_k} = 1$ for all classes $k$, provided both use the same distance, i.e., $q(\hat x, x_i, S) = -d(f(\hat x), f(x_i))$.
- If $d$ is a regular Bregman divergence, then computing $p(\hat y \mid \hat x, S)$ is equivalent to mixture density estimation on the support set with an exponential family density.
- Squared Euclidean and squared Mahalanobis distances are regular Bregman divergences, while cosine distance is not. (This might be why squared Euclidean distance works better than cosine.)
- If $d$ is squared Euclidean, the model is a linear model w.r.t. the input $f(\hat x)$. The model is still powerful because $f$ is non-linear.
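A quick numerical check of the linear-model claim, assuming $d$ is squared Euclidean. Expanding $-\lVert z - c_k\rVert^2 = 2c_k^\top z - c_k^\top c_k - z^\top z$ with $z = f(\hat x)$, the last term is the same for every class and cancels in the softmax, leaving a linear layer with weights $w_k = 2c_k$ and biases $b_k = -c_k^\top c_k$. The helper names below are hypothetical, and $z$ is treated as an already-computed embedding.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def proto_probs(z, protos):
    # Prototypical Network head: softmax of negative squared Euclidean distances.
    return softmax([-sq_dist(z, c) for c in protos])

def linear_probs(z, protos):
    # Equivalent linear head with w_k = 2 c_k and b_k = -||c_k||^2.
    # The dropped term -||z||^2 is identical for every class, so softmax ignores it.
    return softmax([dot([2 * ci for ci in c], z) - dot(c, c) for c in protos])
```

For any embedding `z` and list of prototypes, `proto_probs(z, protos)` and `linear_probs(z, protos)` agree up to floating-point rounding.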