Allocating Constrained Compute Budgets: At Pre-Training or Inference?

Which matters more? Pre-training compute or test-time compute? What are the scaling laws that maximize performance over constrained compute budgets?

Hi everyone 👋

I’ve been interested in better understanding the tradeoffs between compute used at training time vs. inference time and have been on the lookout for a paper that explores this. There aren’t that many out there, and when I asked around, multiple friends pointed me to this paper:

Okay, let’s dive in!


📚 Concepts & Learning

Last week, when I was on vacation filling up on French pastries, coffee, and summer sunshine, I accidentally made a casual-turned-not-so-casual tweet thread about test-time vs. training-time compute budgets. What can I say, sometimes you get nerd-sniped 🤷‍♀️:

I do think this thread summarizes some of the more interesting parts of the paper we’re covering today! Check it out if you want the sparknotes.

Tweet TLDR

  1. Is it better to use your compute budget on model training or on "thinking deeply" when you ask it a question?

  2. Pre-training usage: Do you want your model to spend more time learning about the world?

    1. Absorbing facts

    2. Learning new methods to solve problems

  3. Inference-time usage: Do you want your model to spend more time combining its existing knowledge to generate a solution?

    1. Thinking deeply

    2. Revising / editing

  4. tldr: paper conclusions

    1. Hard problems benefit from more pretraining compute (knowledge-building)

    2. Easy problems benefit from more inference-time compute (revisions)

    3. Hard problems benefit from sampling more: like when you start a fresh chat because LLMs get stuck looping / revising solutions with major issues

      1. The top 1, 2, 3... strategies are probably very different on hard problems

    4. Easy problems benefit from sampling less: instead they make revisions to the top strategy

      1. The top 1, 2, 3... strategies are probably very similar

Okay now let’s get to the deeper dive.

The deeper dive

To start, the high-level question we’re looking to answer is: How should we allocate compute? Is it better to allocate it to training budgets, improving the model’s underlying capabilities? Or is it better to allocate it to inference budgets, giving the model more room to process a question after you’ve asked it?

In the paper, Snell et al. call this question “defining the exchange rate between pretraining and inference-time compute”. I like this framing of the tradeoff because it’s realistic in many settings — you have a finite amount of compute and you have to decide how to allocate it to maximize the accuracy and success rate of your outputs. So, what did they find?

“… rather than focusing purely on scaling pretraining, in some settings it is more effective to pretrain smaller models with less compute, and then apply test-time compute to improve model outputs.”

Instead of just training (called pretraining in the paper) your largest model with as many tokens and as much compute as possible, it may be more effective to train smaller models and then allocate more compute for inference at test time.

When does this work? On easy problems, mostly.

Does this work on harder problems? The findings suggest — not really: increasing inference-time compute did not help. Instead pre-training was the most effective method to improve performance.

Setup for this paper

  1. The models

    1. The models already have pre-training — they come with some inherent knowledge, and the paper uses those models’ performance (with no additional inference scaffolding or further pretraining) as a baseline / floor

    2. The models are PaLM 2-S* models, which are proprietary (not open-source)

  2. The compute

    1. The compute budget is pre-determined and fixed

  3. The tests

    1. Tests are pre-determined and categorized into easy sets vs hard sets

    2. Test sets focus on math reasoning

  4. The inference method

    1. They used beam search as a way of defining how inference compute should be used (more on that later)

  5. Variables tested:

    1. Optimal ratio of inference to training compute allocation

    2. Different compute budgets (small, medium, large)

    3. Different difficulty buckets for tasks (1-5)

    4. Model size (parameter count)
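To keep the grid straight in my head, here’s my own shorthand for the experimental axes (the names and structure below are mine, not the paper’s code):

```python
# My shorthand for the axes varied in the paper's experiments (names are mine).
experiment_grid = {
    "inference_to_pretraining_ratio": "how the fixed budget is split between the two",
    "compute_budget": ["small", "medium", "large"],
    "difficulty_bucket": [1, 2, 3, 4, 5],   # easy (1) through hard (5)
    "model_size": "parameter count of the base model",
}
```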

What are some ways to allocate inference compute?
  1. Parallel sampling (aka best-of-N) — test your prompt multiple times. Imagine you use the same prompt across several ChatGPT conversations. LLMs are non-deterministic, so each time the model responds, it will sample from a distribution of plausible responses

    1. Parallel sampling generally plateaus with larger compute budgets. If you have unlimited compute resources, sampling further away from the top-k results won’t improve your final success rate — this intuitively makes sense!

  2. Sequential revisions — test your prompt one at a time and revise to land on the final answer. Think of this like opening up a single ChatGPT chat, submitting your prompt, waiting for a model answer, and following up with more context over multiple turns of conversation to improve the final response (there’s a rough code sketch of both strategies right after this list)
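To make the two strategies concrete, here’s a minimal sketch in Python. The `generate` and `verifier_score` functions are hypothetical stand-ins (not the paper’s code or any real API), just enough to show the shape of each strategy:

```python
import random

def generate(prompt, previous=None):
    """Hypothetical stand-in for drawing one sample from an LLM,
    optionally conditioned on a previous attempt (for revisions)."""
    return f"answer to {prompt!r} (draft tag {random.random():.3f})"

def verifier_score(response):
    """Hypothetical stand-in for a verifier / reward model score in [0, 1]."""
    return random.random()

def best_of_n(prompt, n):
    """Parallel sampling: draw n independent responses, keep the best-scoring one."""
    candidates = [generate(prompt) for _ in range(n)]
    return max(candidates, key=verifier_score)

def sequential_revisions(prompt, n):
    """Sequential revisions: draw one response, then revise it n - 1 times in a chain."""
    response = generate(prompt)
    for _ in range(n - 1):
        response = generate(prompt, previous=response)  # each revision sees the last attempt
    return response
```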

Here’s a quick graphic:

What are the tradeoffs between parallel vs. sequential sampling?
  1. Imagine parallel sampling is like starting a fresh chat every time. You’ll end up with multiple first drafts, many of which will try different methods to solve the problem — good for exploring diverse solution strategies

  2. Imagine sequential sampling is like making revisions. You have your first draft and you get multiple chances to revise it and come to the final correct conclusion

  3. There’s a tradeoff! When is it better to try each method? Well, it depends on the question

  4. The solution proposed in this paper is to combine both sampling methods: there are a fixed number of parallel “first draft” samples, and the rest of the compute is used on revisions across those samples. It looks like this:
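The paper illustrates this split with a figure; as a rough companion, here’s how that budget split might look in code, reusing the hypothetical `sequential_revisions` and `verifier_score` helpers from the sketch above (again, my own illustration, not the paper’s implementation):

```python
def parallel_plus_revisions(prompt, total_budget, n_parallel):
    """Split a fixed sample budget between breadth and depth:
    n_parallel independent chains, each revised sequentially.

    n_parallel == total_budget -> pure best-of-N (all first drafts)
    n_parallel == 1            -> pure sequential revisions (one long chain)
    Anything in between trades breadth for depth.
    """
    revisions_per_chain = max(1, total_budget // n_parallel)
    finished = [sequential_revisions(prompt, revisions_per_chain) for _ in range(n_parallel)]
    return max(finished, key=verifier_score)
```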

From the paper, we learn that
  • Harder task prompts benefit from more parallel sampling (option 1)

  • Easier task prompts benefit from lots of sequential revisions (option 2)

This is kinda surprising!

Wouldn't hard problems benefit from thinking deeply? Wouldn't it help to reflect over the few plausible strategies to achieve the right answer?

Wouldn't easier problems benefit from lots of parallel sampling? Then, from those response options, selecting the majority best answer?

What are other ways to optimize the inference-time compute usage?

Another thing that Snell et al. cover in their paper is the search method used at inference time. They find beam search (shown on the right below) to be the most effective method in many settings:

“We find that the efficacy of any given verifier search method depends critically on both the compute budget and the question at hand. Specifically, beam-search is more effective on harder questions and at lower compute budgets, whereas best-of-N is more effective on easier questions and at higher budgets. Moreover, by selecting the best search setting for a given question difficulty and test-time compute budget, we can nearly outperform best-of-N using up to 4x less test-time compute.” 
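To give a feel for what “beam search against a verifier” means here, below is a minimal sketch of step-level beam search guided by a process reward model (PRM). The `propose_next_steps` and `prm_score` callables are hypothetical placeholders; the paper’s actual implementation differs in its details:

```python
def beam_search(prompt, propose_next_steps, prm_score,
                beam_width=4, expansions_per_beam=4, max_steps=10):
    """Step-level beam search guided by a process reward model (PRM).

    At every step, each partial solution in the beam proposes a few candidate
    next steps; the PRM scores each extended solution, and only the top
    `beam_width` partial solutions survive to the next round.
    """
    beams = [[]]  # each beam is a list of solution steps taken so far
    for _ in range(max_steps):
        candidates = []
        for steps in beams:
            for step in propose_next_steps(prompt, steps, n=expansions_per_beam):
                candidates.append(steps + [step])
        # keep only the most promising partial solutions
        candidates.sort(key=lambda s: prm_score(prompt, s), reverse=True)
        beams = candidates[:beam_width]
    return max(beams, key=lambda s: prm_score(prompt, s))
```

Best-of-N, by contrast, only scores complete solutions; beam search spends its budget pruning bad partial solutions early, which is plausibly why it helps more on harder questions at small budgets.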

How do different inference methods perform on different tasks?

And how does that compare to 1x sampling on larger base models?

  1. Each verifier search method performs better or worse depending on the task prompt set

  2. Beam-search is effective on harder questions with lower compute budgets

  3. Best-of-N is more effective on easier questions and higher compute budgets

  4. In many cases (see graph below), picking the right search method can nearly match or beat best-of-N while using up to 4x less test-time compute

  5. This is great! If we can route each question to the right search method, we can make much better use of a constrained inference-compute budget (a toy version of that routing is sketched below)
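Here’s what that routing idea might look like as code. This is my own toy illustration: the thresholds are made up, and in practice question difficulty has to be estimated first (the paper bins questions into difficulty buckets, if I remember right based on how often the base model solves them):

```python
def pick_search_strategy(difficulty_bucket, samples_budget):
    """Toy router: choose a test-time search method from an estimated difficulty
    bucket (1 = easiest, 5 = hardest) and the available sample budget.
    Thresholds are illustrative, not taken from the paper."""
    if difficulty_bucket >= 4 or samples_budget <= 16:
        return "beam_search"   # harder questions / smaller budgets favored beam search
    return "best_of_n"         # easier questions / larger budgets favored plain best-of-N
```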

Training vs inference

So how should we think about allocating train-time vs. inference-time compute on different problem sets?

There’s a lot to unpack here:

  • What are the stars? They are models that are sampled once per question (greedy sampling), but those models are 14x larger (14x more parameters) than those mapped by the solid lines (there’s a back-of-the-envelope FLOPs sketch after this list)

  • If the star is below the associated line, it’s more effective to invest in inference vs. training compute

    • Small models with efficient inference-time compute usage perform better on the same tasks

  • If the star is above the associated line, it’s more effective to invest in training vs. inference compute

    • Larger models perform better than smaller models with optimized inference-time methods over the same compute budget
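To build intuition for that exchange rate, a common back-of-the-envelope approximation is that pretraining costs about 6 · N · D FLOPs and generating a token costs about 2 · N FLOPs, where N is the parameter count and D the number of training tokens. If I recall correctly, the paper’s FLOPs-matched comparison uses a similar accounting. Under that approximation (with made-up numbers below), you can ask how many extra inference tokens the smaller model “buys” by skipping the 14x larger pretraining run:

```python
def pretraining_flops(n_params, n_train_tokens):
    # Standard approximation: ~6 FLOPs per parameter per training token.
    return 6 * n_params * n_train_tokens

def inference_flops_per_token(n_params):
    # Standard approximation: ~2 FLOPs per parameter per generated token.
    return 2 * n_params

# Made-up illustrative numbers: a small model vs. one ~14x larger,
# both pretrained on the same number of tokens.
small_n, big_n = 3e9, 42e9   # parameters
train_tokens = 1e12          # pretraining tokens for each model

saved = pretraining_flops(big_n, train_tokens) - pretraining_flops(small_n, train_tokens)
extra_tokens = saved / inference_flops_per_token(small_n)
print(f"Compute saved by training the small model instead: {saved:.2e} FLOPs")
print(f"That buys roughly {extra_tokens:.2e} extra generated tokens at inference time")
```

Whether that trade is worth it depends on how many questions you expect to serve over the model’s lifetime: spread across a few queries the per-question inference budget is huge, spread across billions it shrinks fast, which is exactly the tension the stars-vs-lines plot captures.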

The upshot

So, what’s the optimal ratio? Well… it depends. In general, it seems that harder problems benefit from the latent knowledge in larger pre-trained models and from the option to test out multiple solution methods through parallel sampling. It seems that easier problems are more efficiently solved by smaller models when paired with more inference-time compute and the ability to make revisions to their first few samples.

Other questions

Some other questions that are related to work in this paper include:

  1. How does this training vs inference “exchange rate” differ conditioned on the size of the base model (aka. number of parameters)?

    • Larger models are going to be more expensive both to train and to run inference with

    • How do you take that into consideration when you hold compute budgets as fixed?

    • Maybe larger models already have a ton of latent knowledge from previous pre-training (they’re smart), so they will continue to have that advantage even if downstream post-training and inference are comparatively more expensive than for smaller models

  2. How does this differ based on the task prompt? 

    • Do harder task prompts benefit from more training or inference-time compute? 

    • What does hard mean? 

      • Are stacked questions considered hard? 

        • For example, would a complex prompt with many easy sub-questions be considered difficult?

        • An example of this: Tell me about Apple corporation. Give me the full name of the current CEO. Give me the EBITDA from Q4 2024. Give me the location of the US headquarters… etc. Each question individually is very easy, but if you stack them into one prompt, does that make the prompt hard? This is a trivial example; more complex “stacked questions” could be the steps in a coding sequence or a math proof

        • If these tasks are considered hard, maybe one capability to focus on is improving memory of a large context window. Or improving recursive review of prompts

      • Are underspecified task prompts considered hard (or impossible)?

This is not the Chinchilla paper

Note: This paper is not about compute-optimal scaling in pre-training. Research in that realm is covered in papers like the Chinchilla paper — if you’re interested, I wrote about this exactly two years ago to the day! (see here)

Some things we didn’t cover in this post

Beam search vs. tree search vs. naive sampling (best-of-N)

Snell et al. find beam search to be one of the most effective inference methods for eliciting successful responses. We also didn’t cover the related research on process reward models (PRMs), which act as the verifiers during search. More to be covered in future posts!

🗞️ On My Reading List

That’s it! Have a great day and see you in two weeks! 👋

What did you think about today’s newsletter? Send me a DM on Twitter @barralexandra or reply to this email!

Thanks for reading Superslow AI. If you enjoyed this post, feel free to share it with any AI-curious friends. Cheers!