Collecting Data to Verify Data Consistency Between verl Training and Inference (FSDP Backend)
Overview
Before comparing [data consistency between verl training and inference processes](../accuracy_compare/pytorch_accuracy_compare_instruct.md#verl_training_and_inference_data_consistency_comparison), ensure that the input shapes during training and inference are the same. Otherwise, the precision data dumped during training and inference cannot be matched during comparison.
Input Alignment Analysis for verl Training and Inference
Generally, the input shapes during training and inference differ because of the way the two processes work:
-
The inference process is divided into two stages: prefill and decode (k decode steps).
-
In the prefill stage, the inference input is a prompt.
-
In each of the k decode steps, the output token of the previous step is fed back in, together with the KV cache, to generate the next token. The tokens generated over the k steps form the final inference response.
-
During training, the input is the prompt concatenated with the response generated by inference, and the output is logits.
In short, the inference (prefill) input is the prompt, whereas the training input is the prompt plus the response generated by inference.
Therefore, the training input must be trimmed to the prompt alone so that it matches the inference prefill input.
Preparations
To make the input shape of the training forward pass match that of the inference prefill stage, the response must be removed from the training input. This requires the following prerequisites, which are configured by modifying the training script:
-
Ensure that the training batch is not split into smaller batches:
-
Ensure that the number of mini batches used for gradient updates in each training epoch is 1 (mini_batch_num = 1).
Formula: `mini_batch_num = train_batch_size / train_ppo_mini_batch_size`, where:
- train_batch_size: number of samples (prompts) in each training batch.
- train_ppo_mini_batch_size: number of samples (prompts) in each mini batch.
-
Ensure that the number of gradient accumulation steps (gac) is 1.
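Formula: `gac = train_ppo_mini_batch_size * n_resp_per_prompt / train_ppo_micro_batch_size_per_gpu / DP`, where:
- train_ppo_mini_batch_size: number of samples (prompts) in each mini batch.
- n_resp_per_prompt: number of responses generated per prompt.
- train_ppo_micro_batch_size_per_gpu: size of the micro batch processed on each GPU.
- DP: data parallelism degree.

For example, with DP = 8 and n_resp_per_prompt = 4, setting train_batch_size = train_ppo_mini_batch_size = 8 and train_ppo_micro_batch_size_per_gpu = 4 gives mini_batch_num = 1 and gac = 8 × 4 / 4 / 8 = 1.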
Modify the preceding parameters in the script as follows:
data.train_batch_size=${train_batch_size}
actor_rollout_ref.actor.ppo_mini_batch_size=${train_ppo_mini_batch_size}
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu}
actor_rollout_ref.rollout.n=${n_resp_per_prompt}
-
Ensure that no padding occurs during training and that dynamic batch sizing is disabled:
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.use_dynamic_bsz=False
-
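Set the environment variables read by the modified code in the training script:
- PROMPTS_ONLY=1 removes the response from the training input (see "Modifications to verl Code").
- DUMP_ON=1 enables the msProbe data dump (see "Data Collection").

For example:
export PROMPTS_ONLY=1
export DUMP_ON=1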
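- Ensure that the data collected during training and inference corresponds one-to-one on each card. balance_batch automatically rebalances and redistributes batch data across cards, which breaks this correspondence, so disable it in the training script (trainer.balance_batch=False).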
Modifications to verl Code
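To remove the response from the training input, modify verl/workers/actor/dp_actor.py as shown below, using release/v0.6.1 as an example. In the listing:
- Lines prefixed with + are added.
- Lines prefixed with - are removed.
- `...` marks unchanged code that is omitted.

The prompt-only path in compute_log_prob and update_policy runs only when the PROMPTS_ONLY environment variable is set to a nonzero integer.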
...
...
def _forward_micro_batch(
self, micro_batch, temperature, calculate_entropy=False
) -> tuple[torch.Tensor, torch.Tensor]:
"""..."""
+ # Modification to _forward_micro_batch
- response_length = micro_batch["responses"].size(-1)
+ if "responses" in micro_batch and micro_batch["responses"] is not None:
+ response_length = micro_batch["responses"].size(-1)
+ else:
+ response_length = 0
multi_modal_inputs = {}
...
@GPUMemoryLogger(role="dp actor", logger=logger)
def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
"""..."""
# set to eval
self.actor_module.eval()
+ # Modification to compute_log_prob
+ compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+ if compute_prompts_only:
+ if "responses" in data.batch:
+ responses_len = data.batch["responses"].size(1)
+ data.batch["input_ids"] = data.batch["input_ids"][:, :-responses_len]
+ data.batch["attention_mask"] = data.batch["attention_mask"][:, :-responses_len]
+ if data.batch["position_ids"].dim() == 3:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :, :-responses_len]
+ else:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :-responses_len]
+ # remove responses from batch
+ data.batch["responses"] = None
+ if "rollout_log_probs" in data.batch:
+ data.batch["rollout_log_probs"] = None
+ if "response_mask" in data.batch:
+ data.batch["response_mask"] = None
+
micro_batch_size = data.meta_info["micro_batch_size"]
...
@GPUMemoryLogger(role="dp actor", logger=logger)
def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()
temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
+ # Modification to update_policy
+ compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+ if compute_prompts_only:
+ if "responses" in data.batch:
+ responses_len = data.batch["responses"].size(1)
+ data.batch["input_ids"] = data.batch["input_ids"][:, :-responses_len]
+ data.batch["attention_mask"] = data.batch["attention_mask"][:, :-responses_len]
+ if data.batch["position_ids"].dim() == 3:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :, :-responses_len]
+ else:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :-responses_len]
+ # remove responses from batch
+ data.batch["responses"] = None
+ if "rollout_log_probs" in data.batch:
+ data.batch["rollout_log_probs"] = None
+ if "response_mask" in data.batch:
+ data.batch["response_mask"] = None
+
select_keys = [
"responses",
"response_mask",
"input_ids",
"attention_mask",
"position_ids",
"old_log_probs",
"advantages",
]
...
...
# Extract pre-computed rollout correction weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = model_inputs.get("rollout_is_weights", None)
+ # Modification to update_policy
+ if response_mask is None:
+ prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+ response_mask = prompt_mask
+
# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Compute policy loss (any function is expected to return 2 values)
pg_loss, pg_metrics = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_is_weights=rollout_is_weights,
)
micro_batch_metrics.update(pg_metrics)
...
Data Collection
Add msProbe PrecisionDebugger calls to verl/workers/fsdp_workers.py to dump the data. For details about PrecisionDebugger, see Precision Data Collection in PyTorch.
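Modify the code as shown in the following example, where lines prefixed with + are added and `...` marks omitted, unchanged code. The dump is enabled only when the DUMP_ON environment variable is set to a nonzero integer. Training data is saved in the update_actor subdirectory of the dump path, and inference data in the generate_sequences subdirectory.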
...
class ActorRolloutRefWorker(Worker, DistProfilerExtension):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str, **kwargs):
...
# normalize rollout config
if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
self.config.rollout.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
+ # Modification to __init__
+ # Instantiate PrecisionDebugger.
+ dump_flag = int(os.environ.get("DUMP_ON", 0)) # Set DUMP_ON to quickly enable or disable the dump function.
+ if dump_flag:
+ from msprobe.pytorch import PrecisionDebugger, seed_all
+ seed_all(mode=True)
+ self.debugger = PrecisionDebugger(task='tensor', level='L0', dump_path='example_dump_path', step=[0])
+ self.dump_path_prefix = self.debugger.config.dump_path
+ else:
+ self.debugger = None
...
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
...
with self.ulysses_sharding_manager:
data = data.to("cpu") # data will to device with each micro batch on actor.update_policy
+ # Modification to update_actor
+ if self.debugger:
+ self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'update_actor') # Training results are saved in the update_actor folder.
+ self.debugger.start(model=self.actor.actor_module)
# perform training
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(data=data)
+ if self.debugger:
+ self.debugger.stop()
+ self.debugger.step()
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
...
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
@DistProfiler.annotate(color="red", role="rollout_generate")
def generate_sequences(self, prompts: DataProto):
...
with simple_timer("generate_sequences", timing_generate):
+ # Modification to generate_sequences
+ if self.debugger:
+ self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'generate_sequences') # Inference results are saved in the generate_sequences folder.
+ infer_model = self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
+ self.debugger.start(model=infer_model, token_range=[0, 0])
output = self.rollout.generate_sequences(prompts=prompts)
+ if self.debugger:
+ self.debugger.stop()
+ self.debugger.service._reset_status()
if self._is_actor:
loop.run_until_complete(self.trainer_mode())
log_gpu_memory_usage("After switch to trainer mode", logger=logger)
...