Learning to Learn with Generative Models
of Neural Network Checkpoints

University of California, Berkeley

We explore a data-driven approach for learning to optimize neural networks. We construct a dataset of neural network checkpoints and train a generative model on the parameters. In particular, our model is a conditional diffusion transformer that, given an initial input parameter vector and a prompted loss, error, or return, predicts the distribution over parameter updates that achieve the desired metric. At test time, it can optimize neural networks with unseen parameters for downstream tasks in just one update. We apply our method to different neural network architectures and tasks in supervised and reinforcement learning.


Generative Models of Checkpoints

Over the last decade, the deep learning community has generated a massive number of neural network checkpoints. They contain a wealth of information: diverse parameter configurations and rich metrics, such as test losses, classification errors and reinforcement learning returns, that describe the quality of each checkpoint. We pre-train a generative model on millions of checkpoints. At test time, we use it to generate parameters for a neural network that solves a downstream task.

We refer to our model as G.pt (G and .pt refer to generative models and checkpoint extensions, respectively). G.pt is trained as a transformer-based diffusion model of neural network parameters. It takes as input a starting (potentially randomly-initialized) parameter vector, the starting loss/error/return and a user's desired loss/error/return. Once trained, we can sample a parameter update from the model that matches the user's prompted metric. The backbone of our diffusion model is a transformer that operates on sequences of neural net parameters. Similar to ViTs, our G.pt models leverage relatively minimal domain-specific inductive biases. In this paper, we train a separate G.pt model for each combination of dataset, metric and architecture (e.g., a loss-conditional CIFAR-10 CNN model, an error-conditional MNIST MLP model, etc.).
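The input/output contract described above can be sketched in Python. This is a minimal illustration and not the released G.pt code: the names `one_step_update`, `toy_sampler` and `toy_loss` are hypothetical, and the quadratic toy loss is chosen only so that the stand-in sampler can hit a prompted value exactly. A trained model would replace `toy_sampler` with a diffusion sampling loop.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
# Hypothetical signature: (start_params, start_metric, target_metric, latent) -> new_params
Sampler = Callable[[Vector, float, float, Vector], Vector]


def toy_loss(theta: Vector) -> float:
    """Toy stand-in for a real metric: squared norm of the parameters."""
    return sum(t * t for t in theta)


def toy_sampler(theta: Vector, start_metric: float,
                target_metric: float, latent: Vector) -> Vector:
    """Stand-in for a trained model (an assumption, not G.pt itself).

    For the quadratic toy loss, rescaling theta by sqrt(target / start)
    reaches the prompted loss exactly. The latent is ignored here.
    """
    scale = math.sqrt(target_metric / start_metric)
    return [scale * t for t in theta]


def one_step_update(sampler: Sampler, theta: Vector, start_metric: float,
                    target_metric: float, seed: int = 0) -> Vector:
    """Draw a latent, then ask the sampler for parameters matching the prompt."""
    rng = random.Random(seed)
    latent = [rng.gauss(0.0, 1.0) for _ in theta]
    return sampler(theta, start_metric, target_metric, latent)


theta0 = [3.0, 4.0]  # toy_loss(theta0) is 25.0
theta1 = one_step_update(toy_sampler, theta0, toy_loss(theta0), target_metric=1.0)
```

The caller supplies only the current parameters, their current metric and the desired metric; everything else is left to the sampler.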

We're also releasing our pre-training dataset containing 23M neural network checkpoints across 100,000+ training runs. The dataset contains trained MLPs and CNNs for vision tasks (MNIST, CIFAR-10) and continuous control tasks (Cartpole). All checkpoints are available for download.

Prompting for Losses, Errors and Returns

In contrast to hand-designed optimizers (and prior work on learned optimizers), G.pt takes as input a user's desired loss, error, return, etc. Conditioning on these metrics lets us train on checkpoints with good and bad performance alike. After training, you can prompt G.pt with whatever value you want. Below, we prompt it with a large range of test losses (MNIST model) and returns (Cartpole model) over the course of training. By the end of training, our method successfully generates networks that achieve nearly the full range of prompts.
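To make the point about good and bad checkpoints concrete, here is a hedged sketch of how metric-conditioned training examples could be assembled from one training run. It assumes that each checkpoint is paired with later checkpoints from the same run; that pairing rule, and the names `Checkpoint`, `TrainingExample` and `make_examples`, are illustrative assumptions rather than details stated above.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Checkpoint:
    params: Tuple[float, ...]
    metric: float  # e.g. test loss recorded when the checkpoint was saved


@dataclass(frozen=True)
class TrainingExample:
    start_params: Tuple[float, ...]
    start_metric: float
    target_params: Tuple[float, ...]
    target_metric: float  # the value the model is conditioned on


def make_examples(run: List[Checkpoint]) -> List[TrainingExample]:
    """Pair every checkpoint with every later checkpoint of the same run."""
    examples = []
    for i, start in enumerate(run):
        for target in run[i + 1:]:
            examples.append(TrainingExample(
                start.params, start.metric, target.params, target.metric))
    return examples


# A toy run whose loss drops and then rises again.
run = [
    Checkpoint((0.9, -0.3), 2.0),
    Checkpoint((0.4, -0.1), 0.5),
    Checkpoint((0.5, -0.2), 0.9),
]
examples = make_examples(run)
worse_targets = [e for e in examples if e.target_metric > e.start_metric]
```

Because the target metric is an input rather than an objective, the pair whose target is worse than its start (0.5 to 0.9) is as valid a training example as the improving pairs.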

Training in One Step

By prompting it for minimal loss or maximal return, G.pt can optimize unseen parameter vectors in one step. For example, a Cartpole model with randomly-initialized parameters can be rapidly trained with just one update.

One G.pt update.

Compared to iterative gradient-based optimizers like Adam or Nesterov Momentum, our approach only needs one parameter update to achieve good results. We show examples for two G.pts, one conditioned on MNIST classification error and the other on Cartpole return.

Another advantage of our method over traditional optimizers is that it can directly optimize non-differentiable metrics, like classification errors or RL returns. This is because G.pt is trained as a diffusion model in parameter space, and so we never need to backpropagate through these metrics in order to train our model.
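A small sketch shows why no gradient of the metric is needed. It assumes a standard denoising objective in which the model predicts the clean target parameters from a noised copy; this parameterization, the single noise level `alpha_bar`, and the names `add_noise`, `denoising_loss` and `identity_model` are assumptions made for illustration. The metrics appear only as plain numbers passed to the model, while the loss itself is a squared error between parameter vectors.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
# Hypothetical model signature:
# (noisy_target, start_params, start_metric, target_metric, alpha_bar) -> prediction
Model = Callable[[Vector, Vector, float, float, float], Vector]


def add_noise(theta_target: Vector, alpha_bar: float, rng: random.Random) -> Vector:
    """Forward diffusion: blend the target parameters with Gaussian noise."""
    keep = math.sqrt(alpha_bar)
    mix = math.sqrt(1.0 - alpha_bar)
    return [keep * t + mix * rng.gauss(0.0, 1.0) for t in theta_target]


def denoising_loss(model: Model, theta_start: Vector, metric_start: float,
                   theta_target: Vector, metric_target: float,
                   alpha_bar: float, rng: random.Random) -> float:
    """Mean squared error between predicted and true target parameters."""
    noisy = add_noise(theta_target, alpha_bar, rng)
    predicted = model(noisy, theta_start, metric_start, metric_target, alpha_bar)
    return sum((p - t) ** 2 for p, t in zip(predicted, theta_target)) / len(theta_target)


def identity_model(noisy, theta_start, metric_start, metric_target, alpha_bar):
    """Untrained stand-in that returns its noisy input unchanged."""
    return noisy


rng = random.Random(0)
theta_start, theta_target = [0.2, -0.7, 1.1], [0.5, -0.1, 0.3]
# Classification errors (90% and 2%) enter only as conditioning values.
loss_clean = denoising_loss(identity_model, theta_start, 0.90, theta_target, 0.02, 1.0, rng)
loss_noisy = denoising_loss(identity_model, theta_start, 0.90, theta_target, 0.02, 0.5, rng)
```

Nothing in `denoising_loss` differentiates the classification error; it only needs the error values that were recorded alongside each checkpoint.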

Capturing Parameter Multimodality

Many different neural net parameter vectors can yield the same loss. This creates ambiguity: if we ask G.pt for a small loss, which parameter solution should it choose? Since G.pt is a diffusion model, we can sample different solutions. Below, we show the test error landscape of MNIST MLPs as a function of the top two parameter-space PCA directions. We plot several sampled solutions (prompting for low error) in this landscape. Note that the samples cover distinct local minima in the error landscape.
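A two-dimensional landscape of this kind can be computed by evaluating the metric on a planar slice of parameter space. The sketch below assumes the two directions are already available (for example from PCA over checkpoints, which is not shown) and uses a toy metric; `point_on_plane`, `landscape_grid` and `toy_error` are hypothetical names, not part of any released code.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def point_on_plane(center: Sequence[float], dir1: Sequence[float],
                   dir2: Sequence[float], a: float, b: float) -> Vector:
    """Parameters at plane coordinates (a, b): center + a * dir1 + b * dir2."""
    return [c + a * u + b * v for c, u, v in zip(center, dir1, dir2)]


def landscape_grid(metric_fn: Callable[[Vector], float], center: Sequence[float],
                   dir1: Sequence[float], dir2: Sequence[float],
                   coords: Sequence[float]) -> List[List[float]]:
    """Evaluate metric_fn on a square grid; rows index b, columns index a."""
    return [[metric_fn(point_on_plane(center, dir1, dir2, a, b)) for a in coords]
            for b in coords]


def toy_error(theta: Vector) -> float:
    """Toy stand-in for test error: squared distance from the origin."""
    return sum(t * t for t in theta)


center = [0.0, 0.0, 0.0]
dir1, dir2 = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]  # assumed orthonormal directions
grid = landscape_grid(toy_error, center, dir1, dir2, coords=[-1.0, 0.0, 1.0])
```

Sampled solutions can be placed in the same plot by projecting each one onto `dir1` and `dir2` after subtracting `center`.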

A latent walk through parameter space. We can also visualize different parameter samples directly. For example, we can visualize generated conv1 weights for our CIFAR-10 error-conditional model. We prompt for low test error and fix all other inputs except the input latent code. This lets us take a latent walk through parameter space. Below are interpolations for 16 conv1 filters predicted by G.pt. The latents control subtle variations in each filter as well as their ordering.
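A latent walk of this kind amounts to interpolating between two latent codes while every other input is held fixed. The sketch below uses linear interpolation, which is an assumption; the interpolation scheme is not specified above, and `lerp` and `latent_walk` are illustrative names.

```python
from typing import List

Vector = List[float]


def lerp(z0: Vector, z1: Vector, t: float) -> Vector:
    """Linear interpolation between two latent codes, with t in [0, 1]."""
    return [(1.0 - t) * a + t * b for a, b in zip(z0, z1)]


def latent_walk(z0: Vector, z1: Vector, steps: int) -> List[Vector]:
    """Evenly spaced latents from z0 to z1, endpoints included."""
    if steps < 2:
        raise ValueError("steps must be at least 2 to include both endpoints")
    return [lerp(z0, z1, i / (steps - 1)) for i in range(steps)]


walk = latent_walk([0.0, 2.0], [1.0, 0.0], steps=5)
```

Each latent in `walk` would then be passed to the sampler together with the same starting parameters and the same prompted error, so that any change in the generated filters comes from the latent alone.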

BibTeX

@article{Peebles2022,
  title={Learning to Learn with Generative Models of Neural Network Checkpoints},
  author={William Peebles and Ilija Radosavovic and Tim Brooks and Alexei Efros and Jitendra Malik},
  year={2022},
  journal={arXiv preprint arXiv:2209.12892},
}