OpenAI Article, 7 March 2018

Reptile: A scalable meta-learning algorithm




We’ve developed a simple meta-learning algorithm called Reptile. It works by repeatedly sampling a task, performing stochastic gradient descent on it, and updating the initial parameters towards the final parameters learned on that task. Reptile is the application of the Shortest Descent algorithm to the meta-learning setting, and it is mathematically similar to first-order MAML (a version of the well-known MAML algorithm). Reptile only needs black-box access to an optimizer such as SGD or Adam, and it offers similar computational efficiency and performance.

Meta-learning is the process of learning how to learn. A meta-learning algorithm takes in a distribution of tasks, where each task is a learning problem, and it produces a quick learner: a learner that can generalize from a small number of examples. One well-studied meta-learning problem is few-shot classification. Here each task is a classification problem in which the learner sees only 1–5 input-output examples from each class and must then classify new inputs. Below, you can try out our interactive demo of 1-shot classification, which uses Reptile.


Try clicking the “Edit All” button, drawing three distinct shapes or symbols, then drawing one of them again in the input field on the right, and see how well Reptile can classify it. The first three drawings are the labelled examples: each drawing defines one of the classes. The final drawing represents the unknown example, and Reptile outputs the probabilities of it belonging to each of the classes.
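The following is a minimal plain-Python sketch of how a single N-way, K-shot task might be sampled. The dataset layout (a dict mapping each class to a list of examples) and the name `sample_episode` are hypothetical and do not come from the Reptile code. The demo described above corresponds to N = 3 and K = 1.

```python
import random

def sample_episode(dataset, n_way, k_shot, n_query, rng):
    """Sample one N-way, K-shot task.

    `dataset` is a hypothetical layout: dict mapping class name -> list of examples.
    Returns (support, query) as lists of (example, label) pairs.
    """
    # rng.sample returns the chosen classes in random order, so the
    # class -> label mapping (0..n_way-1) is reshuffled for every task.
    classes = rng.sample(sorted(dataset), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        examples = rng.sample(dataset[cls], k_shot + n_query)
        support += [(x, label) for x in examples[:k_shot]]   # labelled examples
        query += [(x, label) for x in examples[k_shot:]]     # inputs to classify
    return support, query

# Toy dataset: 5 classes with 4 examples each (strings stand in for drawings).
toy = {c: [f"{c}{i}" for i in range(4)] for c in "abcde"}
support, query = sample_episode(toy, n_way=3, k_shot=1, n_query=1, rng=random.Random(0))
print(support)
print(query)
```

The label assignment is reshuffled for every task. A model therefore cannot do well without looking at the support examples, which is exactly the situation in which zero-shot learning is impossible.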

How Reptile works

Like MAML, Reptile seeks an initialization for the parameters of a neural network such that the network can be fine-tuned using a small amount of data from a new task. However, MAML unrolls and differentiates through the computation graph of the gradient descent algorithm. Reptile instead simply performs stochastic gradient descent (SGD) on each task in the standard way, without unrolling a computation graph or calculating any second derivatives. As a result, Reptile requires less computation and memory than MAML. The pseudocode is as follows:

Initialize $\Phi$, the initial parameter vector

for iteration $1, 2, 3, \ldots$ do

  Randomly sample a task $T$

  Perform $k > 1$ steps of SGD on task $T$, starting with parameters $\Phi$, resulting in parameters $W$

  Update: $\Phi \leftarrow \Phi + \epsilon (W - \Phi)$

end for

Return $\Phi$
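The pseudocode can be sketched in plain Python on a hypothetical toy task family. Each task is a 1-D linear regression $y = a x + c$ with its own slope and intercept, and the model is $y = w x + b$. This sketch makes two simplifying assumptions that are not details from the paper. First, the inner loop takes full-batch gradient steps rather than minibatch SGD. Second, the outer step size $\epsilon$ is annealed linearly, as in the sine-wave example later in this post.

```python
import random

# Hypothetical toy task family: y = a*x + c, evaluated on a shared set of inputs.
XS = [-1.0, -0.5, 0.0, 0.5, 1.0]

def sample_task(rng):
    a, c = rng.uniform(1.0, 3.0), rng.uniform(-1.0, 1.0)
    return [(x, a * x + c) for x in XS]

def grad(params, data):
    """Gradient of 0.5 * mean squared error for the model y = w*x + b."""
    w, b = params
    gw = sum((w * x + b - y) * x for x, y in data) / len(data)
    gb = sum((w * x + b - y) for x, y in data) / len(data)
    return [gw, gb]

def inner_sgd(params, data, k, inner_lr):
    """k plain gradient steps on one task, starting from params."""
    params = list(params)
    for _ in range(k):
        g = grad(params, data)
        params = [p - inner_lr * gi for p, gi in zip(params, g)]
    return params

def reptile(n_iters=3000, k=5, inner_lr=0.1, outer_lr=0.1, seed=0):
    rng = random.Random(seed)
    phi = [0.0, 0.0]                                  # initial parameter vector Phi
    for it in range(n_iters):
        task = sample_task(rng)                       # randomly sample a task T
        w = inner_sgd(phi, task, k, inner_lr)         # k > 1 steps of SGD -> W
        eps = outer_lr * (1 - it / n_iters)           # linearly annealed epsilon
        phi = [p + eps * (wi - p) for p, wi in zip(phi, w)]  # Phi <- Phi + eps*(W - Phi)
    return phi

print(reptile())
```

Every task here shares the same inputs and has a quadratic loss, so the expected update vanishes only when $\Phi$ equals the average of the per-task optima. The learned initialization should therefore end up close to $(w, b) = (2, 0)$.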

As an alternative to the last step, we can treat $\Phi - W$ as a gradient and plug it into a more sophisticated optimizer like Adam.
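Below is a minimal sketch of this variant. It assumes a hand-written Adam over plain Python lists. In practice one would use a framework's built-in Adam, and the class and function names here are illustrative. Note the sign: Adam steps against its input, so feeding it $\Phi - W$ moves $\Phi$ toward $W$. With plain SGD and step size $\epsilon$, the same recipe reduces to $\Phi \leftarrow \Phi + \epsilon (W - \Phi)$.

```python
import math

class Adam:
    """Minimal Adam over a list of floats (illustrative, not a framework API)."""

    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = None   # first-moment estimates
        self.v = None   # second-moment estimates
        self.t = 0      # step counter for bias correction

    def step(self, params, grads):
        if self.m is None:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params

def reptile_outer_step(phi, w, opt):
    """Treat Phi - W as the meta-gradient and hand it to the optimizer."""
    meta_grad = [p - wi for p, wi in zip(phi, w)]
    return opt.step(phi, meta_grad)

opt = Adam(lr=0.01)
phi = reptile_outer_step([0.0, 0.0], [1.0, -2.0], opt)
print(phi)
```

On the first step, Adam's bias-corrected update is approximately `lr * sign(gradient)`. Each coordinate of $\Phi$ therefore moves by about 0.01 toward $W$, regardless of how far away $W$ is.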

At first, it may seem surprising that this method works at all. If $k = 1$, this algorithm would correspond to “joint training”, that is, performing SGD on the mixture of all tasks. Joint training can learn a useful initialization in some cases, but it learns very little when zero-shot learning is not possible (e.g. when the output labels are randomly permuted). Reptile requires $k > 1$, where the update depends on the higher-order derivatives of the loss function. As we show in the paper, this behaves very differently from $k = 1$ (joint training).
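The $k = 1$ case can be checked directly. Here $\alpha$ is the inner SGD step size and $L_T$ is the loss on task $T$; this notation is introduced only for illustration. A single inner step gives $W = \Phi - \alpha \nabla L_T(\Phi)$, so the Reptile update becomes

$$
\begin{aligned}
\Phi + \epsilon\,(W - \Phi) &= \Phi - \epsilon\alpha\,\nabla L_T(\Phi), \\
\mathbb{E}_T\!\left[\Phi + \epsilon\,(W - \Phi)\right] &= \Phi - \epsilon\alpha\,\nabla_{\Phi}\,\mathbb{E}_T\!\left[L_T(\Phi)\right].
\end{aligned}
$$

This is an SGD step on the average loss over tasks, i.e. joint training. With $k = 2$, the second gradient is evaluated at $\Phi - \alpha \nabla L_T(\Phi)$ instead of at $\Phi$. To first order in $\alpha$, $\nabla L_T\big(\Phi - \alpha \nabla L_T(\Phi)\big) \approx \nabla L_T(\Phi) - \alpha\,\nabla^2 L_T(\Phi)\,\nabla L_T(\Phi)$, so the Hessian of the loss enters the update.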

To analyze why Reptile works, we approximate the update using a Taylor series. We show that the Reptile update maximizes the inner product between gradients of different minibatches from the same task, corresponding to improved generalization. This finding may have implications outside of the meta-learning setting for explaining the generalization properties of SGD. Our analysis suggests that Reptile and MAML perform a very similar update, including the same two terms with different weights.
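The following is a compressed sketch of that Taylor-series argument for $k = 2$. The two inner steps use different minibatches of the same task. $\bar g_i$ and $\bar H_i$ denote the gradient and Hessian of minibatch $i$'s loss evaluated at the initial point $\Phi$, and $\alpha$ is the inner step size. To first order in $\alpha$:

$$
\begin{aligned}
g_1 &= \bar g_1, \\
g_2 &= \bar g_2 - \alpha \bar H_2 \bar g_1 + O(\alpha^2), \\
\tfrac{1}{\alpha}(\Phi - W) = g_1 + g_2 &= \bar g_1 + \bar g_2 - \alpha \bar H_2 \bar g_1 + O(\alpha^2).
\end{aligned}
$$

Because the two minibatches are exchangeable, $\mathbb{E}[\bar H_2 \bar g_1] = \mathbb{E}[\bar H_1 \bar g_2] = \tfrac{1}{2}\,\mathbb{E}\!\left[\nabla_{\Phi}(\bar g_1 \cdot \bar g_2)\right]$. In expectation, stepping against $g_1 + g_2$ therefore includes a step of size proportional to $\alpha$ up the gradient of $\bar g_1 \cdot \bar g_2$, which increases the inner product between gradients of different minibatches. Writing $\mathrm{AvgGrad} = \mathbb{E}[\bar g_1]$ and $\mathrm{AvgGradInner} = \mathbb{E}[\bar H_2 \bar g_1]$, the same expansion applied to MAML, first-order MAML and Reptile gives

$$
\begin{aligned}
\mathbb{E}[g_{\mathrm{MAML}}] &= \mathrm{AvgGrad} - 2\alpha\,\mathrm{AvgGradInner} + O(\alpha^2), \\
\mathbb{E}[g_{\mathrm{FOMAML}}] &= \mathrm{AvgGrad} - \alpha\,\mathrm{AvgGradInner} + O(\alpha^2), \\
\mathbb{E}[g_{\mathrm{Reptile}}] &= 2\,\mathrm{AvgGrad} - \alpha\,\mathrm{AvgGradInner} + O(\alpha^2),
\end{aligned}
$$

That is, all three share the same two terms with different weights.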

In our experiments, we show that Reptile and MAML yield similar performance on the Omniglot and Mini-ImageNet benchmarks for few-shot classification. Reptile also converges to the solution faster, since the update has lower variance.

Our analysis of Reptile suggests a plethora of different algorithms that we can obtain using different combinations of the SGD gradients. Assume that we perform $k$ steps of SGD on each task using different minibatches, yielding gradients $g_1, g_2, \ldots, g_k$. The figure below shows the learning curves on Omniglot obtained by using different sums of these gradients as the meta-gradient. $g_2$ corresponds to first-order MAML, an algorithm proposed in the original MAML paper. Including more gradients yields faster learning, due to variance reduction. Note that simply using $g_1$ (which corresponds to $k = 1$) yields no progress, as predicted for this task, because zero-shot performance cannot be improved.
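The sketch below shows where these candidate meta-gradients come from. It runs plain SGD on one hypothetical scalar task (fitting a constant to noisy samples), records $g_1, \ldots, g_k$, and forms the sums; every name in it is illustrative. With inner step size $\alpha$, the full sum equals $(\Phi - W)/\alpha$, so it points in the same direction as the Reptile update.

```python
import random

def inner_gradients(phi, minibatches, grad_fn, inner_lr):
    """Run SGD from phi (one step per minibatch) and record every gradient g_i."""
    w, grads = phi, []
    for mb in minibatches:
        g = grad_fn(w, mb)
        grads.append(g)
        w = w - inner_lr * g
    return w, grads

def grad_fn(w, mb):
    """Gradient of the hypothetical loss 0.5 * mean((w - y)^2) over a minibatch."""
    return sum(w - y for y in mb) / len(mb)

rng = random.Random(0)
target = 3.0
minibatches = [[target + rng.gauss(0, 1) for _ in range(5)] for _ in range(4)]
phi, inner_lr = 0.0, 0.1
w, grads = inner_gradients(phi, minibatches, grad_fn, inner_lr)

# Candidate meta-gradients: g1 (k = 1, joint training), g2 (first-order MAML),
# and the partial sums g1 + ... + gj; the full sum is Reptile's direction.
candidates = {"g1": grads[0], "g2": grads[1]}
for j in range(2, len(grads) + 1):
    candidates["+".join(f"g{i}" for i in range(1, j + 1))] = sum(grads[:j])
print(candidates)
```

Because each $g_i$ is evaluated at a different point along the SGD trajectory, the later gradients carry information about how the loss changes after adaptation. $g_1$ alone lacks this information.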

Implementations

Our implementation of Reptile is available on GitHub. It uses TensorFlow for the computations involved and includes code for replicating the experiments on Omniglot and Mini-ImageNet. We’re also releasing a smaller JavaScript implementation that fine-tunes a model pre-trained with TensorFlow; we used this to create the above demo.

Finally, here’s a minimal example of few-shot regression, predicting a random sine wave from 10 $(x, y)$ pairs. This one uses PyTorch and fits in a gist:


```python
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from copy import deepcopy

seed = 0
plot = True
innerstepsize = 0.02  # stepsize in inner SGD
innerepochs = 1  # number of epochs of each inner SGD
outerstepsize0 = 0.1  # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000  # number of outer updates; each iteration we sample one task and update on it

rng = np.random.RandomState(seed)
torch.manual_seed(seed)

# Define task distribution
x_all = np.linspace(-5, 5, 50)[:, None]  # All of the x points
ntrain = 10  # Size of training minibatches

def gen_task():
    "Generate a random sine-wave regression task"
    phase = rng.uniform(low=0, high=2 * np.pi)
    ampl = rng.uniform(0.1, 5)
    f_randomsine = lambda x: np.sin(x + phase) * ampl
    return f_randomsine

# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results
model = nn.Sequential(
    nn.Linear(1, 64),
    nn.Tanh(),
    nn.Linear(64, 64),
    nn.Tanh(),
    nn.Linear(64, 1),
)

def totorch(x):
    # Convert a NumPy array to a float32 tensor
    return torch.as_tensor(x, dtype=torch.float32)

def train_on_batch(x, y):
    # One step of plain SGD on the mean squared error of a minibatch
    x = totorch(x)
    y = totorch(y)
    model.zero_grad()
    ypred = model(x)
    loss = (ypred - y).pow(2).mean()
    loss.backward()
    for param in model.parameters():
        param.data -= innerstepsize * param.grad.data

def predict(x):
    x = totorch(x)
    return model(x).detach().numpy()

# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]

# Reptile training loop
for iteration in range(niterations):
    weights_before = deepcopy(model.state_dict())
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start + ntrain]
            train_on_batch(x_all[mbinds], y_all[mbinds])
    # Interpolate between current weights and trained weights from this task
    # I.e. (weights_before - weights_after) is the meta-gradient
    weights_after = model.state_dict()
    outerstepsize = outerstepsize0 * (1 - iteration / niterations)  # linear schedule
    model.load_state_dict({name:
        weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize
        for name in weights_before})

    # Periodically plot the results on a particular task and minibatch
    if plot and (iteration == 0 or (iteration + 1) % 1000 == 0):
        plt.cla()
        f = f_plot
        weights_before = deepcopy(model.state_dict())  # save snapshot before evaluation
        plt.plot(x_all, predict(x_all), label="pred after 0", color=(0, 0, 1))
        for inneriter in range(32):
            train_on_batch(xtrain_plot, f(xtrain_plot))
            if (inneriter + 1) % 8 == 0:
                frac = (inneriter + 1) / 32
                plt.plot(x_all, predict(x_all), label="pred after %i" % (inneriter + 1), color=(frac, 0, 1 - frac))
        plt.plot(x_all, f(x_all), label="true", color=(0, 1, 0))
        lossval = np.square(predict(x_all) - f(x_all)).mean()
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4, 4)
        plt.legend(loc="lower right")
        plt.pause(0.01)
        model.load_state_dict(weights_before)  # restore from snapshot
        print("-----------------------------")
        print(f"iteration {iteration + 1}")
        print(f"loss on plotted curve {lossval:.3f}")  # would be better to average loss over a set of examples, but this is optimized for brevity
```


Several people have pointed out to us that first-order MAML and Reptile are more closely related than MAML and Reptile. These algorithms take different perspectives on the problem but end up computing similar updates. Specifically, Reptile’s contribution builds on the history of both Shortest Descent and avoiding second derivatives in meta-learning. We’ve since updated the first paragraph to reflect this.


Authors

Alex Nichol, John Schulman
