@@ -40,6 +40,7 @@ def __init__(
40
40
finetuning_model : str | None = None ,
41
41
launch_kwargs : dict [str , Any ] | None = None ,
42
42
train_kwargs : dict [str , Any ] | None = None ,
43
+ use_developer_role : bool = False ,
43
44
** kwargs ,
44
45
):
45
46
"""
@@ -77,6 +78,7 @@ def __init__(
77
78
self .finetuning_model = finetuning_model
78
79
self .launch_kwargs = launch_kwargs or {}
79
80
self .train_kwargs = train_kwargs or {}
81
+ self .use_developer_role = use_developer_role
80
82
self ._warned_zero_temp_rollout = False
81
83
82
84
# Handle model-specific configuration for different model families
@@ -131,6 +133,11 @@ def forward(self, prompt=None, messages=None, **kwargs):
131
133
cache = kwargs .pop ("cache" , self .cache )
132
134
133
135
messages = messages or [{"role" : "user" , "content" : prompt }]
136
+ if self .use_developer_role and self .model_type == "responses" :
137
+ messages = [
138
+ {** m , "role" : "developer" } if m .get ("role" ) == "system" else m
139
+ for m in messages
140
+ ]
134
141
kwargs = {** self .kwargs , ** kwargs }
135
142
self ._warn_zero_temp_rollout (kwargs .get ("temperature" ), kwargs .get ("rollout_id" ))
136
143
if kwargs .get ("rollout_id" ) is None :
@@ -162,6 +169,11 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
162
169
cache = kwargs .pop ("cache" , self .cache )
163
170
164
171
messages = messages or [{"role" : "user" , "content" : prompt }]
172
+ if self .use_developer_role and self .model_type == "responses" :
173
+ messages = [
174
+ {** m , "role" : "developer" } if m .get ("role" ) == "system" else m
175
+ for m in messages
176
+ ]
165
177
kwargs = {** self .kwargs , ** kwargs }
166
178
self ._warn_zero_temp_rollout (kwargs .get ("temperature" ), kwargs .get ("rollout_id" ))
167
179
if kwargs .get ("rollout_id" ) is None :
0 commit comments