跳转至

VeRL 训推交叉打桩功能指南

功能概述

在强化学习训练过程中的两个阶段(推理、训练),精度问题一般需要先定界出问题的阶段,再去采集该阶段精度数据进行分析。本文基于打桩替换掉一个阶段输出的思路进行定界,提供以下功能的使用指南:

功能 说明
Rollout Skip 基础功能 跳过推理阶段,直接加载已保存的推理数据,节省训练时间
Rollout Skip 增强功能 支持多次调用 generate_sequences,避免重复使用相同数据
交叉验证 两个训练过程交叉使用推理数据和检查点,实现精度定界

一、Rollout Skip 基础功能

Rollout Skip是VeRL 提供的基础功能,用于跳过推理阶段,直接加载已保存的推理数据,适用于需要反复运行相同配置实验的场景。

版本支持:VeRL v0.6.0 - v0.7.1

使用方式

在训练启动命令中添加以下参数:

python train.py \
  actor_rollout_ref.rollout.skip_rollout=True \
  actor_rollout_ref.rollout.skip_dump_dir="/tmp/rollout_dump"

落盘文件

文件命名格式:{experiment_name}_{project_name}_GBS{gbs}__N{n}

  • experiment_name:实验名称
  • project_name:项目名称
  • gbs:生成批次大小
  • n:每个 prompt 生成的回复数量

使用效果

  • 首次运行:执行推理并保存数据
  • 后续运行:直接加载已保存的推理数据,跳过推理阶段

二、Rollout Skip 增强功能

当前基础功能仅支持一次 generate_sequences 调用的数据落盘,当 generate_sequences 被多次调用时,基础功能会导致后续调用重复使用第一次的数据,和实际的推理数据不一致。增强功能通过添加索引解决此问题。

代码修改

修改文件 verl/verl/utils/rollout_skip.py

class RolloutSkip__init__ 方法中添加索引变量:

 class RolloutSkip:
     def __init__(self, config, rollout_wg):
         self.rollout_config = config.actor_rollout_ref.rollout
         self.exp_name = config.data.get("experiment_name", "")
         self.project_name = config.data.get("project_name", "")
         self.n = int(self.rollout_config.get("n", 0))
         self.gbs = int(config.data.get("gen_batch_size", config.data.get("train_batch_size", 0)))
+        self.index = 0

class RolloutSkip 中添加索引递增方法:

+    def _add_index(self):
+        self.index += 1

class RolloutSkip 中修改 curr_path_dump 属性,添加索引后缀:

     @property
     def curr_path_dump(self):
-        return self.dumped_dir.joinpath(f"{self.exp_name}_{self.project_name}_GBS{self.gbs}__N{self.n}").absolute()
+        return self.dumped_dir.joinpath(
+            f"{self.exp_name}_{self.project_name}_GBS{self.gbs}_N{self.n}_{self.index}").absolute()

wrap_generate_sequences 函数中添加索引递增调用:

 def wrap_generate_sequences(rolloutskip: RolloutSkip, rollout_wg):
     generate_sequences = rollout_wg.generate_sequences

     def warp_fn(batch, **kwargs):
         gen_batch_output = rolloutskip.try_load()

         if gen_batch_output is None:
             # * 1. Generation
             gen_batch_output = generate_sequences(batch, **kwargs)
             # * 2. Dump
             rolloutskip.dump(gen_batch_output)
+        rolloutskip._add_index()
         return gen_batch_output
     return warp_fn

落盘文件

文件命名格式:{experiment_name}_{project_name}_GBS{gbs}_N{n}_{index}

  • 增加了 {index} 后缀,每次调用自动递增

使用效果

  • 首次运行
  • 第1次调用:执行推理并保存为 ..._N{n}_0
  • 第2次调用:执行推理并保存为 ..._N{n}_1
  • 以此类推,每次调用生成独立文件

  • 后续运行

  • 第1次调用:加载 ..._N{n}_0
  • 第2次调用:加载 ..._N{n}_1
  • 以此类推,每次调用加载对应索引的文件

  • 避免多次调用重复使用相同数据的问题

使用方式

代码修改完成后,使用方式与基础功能相同,无需额外参数,索引自动递增:

python train.py \
  actor_rollout_ref.rollout.skip_rollout=True \
  actor_rollout_ref.rollout.skip_dump_dir="/tmp/rollout_dump"

三、交叉验证功能

通过训推交叉打桩的方式,实现训练和推理数据的共享与复用,帮助定界精度问题发生在训练还是推理阶段。

工作原理

两种工作模式:

模式 推理阶段 训练阶段
ckpt 执行推理 → 保存rollout数据 执行训练 → 加载ckpt
rollout 加载rollout数据(跳过推理) 执行训练 → 保存ckpt

两个训练过程通过共享数据目录进行数据交换,实现训练和推理的分离与交叉验证。

