Matching Networks / Prototypical Networks

Setup

Few-shot learning task

  • Abstract formulation:
    • Training data = labeled examples
    • Inputs:
      • Support set $S = \set{(x_i, y_i)}$ of input-output pairs
      • Unlabeled input $\hat x$
    • Output = the label $\hat y$ of $\hat x$, or more generally, a distribution $p(\hat y\mid\hat x, S)$
  • One common instantiation is the $N$-way $K$-shot classification task (see the episode-sampling sketch after this list):
    • In each test instance, choose $N$ classes that do not appear in the training data.
    • Support set $S$ = $K$ labeled examples for each of $N$ classes
    • Input $\hat x$ = an unlabeled example from one of the $N$ classes
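
A minimal sketch of sampling one such episode, assuming the data is stored as a dict mapping each class label to its list of examples (`data_by_class`, `n_query`, and the return format are illustrative choices, not part of the abstract setup):

```python
import random

def sample_episode(data_by_class, n_way=5, k_shot=1, n_query=15):
    """Sample one N-way K-shot episode from a dict {class_label: [examples]}.

    At test time, `data_by_class` would contain only classes held out from training.
    """
    # Pick N classes for this episode.
    classes = random.sample(list(data_by_class.keys()), n_way)

    support, query = [], []
    for label in classes:
        examples = random.sample(data_by_class[label], k_shot + n_query)
        # First K examples form the support set S; the rest are queries x_hat.
        support += [(x, label) for x in examples[:k_shot]]
        query += [(x, label) for x in examples[k_shot:]]
    return support, query  # query labels are kept only for evaluation
```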

Matching Networks

Paper (Vinyals et al., NIPS 2016) Matching Networks for One Shot Learning

Model

  • The model is a network that predicts a linear combination of the support labels:$$\hat y = \sum_i a(\hat x, x_i)\,y_i$$ where $a$ is an "attention mechanism":$$a(\hat x, x_i) = \frac{\exp\crab{q(\hat x, x_i, S)}}{\sum_{i'} \exp\crab{q(\hat x, x_{i'}, S)}}$$ for some network $q$.
    • One simple choice for $q$ is the cosine similarity between the embeddings of $\hat x$ and $x_i$ (used in the sketch after this list).
    • The final $q$ in the paper uses "full context embeddings": it treats $S$ as a sequence and runs LSTMs over the embeddings of $S$, $\hat x$, and $x_i$, so that each embedding is conditioned on the entire support set.
  • To be pedantic, we can write$$p(\hat y \mid \hat x, S) = \sum_i a(\hat x, x_i)\,\II[y_i = \hat y]$$ which is a valid distribution since $\sum_i a(\hat x, x_i) = 1$.
  • Train with meta-learning. In each episode:
    • Sample a label set $L$ (e.g., take 5 labels from the training data)
    • Sample a support set $S$ (e.g., take 5 examples for each label in $L$)
    • Sample a training set $B$ (a bunch of examples whose labels are in $L$)
    • Maximize$$\sum_{(x, y)\in B}\log p(y\mid x, S)$$
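
A minimal PyTorch sketch of the prediction rule and episode objective above, using the simple cosine-similarity choice for $q$ (the embedder `f`, the tensor shapes, and the one-hot label encoding are assumptions; the paper's full-context LSTM embeddings are omitted):

```python
import torch
import torch.nn.functional as F

def matching_net_log_probs(f, support_x, support_y, query_x, n_way):
    """log p(y_hat | x_hat, S) for each query, with q = cosine similarity.

    Assumes support_y contains class indices in 0..n_way-1.
    """
    s_emb = F.normalize(f(support_x), dim=-1)        # (N*K, D)
    q_emb = F.normalize(f(query_x), dim=-1)          # (Q, D)
    # a(x_hat, x_i): softmax over cosine similarities to the support examples.
    attn = F.softmax(q_emb @ s_emb.t(), dim=-1)      # (Q, N*K)
    one_hot = F.one_hot(support_y, n_way).float()    # (N*K, N)
    # p(y_hat = k | x_hat, S) = sum_i a(x_hat, x_i) * [y_i = k]
    probs = attn @ one_hot                           # (Q, N)
    return probs.clamp_min(1e-8).log()

def episode_loss(f, support_x, support_y, query_x, query_y, n_way):
    """Negative of the meta-learning objective: sum over B of log p(y | x, S)."""
    log_p = matching_net_log_probs(f, support_x, support_y, query_x, n_way)
    return F.nll_loss(log_p, query_y)
```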

Relation to other models

  • If $a$ is a kernel on $X \times X$, then the model resembles a kernel density estimator.
  • If $a$ is constant on the $m$ closest $x_i$'s and 0 otherwise, then the model resembles $m$-nearest neighbors (see the sketch after this list).
  • If we treat $S$ as a memory with $x_i$'s as keys and $y_i$'s as values, then the model is a memory mechanism with an extensible memory.
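
For instance, the $m$-nearest-neighbor special case corresponds to attention weights that are uniform over the $m$ closest support embeddings. A sketch, assuming Euclidean distance in the embedding space:

```python
import torch

def knn_attention(q_emb, s_emb, m):
    """a(x_hat, x_i) = 1/m on the m nearest support embeddings, 0 elsewhere."""
    dists = torch.cdist(q_emb, s_emb)                # (Q, N*K) Euclidean distances
    _, idx = dists.topk(m, dim=-1, largest=False)    # indices of the m nearest neighbors
    attn = torch.zeros_like(dists)
    attn.scatter_(-1, idx, 1.0 / m)                  # uniform weight over the neighbors
    return attn  # plugging this into the label-averaging rule gives m-NN voting
```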

Prototypical network

Paper (Snell et al., NIPS 2017) Prototypical Networks for Few-shot Learning

Assumption There exists an embedding space where, for each class, examples from that class cluster around a single prototype representation.

Model

  • The model produces a prototype vector $c_k$ for each label $k$ (sketched in code after this list).
    • Let $S_k$ be the set of support examples with label $k$.
    • In the standard setup, let $c_k$ = uniform average of the embeddings of examples in $S_k$:$$c_k = \frac{1}{\card{S_k}}\sum_{(x_i, y_i) \in S_k} f(x_i)$$ for some embedder $f$.
    • A prototype vector can also be defined for a zero-shot setup. Given metadata $v_k$ (e.g., text description) for the label $k$ instead of $S_k$, we can do $c_k = g(v_k)$ for some embedder $g$.
    • Prototypical Networks rely on the assumption of "1 prototype per class", and might degrade if this is false. However, a powerful embedder $f$ should be able to produce an embedding space where this assumption is true!
  • The prototype vector can be used to classify new examples $\hat x$:$$p(\hat y = k\mid \hat x, S) = \frac{\exp\crab{-d(f(\hat x), c_k)}}{\sum_{k'}\exp\crab{-d(f(\hat x), c_{k'})}}$$ for some distance function $d$.
    • The paper finds that squared Euclidean distance works better than cosine distance. (This also applies to Matching Networks.)
  • For training, use meta-learning like in Matching Networks.
    • The paper finds it helpful to sample MORE classes per episode during training than at test time.
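
A minimal PyTorch sketch of the prototype computation and classification rule above, with $d$ = squared Euclidean distance (the embedder `f` and tensor shapes are assumptions):

```python
import torch
import torch.nn.functional as F

def prototypical_log_probs(f, support_x, support_y, query_x, n_way):
    """log p(y_hat = k | x_hat, S) using class prototypes and squared Euclidean d.

    Assumes support_y contains class indices in 0..n_way-1.
    """
    s_emb = f(support_x)                                  # (N*K, D)
    q_emb = f(query_x)                                    # (Q, D)
    # c_k = uniform average of the support embeddings with label k.
    prototypes = torch.stack(
        [s_emb[support_y == k].mean(dim=0) for k in range(n_way)])   # (N, D)
    # -d(f(x_hat), c_k) with d = squared Euclidean distance.
    neg_dists = -torch.cdist(q_emb, prototypes).pow(2)    # (Q, N)
    return F.log_softmax(neg_dists, dim=-1)

# Training mirrors Matching Networks: sample an episode, then minimize
# F.nll_loss(prototypical_log_probs(...), query_y), possibly with more
# ("higher-way") classes per training episode than at test time.
```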

Relation to other models

  • Comparison with Matching Networks:
    • Matching Networks compare $\hat x$ to each support example $x_i$, then take an attention-weighted average of the labels $y_i$.
    • Prototypical Networks first average the support embeddings within each class to get prototypes, then compare $\hat x$ to the prototypes.
    • The two are equivalent when $\card{S_k} = 1$ for all classes $k$.
  • If $d$ is a regular Bregman divergence, then the model is equivalent to mixture density estimation with an exponential family density.
    • Squared Euclidean and Mahalanobis distances are regular Bregman divergences, while cosine distance isn't. (This might be why squared Euclidean distance works better than cosine.)
  • If $d$ is squared Euclidean distance, the model is a linear model w.r.t. the embedding $f(\hat x)$, as the expansion below shows. The model is still powerful due to the non-linear $f$.
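    • To see this, expand the squared distance; the $-\|f(\hat x)\|^2$ term is shared by all classes and cancels in the softmax:$$-\|f(\hat x) - c_k\|^2 = -\|f(\hat x)\|^2 + 2\,c_k^\top f(\hat x) - \|c_k\|^2 = \text{const} + w_k^\top f(\hat x) + b_k$$ where $w_k = 2 c_k$ and $b_k = -\|c_k\|^2$.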