
[RFC] Support Qwen-Image Flow-GRPO Training based on vLLM-Omni #4639

@chenyingshu


Motivation

The goal is to enhance verl’s scalability so that it can support online DPO-like training for state-of-the-art diffusion image and video generation models, including Qwen-Image, Z-Image, Wan2.2, and others. We choose Flow-GRPO as the representative algorithm in this domain; additional algorithms such as DiffusionNFT and DanceGRPO can be integrated on top of this update. As an initial step, Qwen-Image has been selected as the first supported model for multimodal generation tasks.

At present, verl does not support diffusion-based generation models. Enabling them requires two major extensions: first, a rollout engine capable of handling image and video generation tasks, built on components such as vLLM-Omni; and second, a training engine for diffusion models, built on diffusers with an FSDP backend. Consequently, this change introduces diffusers and vLLM-Omni as new dependencies of verl.

In the following sections, we first give a brief overview of FlowGRPO and then describe the code changes needed to enable Qwen-Image FlowGRPO training.

Overall Structure


Figure 1. Overview of integrating the FlowGRPO algorithm into verl. The left panel shows the entry point and algorithm implementation with a standalone RayFlowGRPOTrainer. The right panel shows the corresponding workers (class names in bold) that need to be implemented. Other miscellaneous changes, such as configs, dataloaders and the logger, are not shown here.

Algorithm Implementation

The FlowGRPO functions shown in the left panel are briefly explained below:

i. generate_sequence: The sequence generation for the diffusion model produces a sequence of image/video latent samples during the diffusion process, rather than tokens as in LLMs. This includes the final generated image/video, prompt embeddings, timesteps, and other necessary information from the denoising stage.

  • Inputs: Prompts
  • Outputs: Generated Images/Videos; Prompt embeddings; Timesteps; Latent samples, log probabilities, and latent sample means during the sampling stage
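To make the outputs of generate_sequence concrete, here is a minimal sketch of a per-prompt rollout record. The class name DiffusionRolloutSample and its fields are hypothetical (they are not verl's DataProto keys), and plain Python lists stand in for tensors. The only structural assumption is that a trajectory of T denoising steps stores T + 1 latents (initial noise through final latent), plus one log probability and one mean per step.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DiffusionRolloutSample:
    """Hypothetical per-prompt record of one denoising trajectory (field names are illustrative)."""
    prompt: str
    timesteps: List[float]                  # T sampled timesteps, descending
    latents: List[List[float]]              # T + 1 latents: initial noise ... final latent
    log_probs: List[float]                  # one sampler log probability per step
    prev_sample_means: List[List[float]]    # one Gaussian step mean per step
    prompt_embeds: List[float] = field(default_factory=list)

    def __post_init__(self):
        t = len(self.timesteps)
        # Step i maps latents[i] -> latents[i + 1], so there is one extra latent.
        if len(self.latents) != t + 1:
            raise ValueError("expected one more latent than timesteps")
        if len(self.log_probs) != t or len(self.prev_sample_means) != t:
            raise ValueError("expected one log prob and one mean per timestep")
```

The final image/video would be decoded from latents[-1]; that decoding step is outside this sketch.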

ii. compute_rm_score: The reward model scores the generated images/videos according to the user’s task, e.g., OCR accuracy, GenEval score, or CLIP score. For simplicity and generality, the scores can be obtained through API calls to a vLLM server.

  • Inputs: Generated Images/Videos
  • Outputs: Scores of each image/video
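As an illustration of a task-specific reward, the sketch below computes an OCR-style score: one minus the normalized edit distance between the text recognized in the image and the target text. The function names are hypothetical. The OCR model that produces `recognized` is not shown, and the normalization (dropping spaces, lowercasing) is an assumption rather than a fixed part of the design.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance between a and b using a rolling single-row DP table."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (ca != cb)))  # substitution (0 if equal)
        prev = cur
    return prev[-1]


def ocr_reward(recognized: str, target: str) -> float:
    """Hypothetical OCR reward in [0, 1]: 1 - edit distance / target length."""
    recognized = recognized.replace(" ", "").lower()
    target = target.replace(" ", "").lower()
    if not target:
        return 0.0
    dist = levenshtein(recognized, target)
    return max(0.0, 1.0 - dist / len(target))
```

For example, ocr_reward("hallo", "hello") is 0.8 (one substitution over five target characters).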

iii. compute_old_log_prob: As with LLMs, the training and inference engines can behave inconsistently, so we add support for recomputing the old log probabilities on the training side.

  • Inputs: Latent samples, timesteps, and prompt embeddings from the inference engine
  • Outputs: Updated old log probabilities
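Each step of the SDE sampler is Gaussian, so the log probability of a latent sample given its step mean and standard deviation has a closed form. The sketch below assumes the training side has already recomputed the step mean from the stored latents. The function name is hypothetical. Averaging over latent elements instead of summing is an assumed convention that keeps the scale independent of the latent size.

```python
import math
from typing import Sequence


def gaussian_step_log_prob(prev_sample: Sequence[float],
                           prev_sample_mean: Sequence[float],
                           std: float) -> float:
    """Per-element mean of log N(prev_sample; prev_sample_mean, std^2) for one step."""
    if std <= 0:
        raise ValueError("std must be positive")
    if len(prev_sample) != len(prev_sample_mean) or not prev_sample:
        raise ValueError("inputs must be non-empty and the same length")
    # Normalizing constant of a 1-D Gaussian: log(std) + 0.5 * log(2 * pi).
    log_norm = math.log(std) + 0.5 * math.log(2 * math.pi)
    total = 0.0
    for x, mu in zip(prev_sample, prev_sample_mean):
        total += -((x - mu) ** 2) / (2 * std ** 2) - log_norm
    return total / len(prev_sample)
```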

iv. compute_ref_old_prob: This computes the latent sample means under the reference model, which are used for the KL divergence computation.

  • Inputs: Latent samples, timesteps, and prompt embeddings from the inference engine
  • Outputs: Reference latent sample means
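When the policy and reference steps share the same noise level, the KL divergence between the two Gaussian step distributions reduces to a scaled squared distance between their means. This is why only the reference latent sample means are needed. The sketch below works under that shared-standard-deviation assumption, and the function name is hypothetical.

```python
from typing import Sequence


def gaussian_kl_shared_std(mean_policy: Sequence[float],
                           mean_ref: Sequence[float],
                           std: float) -> float:
    """KL(N(mu_p, s^2 I) || N(mu_r, s^2 I)), averaged per latent element.

    With equal variances the log-ratio and trace terms cancel, leaving
    (mu_p - mu_r)^2 / (2 s^2) per element.
    """
    if len(mean_policy) != len(mean_ref) or not mean_policy:
        raise ValueError("means must be non-empty and the same length")
    sq = sum((p - r) ** 2 for p, r in zip(mean_policy, mean_ref))
    return sq / len(mean_policy) / (2 * std ** 2)
```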

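v. compute_advantage: The advantage calculation is essentially the same as in the GRPO algorithm: each sample's score is normalized against the other samples in its group (samples generated from the same prompt, identified by UID).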

  • Inputs: Scores of each image/video, UIDs to specify the group the sample belongs to
  • Outputs: Advantage score of each sample
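The group-wise normalization can be sketched as follows. Samples sharing a UID form one group, and each score is standardized with its group's mean and standard deviation. The function name is hypothetical. This sketch uses the population standard deviation plus a small epsilon, so a group of identical scores (or a single sample) yields zero advantage.

```python
import statistics
from collections import defaultdict
from typing import Hashable, List, Sequence


def group_advantages(scores: Sequence[float], uids: Sequence[Hashable],
                     eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages: standardize each score within its UID group."""
    groups = defaultdict(list)
    for score, uid in zip(scores, uids):
        groups[uid].append(score)
    stats = {}
    for uid, vals in groups.items():
        # Population std: 0 for a single-sample or constant group.
        stats[uid] = (statistics.fmean(vals), statistics.pstdev(vals))
    return [(s - stats[u][0]) / (stats[u][1] + eps) for s, u in zip(scores, uids)]
```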

vi. update_actor: The overall workflow of updating the actor is very similar to PPO/GRPO, but the loss function differs slightly from GRPO's.

  • Inputs: Latent samples, prompt embeddings, timesteps, updated old log probabilities, advantages, and reference latent sample means
  • Outputs: Loss value
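The following sketches the per-sample, per-timestep objective. It assumes a PPO-style clipped surrogate on the step log-probability ratio, plus a KL penalty weighted by kl_coef. The function name and the default values of clip_range and kl_coef are placeholders, not settings from this RFC. In practice the loss would be averaged over samples and sampled timesteps.

```python
import math


def flow_grpo_step_loss(log_prob: float, old_log_prob: float, advantage: float,
                        kl: float, clip_range: float = 0.2, kl_coef: float = 0.0) -> float:
    """Clipped surrogate for one sample at one denoising step, plus a KL penalty."""
    ratio = math.exp(log_prob - old_log_prob)
    clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
    # Pessimistic bound: take the larger of the two negated surrogates.
    policy_loss = max(-advantage * ratio, -advantage * clipped)
    return policy_loss + kl_coef * kl
```

With a positive advantage and a ratio of 2, the clip caps the surrogate at 1 + clip_range, so the update cannot push the ratio further.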

New Components

The right panel shows the classes we need to add to support the FlowGRPO algorithm. Other diffusion-based RL algorithms can reuse these components to speed up development.

Rollout

We add DiffusionAgentLoopWorker and DiffusionSingleTurnAgentLoop to support a diffusion-based agent loop and asynchronous rollout.
DiffusionAgentLoopWorker runs DiffusionSingleTurnAgentLoop, which calls the server manager to generate sequences.

We use the vLLM-Omni API for image generation, which relies on the following PRs:
vllm-project/vllm-omni#355
vllm-project/vllm-omni#376
vllm-project/vllm-omni#371
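The fan-out pattern used by DiffusionAgentLoopWorker can be illustrated with plain asyncio. The pattern is one task per prompt, each running a single-turn loop against a server, with results collected by asyncio.gather, which returns them in input order. FakeOmniServer and the round-robin server assignment are hypothetical simplifications. They stand in for the vLLM-Omni server handles and the real server manager.

```python
import asyncio
import itertools
from typing import Any, Dict, List


class FakeOmniServer:
    """Hypothetical stand-in for a vLLM-Omni server handle."""

    def __init__(self, name: str):
        self.name = name

    async def generate(self, prompt: str) -> Dict[str, Any]:
        await asyncio.sleep(0)  # simulate an async RPC round-trip
        return {"prompt": prompt, "server": self.name, "image": b"<png bytes>"}


async def run_single_turn(server: FakeOmniServer, prompt: str) -> Dict[str, Any]:
    """Single-turn loop: one request in, one image out."""
    return await server.generate(prompt)


async def generate_batch(prompts: List[str],
                         servers: List[FakeOmniServer]) -> List[Dict[str, Any]]:
    # Assign prompts to servers round-robin and run all requests concurrently.
    cycle = itertools.cycle(servers)
    tasks = [asyncio.create_task(run_single_turn(next(cycle), p)) for p in prompts]
    return await asyncio.gather(*tasks)
```

Usage: `asyncio.run(generate_batch(["a", "b", "c"], [FakeOmniServer("s0"), FakeOmniServer("s1")]))` returns three results in prompt order, served by s0, s1, s0.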

1a) DiffusionAgentLoopWorker

The agent loop worker for asynchronous rollout; it implements generation in generate_sequences.

class DiffusionAgentLoopWorker:
    """Diffusion Agent loop worker takes a batch of messages and run each message in an agent loop.
    """
    def __init__(
        self,
        config: DictConfig,
        server_handles: list[ray.actor.ActorHandle],
        reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
    ):
        """Initialize agent loop manager.
        Args:
            config (DictConfig): whole config for main entrypoint.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
            reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
        """
        ...

    async def generate_sequences(self, batch: DataProto) -> DataProto:
        """Generate sequences () from agent loop.
        Args:
            batch (DataProto): Input batch.
        Returns:
            DataProto: Output batch.
            - prompts: [bsz, prompt_length], prompt token ids from dataset.
            - responses: [bsz, channel, height, width],  output images from diffusion generation.
            ...
        """
        ...
        tasks = []
        for i in range(len(batch)):
            tasks.append(
                asyncio.create_task(
                    self._run_agent_loop(...)
                )
            )
        outputs = await asyncio.gather(*tasks)
        ...
    
    async def _run_agent_loop(
        self,
        sampling_params: dict[str, Any],
        trajectory: dict[str, Any],
        *,
        agent_name: str,
        trace: bool = True,
        **kwargs,
    ) -> _InternalDiffusionAgentLoopOutput:
        """
        Returns:
            _InternalDiffusionAgentLoopOutput: Internal agent loop output, e.g., response_image, logprobs, etc.
        """
        ...

    async def _compute_score(self, output, prompts, responses, attention_mask, input_ids, kwargs):
        """
            Call a reward loop worker to compute the reward score for a single sample.
            Input includes prompts, response images, etc.
        """
        ...

1b) DiffusionSingleTurnAgentLoop

An agent loop that supports single-turn response generation from the server.

@register("diffusion_single_turn_agent")
class DiffusionSingleTurnAgentLoop(AgentLoopBase):
    """Agent loop for diffusion model serving."""

    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """
        Run agent loop to interact with vLLM-Omni server and environment.
        """
        ...

1c) vLLMOmniReplica

The rollout class that launches API servers (i.e., vLLMOmniHttpServer) for asynchronous rollout calls.

class vLLMOmniReplica(RolloutReplica):
    def __init__(...):
        ...
        self.server_class = ray.remote(vLLMOmniHttpServer)
    ...

1d) vLLMOmniHttpServer

The vLLM HTTP server running on a single node; it serves vLLM-Omni requests.

class vLLMOmniHttpServer:
    """vLLM-Omni http server in single node, this is equivalent to launch server with command line:
    ```
    vllm serve --omni --tensor-parallel-size=8 ...
    ```
    """

    def __init__(
        self,
        config: DiffusionRolloutConfig,
        model_config: DiffusersModelConfig,
        rollout_mode: RolloutMode,
        workers: list[ActorHandle],
        replica_rank: int,
        node_rank: int,
        gpus_per_node: int,
        nnodes: int,
        cuda_visible_devices: str,
    ):
        """
        Args:
            config (DiffusionRolloutConfig): full config.
            model_config (DiffusersModelConfig): model config.
            rollout_mode (RolloutMode): rollout mode.
            workers (list[ActorHandle]): rollout worker actor handles.
            replica_rank (int): replica rank, a replica may contain multiple nodes.
            node_rank (int): node rank.
            gpus_per_node (int): number of gpus per node.
            nnodes (int): number of nodes.
            cuda_visible_devices (str): cuda visible devices.
        """
        ...
    
    async def launch_server(self, master_address: str = None, master_port: int = None, dp_rpc_port: int = None):
        ...

    async def generate(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
        video_data: Optional[list[Any]] = None,
        negative_prompt_ids: Optional[list[int]] = None,
        priority: int = 0,
    ) -> ImageOutput:
        """Generate sequence with token-in-image-out."""
        ...

Reward

RewardLoopManager (Modified)

Manages reward loop workers and generates rewards by calling them.

class RewardLoopManager:
    def compute_rm_score(self, data: DataProto) -> DataProto:
        """ Add a condition for image reward score processing """
        ...
        # compute rm score
        if self.config.reward.reward_manager.name == "image":
            rm_scores = torch.tensor(scores, dtype=torch.float32).unsqueeze(-1)
            # no need to handle valid response length
        else:
            ... # unchanged
        ...
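Image rewards can skip the valid-response-length handling because there is one score per image rather than per token. The sketch below contrasts the two layouts, using plain lists instead of tensors. Images get a [bsz, 1] tensor. Text gets a [bsz, response_length] tensor with each score placed at the last valid response token. Both function names are hypothetical.

```python
from typing import List


def to_reward_tensor_image(scores: List[float]) -> List[List[float]]:
    """Image responses: one scalar per sample, shape [bsz, 1]."""
    return [[float(s)] for s in scores]


def to_reward_tensor_text(scores: List[float], response_lengths: List[int],
                          max_response_len: int) -> List[List[float]]:
    """Text responses: score at the last valid token, shape [bsz, max_response_len]."""
    rows = []
    for s, n in zip(scores, response_lengths):
        if not 1 <= n <= max_response_len:
            raise ValueError("response length out of range")
        row = [0.0] * max_response_len
        row[n - 1] = float(s)  # all other positions stay 0
        rows.append(row)
    return rows
```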

RewardLoopWorker (Modified)

A loop worker that computes rewards with different logic depending on the input type.

@ray.remote
class RewardLoopWorker:
    
    def _init_reward_fn(self):
        """ Handle diffusion and non-diffusion input separately """
        ...
        # extract response and prepare input
        response = data_item.batch["responses"]
        if response.ndim == 3: 
            # handling multi-modal response
            response_image = response
            image_base64 = ...  # encode response_image as base64
            query = ...  # build the message content from image_base64
            chat.append({"role": "assistant", "content": query})
        else: # unchanged
            ...
            # decode
            ...
            # remove bos and eos
            ...
            chat.append({"role": "assistant", "content": rollout_response})
        ...
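One way to pass an image response to an OpenAI-compatible reward server is to embed it as a base64 data URL inside an `image_url` content part. The helper below is a hypothetical sketch. It assumes the image tensor has already been encoded to PNG bytes (e.g., with an image library, not shown). The role defaults to "assistant" to mirror the code above, although depending on the server and the reward model's chat template, the image may need to go in a user message instead.

```python
import base64
from typing import Any, Dict


def image_message(image_bytes: bytes, mime: str = "image/png",
                  role: str = "assistant") -> Dict[str, Any]:
    """Build a chat message that carries the image as a base64 data URL."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return {
        "role": role,
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
        ],
    }
```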

ImageRewardManager

Manages image-related reward computation.

@register("image")
class ImageRewardManager(RewardManagerBase):
    """The reward manager for image response."""

    def __init__(self, config, tokenizer, compute_score, reward_router_address=None, reward_model_tokenizer=None):
        super().__init__(config, tokenizer, compute_score)
        ...

    async def run_single(self, data: DataProto) -> dict:
        """ Run sync or async reward computing for a single sample. """
        ...
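run_single may receive either a synchronous or an asynchronous compute_score function. One way to support both is to await coroutine functions directly and push synchronous ones to a worker thread, so slow scorers do not block the event loop. The helper name is hypothetical; asyncio.to_thread requires Python 3.9+.

```python
import asyncio
import inspect
from typing import Any, Callable


async def call_compute_score(compute_score: Callable[..., Any], **kwargs) -> Any:
    """Await async scorers directly; run sync scorers in a worker thread."""
    if inspect.iscoroutinefunction(compute_score):
        return await compute_score(**kwargs)
    return await asyncio.to_thread(compute_score, **kwargs)
```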

default_compute_score_image

A default rule-based reward function for images.

def default_compute_score_image():
    ...

Actor

2) DiffusersFSDPEngine

The default base training engine for diffusers models, supporting diffusion pipeline instantiation as well as forward and backward steps.

class DiffusersFSDPEngine(FSDPEngine):
    """
    Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).
    Supports model sharding, activation/optimizer offloading, LoRA.
    """
    def __init__(
        self,
        model_config: DiffusersModelConfig,
        engine_config: FSDPEngineConfig,
        optimizer_config: FSDPOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        """
        Initialize the DiffusersFSDPEngine.
        Set up distributed device meshes, LoRA, and offload policies based on config.
        Args:
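            model_config: Diffusers model configuration.
            engine_config: FSDP engine configuration.
            optimizer_config: optimizer configuration.
            checkpoint_config: checkpoint configuration.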
        """
        ...

    def initialize(self):
        """
        Build the model, optimizer, and learning rate scheduler under FSDP.

        Applies device, dtype, and precision configurations, including mixed precision.
        Sets up checkpoint manager and FLOPs counter.
        """
        ...

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
        ...
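For context on what forward_step recomputes, the sketch below gives the mean and standard deviation of one SDE denoising step for a flow-matching model, following the SDE sampler described in the Flow-GRPO paper as we understand it. The sketch rests on four assumptions:

- the convention x_sigma = (1 - sigma) * x0 + sigma * noise;
- a velocity prediction from the transformer;
- sigma_prev < sigma;
- a noise-level hyperparameter, where 0.7 is only a default here.

The function name is hypothetical. sigma = 1 is excluded because the noise scale diverges there. With noise_level = 0 the step reduces to a deterministic Euler ODE step, which is a useful sanity check.

```python
import math
from typing import List, Sequence, Tuple


def sde_step_mean_std(sample: Sequence[float], velocity: Sequence[float],
                      sigma: float, sigma_prev: float,
                      noise_level: float = 0.7) -> Tuple[List[float], float]:
    """Mean and std of the next latent for one SDE step of a flow-matching model."""
    if not 0.0 < sigma < 1.0 or not 0.0 <= sigma_prev < sigma:
        raise ValueError("expected 0 <= sigma_prev < sigma < 1")
    dt = sigma_prev - sigma  # negative: denoising moves toward sigma = 0
    std_dev_t = math.sqrt(sigma / (1.0 - sigma)) * noise_level
    # Drift v + std_dev_t^2 / (2 sigma) * (x + (1 - sigma) v), integrated over dt.
    a = 1.0 + std_dev_t ** 2 / (2.0 * sigma) * dt
    b = (1.0 + std_dev_t ** 2 * (1.0 - sigma) / (2.0 * sigma)) * dt
    mean = [a * x + b * v for x, v in zip(sample, velocity)]
    std = std_dev_t * math.sqrt(-dt)  # diffusion term scaled by sqrt(|dt|)
    return mean, std
```

The returned mean and std can be plugged into a Gaussian log-density to obtain the per-step log probability used by compute_old_log_prob and update_actor.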

Development Plan

See the latest progress in PR #5297.

Rollout

Trainer (left panel)

Actor

Reward

FlowGRPO Algorithm Support

Dataloader and Logger

Future Plan

Acceleration

Algorithm

  • Support other reinforcement learning paradigms for diffusion models, e.g., DiffusionNFT, DanceGRPO, etc.

More Supported Models

  • Support more visual generation and editing diffusion models, e.g., Wan2.2 (video), Qwen-Image-Edit (edit).
  • Support MoE models for RL, e.g., HunyuanImage.
  • Support Unified Multimodal Understanding and Generation Models, e.g., Qwen3-Omni.

Training Engine

  • Support VeOmni training engine.

Rollout Engine: vLLM-Omni