约束限制

  • 当前支持VeRL强化学习框架
  • 需要两个训练过程使用相同的配置和数据,且两个训练过程能同时拉起
  • 共享数据路径需要可访问且具有足够存储空间

参数说明

参数 选项值 参数说明
share_data "rollout" / "ckpt" 工作模式选择。
• rollout:加载推理数据,保存检查点。
• ckpt:保存推理数据,加载检查点。
share_data_dir str 共享数据存储路径。默认值为:/tmp/verl/share_data_dir

代码修改说明

以VeRL v0.7.1 版本为例,修改如下:

1. 新增工具类

创建文件 verl/verl/utils/share_data.py:

import torch
from pathlib import Path
from verl.protocol import DataProto
import time
from functools import wraps

class ShareDataTest:

    print_mark = "[ShareDataTest]"

    def __init__(self, config):
        self.rollout_config = config

        self.rollout_index = 0  # 需要添加
        self.ckpt_index = 0
        self.share_flag = self.rollout_config.get("share_data", "")

        self.dumped_dir = Path(self.rollout_config.get("share_data_dir", "/tmp/verl/share_data_dir"))
        self.dumped_dir.mkdir(parents=True, exist_ok=True)

        # Check if path is in Ray temporary directory
        if str(self.dumped_dir.absolute()).startswith("/tmp/ray/session"):
            print(
                f"\033[33m{self.print_mark} Warning: \nUsing dump path ",
                f"'{self.dumped_dir.absolute()}' is not recommended ",
                "as it's located in /tmp/ray/session*\033[0m",
                flush=True,
            )

        print(
            f"{self.print_mark} dump path set to: ",
            f"{self.dumped_dir.absolute()}",
            flush=True,
        )

    def _add_rollout_index(self):
        self.rollout_index += 1

    def _add_ckpt_index(self):
        self.ckpt_index += 1

    @property
    def curr_path_rollout(self):
        return self.dumped_dir.joinpath(
            f"rollout_{self.rollout_index}").absolute()

    @property
    def curr_path_ckpt(self):
        return self.dumped_dir.joinpath(
            f"ckpt_{self.ckpt_index}").absolute()

    def try_load(self):
        if not self.curr_path_rollout.exists():
            print(
                f"{self.print_mark} No data dump found at {self.curr_path_rollout}.",
                flush=True,
            )
            return None

        try:
            # * Load
            ret_batch = DataProto.load_from_disk(self.curr_path_rollout)
            print(
                f"\033[32m{self.print_mark} Successfully load pre-generated data from {self.curr_path_rollout}\033[0m",
                flush=True,
            )
            return ret_batch
        except Exception as e:
            print(
                f"\033[31m{self.print_mark} Failed to load pre-generated data from {self.curr_path_rollout}",
                f"Error: {str(e)}\033[0m",
                flush=True,
            )
            return None

    def dump(self, outputs: DataProto):
        try:
            outputs.save_to_disk(self.curr_path_rollout)
            print(
                f"\033[32m{self.print_mark} Successfully dump data in {self.curr_path_rollout}\033[0m",
                flush=True,
            )
        except Exception as e:
            print(
                f"\033[31m{self.print_mark} Failed to dump data in {self.curr_path_rollout}: {e}\033[0m",
                flush=True,
            )


def wrap_generate_sequences(share_data: ShareDataTest, worker):
    """包装generate_sequences方法,支持数据共享"""
    original_generate_sequences = worker.generate_sequences

    @wraps(original_generate_sequences)
    def wrapped_fn(*args, **kwargs):
        if share_data.share_flag == "rollout":
            # rollout模式:从共享文件加载数据
            gen_batch_output = None
            while gen_batch_output is None:
                print(f"\033[32m{share_data.print_mark} Waiting for shared data...\033[0m", flush=True)
                time.sleep(20)
                gen_batch_output = share_data.try_load()

        elif share_data.share_flag == "ckpt":
            # ckpt模式:生成数据并保存到共享文件
            gen_batch_output = original_generate_sequences(*args, **kwargs)
            share_data.dump(gen_batch_output)
        else:
            # 默认模式:直接执行
            gen_batch_output = original_generate_sequences(*args, **kwargs)

        # 增加索引
        share_data._add_rollout_index()
        return gen_batch_output

    return wrapped_fn

