Skip to content

Commit fe540f6

Browse files
authored
[plugin] Custom model_runner/model support (#3186)
* support custom model&&model_runner * fix merge * add test && update doc * fix codestyle * fix unittest * load model in rl
1 parent 72ef5a9 commit fe540f6

File tree

15 files changed

+150
-13
lines changed

15 files changed

+150
-13
lines changed

docs/features/plugins.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,23 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained cla
2020
# File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
2121
from fastdeploy.model_registry import ModelRegistry
2222
from my_custom_model import MyModelForCasualLM, MyPretrainedModel
23+
from fastdeploy.config import ErnieArchitectures
2324

2425
def register():
2526
if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
27+
if MyModelForCasualLM.name().startswith("Ernie"):
28+
ErnieArchitectures.register_ernie_model_arch(MyModelForCasualLM)
2629
ModelRegistry.register_model_class(MyModelForCasualLM)
2730
ModelRegistry.register_pretrained_model(MyPretrainedModel)
2831
```
32+
Assuming you have a custom model_runner class `MyModelRunner`, you can write the following registration function:
33+
```python
34+
# File: fd_add_dummy_model_runner/__init__.py
35+
from .my_model_runner import MyModelRunner
36+
37+
def get_runner():
38+
return MyModelRunner
39+
```
2940

3041
#### 2. Register Plugin in `setup.py`
3142

@@ -36,11 +47,14 @@ from setuptools import setup
3647
setup(
3748
name="fastdeploy-plugins",
3849
version="0.1",
39-
packages=["fd_add_dummy_model"],
50+
packages=["fd_add_dummy_model", "fd_add_dummy_model_runner"],
4051
entry_points={
4152
"fastdeploy.model_register_plugins": [
4253
"fd_add_dummy_model = fd_add_dummy_model:register",
4354
],
55+
"fastdeploy.model_runner_plugins": [
56+
"model_runner = fd_add_dummy_model:get_runner"
57+
],
4458
},
4559
)
4660
```

fastdeploy/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ class ErnieArchitectures:
6363
"Ernie4_5_VLMoeForConditionalGeneration",
6464
}
6565

66+
@classmethod
67+
def register_ernie_model_arch(cls, model_class):
68+
if model_class.name().startswith("Ernie") and model_class.name() not in cls.ARCHITECTURES:
69+
cls.ARCHITECTURES.add(model_class.name())
70+
6671
@classmethod
6772
def contains_ernie_arch(cls, architectures):
6873
"""Check if any ERNIE architecture is present in the given architectures."""

fastdeploy/entrypoints/llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from fastdeploy.engine.args_utils import EngineArgs
2929
from fastdeploy.engine.engine import LLMEngine
3030
from fastdeploy.engine.sampling_params import SamplingParams
31+
from fastdeploy.plugins.model_register import load_model_register_plugins
3132
from fastdeploy.utils import (
3233
deprecated_kwargs_warning,
3334
llm_logger,
@@ -76,6 +77,7 @@ def __init__(
7677
):
7778
deprecated_kwargs_warning(**kwargs)
7879

80+
load_model_register_plugins()
7981
model = retrive_model_from_server(model, revision)
8082
engine_args = EngineArgs(
8183
model=model,

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
main_process_metrics,
4747
)
4848
from fastdeploy.metrics.trace_util import inject_to_metadata, instrument
49+
from fastdeploy.plugins.model_register import load_model_register_plugins
4950
from fastdeploy.utils import (
5051
FlexibleArgumentParser,
5152
api_server_logger,
@@ -393,6 +394,7 @@ def launch_controller_server():
393394
def main():
394395
"""main函数"""
395396

397+
load_model_register_plugins()
396398
if load_engine() is None:
397399
return
398400

fastdeploy/input/preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
33
#
4-
# Licensed under the Apache License, Version 2.0 (the "License"
4+
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at
77
#

fastdeploy/model_executor/layers/attention/attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
linear_shift: paddle.Tensor = None,
4949
linear_smooth: paddle.Tensor = None,
5050
use_neox_rotary_style: bool = False,
51+
use_qk_norm: bool = False,
5152
) -> None:
5253
"""
5354
Initializes `LMLayer` with the given parameters.

fastdeploy/model_executor/models/model_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def get_class(cls, name):
6464

6565
@classmethod
6666
def get_supported_archs(cls):
67-
assert len(cls._arch_to_model_cls) == len(
68-
cls._arch_to_model_cls
69-
), "model class / pretrained model registry num is not same"
67+
assert len(cls._arch_to_model_cls) >= len(
68+
cls._arch_to_pretrained_model_cls
69+
), "model class num is more than pretrained model registry num"
7070
return [key for key in cls._arch_to_model_cls.keys()]
7171

7272

fastdeploy/plugins/model_runner/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ def load_model_runner_plugins():
2828
plugins_loaded = True
2929

3030
plugins = load_plugins_by_group(group=PLUGINS_GROUP)
31-
assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
32-
return next(iter(plugins.values()))
31+
assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
32+
return next(iter(plugins.values()))()

fastdeploy/rl/rollout_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def __init__(self, rollout_model_config: RolloutModelConfig):
5656
def _init_model(self) -> nn.Layer:
5757
"""Load model from loader based on config."""
5858
context = paddle.LazyGuard()
59+
from fastdeploy.plugins.model_register import load_model_register_plugins
60+
61+
load_model_register_plugins()
5962
architectures = f"{self.fd_config.model_config.architectures[0]}RL"
6063
with context:
6164
model_cls = ModelRegistry.get_class(architectures)

fastdeploy/worker/gpu_worker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,19 @@
2626
from fastdeploy.config import FDConfig
2727
from fastdeploy.engine.request import Request
2828
from fastdeploy.platforms import current_platform
29+
from fastdeploy.plugins.model_runner import load_model_runner_plugins
2930
from fastdeploy.utils import get_logger
30-
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
31+
from fastdeploy.worker.model_runner_base import ModelRunnerBase
3132
from fastdeploy.worker.output import ModelRunnerOutput
3233
from fastdeploy.worker.worker_base import WorkerBase
3334

3435
logger = get_logger("gpu_worker", "gpu_worker.log")
3536

37+
try:
38+
ModelRunner = load_model_runner_plugins()
39+
except:
40+
from fastdeploy.worker.gpu_model_runner import GPUModelRunner as ModelRunner
41+
3642

3743
class GpuWorker(WorkerBase):
3844
def __init__(
@@ -70,7 +76,7 @@ def init_device(self):
7076
raise RuntimeError(f"Not support device type: {self.device_config.device}")
7177

7278
# Construct model runner
73-
self.model_runner: GPUModelRunner = GPUModelRunner(
79+
self.model_runner: ModelRunnerBase = ModelRunner(
7480
fd_config=self.fd_config,
7581
device=self.device,
7682
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],

0 commit comments

Comments
 (0)