📚 Módulo 7: Configuración de Entrenamiento con TRL (Aprendizaje por Refuerzo de Transformadores)

7.1 ¿Qué Es TRL y Por Qué Usarlo?

TRL (Transformer Reinforcement Learning) es una biblioteca de Hugging Face diseñada específicamente para entrenar modelos de lenguaje con enfoques modernos: desde Ajuste Fino Supervisado (SFT) hasta técnicas avanzadas como RLHF (Aprendizaje por Refuerzo a partir de Retroalimentación Humana) y DPO (Optimización de Preferencia Directa).

Para nuestro caso — ajuste fino supervisado con LoRA/QLoRA — usaremos el SFTTrainer, una clase que extiende el Trainer estándar de Hugging Face pero optimizada para tareas de generación de texto. Ventajas clave:

  • Manejo automático de secuencias de longitud variable.
  • Soporte integrado para conjuntos de datos de instrucciones (formato Alpaca).
  • Compatibilidad nativa con PEFT y cuantización.
  • Integración con herramientas de monitoreo como Weights & Biases (wandb).
  • Optimizaciones de memoria y rendimiento para entrenamiento eficiente.

7.2 Instalación y Configuración Inicial

# Instalar TRL (si no se hizo antes)
!pip install -q trl

# Importar componentes clave
from trl import SFTTrainer
from transformers import TrainingArguments

7.3 Configurar TrainingArguments

TrainingArguments define todos los hiperparámetros de entrenamiento: tamaño de lote, épocas, tasa de aprendizaje, puntos de control, registro, etc.

training_args = TrainingArguments(
    output_dir="./results",              # Directorio para guardar puntos de control y registros
    num_train_epochs=3,                  # Número de épocas completas de entrenamiento
    per_device_train_batch_size=4,       # Tamaño de lote por GPU (ajustar según memoria)
    gradient_accumulation_steps=4,       # Acumular gradientes para simular lotes más grandes
    optim="paged_adamw_8bit",            # Optimizador eficiente en memoria (esencial para QLoRA)
    save_steps=500,                      # Guardar punto de control cada 500 pasos
    logging_steps=100,                   # Registrar métricas cada 100 pasos
    learning_rate=2e-4,                  # Tasa de aprendizaje LoRA (típico: 1e-4 a 3e-4)
    weight_decay=0.01,                   # Regularización L2
    fp16=True,                           # Entrenamiento de precisión mixta (FP16)
    bf16=False,                          # Deshabilitado a menos que la GPU soporte BF16 (A100, H100)
    max_grad_norm=0.3,                   # Recorte de gradientes para estabilidad
    warmup_ratio=0.03,                   # Calentamiento lineal de tasa de aprendizaje
    lr_scheduler_type="cosine",          # Decaimiento de tasa de aprendizaje coseno
    report_to="wandb",                   # Reportar métricas a Weights & Biases (opcional)
    evaluation_strategy="steps",         # Evaluar durante el entrenamiento
    eval_steps=500,                      # Evaluar cada 500 pasos
    save_total_limit=2,                  # Mantener solo los 2 últimos puntos de control
    load_best_model_at_end=True,         # Cargar mejor modelo al final (por métrica de evaluación)
    metric_for_best_model="eval_loss",   # Métrica que define "mejor modelo"
    greater_is_better=False,             # Menor pérdida es mejor
    push_to_hub=False,                   # No subir a Hugging Face Hub (opcional)
)

Notas Clave:

  • per_device_train_batch_size=4 + gradient_accumulation_steps=4 = tamaño de lote efectivo de 16.
  • optim="paged_adamw_8bit" es esencial para evitar OOM en QLoRA.
  • fp16=True acelera el entrenamiento y reduce la memoria. Si tu GPU soporta BF16 (Ampere+), usa bf16=True y fp16=False.
  • report_to="wandb" requiere una cuenta gratuita de Weights & Biases. De lo contrario, usa report_to="none".

7.4 Preparar el Conjunto de Datos para SFTTrainer

El SFTTrainer espera un conjunto de datos con un campo de texto formateado. Usaremos la función format_instruction del Módulo 6.

from datasets import Dataset

# Asumimos que tenemos ejemplos formateados en Alpaca
dataset_dict = {
    "instruction": [
        "Escribe una descripción corta para un producto tecnológico.",
        "Resume el siguiente texto en una oración.",
    ],
    "input": [
        "Producto: Auriculares inalámbricos con cancelación de ruido. Precio: $129.99.",
        "La IA generativa está transformando industrias como la educación, el entretenimiento y la atención médica al permitir la creación automatizada de contenido de alta calidad.",
    ],
    "output": [
        "Disfruta de tu música sin distracciones con estos auriculares inalámbricos de alta fidelidad. Con cancelación activa de ruido y hasta 30 horas de duración de batería, son ideales para viajar, trabajar o simplemente relajarse. Solo $129.99.",
        "La IA generativa está revolucionando sectores clave al automatizar la creación de contenido de alta calidad.",
    ]
}

# Convertir a Dataset de Hugging Face
dataset = Dataset.from_dict(dataset_dict)

# Aplicar formateo
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for instruction, input_text, output in zip(instructions, inputs, outputs):
        text = f"### Instrucción:\n{instruction}\n\n### Entrada:\n{input_text}\n\n### Respuesta:\n{output}"
        texts.append(text)
    return texts

# El SFTTrainer usará esta función para formatear ejemplos

7.5 Crear e Iniciar el SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,  # Función para formatear prompts
    max_seq_length=512,                       # Longitud máxima de secuencia
    tokenizer=tokenizer,
    packing=False,                            # No empaquetar secuencias (mejor para ajuste de instrucciones)
    dataset_text_field="text",                # Campo que contiene texto (innecesario si se usa formatting_func)
)

# Iniciar entrenamiento
trainer.train()

Importante: Si el conjunto de datos es grande, divídelo en train_dataset y eval_dataset y pásalos ambos al entrenador. Aquí, por simplicidad, usamos solo entrenamiento.

Course Info

Course: AI-course3

Language: ES

Lesson: Module7