#!/usr/bin/env python3
"""
DPO (Direct Preference Optimization) Training for Domain Name Model
Uses preference pairs generated by hybrid judges to improve model quality.
Usage (on RunPod):
python train_dpo.py \
--model_path /workspace/training/output-full \
--preferences /workspace/training/rlhf/preference_pairs.jsonl \
--output /workspace/training/output-dpo
"""
import argparse
import json

import torch
from datasets import Dataset
from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer, DPOConfig
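# Each line of the preferences file is expected to be a JSON object with
# "prompt", "chosen", and "rejected" keys (the fields read by load_preferences
# below). Illustrative example line -- the values are made up:
# {"prompt": "Suggest a domain for an online bakery", "chosen": "sweetcrumb.com", "rejected": "bakerysite123.net"}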
def load_preferences(preferences_file: str) -> Dataset:
    """Load preference pairs into HuggingFace Dataset"""
    data = {"prompt": [], "chosen": [], "rejected": []}
    with open(preferences_file, "r") as f:
        for line in f:
            item = json.loads(line)
            data["prompt"].append(item["prompt"])
            data["chosen"].append(item["chosen"])
            data["rejected"].append(item["rejected"])
    return Dataset.from_dict(data)


def format_for_dpo(example, tokenizer):
    """Format examples for DPO training"""
    # Create full response with domain context
    chosen_response = f"Here's a great domain name suggestion:\n\n{example['chosen']}"
    rejected_response = f"Here's a domain name suggestion:\n\n{example['rejected']}"
    return {
        "prompt": example["prompt"],
        "chosen": chosen_response,
        "rejected": rejected_response,
    }
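# Illustrative example of a formatted record (made-up values):
#   prompt:   "Suggest a domain name for an online bakery"
#   chosen:   "Here's a great domain name suggestion:\n\nsweetcrumb.com"
#   rejected: "Here's a domain name suggestion:\n\nbakerysite123.net"
# Note that the chosen and rejected wrappers differ ("great" vs. none), so the
# wrapper text itself contributes to the chosen/rejected contrast that DPO learns.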
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True, help="Path to SFT LoRA adapter")
    parser.add_argument("--base_model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--preferences", required=True, help="Preference pairs JSONL")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--beta", type=float, default=0.1, help="DPO beta parameter")
    args = parser.parse_args()
print("Loading model and tokenizer...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
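    # Rough expectation, not measured here: NF4 weights for a ~7B base are on the
    # order of 4 GB of GPU memory, before activations, optimizer state, and LoRA params.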
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # Important for DPO
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Load SFT adapter first
    print(f"Loading SFT adapter from {args.model_path}...")
    model = PeftModel.from_pretrained(model, args.model_path)

    # Merge SFT adapter into base model for DPO
    print("Merging SFT adapter...")
    model = model.merge_and_unload()
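    # Caveat (exact behavior depends on the installed peft version): merging LoRA
    # weights into a 4-bit base dequantizes and re-quantizes the affected layers,
    # so the merged weights may differ slightly from a merge into a full-precision
    # base; peft typically logs a warning about possible rounding differences.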
    # Create new LoRA config for DPO training
    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],  # Simpler target for DPO
        bias="none",
        task_type="CAUSAL_LM",
    )
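    # q_proj/v_proj is the classic minimal target set from the original LoRA paper;
    # with r=8 it keeps the DPO-stage adapter small and the preference update cheap.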
    # Load preference dataset
    print(f"Loading preferences from {args.preferences}...")
    dataset = load_preferences(args.preferences)
    print(f" Loaded {len(dataset)} preference pairs")

    # Format for DPO
    dataset = dataset.map(lambda x: format_for_dpo(x, tokenizer))

    # Split train/test
    dataset = dataset.train_test_split(test_size=0.1)
    # DPO training config
    training_args = DPOConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        learning_rate=args.learning_rate,
        beta=args.beta,  # DPO beta - controls preference strength
        logging_steps=10,
        save_steps=100,
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",
        remove_unused_columns=False,
        max_length=512,
        max_prompt_length=256,
    )
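    # Effective batch size per device is batch_size * gradient_accumulation_steps
    # (2 * 4 = 8 with the defaults). Beta controls how far the policy may drift from
    # the reference: higher beta stays closer to the reference, lower beta lets the
    # preference data dominate.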
    # Create DPO trainer
    # Note: TRL 0.12+ renamed 'tokenizer' to 'processing_class'
    trainer = DPOTrainer(
        model=model,
        ref_model=None,  # Will use implicit reference
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,  # TRL 0.12+ API
        peft_config=peft_config,
    )
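    # With peft_config set and ref_model=None, recent TRL versions compute reference
    # log-probs by running the same model with the LoRA adapter disabled, so no
    # second copy of the base weights is needed.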
print("Starting DPO training...")
trainer.train()
print(f"Saving model to {args.output}...")
trainer.save_model(args.output)
tokenizer.save_pretrained(args.output)
print("DPO training complete!")
if __name__ == "__main__":
main()
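# After training, the output directory contains a LoRA adapter (plus the tokenizer).
# Illustrative sketch of loading it for inference, assuming the same base model:
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   model = PeftModel.from_pretrained(base, "/workspace/training/output-dpo")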