How to Fine-Tune a GPT-2 Model: A Complete Guide
Fine-tuning a pre-trained language model like GPT-2 allows you to adapt its general knowledge to specific tasks without training from scratch. In this guide, we'll walk through the complete process of fine-tuning GPT-2 for spam classification using PyTorch.
Why Fine-Tune Instead of Training From Scratch?
Training a large language model from scratch requires enormous computational resources and massive datasets. Fine-tuning leverages the knowledge already captured in pre-trained weights, allowing you to:
- Use significantly less training data
- Reduce training time dramatically
- Achieve better performance on specialized tasks
- Require less computational power
The Dataset: SMS Spam Classification
For this tutorial, we'll use the SMS Spam Collection dataset, which contains text messages labeled as either "ham" (legitimate) or "spam". This binary classification task is perfect for demonstrating fine-tuning concepts.
Step 1: Data Preparation
First, we download and prepare a balanced dataset:
import pandas as pd

# Load the dataset (data_file_path points to the tab-separated SMS Spam Collection file)
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])

# Create a balanced dataset (equal numbers of spam and ham samples)
def create_balanced_dataset(df):
    num_spam = df[df["Label"] == "spam"].shape[0]
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df

balanced_df = create_balanced_dataset(df)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
Split the data into training (70%), validation (10%), and test (20%) sets:
def random_split(df, train_frac=0.7, validation_frac=0.1):
    # Shuffle the whole DataFrame, then slice it into train/validation/test
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]
    return train_df, validation_df, test_df
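The SpamDataset class defined in Step 2 reads each split from a CSV file, so we write the splits to disk first. The file names below are placeholders you can change to fit your project layout:

train_df, validation_df, test_df = random_split(balanced_df, train_frac=0.7, validation_frac=0.1)

# Persist the splits so the Dataset class below can load them by file name
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)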
Step 2: Create a Custom Dataset Class
We need to tokenize the text data and prepare it for the model:
import tiktoken
import torch
from torch.utils.data import Dataset

class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Encode all texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        # Determine max length
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length

        # Truncate sequences that exceed max_length
        self.encoded_texts = [
            encoded_text[:self.max_length]
            for encoded_text in self.encoded_texts
        ]

        # Pad sequences to max_length with the <|endoftext|> token ID (50256)
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long),
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        # Length of the longest tokenized text in the dataset
        return max(len(encoded_text) for encoded_text in self.encoded_texts)
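The training code in Step 7 expects train_loader and val_loader, which are not constructed in the snippets above. One way to build them, assuming the CSV file names from the earlier sketch and an illustrative batch size, is:

from torch.utils.data import DataLoader

tokenizer = tiktoken.get_encoding("gpt2")

train_dataset = SpamDataset("train.csv", tokenizer=tokenizer)
# Reuse the training max_length so all splits are padded to the same width
val_dataset = SpamDataset("validation.csv", tokenizer=tokenizer, max_length=train_dataset.max_length)
test_dataset = SpamDataset("test.csv", tokenizer=tokenizer, max_length=train_dataset.max_length)

batch_size = 8  # illustrative; adjust to your hardware
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)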
Step 3: Load Pre-trained GPT-2 Weights
Instead of random initialization, we load pre-trained weights:
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt

# Configure the model
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True,
    "emb_dim": 768,  # GPT-2 Small
    "n_layers": 12,
    "n_heads": 12
}

# Download and load pre-trained weights
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")

# Initialize model and load weights
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
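Before modifying anything, it can help to confirm the weights loaded correctly. This is a minimal sanity check, assuming GPTModel's forward pass takes a batch of token IDs and returns one logit vector per token position (as implied by the classification code later on):

# Quick shape check: the pre-trained model should map token IDs to vocabulary logits
model.eval()
dummy_input = torch.randint(0, BASE_CONFIG["vocab_size"], (1, 4))  # batch of 1, 4 tokens
with torch.no_grad():
    output = model(dummy_input)
print(output.shape)  # expected: [1, 4, 50257] (batch, tokens, vocab size)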
Step 4: Modify the Model for Classification
The original GPT-2 output head produces a distribution over its 50,257-token vocabulary for text generation. For binary classification we only need two outputs, so we replace that head with a small linear layer:
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Replace output head for binary classification
num_classes = 2
model.out_head = torch.nn.Linear(BASE_CONFIG["emb_dim"], num_classes)

# Unfreeze the last transformer block and final layer norm
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in model.final_norm.parameters():
    param.requires_grad = True
This strategy keeps most of the model frozen (preserving the learned language representations) and trains only:
- The last transformer block
- The final layer normalization
- The new classification head
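As a quick check that the freezing worked as intended, you can count how many parameters remain trainable; this sketch uses only standard PyTorch calls:

# Compare trainable parameters against the full model size after freezing
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} of {total_params:,}")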
Step 5: Define Training Functions
We need functions to calculate loss and accuracy:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Use last token's output
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss

def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0
    if num_batches is None:
        num_batches = len(data_loader)
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)
            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]
            predicted_labels = torch.argmax(logits, dim=-1)
            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples
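The training loop in Step 6 also calls an evaluate_model helper that isn't shown above. A minimal sketch, assuming it simply averages the batch loss over a few batches from each loader (calc_loss_loader is an assumed helper name), could look like this:

def calc_loss_loader(data_loader, model, device, num_batches=None):
    # Average calc_loss_batch over the first num_batches batches of the loader
    total_loss = 0.0
    if len(data_loader) == 0:
        return float("nan")
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    # Estimate train and validation loss on a small, fixed number of batches
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()  # restore training mode
    return train_loss, val_loss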
Step 6: Training Loop
Here's the complete training function:
def train_classifier_simple(model, train_loader, val_loader, optimizer,
                            device, num_epochs, eval_freq, eval_iter):
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    for epoch in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            examples_seen += input_batch.shape[0]
            global_step += 1

            # Periodic evaluation
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(
            train_loader, model, device, num_batches=eval_iter
        )
        val_accuracy = calc_accuracy_loader(
            val_loader, model, device, num_batches=eval_iter
        )
        print(f"Training accuracy: {train_accuracy*100:.2f}% | "
              f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
Step 7: Run the Training
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Configure optimizer with weight decay for regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

# Train
num_epochs = 5
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5
)
Results
After just 5 epochs, the fine-tuned model achieved:
- Training accuracy: 87.50%
- Validation accuracy: 82.50%
Starting from 53.75% accuracy (essentially random guessing), the model learned to distinguish spam from legitimate messages in minimal training time.
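Two follow-up checks are worth running at this point. Both sketches assume the test_loader and tokenizer built in the Step 2 sketch; the classify_text helper is illustrative and not part of the original code:

# Accuracy on the held-out test split
test_accuracy = calc_accuracy_loader(test_loader, model, device)
print(f"Test accuracy: {test_accuracy*100:.2f}%")

# Classify a single new message end to end
def classify_text(text, model, tokenizer, device, max_length, pad_token_id=50256):
    token_ids = tokenizer.encode(text)[:max_length]
    token_ids += [pad_token_id] * (max_length - len(token_ids))  # pad to fixed width
    input_tensor = torch.tensor(token_ids, device=device).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]
    return "spam" if torch.argmax(logits, dim=-1).item() == 1 else "ham"

print(classify_text("You won a free prize! Reply to claim.", model, tokenizer,
                    device, max_length=train_dataset.max_length))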
Key Takeaways
- Leverage pre-trained weights: Starting with GPT-2's language understanding gives you a massive head start
- Freeze strategically: Only train the layers necessary for your task to prevent overfitting and reduce computational cost
- Use appropriate learning rates: Small learning rates (like 5e-5) work well for fine-tuning to avoid catastrophic forgetting
- Add regularization: Weight decay helps prevent overfitting when adapting to smaller datasets
- Monitor validation metrics: Always track validation performance to detect overfitting early
Next Steps
To further improve your fine-tuned model:
- Experiment with unfreezing more layers
- Try different learning rate schedules (see the sketch after this list)
- Augment your training data
- Tune hyperparameters like batch size and weight decay
- Test different GPT-2 model sizes (medium, large, XL)
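For the learning-rate schedules mentioned in the list above, one simple option is cosine annealing applied to the AdamW optimizer from Step 7. This is a minimal sketch that steps the schedule once per epoch, not part of the original training code:

from torch.optim.lr_scheduler import CosineAnnealingLR

# Decay the learning rate from its initial value toward zero over the run
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    # ... run one epoch of training as in train_classifier_simple ...
    scheduler.step()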
Fine-tuning opens up powerful possibilities for adapting large language models to specific domains and tasks without the prohibitive costs of training from scratch. With the right approach, you can achieve impressive results even with limited data and computational resources.