def after_update_policy(share_data, load_func, dump_func):
    """
    根据配置选这加载ckpt还是保存ckpt
    """
    # 如果需要读ckpt
    if share_data.share_flag == "ckpt":
        while not share_data.curr_path_ckpt.exists():
            print(f"\033[32m{share_data.print_mark} waiting for {share_data.curr_path_ckpt}\033[0m", flush=True)
            time.sleep(20)
        time.sleep(60)
        succ = False
        while not succ:
            try:
                load_func(share_data.curr_path_ckpt)
                succ = True
                print(f"\033[32m{share_data.print_mark} Successfully load ckpt from {share_data.curr_path_ckpt}\033[0m", flush=True)
            except Exception as e:
                print(f"\033[31m{share_data.print_mark} Load ckpt failed from {share_data.curr_path_ckpt} because of {e} \033[0m", flush=True)
                time.sleep(20)
    # 如果需要存ckpt
    if share_data.share_flag == "rollout":
        try:
            dump_func(share_data.curr_path_ckpt)
            print(f"\033[32m{share_data.print_mark} Successfully save ckpt to {share_data.curr_path_ckpt}\033[0m", flush=True)
        except Exception as e:
            print(f"\033[31m{share_data.print_mark} Failed to save ckpt to {share_data.curr_path_ckpt}: {e}\033[0m", flush=True)
    share_data._add_ckpt_index()

2. 修改Trainer类

修改文件 verl/verl/trainer/ppo/ray_trainer.py:

fit() 函数中添加初始化逻辑,并包装 generate_sequences 方法:

 if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
     rollout_skip = RolloutSkip(self.config, self.async_rollout_manager)
     rollout_skip.wrap_generate_sequences()
+if self.config.trainer.get("share_data", None):
+    from verl.utils.share_data import ShareDataTest, wrap_generate_sequences, after_update_policy
+    share_data_manager = ShareDataTest(self.config.trainer)
+    self.async_rollout_manager.generate_sequences = \
+        wrap_generate_sequences(share_data_manager, self.async_rollout_manager)

fit() 函数的训练循环中,_update_actor 调用之后添加检查点同步:

 if self.config.trainer.critic_warmup <= self.global_steps:
     # update actor
     with marked_timer("update_actor", timing_raw, color="red"):
         actor_output = self._update_actor(batch)
+    if self.config.trainer.get("share_data", None):
+        after_update_policy(
+            share_data_manager, 
+            self.actor_rollout_wg.load_checkpoint, 
+            self.actor_rollout_wg.save_checkpoint
+        )

使用步骤

结果分析

交叉验证的定界思路:

背景:训练过程A出现崩溃,但无法确定是训练阶段还是推理阶段的问题。需要借助一个正常的训练过程B(作为标杆),在相同实验配置下进行交叉验证。

定界方法

场景 训练过程A 训练过程B 结果判断
场景1 自身推理 + 加载ckpt 加载rollout + 自身训练 如果训练A崩溃,说明A的推理阶段有问题
场景2 加载rollout + 自身训练 自身推理 + 加载ckpt 如果训练A崩溃,说明A的训练阶段有问题

场景1测试(定界推理阶段):

训练过程A(ckpt 模式):

python train.py \
  ++trainer.share_data="ckpt" \
  ++trainer.share_data_dir="/root/autodl-tmp/share_data"
  • 推理阶段:执行推理并保存rollout数据
  • 训练阶段:加载训练过程B的检查点

训练过程B(rollout 模式):

python train.py \
  ++trainer.share_data="rollout" \
  ++trainer.share_data_dir="/root/autodl-tmp/share_data"
  • 推理阶段:加载训练过程A的rollout数据
  • 训练阶段:执行训练并保存检查点

场景2测试(定界训练阶段):

训练过程A(rollout 模式):

python train.py \
  ++trainer.share_data="rollout" \
  ++trainer.share_data_dir="/root/autodl-tmp/share_data"
  • 推理阶段:加载训练过程B的rollout数据
  • 训练阶段:执行训练并保存检查点

训练过程B(ckpt 模式):

python train.py \
  ++trainer.share_data="ckpt" \
  ++trainer.share_data_dir="/root/autodl-tmp/share_data"
  • 推理阶段:执行推理并保存rollout数据
  • 训练阶段:加载训练过程A的检查点

通过这种方式,可以快速定位问题发生在训练过程A的哪个阶段。

日志说明

运行过程中会打印以下日志信息,用于监控数据共享状态:

日志信息 含义
dump path set to: {path} 初始化时打印共享数据目录路径
Warning: Using dump path... 警告:使用了 Ray 临时目录,不推荐
Waiting for shared data... rollout 模式下等待共享数据
No data dump found at {path} 未找到共享数据文件
Successfully load pre-generated data from {path} 成功加载推理数据(绿色)
Failed to load pre-generated data from {path} 加载推理数据失败(红色)
Successfully dump data in {path} 成功保存推理数据(绿色)
Failed to dump data in {path} 保存推理数据失败(红色)
waiting for {ckpt_path} ckpt 模式下等待检查点文件
Successfully load ckpt from {path} 成功加载检查点(绿色)
Load ckpt failed from {path} 加载检查点失败(红色)
Successfully save ckpt to {path} 成功保存检查点(绿色)
Failed to save ckpt to {path} 保存检查点失败(红色)

数据说明

共享数据目录下的文件命名格式:

  • 推理数据rollout_{index},index 从 0 开始递增
  • 检查点数据ckpt_{index},index 从 0 开始递增

两个训练过程通过相同的 index 保持数据同步。