Skip to content

Commit 0b496e8

Browse files
committed
Initial thinking support for Claude

- Updated the Anthropic SDK from v0.32.1 to v0.39.0
- Added support for thinking messages in the language model interface
- Added a ThinkingChatResponseContent implementation
- Added proper handling for Claude's thinking feature in the Anthropic language model
- Added a TextPartRenderer to display thinking content
- Various interface updates to support thinking messages throughout the code
1 parent ac4ff54 commit 0b496e8

File tree

19 files changed

+363
-109
lines changed

19 files changed

+363
-109
lines changed

examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
import {
1818
AbstractStreamParsingChatAgent,
1919
ChatAgent,
20-
ChatMessage,
2120
ChatModel,
2221
MutableChatRequestModel,
2322
lastProgressMessage,
2423
QuestionResponseContentImpl,
2524
unansweredQuestions
2625
} from '@theia/ai-chat';
27-
import { Agent, PromptTemplate } from '@theia/ai-core';
26+
import { Agent, LanguageModelMessage, PromptTemplate } from '@theia/ai-core';
2827
import { injectable, interfaces, postConstruct } from '@theia/core/shared/inversify';
2928

3029
export function bindAskAndContinueChatAgentContribution(bind: interfaces.Bind): void {
@@ -161,15 +160,15 @@ export class AskAndContinueChatAgent extends AbstractStreamParsingChatAgent {
161160
* As the question/answer are handled within the same response, we add an additional user message at the end to indicate to
162161
* the LLM to continue generating.
163162
*/
164-
protected override async getMessages(model: ChatModel): Promise<ChatMessage[]> {
163+
protected override async getMessages(model: ChatModel): Promise<LanguageModelMessage[]> {
165164
const messages = await super.getMessages(model, true);
166165
const requests = model.getRequests();
167166
if (!requests[requests.length - 1].response.isComplete && requests[requests.length - 1].response.response?.content.length > 0) {
168167
return [...messages,
169168
{
170169
type: 'text',
171170
actor: 'user',
172-
query: 'Continue generating based on the user\'s answer or finish the conversation if 5 or more questions were already answered.'
171+
text: 'Continue generating based on the user\'s answer or finish the conversation if 5 or more questions were already answered.'
173172
}];
174173
}
175174
return messages;

packages/ai-anthropic/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"version": "1.59.0",
44
"description": "Theia - Anthropic Integration",
55
"dependencies": {
6-
"@anthropic-ai/sdk": "^0.32.1",
6+
"@anthropic-ai/sdk": "^0.39.0",
77
"@theia/ai-core": "1.59.0",
88
"@theia/core": "1.59.0"
99
},

packages/ai-anthropic/src/node/anthropic-language-model.ts

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
import {
1818
LanguageModel,
1919
LanguageModelRequest,
20-
LanguageModelRequestMessage,
20+
LanguageModelMessage,
2121
LanguageModelResponse,
2222
LanguageModelStreamResponse,
2323
LanguageModelStreamResponsePart,
2424
LanguageModelTextResponse
2525
} from '@theia/ai-core';
2626
import { CancellationToken, isArray } from '@theia/core';
2727
import { Anthropic } from '@anthropic-ai/sdk';
28-
import { MessageParam } from '@anthropic-ai/sdk/resources';
28+
import { Message, MessageParam } from '@anthropic-ai/sdk/resources';
2929

3030
const DEFAULT_MAX_TOKENS_STREAMING = 4096;
3131
const DEFAULT_MAX_TOKENS_NON_STREAMING = 2048;
@@ -42,23 +42,36 @@ interface ToolCallback {
4242
args: string;
4343
}
4444

45+
const createMessageContent = (message: LanguageModelMessage): MessageParam['content'] => {
46+
if (LanguageModelMessage.isTextMessage(message)) {
47+
return message.text;
48+
} else if (LanguageModelMessage.isThinkingMessage(message)) {
49+
return [{ signature: message.signature, thinking: message.thinking, type: 'thinking' }];
50+
} else if (LanguageModelMessage.isToolUseMessage(message)) {
51+
return [{ id: message.id, input: message.input, name: message.name, type: 'tool_use' }];
52+
} else if (LanguageModelMessage.isToolResultMessage(message)) {
53+
return [{ type: 'tool_result', tool_use_id: message.tool_use_id }];
54+
}
55+
throw new Error(`Unknown message type:'${JSON.stringify(message)}'`);
56+
};
57+
4558
/**
4659
* Transforms Theia language model messages to Anthropic API format
4760
* @param messages Array of LanguageModelRequestMessage to transform
4861
* @returns Object containing transformed messages and optional system message
4962
*/
5063
function transformToAnthropicParams(
51-
messages: readonly LanguageModelRequestMessage[]
64+
messages: readonly LanguageModelMessage[]
5265
): { messages: MessageParam[]; systemMessage?: string } {
5366
// Extract the system message (if any), as it is a separate parameter in the Anthropic API.
5467
const systemMessageObj = messages.find(message => message.actor === 'system');
55-
const systemMessage = systemMessageObj?.query;
68+
const systemMessage = systemMessageObj && LanguageModelMessage.isTextMessage(systemMessageObj) && systemMessageObj.text || '';
5669

5770
const convertedMessages = messages
5871
.filter(message => message.actor !== 'system')
5972
.map(message => ({
6073
role: toAnthropicRole(message),
61-
content: message.query || '',
74+
content: createMessageContent(message)
6275
}));
6376

6477
return {
@@ -74,7 +87,7 @@ export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier');
7487
* @param message The message to convert
7588
* @returns Anthropic role ('user' or 'assistant')
7689
*/
77-
function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assistant' {
90+
function toAnthropicRole(message: LanguageModelMessage): 'user' | 'assistant' {
7891
switch (message.actor) {
7992
case 'ai':
8093
return 'assistant';
@@ -151,7 +164,7 @@ export class AnthropicModel implements LanguageModel {
151164
...(systemMessage && { system: systemMessage }),
152165
...settings
153166
};
154-
167+
console.log(JSON.stringify(params));
155168
const stream = anthropic.messages.stream(params);
156169

157170
cancellationToken?.onCancellationRequested(() => {
@@ -164,11 +177,15 @@ export class AnthropicModel implements LanguageModel {
164177

165178
const toolCalls: ToolCallback[] = [];
166179
let toolCall: ToolCallback | undefined;
180+
const currentMessages: Message[] = [];
167181

168182
for await (const event of stream) {
169183
if (event.type === 'content_block_start') {
170184
const contentBlock = event.content_block;
171185

186+
if (contentBlock.type === 'thinking') {
187+
yield { thought: contentBlock.thinking, signature: contentBlock.signature ?? '' };
188+
}
172189
if (contentBlock.type === 'text') {
173190
yield { content: contentBlock.text };
174191
}
@@ -178,7 +195,12 @@ export class AnthropicModel implements LanguageModel {
178195
}
179196
} else if (event.type === 'content_block_delta') {
180197
const delta = event.delta;
181-
198+
if (delta.type === 'thinking_delta') {
199+
yield { thought: delta.thinking, signature: '' };
200+
}
201+
if (delta.type === 'signature_delta') {
202+
yield { thought: '', signature: delta.signature };
203+
}
182204
if (delta.type === 'text_delta') {
183205
yield { content: delta.text };
184206
}
@@ -198,6 +220,8 @@ export class AnthropicModel implements LanguageModel {
198220
}
199221
throw new Error(`The response was stopped because it exceeded the max token limit of ${event.usage.output_tokens}.`);
200222
}
223+
} else if (event.type === 'message_start') {
224+
currentMessages.push(event.message);
201225
}
202226
}
203227
if (toolCalls.length > 0) {
@@ -215,16 +239,16 @@ export class AnthropicModel implements LanguageModel {
215239
});
216240
yield { tool_calls: calls };
217241

218-
const toolRequestMessage: Anthropic.Messages.MessageParam = {
219-
role: 'assistant',
220-
content: toolResult.map(call => ({
242+
// const toolRequestMessage: Anthropic.Messages.MessageParam = {
243+
// role: 'assistant',
244+
// content: toolResult.map(call => ({
221245

222-
type: 'tool_use',
223-
id: call.id,
224-
name: call.name,
225-
input: JSON.parse(call.arguments)
226-
}))
227-
};
246+
// type: 'tool_use',
247+
// id: call.id,
248+
// name: call.name,
249+
// input: JSON.parse(call.arguments)
250+
// }))
251+
// };
228252

229253
const toolResponseMessage: Anthropic.Messages.MessageParam = {
230254
role: 'user',
@@ -234,7 +258,15 @@ export class AnthropicModel implements LanguageModel {
234258
content: that.formatToolCallResult(call.result)
235259
}))
236260
};
237-
const result = await that.handleStreamingRequest(anthropic, request, cancellationToken, [...(toolMessages ?? []), toolRequestMessage, toolResponseMessage]);
261+
const result = await that.handleStreamingRequest(
262+
anthropic,
263+
request,
264+
cancellationToken,
265+
[
266+
...(toolMessages ?? []),
267+
...currentMessages.map(m => ({ role: m.role, content: m.content })),
268+
toolResponseMessage
269+
]);
238270
for await (const nestedEvent of result.stream) {
239271
yield nestedEvent;
240272
}

packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import {
3333
HorizontalLayoutPartRenderer,
3434
InsertCodeAtCursorButtonAction,
3535
MarkdownPartRenderer,
36+
TextPartRenderer,
3637
ToolCallPartRenderer,
3738
} from './chat-response-renderer';
3839
import {
@@ -79,6 +80,7 @@ export default new ContainerModule((bind, _unbind, _isBound, rebind) => {
7980

8081
bind(ContextVariablePicker).toSelf().inSingletonScope();
8182

83+
bind(ChatResponsePartRenderer).to(TextPartRenderer).inSingletonScope();
8284
bind(ChatResponsePartRenderer).to(HorizontalLayoutPartRenderer).inSingletonScope();
8385
bind(ChatResponsePartRenderer).to(ErrorPartRenderer).inSingletonScope();
8486
bind(ChatResponsePartRenderer).to(MarkdownPartRenderer).inSingletonScope();

packages/ai-chat-ui/src/browser/chat-response-renderer/text-part-renderer.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ export class TextPartRenderer implements ChatResponsePartRenderer<ChatResponseCo
2828
return 1;
2929
}
3030
render(response: ChatResponseContent): ReactNode {
31-
if (response && ChatResponseContent.hasAsString(response)) {
32-
return <span>{response.asString()}</span>;
31+
if (response && ChatResponseContent.hasDisplayString(response)) {
32+
return <span>{response.asDisplayString()}</span>;
3333
}
3434
return <span>
3535
{nls.localize('theia/ai/chat-ui/text-part-renderer/cantDisplay',

packages/ai-chat/src/common/chat-agents.ts

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,19 @@ import {
2424
AIVariableContext,
2525
CommunicationRecordingService,
2626
getTextOfResponse,
27+
isTextResponsePart,
28+
isThinkingResponsePart,
29+
isToolCallResponsePart,
2730
LanguageModel,
31+
LanguageModelMessage,
2832
LanguageModelRequirement,
2933
LanguageModelResponse,
3034
LanguageModelService,
3135
LanguageModelStreamResponse,
3236
PromptService,
3337
PromptTemplate,
3438
ResolvedPromptTemplate,
39+
TextMessage,
3540
ToolCall,
3641
ToolRequest,
3742
} from '@theia/ai-core';
@@ -40,8 +45,7 @@ import {
4045
isLanguageModelStreamResponse,
4146
isLanguageModelTextResponse,
4247
LanguageModelRegistry,
43-
LanguageModelStreamResponsePart,
44-
MessageActor,
48+
LanguageModelStreamResponsePart
4549
} from '@theia/ai-core/lib/common';
4650
import { ContributionProvider, ILogger, isArray } from '@theia/core';
4751
import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify';
@@ -53,24 +57,13 @@ import {
5357
ErrorChatResponseContentImpl,
5458
MarkdownChatResponseContentImpl,
5559
ToolCallChatResponseContentImpl,
56-
ChatRequestModel
60+
ChatRequestModel,
61+
ThinkingChatResponseContentImpl
5762
} from './chat-model';
5863
import { findFirstMatch, parseContents } from './parse-contents';
5964
import { DefaultResponseContentFactory, ResponseContentMatcher, ResponseContentMatcherProvider } from './response-content-matcher';
6065
import { ChatToolRequest, ChatToolRequestService } from './chat-tool-request-service';
6166

62-
/**
63-
* A conversation consists of a sequence of ChatMessages.
64-
* Each ChatMessage is either a user message, AI message or a system message.
65-
*
66-
* For now we only support text based messages.
67-
*/
68-
export interface ChatMessage {
69-
actor: MessageActor;
70-
type: 'text';
71-
query: string;
72-
}
73-
7467
/**
7568
* System message content, enriched with function descriptions.
7669
*/
@@ -187,10 +180,10 @@ export abstract class AbstractChatAgent implements ChatAgent {
187180
const messages = await this.getMessages(request.session);
188181

189182
if (systemMessageDescription) {
190-
const systemMsg: ChatMessage = {
183+
const systemMsg: LanguageModelMessage = {
191184
actor: 'system',
192185
type: 'text',
193-
query: systemMessageDescription.text
186+
text: systemMessageDescription.text
194187
};
195188
// insert system message at the beginning of the request messages
196189
messages.unshift(systemMsg);
@@ -252,21 +245,28 @@ export abstract class AbstractChatAgent implements ChatAgent {
252245

253246
protected async getMessages(
254247
model: ChatModel, includeResponseInProgress = false
255-
): Promise<ChatMessage[]> {
248+
): Promise<LanguageModelMessage[]> {
256249
const requestMessages = model.getRequests().flatMap(request => {
257-
const messages: ChatMessage[] = [];
250+
const messages: LanguageModelMessage[] = [];
258251
const text = request.message.parts.map(part => part.promptText).join('');
259252
messages.push({
260253
actor: 'user',
261254
type: 'text',
262-
query: text,
255+
text: text,
263256
});
264257
if (request.response.isComplete || includeResponseInProgress) {
265-
messages.push({
266-
actor: 'ai',
267-
type: 'text',
268-
query: request.response.response.asString(),
258+
const responseMessages: LanguageModelMessage[] = request.response.response.content.flatMap(c => {
259+
if (ChatResponseContent.hasToLanguageModelMessage(c)) {
260+
return c.toLanguageModelMessage();
261+
}
262+
263+
return {
264+
actor: 'ai',
265+
type: 'text',
266+
text: c.asString?.() ?? c.asDisplayString?.() ?? '',
267+
} as TextMessage;
269268
});
269+
messages.push(...responseMessages);
270270
}
271271
return messages;
272272
});
@@ -276,7 +276,7 @@ export abstract class AbstractChatAgent implements ChatAgent {
276276

277277
protected async sendLlmRequest(
278278
request: MutableChatRequestModel,
279-
messages: ChatMessage[],
279+
messages: LanguageModelMessage[],
280280
toolRequests: ChatToolRequest[],
281281
languageModel: LanguageModel
282282
): Promise<LanguageModelResponse> {
@@ -409,17 +409,24 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
409409
}
410410

411411
protected parse(token: LanguageModelStreamResponsePart, request: MutableChatRequestModel): ChatResponseContent | ChatResponseContent[] {
412-
const content = token.content;
413-
// eslint-disable-next-line no-null/no-null
414-
if (content !== undefined && content !== null) {
415-
return this.defaultContentFactory.create(content, request);
412+
if (isTextResponsePart(token)) {
413+
const content = token.content;
414+
// eslint-disable-next-line no-null/no-null
415+
if (content !== undefined && content !== null) {
416+
return this.defaultContentFactory.create(content, request);
417+
}
418+
}
419+
if (isToolCallResponsePart(token)) {
420+
const toolCalls = token.tool_calls;
421+
if (toolCalls !== undefined) {
422+
const toolCallContents = toolCalls.map(toolCall =>
423+
this.createToolCallResponseContent(toolCall)
424+
);
425+
return toolCallContents;
426+
}
416427
}
417-
const toolCalls = token.tool_calls;
418-
if (toolCalls !== undefined) {
419-
const toolCallContents = toolCalls.map(toolCall =>
420-
this.createToolCallResponseContent(toolCall)
421-
);
422-
return toolCallContents;
428+
if (isThinkingResponsePart(token)) {
429+
return new ThinkingChatResponseContentImpl(token.thought, token.signature);
423430
}
424431
return this.defaultContentFactory.create('', request);
425432
}

0 commit comments

Comments (0)