TRL documentation
GRPO With Replay Buffer
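This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with groups that have high rewards and high standard deviation and that were already used to train the model in earlier batches. A zero-variance group gives every completion an advantage of zero, so it contributes no learning signal.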
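Groups with zero reward standard deviation are worth replacing because GRPO normalizes rewards within each group: when all rewards in a group are identical, every completion gets an advantage of zero. The sketch below shows this in plain Python. It is an illustration, not TRL's implementation. group_advantages is a hypothetical helper, the epsilon value of 1e-4 is an assumption, and each group is assumed to have at least two completions.

```python
import statistics

def group_advantages(rewards, num_generations, eps=1e-4):
    """Normalize each reward by the mean and std of its own group (hypothetical helper)."""
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.mean(group)
        std = statistics.stdev(group)  # sample std (n - 1), like torch.Tensor.std's default
        # eps keeps the division finite when every reward in the group is identical
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages

rewards = [
    0.0, 0.0, 0.0, 0.0,  # group 1: identical rewards, std = 0
    0.1, 0.9, 0.4, 0.6,  # group 2: rewards differ, std > 0
]
advantages = group_advantages(rewards, num_generations=4)
print(advantages[:4])  # [0.0, 0.0, 0.0, 0.0]: group 1 contributes no learning signal
```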
Usage
import torch
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Reward function that occasionally returns identical rewards, guaranteeing that some groups have 0 std
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0.0] * len(completions)  # constant rewards: every group in this batch has 0 std
    else:
        return torch.rand(len(completions)).tolist()

training_args = GRPOWithReplayBufferConfig(
    output_dir="./tmp",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)
trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)

# Snapshot the initial weights, e.g. to check after training that they were updated
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
GRPOWithReplayBufferTrainer
class trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer(args: trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_config.GRPOWithReplayBufferConfig | None = None, **kwargs)
train
(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, typing.Any] | None = None, ignore_keys_for_eval: list[str] | None = None)
Parameters
- resume_from_checkpoint (str or bool, optional): If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equal to True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or dict[str, Any], optional): The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (list[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
Main training entry point.
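The resume_from_checkpoint rules above can be illustrated with a small helper. This is a sketch, not the Trainer's own code: resolve_resume_checkpoint is hypothetical, and it assumes checkpoints are saved as checkpoint-<step> folders under output_dir, the naming transformers uses.

```python
import os
import re
import tempfile

# Assumption: checkpoints live in sub-directories named "checkpoint-<global_step>"
CHECKPOINT_DIR = re.compile(r"^checkpoint-(\d+)$")

def resolve_resume_checkpoint(resume_from_checkpoint, output_dir):
    """Map the documented str/bool values to a checkpoint path (hypothetical helper)."""
    if resume_from_checkpoint is None or resume_from_checkpoint is False:
        return None  # train from scratch
    if isinstance(resume_from_checkpoint, str):
        return resume_from_checkpoint  # explicit local path
    # True: use the checkpoint with the highest step number in output_dir
    steps = [
        (int(match.group(1)), name)
        for name in os.listdir(output_dir)
        if (match := CHECKPOINT_DIR.match(name)) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not steps:
        raise ValueError(f"No checkpoint found in {output_dir}")
    return os.path.join(output_dir, max(steps)[1])

with tempfile.TemporaryDirectory() as tmp:
    for name in ("checkpoint-99", "checkpoint-500", "checkpoint-1000", "runs"):
        os.mkdir(os.path.join(tmp, name))
    latest = resolve_resume_checkpoint(True, tmp)
print(os.path.basename(latest))  # checkpoint-1000
```

Step numbers are compared as integers rather than strings, so checkpoint-1000 wins over checkpoint-99.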
save_model
Saves the model so that it can be reloaded with from_pretrained(). Saving happens only on the main process.
push_to_hub
(commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs)
Parameters
- commit_message (str, optional, defaults to "End of training"): Message to commit while pushing.
- blocking (bool, optional, defaults to True): Whether the function should return only when the git push has finished.
- token (str, optional, defaults to None): Token with write permission. If set, it overrides the token from the Trainer's original args.
- revision (str, optional): The git revision to commit from. Defaults to the head of the "main" branch.
- kwargs (dict[str, Any], optional): Additional keyword arguments passed along to Trainer.create_model_card.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
GRPOWithReplayBufferConfig
class trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig
(output_dir: str | None = None, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no', prediction_loss_only: bool = False, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 8, gradient_accumulation_steps: int = 1, eval_accumulation_steps: int | None = None, eval_delay: float = 0, torch_empty_cache_steps: int | None = None, learning_rate: float = 1e-06, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, num_train_epochs: float = 3.0, max_steps: int = -1, lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_ratio: float | None = None, warmup_steps: float = 0, log_level: str = 'passive', log_level_replica: str = 'warning', log_on_each_node: bool = True, logging_dir: str | None = None, logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps', logging_first_step: bool = False, logging_steps: float = 10, logging_nan_inf_filter: bool = True, save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps', save_steps: float = 500, save_total_limit: int | None = None, enable_jit_checkpoint: bool = False, save_on_each_node: bool = False, save_only_model: bool = False, restore_callback_states_from_checkpoint: bool = False, use_cpu: bool = False, seed: int = 42, data_seed: int | None = None, bf16: bool | None = None, fp16: bool = False, bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: bool | None = None, local_rank: int = -1, ddp_backend: str | None = None, debug: str | list[transformers.debug_utils.DebugOption] = '', dataloader_drop_last: bool = False, eval_steps: float | None = None, dataloader_num_workers: int = 0, dataloader_prefetch_factor: int | None = None, run_name: str | None = None, disable_tqdm: bool | None = None, remove_unused_columns: bool | None = False, label_names: list[str] | None = None,
load_best_model_at_end: bool = False, metric_for_best_model: str | None = None, greater_is_better: bool | None = None, ignore_data_skip: bool = False, fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None, fsdp_config: dict[str, typing.Any] | str | None = None, accelerator_config: dict | str | None = None, parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None, deepspeed: dict | str | None = None, label_smoothing_factor: float = 0.0, optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused', optim_args: str | None = None, group_by_length: bool = False, length_column_name: str = 'length', report_to: None | str | list[str] = 'none', project: str = 'huggingface', trackio_space_id: str | None = 'trackio', ddp_find_unused_parameters: bool | None = None, ddp_bucket_cap_mb: int | None = None, ddp_broadcast_buffers: bool | None = None, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, skip_memory_metrics: bool = True, push_to_hub: bool = False, resume_from_checkpoint: str | None = None, hub_model_id: str | None = None, hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save', hub_token: str | None = None, hub_private_repo: bool | None = None, hub_always_push: bool = False, hub_revision: str | None = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None, include_for_metrics: list = <factory>, eval_do_concat_batches: bool = True, auto_find_batch_size: bool = False, full_determinism: bool = False, ddp_timeout: int = 1800, torch_compile: bool = False, torch_compile_backend: str | None = None, torch_compile_mode: str | None = None, include_num_input_tokens_seen: str | bool = 'no', neftune_noise_alpha: float | None = None, optim_target_modules: None | str | list[str] = None, batch_eval_metrics: bool = False, eval_on_start: bool = False, use_liger_kernel: bool = False, liger_kernel_config: dict[str, bool] | None = None, eval_use_gather_object: bool = False,
average_tokens_across_devices: bool = True, use_cache: bool = False, model_init_kwargs: dict | str | None = None, disable_dropout: bool = False, cast_lm_head_to_fp32: bool = False, num_generations: int | None = 8, num_generations_eval: int | None = None, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool | None = True, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, min_p: float | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, repetition_penalty: float = 1.0, use_transformers_paged: bool = False, cache_implementation: str | None = None, use_vllm: bool = False, vllm_mode: str = 'server', vllm_model_impl: str = 'vllm', vllm_enable_sleep_mode: bool = False, vllm_structured_outputs_regex: str | None = None, vllm_server_base_url: str | None = None, vllm_server_host: str = '0.0.0.0', vllm_server_port: int = 8000, vllm_server_timeout: float = 240.0, vllm_group_port: int = 51216, vllm_gpu_memory_utilization: float = 0.3, vllm_max_model_length: int | None = None, vllm_tensor_parallel_size: int = 1, beta: float = 0.0, num_iterations: int = 1, epsilon: float = 0.2, delta: float | None = None, epsilon_high: float | None = None, sapo_temperature_neg: float = 1.05, sapo_temperature_pos: float = 1.0, importance_sampling_level: str = 'token', reward_weights: list[float] | None = None, multi_objective_aggregation: str = 'sum_then_normalize', scale_rewards: str = 'group', loss_type: str = 'dapo', mask_truncated_completions: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.6, ref_model_sync_steps: int = 512, top_entropy_quantile: float = 1.0, max_tool_calling_iterations: int | None = None, vllm_importance_sampling_correction: bool = True, vllm_importance_sampling_mode: str = 'sequence_mask', vllm_importance_sampling_cap: float = 3.0, off_policy_mask_threshold: float | None = None, use_bias_correction_kl: bool = False,
log_completions: bool = False, num_completions_to_print: int | None = None, log_unique_prompts: bool = False, replay_buffer_size: int = 64)
New Parameters:
replay_buffer_size (int, optional, defaults to 64):
Size of the replay buffer, a cache that stores the rollouts with the highest advantage scores and reward variance per group. If a new group has zero variance, it is replaced with a group sampled from the replay buffer.
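The replacement rule can be sketched in plain Python. This is a simplified illustration rather than TRL's code: replace_zero_std_groups is a hypothetical helper, each group is assumed to be a dict holding a rewards list alongside its other per-group data, and buffered groups are assumed to be handed out in order.

```python
import statistics

def replace_zero_std_groups(groups, buffered_groups):
    """Swap each zero-variance group for one drawn from the buffer (hypothetical helper).

    Each group is a dict with a "rewards" list plus any other per-group data
    (prompt, completion ids, ...); the whole dict is replaced at once.
    """
    buffered = iter(buffered_groups)
    result = []
    for group in groups:
        if statistics.pstdev(group["rewards"]) == 0:
            replacement = next(buffered, None)
            # If the buffer is exhausted, keep the original (uninformative) group
            result.append(group if replacement is None else replacement)
        else:
            result.append(group)
    return result

groups = [
    {"prompt": "a", "rewards": [0.0, 0.0, 0.0, 0.0]},  # zero variance
    {"prompt": "b", "rewards": [0.1, 0.9, 0.4, 0.7]},
]
from_buffer = [{"prompt": "earlier", "rewards": [0.2, 1.0, 0.5, 0.8]}]
print([g["prompt"] for g in replace_zero_std_groups(groups, from_buffer)])  # ['earlier', 'b']
```

Replacing the whole group dict, rather than only its rewards, keeps prompts, completions and rewards consistent with one another.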
ReplayBuffer
A simple replay buffer to store and sample previously seen rollouts.
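A bounded buffer of this kind can be sketched with a min-heap, so the lowest-scoring group is the one evicted when the buffer is full. This is an assumption-laden illustration, not TRL's ReplayBuffer: the class name SimpleReplayBuffer, the per-group score, and score-proportional sampling with replacement are all choices made here for the example, and scores are assumed to be positive.

```python
import heapq
import itertools
import random

class SimpleReplayBuffer:
    """Hypothetical bounded buffer that keeps the highest-scoring groups."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap = []                    # min-heap of (score, tie_breaker, group)
        self._counter = itertools.count()  # avoids comparing dicts when scores tie

    def add(self, score: float, group: dict) -> None:
        if self.max_size <= 0:
            return  # a zero-size buffer stores nothing
        entry = (score, next(self._counter), group)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            # Evict the lowest-scoring group to make room for a better one
            heapq.heapreplace(self._heap, entry)

    def sample(self, num_samples: int) -> list:
        if not self._heap:
            return []
        groups = [entry[2] for entry in self._heap]
        weights = [entry[0] for entry in self._heap]  # assumes positive scores
        # Draw with replacement, favoring high-scoring groups
        return random.choices(groups, weights=weights, k=num_samples)

    def __len__(self) -> int:
        return len(self._heap)

buffer = SimpleReplayBuffer(max_size=2)
buffer.add(0.1, {"prompt": "a"})
buffer.add(0.5, {"prompt": "b"})
buffer.add(0.3, {"prompt": "c"})  # evicts "a", the lowest-scoring group
print(len(buffer))  # 2
```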