Nudging uses a small aligned model to generate nudging tokens that steer a large base model's output in the desired direction whenever the base model's uncertainty is high.
Large language models (LLMs) require alignment, such as instruction tuning or reinforcement learning from human feedback, to follow user instructions effectively and safely. Conventional training pipelines require separate alignment tuning for every model size within each model family, which incurs substantial computational overhead and hinders the rapid iteration and development of new model families.
In this work, we propose nudging, a simple, plug-and-play, and training-free algorithm that aligns any base model at inference time using a small aligned model. Nudging is motivated by recent findings that alignment primarily alters the model's behavior on a small subset of stylistic tokens, such as "Sure" or "Thank". We find that base models are significantly more uncertain when generating these alignment-related tokens. Leveraging this observation, nudging employs a small aligned model to generate nudging tokens to steer the large base model's output toward desired directions when the base model's uncertainty is high.
We evaluate the effectiveness of nudging across 3 model families and 13 tasks, covering reasoning, general knowledge, instruction following, and safety benchmarks. Without any additional training, nudging a large base model with a 7x-14x smaller aligned model achieves zero-shot performance comparable to, and sometimes surpassing, that of large aligned models. For example, nudging OLMo-7b with OLMo-1b-instruct—affecting less than 9% of tokens—achieves a 10% absolute improvement on GSM8K over OLMo-7b-instruct. Unlike prior inference-time tuning methods, nudging enables off-the-shelf collaboration between model families. For instance, nudging Gemma-2-27b with Llama-2-7b-chat outperforms Llama-2-70b-chat on various tasks. Overall, this work introduces a simple yet powerful approach to token-level model collaboration, offering a modular solution to LLM alignment.
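At decoding time, the idea described above amounts to a simple loop: let the large base model generate as long as it is confident, and hand the next word to the small aligned model whenever it is not. Below is a minimal word-level sketch of that loop; the callables `base_next` and `nudge_next`, the empty-word stopping convention, and the threshold value `gamma` are hypothetical stand-ins rather than the paper's exact implementation.

```python
from typing import Callable, Tuple

def nudging_decode(
    prompt: str,
    base_next: Callable[[str], Tuple[str, float]],   # large base model: (next word, top-1 token prob)
    nudge_next: Callable[[str], Tuple[str, float]],  # small aligned model: same interface
    gamma: float = 0.4,                              # uncertainty threshold (hypothetical value)
    max_words: int = 256,
) -> str:
    """Word-level sketch of inference-time nudging.

    The base model proposes each next word. If its confidence (top-1 token
    probability) falls below `gamma`, the word is taken from the small
    aligned model instead, steering the output toward aligned behavior
    while the base model still does most of the generation.
    """
    text = prompt
    for _ in range(max_words):
        word, confidence = base_next(text)
        if confidence < gamma:          # base model is uncertain -> nudge
            word, _ = nudge_next(text)
        if word == "":                  # empty word signals end of generation
            break
        text += word
    return text[len(prompt):]
```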
Previous work finds that, after alignment, the token distributions of base models shift significantly only at a small set of output positions. By identifying (1) where the base and aligned models would disagree (alignment-related positions) and (2) what the aligned model would generate at these positions (alignment tokens), we can insert these tokens during decoding to nudge a base model to behave like an aligned model. We therefore study two questions: where to nudge and what to nudge.
To answer where to nudge, we compare the token distributions of the base model and the aligned model. We find that the base model's uncertainty helps identify positions where its token distribution differs significantly from that of the aligned model.
To answer what to nudge, we analyze how aligned models of different sizes within the same family behave at alignment-related positions. We find that small and large aligned models usually have similar token distributions at these positions.
Counts and ratios of different token-position types for three model pairs on three tasks. Agree, weakly disagree, and disagree refer to positions where the top-1 token from the aligned model has rank 1, rank 2-3, or rank greater than 3 in the base model's token distribution, respectively. Base models are mostly certain and agree with their aligned counterparts at the token level; as a base model becomes more uncertain, it is more likely to disagree with its aligned counterpart.
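The position types above can be computed directly from the two models' next-token distributions. The following sketch mirrors the rank-based definition in the caption; it assumes the base and aligned models share a vocabulary (as they do within a model family), and the function name is ours, not the paper's.

```python
import torch

def agreement_type(base_logits: torch.Tensor, aligned_logits: torch.Tensor) -> str:
    """Classify one output position by how the base model ranks the aligned
    model's top-1 token (logits are 1-D tensors over a shared vocabulary)."""
    aligned_top1 = aligned_logits.argmax()
    # Rank of the aligned top-1 token under the base model (1 = most likely).
    rank = int((base_logits > base_logits[aligned_top1]).sum()) + 1
    if rank == 1:
        return "agree"
    if rank <= 3:
        return "weakly disagree"
    return "disagree"
```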
Model | GSM8K | MMLU | Just-eval |
---|---|---|---|
Llama-2 | 82.6 | 65.4 | 69.7 |
Gemma-2 | 87.6 | 58.2 | 59.6 |
OLMo | 38.0 | 42.7 | 46.9 |
The ratio (%) of token positions where the top-1 token from the large aligned model is among the top-3 of the small aligned model, measured at positions where the large base and aligned models disagree. At alignment-related positions, the small and large aligned models usually have similar token distributions.
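The quantity in this table can be measured analogously: at positions where the large base and aligned models disagree, check whether the large aligned model's top-1 token appears in the small aligned model's top-3. A short sketch, again assuming a shared vocabulary and using hypothetical names:

```python
import torch

def small_covers_large(small_aligned_logits: torch.Tensor,
                       large_aligned_logits: torch.Tensor,
                       k: int = 3) -> bool:
    """Check whether the large aligned model's top-1 token is among the
    small aligned model's top-k tokens at a given output position."""
    large_top1 = int(large_aligned_logits.argmax())
    small_topk = torch.topk(small_aligned_logits, k).indices.tolist()
    return large_top1 in small_topk
```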
To demonstrate the effectiveness of nudging, we evaluate it across three model families (Llama-2, Gemma-2, and OLMo) on 13 tasks covering reasoning, general knowledge, instruction following, and safety benchmarks. We find that nudging a large base model with a much smaller aligned model matches, and sometimes surpasses, the performance of the large aligned model.
Model | GSM | SVP | MA | MM | Arc | CS | ST | Date | SP | CF | LLC | Avg. |
---|---|---|---|---|---|---|---|---|---|---|---|---|
Llama-2 | | | | | | | | | | | | |
70b | 10.0 | 11.7 | 4.4 | 26.6 | 78.3 | 42.2 | 62.7 | 44.7 | 42.1 | 47.7 | 1.3 | 33.8 |
7b-chat | 25.5 | 43.3 | 62.8 | 40.9 | 54.1 | 52.2 | 50.4 | 33.9 | 51.7 | 45.0 | 7.3 | 42.5 |
70b-chat | 48.5 | 64.0 | 63.9 | 57.4 | 77.6 | 70.3 | 58.9 | 48.8 | 64.9 | 38.3 | 31.3 | 56.7 |
Nudging | 46.2 | 63.3 | 71.1 | 57.4 | 75.9 | 59.2 | 60.0 | 47.7 | 59.5 | 57.4 | 38.7 | 57.9 |
Gemma-2 | | | | | | | | | | | | |
27b | 6.7 | 8.3 | 7.0 | 17.7 | 24.2 | 16.0 | 21.3 | 12.5 | 7.9 | 7.6 | 6.7 | 12.4 |
2b-it | 63.8 | 72.3 | 92.2 | 57.5 | 78.6 | 70.0 | 53.4 | 30.4 | 56.2 | 33.9 | 4.7 | 55.7 |
27b-it | 85.4 | 86.7 | 99.4 | 75.1 | 92.7 | 71.7 | 70.6 | 69.6 | 74.3 | 11.3 | 82.0 | 74.4 |
Nudging | 74.6 | 77.0 | 92.2 | 66.8 | 88.9 | 69.8 | 62.3 | 49.9 | 63.0 | 42.7 | 86.0 | 70.3 |
OLMo | | | | | | | | | | | | |
7b | 18.8 | 16.7 | 35.0 | 22.5 | 37.1 | 71.0 | 40.6 | 15.2 | 38.5 | 22.4 | 0.0 | 28.9 |
1b-it | 10.2 | 12.0 | 34.4 | 31.6 | 37.1 | 56.6 | 64.2 | 4.3 | 44.5 | 49.1 | 0.0 | 31.3 |
7b-it | 14.1 | 22.7 | 32.8 | 49.8 | 60.3 | 70.9 | 61.3 | 9.8 | 64.6 | 44.9 | 0.0 | 39.2 |
Nudging | 24.2 | 30.7 | 71.1 | 41.3 | 47.0 | 68.5 | 62.9 | 6.0 | 49.9 | 47.1 | 0.0 | 40.8 |
Zero-shot performance on standard benchmarks, where nudging uses the smaller aligned model (second row of each block) to nudge the larger base model (first row). Nudging a large base model with a much smaller aligned model performs on par with the aligned version of the large base model. Nudging is particularly effective on math tasks, GSM (GSM8K), SVP (SVAMP), and MA (MultiArith), and on symbolic reasoning tasks, CF (Coin Flip) and LLC (LastLetterConcat).
Model | GSM8K | MMLU |
---|---|---|
Gemma-2-27b | 7.0 | 16.5 |
OLMo-7b-it | 13.0 | 49.0 |
Nudging (Gemma-2-27b + OLMo-7b-it) | 41.0 | 62.5 |
Llama-2-7b-chat | 22.0 | 39.5 |
Nudging (Gemma-2-27b + Llama-2-7b-chat) | 65.0 | 67.0 |
Llama-2-70b-chat | 52.0 | 53.0 |
Zero-shot performance of different models. Nudging works effectively even when the base and aligned models come from different families.
Explore the base models' outputs and the nudging outputs on different tasks and model families.
@misc{fei2024nudginginferencetimealignmentmodel,
title={Nudging: Inference-time Alignment via Model Collaboration},
author={Yu Fei and Yasaman Razeghi and Sameer Singh},
year={2024},
eprint={2410.09300},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2410.09300},
}