Meta learning
Definition
- "learn to learn": aims to design models that can learn new skills or adapt to new environments rapidly from only a few training samples, much as humans do. For details, see Meta-Learning: Learning to Learn Fast.
- Optimization aim:
$$
\begin{aligned}
\theta^* = \arg\min_\theta \mathbb{E}_{\mathcal{D}\sim p(\mathcal{D})} [\mathcal{L}_\theta(\mathcal{D})]
\end{aligned}
$$
where $\mathcal{D}=\langle S, B\rangle$ consists of a support set $S$ and a batch set $B$.
- Training steps
- sample a subset of labels $L\subset\mathcal{L}$.
- sample a support set $S^L \subset \mathcal{D}$ and a training batch $B^L \subset \mathcal{D}$. Both contain only data points with labels from the sampled label set $L$: $y \in L, \forall (x, y) \in S^L, B^L$.
- the support set is part of the model input.
- the model parameters are updated via backpropagation on the loss computed from the mini-batch $B^L$.
Each pair of sampled datasets $(S^L, B^L)$ is treated as one data point, so that the trained model can generalize to other datasets. Symbols in red are what meta-learning adds on top of the general supervised learning objective.
$$
\begin{aligned}
\theta = \arg\max_\theta \color{red}{E_{L\subset\mathcal{L}}[} E_{\color{red}{S^L \subset\mathcal{D}, }B^L \subset\mathcal{D}} [\sum_{(x, y)\in B^L} P_\theta(x, y\color{red}{, S^L})] \color{red}{]}
\end{aligned}
$$
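The episodic sampling in the training steps above can be sketched in a few lines; a minimal sketch, assuming the dataset is a dict mapping each label to its examples (all names here are illustrative):

```python
import random

def sample_episode(dataset, n_way, k_shot, batch_size):
    """Sample one meta-learning 'data point': a label subset L, then a
    support set S^L and a training batch B^L drawn only from labels in L."""
    # dataset: dict mapping label -> list of examples (assumed structure)
    labels = random.sample(sorted(dataset), n_way)       # L subset of all labels
    support, batch = [], []
    for y in labels:
        xs = random.sample(dataset[y], k_shot + batch_size)
        support += [(x, y) for x in xs[:k_shot]]         # S^L: conditions the model
        batch   += [(x, y) for x in xs[k_shot:]]         # B^L: drives the loss
    return support, batch

# toy usage: 5 labels, 20 examples each, 2-way 3-shot episodes
data = {c: list(range(20)) for c in "ABCDE"}
S, B = sample_episode(data, n_way=2, k_shot=3, batch_size=2)
```

Each call yields one $(S^L, B^L)$ pair; an outer training loop would repeat this and backpropagate the loss on $B^L$.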
- Training stages
- meta-learner: an optimizer $g_\phi$ that learns how to update the learner model's parameters via the support set $S$: $\theta' = g_\phi(\theta, S)$.
- learner: a classifier $f_\theta$, the "learner" model, trained to perform a given task.
- the final learning objective is:
$$
\begin{aligned}
\mathbb{E}_{L\subset\mathcal{L}}[ \mathbb{E}_{S^L \subset\mathcal{D}, B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})]]
\end{aligned}
$$
Common methods
model-based: $f_\theta(\mathbf{x}, S)$
- use a recurrent network with internal (or external) memory.
- rapid parameter updates are achieved by a meta-learner model or by the internal architecture.
- memory-augmented neural network for meta-learning
- uses external memory storage to facilitate the learning process without forgetting information that may be needed later.
- encodes new information quickly, so it adapts to new tasks after only a few samples.
- memory-augmented neural network (MANN): how to assign attention weights. The memory serves as a knowledge repository; the controller learns to read and write memory rows. Attention weights are generated by an addressing mechanism: content-based + location-based.
- MANN for meta-learning: the memory update should support efficient information retrieval and storage; the key questions are how to read from and how to write into memory.
- meta network
- architecture
- loss gradients are used as meta information to populate models that learn fast weights.
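The content-based half of MANN's addressing mechanism can be sketched as cosine-similarity attention over memory rows; a minimal numpy sketch with a toy identity memory (not the full MANN read/write machinery):

```python
import numpy as np

def content_based_read(memory, key):
    """Content-based addressing: cosine similarity between the controller's
    key and each memory row, softmax-normalized into attention (read)
    weights, then a weighted sum of the rows."""
    sim = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-8)
    w = np.exp(sim) / np.exp(sim).sum()   # attention weights over rows
    return w @ memory, w                  # retrieved vector, weights

M = np.eye(4)                             # 4 toy memory rows
r, w = content_based_read(M, np.array([1.0, 0.0, 0.0, 0.0]))
```

The key matching row 0 gets the largest attention weight, so the retrieved vector is dominated by that row; location-based addressing (not shown) would further shift these weights.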
metric-based: $\sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_i$
- learn an efficient distance metric.
- similar to nearest-neighbor algorithms (kNN, k-means) and kernel density estimation.
- the predicted probability over labels is a weighted sum of the support set labels, where the weights are generated by a kernel function $k_\theta$ measuring the similarity between two data samples:
$$
\begin{aligned}
P_\theta(y \vert \mathbf{x}, S) = \sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_{i}
\end{aligned}
$$
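A minimal sketch of this kernel-weighted prediction, using a toy RBF kernel on 1-D inputs as a stand-in for a learned similarity (the kernel choice and all names are illustrative):

```python
import numpy as np

def kernel_predict(x, support_x, support_y, kernel):
    """P(y|x,S) as a weighted sum of one-hot support labels, with
    weights from a kernel k(x, x_i), normalized to sum to one."""
    w = np.array([kernel(x, xi) for xi in support_x])
    w = w / w.sum()              # normalize into a probability distribution
    return w @ support_y         # support_y: one-hot label rows

# toy RBF kernel; a metric-based method would learn the embedding instead
rbf = lambda a, b: np.exp(-(a - b) ** 2)
Y = np.eye(2)[[0, 0, 1]]         # one-hot labels of 3 support points
p = kernel_predict(0.1, [0.0, 0.2, 5.0], Y, rbf)
```

The query at 0.1 sits near the two class-0 support points, so nearly all probability mass lands on class 0.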
- crucial point: a good kernel.
- solution: explicitly learn embedding vectors of the input data and use them to design proper kernel functions.
- Siamese networks:
- assumption: the learned embedding can be generalized to be useful for measuring the distance between images of unknown categories.
- verification on image pairs.
- the final prediction is the class of the support image with the highest probability.
- Matching networks:
- two embedding functions: $g_\theta$ encodes the support samples of the $k$ classes, while $f_\theta$ encodes the test images.
- the attention kernel depends on the two embedding functions $g$ and $f$; in the simple version, $f=g$. A more complex version that integrates full context embeddings (FCE) brings improvements on hard tasks, but not on simple ones.
- Relation networks
- image pairs, feature concatenation.
- the relation is modeled with an MSE loss function.
- Prototypical networks
- images of each class are embedded into $M$-dimensional feature vectors; each class has a prototype feature vector, the mean of the embedded support samples of that class.
- squared Euclidean distance loss.
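Prototypical-network inference follows directly from the two bullets above: average the embedded support samples per class, then softmax over negative squared Euclidean distances to the prototypes. A sketch with toy 2-D embeddings (illustrative names throughout):

```python
import numpy as np

def prototypes(embedded_support, labels):
    """Class prototype = mean of the embedded support vectors per class."""
    classes = sorted(set(labels))
    mask = np.array(labels)
    return classes, np.stack([embedded_support[mask == c].mean(axis=0)
                              for c in classes])

def classify(z, protos):
    """Softmax over negative squared Euclidean distance to each prototype."""
    d = ((protos - z) ** 2).sum(axis=1)
    return np.exp(-d) / np.exp(-d).sum()

Z = np.array([[0., 0.], [0., 2.], [4., 4.], [4., 6.]])  # toy embedded support
cs, P = prototypes(Z, [0, 0, 1, 1])                     # prototypes: (0,1), (4,5)
p = classify(np.array([0., 1.]), P)                     # query near class 0
```

In a real model the embedding function producing `Z` is the learned part; the rest is exactly this mean-and-distance computation.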
optimization-based: $P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})$
- optimize the model parameters explicitly for fast learning.
- model the optimization algorithm explicitly.
- LSTM meta-learner
- MAML
- Reptile
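Of the three, Reptile is simple enough to sketch in full: repeatedly run a few SGD steps on a sampled task, then move the meta-parameters a fraction of the way toward the task-adapted parameters. A toy sketch on a 1-D quadratic task family (the task family and all names are illustrative):

```python
import numpy as np

def reptile(theta, sample_task, inner_steps=5, inner_lr=0.1,
            meta_lr=0.5, meta_steps=200):
    """Reptile meta-training: inner-loop SGD on one task, then an outer
    step of theta toward the adapted parameters phi."""
    rng = np.random.default_rng(0)
    for _ in range(meta_steps):
        grad = sample_task(rng)              # gradient fn of one sampled task
        phi = theta.copy()
        for _ in range(inner_steps):         # inner-loop SGD on this task
            phi -= inner_lr * grad(phi)
        theta += meta_lr * (phi - theta)     # outer update toward phi
    return theta

# toy task family: minimize (x - c)^2 with c ~ N(0, 0.1),
# so a good initialization sits near 0
def sample_task(rng):
    c = rng.normal(0.0, 0.1)
    return lambda x: 2 * (x - c)             # gradient of (x - c)^2

theta = reptile(np.array([5.0]), sample_task)
```

Starting far away at 5.0, the meta-parameters converge near 0, the point from which every task in the family can be solved in a few gradient steps. MAML differs by backpropagating through the inner loop instead of using the parameter difference.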