Skip to content

Commit 33cdd47

Browse files
feat(ark): support partial assistant message (#670)
1 parent d45f69c commit 33cdd47

File tree

4 files changed

+28
-9
lines changed

4 files changed

+28
-9
lines changed

components/model/ark/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ require (
99
github.com/eino-contrib/jsonschema v1.0.3
1010
github.com/smartystreets/goconvey v1.8.1
1111
github.com/stretchr/testify v1.11.1
12-
github.com/volcengine/volcengine-go-sdk v1.1.49
12+
github.com/volcengine/volcengine-go-sdk v1.2.9
1313

1414
)
1515

components/model/ark/go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9
138138
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
139139
github.com/volcengine/volcengine-go-sdk v1.1.49 h1:jkk3Zt6uFGiZshrVshsdRvadzuHIf4nLkekIZM+wLkY=
140140
github.com/volcengine/volcengine-go-sdk v1.1.49/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
141+
github.com/volcengine/volcengine-go-sdk v1.2.9 h1:du2gnImtyWXKkQFnJW/GXCs+UBibGGOXIbP1Ams2pB8=
142+
github.com/volcengine/volcengine-go-sdk v1.2.9/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
141143
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
142144
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
143145
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=

components/model/ark/message_extra.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ const (
3030
keyOfResponseID = "ark-response-id"
3131
keyOfResponseCacheExpireAt = "ark-response-cache-expire-at"
3232
keyOfServiceTier = "ark-service-tier"
33+
keyOfPartial = "ark-partial"
3334
ImageSizeKey = "seedream-image-size"
3435
)
3536

@@ -162,14 +163,7 @@ func InvalidateMessageCaches(messages []*schema.Message) error {
162163
continue
163164
}
164165

165-
// there may be concurrency
166-
extra := make(map[string]any, len(msg.Extra))
167-
for k, v := range msg.Extra {
168-
extra[k] = v
169-
}
170-
171-
delete(extra, keyOfResponseCacheExpireAt)
172-
msg.Extra = extra
166+
delete(msg.Extra, keyOfResponseCacheExpireAt)
173167
}
174168
return nil
175169
}
@@ -380,3 +374,21 @@ func getImageSize(extra map[string]any) (string, bool) {
380374
}
381375
return size, true
382376
}
377+
378+
// SetPartial marks the message as a partial message to enable continuation (prefill) mode.
379+
// By pre-filling part of the assistant role's content, it guides and controls the model
380+
// to continue generating from existing text fragments and maintain consistency in role-play scenarios.
381+
// To use this, set the role of the last message in the input list to assistant and call SetPartial
382+
// on it. The model will then continue writing based on the message's content.
383+
// Only available for ResponsesAPI.
384+
func SetPartial(msg *schema.Message) {
385+
setMsgExtra(msg, keyOfPartial, true)
386+
}
387+
388+
func getPartial(msg *schema.Message) bool {
389+
v, ok := getMsgExtraValue[bool](msg, keyOfPartial)
390+
if !ok {
391+
return false
392+
}
393+
return v
394+
}

components/model/ark/responses_api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,11 @@ func (cm *ResponsesAPIChatModel) toArkAssistantRoleItemInputMessage(msg *schema.
793793
Role: responses.MessageRole_assistant,
794794
}
795795

796+
if getPartial(msg) {
797+
b := true
798+
inputItemMessage.Partial = &b
799+
}
800+
796801
if len(msg.UserInputMultiContent) > 0 {
797802
return nil, fmt.Errorf("if assistant role, UserInputMultiContent cannot be set")
798803
}

0 commit comments

Comments
 (0)