Introduction
Large Language Models (LLMs) can generate impressively human-like text for a wide range of applications, from chatbots to content creation and e-mail drafting. However, the quality of the generated text depends heavily on the decoding (sampling) strategy used during generation, the step that lies underneath our favorite AI assistants (ChatGPT, Gemini, Claude, etc.). This blog post explores different decoding strategies, with a focus on their practical implementation and their impact on the generated output.
Overview
This blog post is based on the methods discussed in the paper "The Curious Case of Neural Text Degeneration". The primary contribution of this paper is Nucleus Sampling, a method designed to improve the quality and human-likeness of generated text. In this post, we’ll examine several decoding strategies, including Greedy Decoding, Temperature Sampling, Top-K Sampling, and Nucleus Sampling, and compare their effectiveness. You can also explore the full implementation of these methods in my GitHub repository here.
Experimental Setting
1. Model Used
For the experiments, I used the GPT-2 large model, which has approximately 774 million parameters. This model was pre-trained on a large corpus of English text in a self-supervised manner. I chose a base GPT-2, which was not further tuned with supervised fine-tuning (SFT) or reinforcement learning from human feedback (RLHF), rather than a more advanced model. This makes it easier to isolate the impact of each sampling strategy, without interference from post-training that already pushes the model toward "human-like" output.
2. Prompt and Task
We provided the model with substantial context to make it easier to tell meaningful generations from senseless ones. Since GPT-2 was trained on WebText, a large corpus of web pages linked from Reddit, it has enough general knowledge to generate informative content. The key aspect of our experiments is to evaluate how different decoding methods influence the coherence and relevance of the generated text. Generation was limited by the `max_new_tokens` parameter, set to 100, so the output stops either when this limit is reached or when the model generates the end-of-sequence (EOS) token.
Decoding Methods
In this part I will discuss several sampling strategies and their practical implementation.
The main function used to generate text from a prompt, given a model, its tokenizer and a sampling strategy, is shown below.
```python
import torch


@torch.no_grad()  # inference only: no gradients are needed
def generate(model, input, tokenizer, sampler, max_new_tokens, context_length=None):
    """
    Generates text using a specified sampling strategy.

    Parameters:
    - model: The pre-trained language model used for generating text.
    - input: The initial input text for the model to start generating from.
    - tokenizer: The tokenizer associated with the model, used to encode and decode text.
    - sampler: The sampling method used to select the next token during generation (e.g., Greedy, Top-K, Nucleus).
    - max_new_tokens: The maximum number of new tokens to generate.
    - context_length: The maximum length of the context that the model should consider. If None, it defaults to the tokenizer's model max length.

    Returns:
    - Generated text decoded from the token indices.
    """
    # Set context length to the tokenizer's max length if not provided
    if context_length is None:
        context_length = tokenizer.model_max_length
    # Run on whichever device holds the model's weights
    device = next(model.parameters()).device
    # Set the model to evaluation mode (disables dropout)
    model.eval()
    # Tokenize the input text
    encoded = tokenizer(input, return_tensors="pt").to(device)
    EOS = tokenizer.eos_token_id  # End of sequence token ID
    # Prompt token indices, shape (1, T), and an empty (1, 0) tensor for the output
    idx = encoded["input_ids"]
    generated = torch.empty((idx.shape[0], 0), dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        # The model only considers a context of at most `context_length` tokens,
        # so crop the input to its last `context_length` tokens
        idx_cond = idx[:, -context_length:]
        # Get the model's predictions (logits) and keep only the last time step
        logits = model(idx_cond).logits[:, -1, :]
        # Sample the next token using the provided sampler, reshaped to (1, 1)
        next_token = sampler(logits).to(device).view(-1, 1)
        # Append the sampled token to both the output and the running context
        generated = torch.cat([generated, next_token], dim=-1)
        idx = torch.cat([idx, next_token], dim=-1)
        # Stop once the model emits the end-of-sequence token
        if next_token.item() == EOS:
            break
    # Decode the generated token indices into text
    return tokenizer.decode(generated.squeeze(0))
```
1. Greedy Decoding
Greedy Decoding is the simplest and cheapest decoding strategy, though rarely the most effective. At each step of the generation process, the model selects the token with the highest probability given the previously generated tokens.
Formally, the next token \( y_t \), given the previously generated tokens \( y_1, y_2, \ldots, y_{t-1} \), is chosen as:
$$ y_t = \arg\max_{y} P(y \mid y_1, y_2, \ldots, y_{t-1}) $$
This process is repeated iteratively until the model generates an end-of-sequence token or reaches a predefined maximum sequence length. While simple, Greedy Decoding often leads to repetitive and uncreative outputs due to its focus on selecting only the most likely token at each step.
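To make the repetition problem concrete, here is a minimal sketch on a hypothetical four-token vocabulary whose next-token probabilities depend only on the previous token (a made-up bigram table, not GPT-2). Because greedy decoding always takes the argmax, the first time it revisits a token it is locked into the same cycle, and the token "beautiful" is never produced even though it has non-trivial probability at every step.

```python
import torch

# Hypothetical bigram table: row i holds P(next token | previous token = i).
# Tokens: 0 = "the", 1 = "city", 2 = "is", 3 = "beautiful".
transition = torch.tensor([
    [0.05, 0.60, 0.20, 0.15],  # after "the"
    [0.10, 0.05, 0.70, 0.15],  # after "city"
    [0.55, 0.05, 0.05, 0.35],  # after "is"
    [0.40, 0.30, 0.20, 0.10],  # after "beautiful"
])
vocab = ["the", "city", "is", "beautiful"]


def greedy_decode(start: int, steps: int) -> list:
    """Repeatedly append the single most probable next token."""
    tokens = [start]
    for _ in range(steps):
        probs = transition[tokens[-1]]
        tokens.append(int(torch.argmax(probs)))
    return tokens


print(" ".join(vocab[t] for t in greedy_decode(start=0, steps=8)))
```

Running it prints `the city is the city is the city is`: a deterministic loop, which is exactly the kind of degeneration sampling-based methods try to avoid.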
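```python
import torch


class Sampler:
    """Base class: maps logits of shape (batch, vocab) to token ids of shape (batch,)."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class GreedySampler(Sampler):
    def __call__(self, logits: torch.Tensor):
        # Pick the highest-scoring token. Softmax is monotonic, so the argmax
        # over logits equals the argmax over probabilities.
        return logits.argmax(dim=-1)
```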
Example of greedy decoding generation:
Prompt: "The capital of France, Paris, is a beautiful city where people are gorgeous and elegant. You get to drink coffee outside..."
Output: "Everyone talks about the food courses. If you think around the block you may find things interesting. I love to try new things. 'Price: Approx $3-4, $15 gratuity. Wednesday of the 12th Sunday: Some dinner and drinks dancing, feel free to come.' 31. DATE: February 17th 1993, VENUE: Basketballrama GAMES: Wins over Memphis, Dallas, Dallas RESULT"
2. Temperature Sampling
Temperature Sampling involves adjusting the softmax function to control the level of randomness in the generated sequences. The softmax function, which converts logits (unnormalized log probabilities) into a probability distribution, is modified by a temperature parameter \( T \).
The probability of selecting a token \( y_i \) during Temperature Sampling is given by:
$$ P(y_i) = \frac{e^{\frac{\text{logit}(y_i)}{T}}}{\sum_{j}e^{\frac{\text{logit}(y_j)}{T}}} $$
In this equation, \( \text{logit}(y_i) \) is the logit value associated with token \( y_i \), and \( T \) is the temperature parameter. A higher temperature increases diversity by making the probability distribution flatter, allowing for more exploration. Conversely, a lower temperature sharpens the distribution, focusing the generation on more probable tokens. At \( T = 1 \) the model's original distribution is recovered, and as \( T \to 0 \) the procedure approaches Greedy Decoding.
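The effect of \( T \) is easy to check numerically. The sketch below uses an arbitrary, made-up logit vector and reports the largest probability and the entropy of the temperature-scaled distribution; `temperature_softmax` and `entropy` are hypothetical helpers for illustration. As \( T \) grows the distribution flattens (entropy rises), and as \( T \) shrinks it concentrates on the highest-logit token.

```python
import torch

# Made-up logits for a five-token vocabulary
logits = torch.tensor([2.0, 1.0, 0.5, 0.0, -1.0])


def temperature_softmax(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax of logits / T, i.e. the temperature-scaled distribution."""
    return torch.softmax(logits / temperature, dim=-1)


def entropy(probs: torch.Tensor) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return float(-(probs * probs.log()).sum())


for t in (0.1, 0.5, 1.0, 2.0, 10.0):
    probs = temperature_softmax(logits, t)
    print(f"T={t:>4}: top prob={float(probs.max()):.3f}, entropy={entropy(probs):.3f}")
```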
Example using Temperature Sampling:
Prompt: "The capital of France, Paris, is a beautiful city where people are gorgeous and elegant. You get to drink coffee outside..."
Output: "Several cafes have a chance to visit the house of the deceased. Only one word can describe the house of the deceased. The people that passed away fell in love and lived in this house. Demons were living in the house and were possessed by a spirit. There are haunted houses, ghost stories, and even the myth of the witch who haunted the bodies of three men who went into a house and when they came out, one of them was dead."
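```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class TemperatureSampler(Sampler):
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        # Divide logits by T before the softmax: T < 1 sharpens, T > 1 flattens
        logits = logits / self.temperature
        probs = self.softmax(logits)
        # Draw one token per row from the rescaled distribution
        dist = Categorical(probs)
        return dist.sample()
```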
3. Top-K Sampling
Top-K Sampling restricts the model’s choice of tokens to the \( K \) most probable tokens at each step. This reduces the chance of sampling nonsensical low-probability tokens, but because \( K \) is fixed regardless of how peaked or flat the distribution is, it can be too restrictive at some steps and too permissive at others, and a small \( K \) may also lead to repetitive outputs.
The steps involved in Top-K Sampling are:
- Compute Probabilities: The model computes probabilities for each word in the vocabulary based on the context.
- Sort Probabilities: The probabilities are sorted in descending order.
- Select Top-K Words: The top-K words with the highest probabilities are selected.
- Normalize Probabilities: The selected probabilities are normalized to create a distribution over the top-K words.
- Sample from Distribution: A word is sampled from this distribution to obtain the next predicted word.
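The steps above map almost one-to-one onto tensor operations. Here is a minimal single-vector sketch, assuming 1-D logits and plain sampling from the renormalized top-K distribution (no temperature); `top_k_sample` is a hypothetical helper, not part of the original code.

```python
from typing import Optional

import torch


def top_k_sample(logits: torch.Tensor, k: int,
                 generator: Optional[torch.Generator] = None) -> int:
    """Sample one token id from the k most probable tokens (1-D logits)."""
    # 1. Compute probabilities over the whole vocabulary
    probs = torch.softmax(logits, dim=-1)
    # 2-3. Sort in descending order and keep the top k (torch.topk does both)
    top_probs, top_ids = torch.topk(probs, k)
    # 4. Renormalize so the k kept probabilities sum to 1
    top_probs = top_probs / top_probs.sum()
    # 5. Sample a position within the top k, then map it back to a token id
    choice = torch.multinomial(top_probs, num_samples=1, generator=generator)
    return int(top_ids[choice])


# Made-up logits over a five-token vocabulary
logits = torch.tensor([3.0, 2.5, 1.0, 0.2, -2.0])
gen = torch.Generator().manual_seed(0)
samples = [top_k_sample(logits, k=2, generator=gen) for _ in range(1000)]
print(sorted(set(samples)))  # only tokens 0 and 1 can ever appear
```

Because `torch.topk` returns its values already sorted, steps 2 and 3 collapse into one call. With `k=1` the function reduces to Greedy Decoding.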
Example using Top-K Sampling:
Prompt: "The capital of France, Paris, is a beautiful city where people are gorgeous and elegant. You get to drink coffee outside..."
Output: "The place has a huge number of restaurants, and a lot of different places to get your food from and to get your drink. You could have a beer, and the waiter will bring it to your table and serve you something else. There is no reason why this should not be possible. I am happy to say we have been successful in getting the idea for this idea into the open source community, but as you will see in the following video, it isn't a simple..."
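```python
import torch


class TopKSampler(Sampler):
    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler  # inner sampler applied to the masked logits

    def __call__(self, logits: torch.Tensor):
        # Start from a tensor filled with -inf (probability 0 after softmax)
        masked = torch.full_like(logits, float("-inf"))
        # Keep only the k largest logits, copied back to their original positions
        values, indices = torch.topk(logits, self.k, dim=-1)
        masked.scatter_(-1, indices, values)
        # Let the inner sampler (e.g. TemperatureSampler) pick among the k survivors
        return self.sampler(masked)
```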
4. Nucleus Sampling
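Nucleus Sampling, also known as Top-P Sampling, addresses the limitations of Top-K Sampling by dynamically adjusting the number of candidate tokens based on their cumulative probability. Instead of fixing the number of candidates, Nucleus Sampling keeps the smallest set of most probable tokens whose cumulative probability mass reaches a threshold \( p \). This set is called the nucleus, \( V^{(p)} \).
The key intuition behind Nucleus Sampling is to adapt to the changing probabilities of tokens from one step to the next, allowing for a dynamic candidate pool. This flexibility enables the generation of more diverse and contextually relevant outputs.
Formally, \( V^{(p)} \) is the smallest set of tokens satisfying
$$ \sum_{y \in V^{(p)}} P(y \mid y_1, \ldots, y_{t-1}) \ge p $$
and the next token is sampled from the distribution renormalized over the nucleus:
$$ P_{\text{nucleus}}(y_t \mid y_1, \ldots, y_{t-1}) = \begin{cases} \dfrac{P(y_t \mid y_1, \ldots, y_{t-1})}{\sum_{y \in V^{(p)}} P(y \mid y_1, \ldots, y_{t-1})} & \text{if } y_t \in V^{(p)} \\ 0 & \text{otherwise} \end{cases} $$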
Nucleus Sampling is particularly effective when creativity and coherence need to be balanced. When the model is confident, the nucleus may contain only one or two tokens; when the distribution is flat, it expands to include many plausible lower-probability tokens. This keeps the unreliable low-probability tail out of reach while still allowing variety, which tends to produce more natural, human-like text.
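To see how the nucleus grows and shrinks, the sketch below computes \( V^{(p)} \) for three made-up distributions with \( p = 0.85 \), using the same rule as the `NucleusSampler` class (keep tokens whose cumulative mass is still below \( p \), plus the one token that crosses it). `nucleus` is a hypothetical helper for illustration.

```python
import torch


def nucleus(probs: torch.Tensor, p: float):
    """Return (token ids, renormalized probs) of the smallest set whose mass >= p."""
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # Tokens whose cumulative mass is below p, plus the one that pushes it to >= p
    size = min(int((cum < p).sum()) + 1, probs.numel())
    kept = sorted_probs[:size]
    return sorted_ids[:size], kept / kept.sum()


examples = {
    "moderate": torch.tensor([0.5, 0.3, 0.1, 0.06, 0.04], dtype=torch.float64),
    "peaked": torch.tensor([0.95, 0.03, 0.02], dtype=torch.float64),
    "flat": torch.full((10,), 0.1, dtype=torch.float64),
}
for name, probs in examples.items():
    ids, renorm = nucleus(probs, p=0.85)
    print(f"{name:>8}: nucleus size={ids.numel()}, probs={renorm.tolist()}")
```

The moderate distribution keeps 3 tokens, the peaked one only 1 and the flat one 9, whereas a Top-K sampler would keep the same \( K \) tokens in all three cases.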
Example using Nucleus Sampling:
Prompt: "The capital of France, Paris, is a beautiful city where people are gorgeous and elegant. You get to drink coffee outside..."
Output: "If you are going to do this you should have the same clothes and shoes as the person you are meeting. You should also bring the same tools you will be using. I have used this for years and it works great. Here are some other tools I use. I use a basic blender and a food processor. I also use a cheese grater and a cheese cloth. I have been using a food processor for a few years."
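```python
import torch
import torch.nn as nn


class NucleusSampler(Sampler):
    def __init__(self, p: float, sampler: Sampler):
        self.p = p
        self.sampler = sampler  # inner sampler applied within the nucleus
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        probs = self.softmax(logits)
        # Sort tokens by probability, most likely first
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Keep tokens whose cumulative mass is still below p; shifting the mask
        # right by one also keeps the token that crosses p (and always the top token)
        nucleus = cum_sum_probs < self.p
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
        # Give every token outside the nucleus a log-probability of -inf
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float("-inf")
        # Sample a position in the sorted order, then map it back to a vocabulary id
        sampled_sorted_indexes = self.sampler(sorted_log_probs)
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
        return res.squeeze(-1)
```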
Conclusion
Decoding strategies play a crucial role in the quality and coherence of text generated by large language models. While Greedy Decoding is straightforward, it often leads to repetitive outputs. Temperature Sampling allows for greater diversity, while Top-K Sampling provides control over the token selection process. Nucleus Sampling, however, offers the most flexibility, dynamically adjusting to the context to produce more natural and contextually relevant text.
Understanding and implementing these decoding methods can noticeably improve the text an LLM produces without changing the model itself, making it more effective at generating human-like language across a variety of applications.