Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/extension/agents/claude/node/claudeCodeAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ export class ClaudeAgentManager extends Disposable {
super();
}

public async handleRequest(claudeSessionId: string | undefined, request: vscode.ChatRequest, _context: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken, modelId?: string, permissionMode?: PermissionMode): Promise<vscode.ChatResult & { claudeSessionId?: string }> {
public async handleRequest(claudeSessionId: string | undefined, request: vscode.ChatRequest, _context: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken, modelId: string, permissionMode?: PermissionMode): Promise<vscode.ChatResult & { claudeSessionId?: string }> {
try {
// Get server config, start server if needed
const serverConfig = (await this.getLangModelServer()).getConfig();
const langModelServer = await this.getLangModelServer();
const serverConfig = langModelServer.getConfig();

const sessionIdForLog = claudeSessionId ?? 'new';
this.logService.trace(`[ClaudeAgentManager] Handling request for sessionId=${sessionIdForLog}, modelId=${modelId}, permissionMode=${permissionMode}.`);
Expand All @@ -64,7 +65,7 @@ export class ClaudeAgentManager extends Disposable {
session = this._sessions.get(claudeSessionId)!;
} else {
this.logService.trace(`[ClaudeAgentManager] Creating Claude session for sessionId=${sessionIdForLog}.`);
const newSession = this.instantiationService.createInstance(ClaudeCodeSession, serverConfig, claudeSessionId, modelId, permissionMode);
const newSession = this.instantiationService.createInstance(ClaudeCodeSession, serverConfig, langModelServer, claudeSessionId, modelId, permissionMode);
if (newSession.sessionId) {
this._sessions.set(newSession.sessionId, newSession);
}
Expand Down Expand Up @@ -173,7 +174,7 @@ export class ClaudeCodeSession extends Disposable {
private _abortController = new AbortController();
private _editTracker = new ExternalEditTracker();
private _settingsChangeTracker: ClaudeSettingsChangeTracker;
private _currentModelId: string | undefined;
private _currentModelId: string;
private _currentPermissionMode: PermissionMode;

/**
Expand Down Expand Up @@ -202,8 +203,9 @@ export class ClaudeCodeSession extends Disposable {

constructor(
private readonly serverConfig: IClaudeLanguageModelServerConfig,
private readonly langModelServer: ClaudeLanguageModelServer,
public sessionId: string | undefined,
initialModelId: string | undefined,
initialModelId: string,
initialPermissionMode: PermissionMode | undefined,
@ILogService private readonly logService: ILogService,
@IWorkspaceService private readonly workspaceService: IWorkspaceService,
Expand Down Expand Up @@ -284,14 +286,14 @@ export class ClaudeCodeSession extends Disposable {
* @param toolInvocationToken Token for invoking tools
* @param stream Response stream for sending results back to VS Code
* @param token Cancellation token for request cancellation
* @param modelId Optional model ID to use for this request
* @param modelId Model ID to use for this request
*/
public async invoke(
prompt: string,
toolInvocationToken: vscode.ChatParticipantToolToken,
stream: vscode.ChatResponseStream,
token: vscode.CancellationToken,
modelId?: string,
modelId: string,
permissionMode?: PermissionMode
): Promise<void> {
if (this._store.isDisposed) {
Expand All @@ -309,9 +311,8 @@ export class ClaudeCodeSession extends Disposable {
}

// Update model and permission mode on active session if they changed
if (modelId !== undefined) {
await this._setModel(modelId);
}
await this._setModel(modelId);

if (permissionMode !== undefined) {
await this._setPermissionMode(permissionMode);
}
Expand Down Expand Up @@ -381,7 +382,7 @@ export class ClaudeCodeSession extends Disposable {
},
resume: this.sessionId,
// Pass the model selection to the SDK
...(this._currentModelId !== undefined ? { model: this._currentModelId } : {}),
model: this._currentModelId,
// Pass the permission mode to the SDK
...(this._currentPermissionMode !== undefined ? { permissionMode: this._currentPermissionMode } : {}),
hooks: this._buildHooks(token),
Expand Down Expand Up @@ -478,6 +479,10 @@ export class ClaudeCodeSession extends Disposable {
token: request.token
};

// Increment user-initiated message count for this model
// This is used by the language model server to track which requests are user-initiated
this.langModelServer.incrementUserInitiatedMessageCount(this._currentModelId);

yield {
type: 'user',
message: {
Expand Down
53 changes: 16 additions & 37 deletions src/extension/agents/claude/node/claudeLanguageModelServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ import { SSEParser } from '../../../../util/vs/base/common/sseParser';
import { generateUuid } from '../../../../util/vs/base/common/uuid';
import { IInstantiationService } from '../../../../util/vs/platform/instantiation/common/instantiation';

/**
* Marker string to identify user-initiated messages from VS Code in the Messages API.
*/
export const VSCODE_USER_INITIATED_MESSAGE_MARKER = '__vscode_user_initiated_message__';

export interface IClaudeLanguageModelServerConfig {
readonly port: number;
readonly nonce: string;
Expand Down Expand Up @@ -64,6 +59,7 @@ const DEFAULT_MAX_OUTPUT_TOKENS = 64_000;
export class ClaudeLanguageModelServer extends Disposable {
private server: http.Server;
private config: IClaudeLanguageModelServerConfig;
private readonly _userInitiatedMessageCounts = new Map<string, number>();

constructor(
@ILogService private readonly logService: ILogService,
Expand Down Expand Up @@ -164,38 +160,6 @@ export class ClaudeLanguageModelServer extends Disposable {
try {
const requestBody: AnthropicMessagesRequest = JSON.parse(bodyString);

// Determine if this is a user-initiated message
const lastMessage = requestBody.messages?.at(-1);
const lastContentItems = !lastMessage || typeof lastMessage.content === 'string'
? []
: lastMessage.content;

// Find the index of the marker content item if it exists
const markerIndex = lastContentItems.findIndex(
c => c.type === 'text' &&
// Our marker
c.text.includes(VSCODE_USER_INITIATED_MESSAGE_MARKER) &&
// The name of the hook we are using
c.text.includes('UserPromptSubmit')
);

const isUserInitiatedMessage =
// A user initiated message would only be of role 'user'
lastMessage?.role === 'user' &&
// We expect our marker AND the user's actual message so there will be multiple content items
lastContentItems.length > 1 &&
// The marker must be in a preceding content item, not the last one (which is the actual user message)
markerIndex !== -1 &&
markerIndex !== lastContentItems.length - 1;

// Remove the marker content item and the one before it (which just provides the status of our hook)
// so they don't influence the request
if (isUserInitiatedMessage) {
// Remove marker and its preceding item (if it exists)
const indicesToRemove = markerIndex > 0 ? [markerIndex - 1, markerIndex] : [markerIndex];
lastMessage.content = lastContentItems.filter((_, i) => !indicesToRemove.includes(i));
}

const allEndpoints = await this.endpointProvider.getAllChatEndpoints();
// Filter to only endpoints that support the Messages API
const endpoints = allEndpoints.filter(e => e.apiType === 'messages');
Expand All @@ -212,6 +176,12 @@ export class ClaudeLanguageModelServer extends Disposable {
return;
}
requestBody.model = selectedEndpoint.model;
// Determine if this is a user-initiated message using counter-based approach
const count = this._userInitiatedMessageCounts.get(selectedEndpoint.model) ?? 0;
const isUserInitiatedMessage = count > 0;
if (isUserInitiatedMessage) {
this._userInitiatedMessageCounts.set(selectedEndpoint.model, count - 1);
}

// Set up streaming response
res.writeHead(200, {
Expand Down Expand Up @@ -355,6 +325,15 @@ export class ClaudeLanguageModelServer extends Disposable {
return { ...this.config };
}

/**
* Increments the user-initiated message count for a given model.
* Called when a user sends a new message in a Claude session.
*/
public incrementUserInitiatedMessageCount(modelId: string): void {
const current = this._userInitiatedMessageCounts.get(modelId) ?? 0;
this._userInitiatedMessageCounts.set(modelId, current + 1);
}

private info(message: string): void {
const messageWithClassName = `[ClaudeLanguageModelServer] ${message}`;
this.logService.info(messageWithClassName);
Expand Down
11 changes: 1 addition & 10 deletions src/extension/agents/claude/node/hooks/loggingHooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import { ILogService } from '../../../../../platform/log/common/logService';
import { CapturingToken } from '../../../../../platform/requestLogger/common/capturingToken';
import { IClaudeSessionStateService } from '../claudeSessionStateService';
import { registerClaudeHook } from './claudeHookRegistry';
import { VSCODE_USER_INITIATED_MESSAGE_MARKER } from '../claudeLanguageModelServer';

/**
* Logging hook for Notification events.
Expand Down Expand Up @@ -61,15 +60,7 @@ export class UserPromptSubmitLoggingHook implements HookCallbackMatcher {
// Create a capturing token for this request to group tool calls under the request
const capturingToken = new CapturingToken(hookInput.prompt, 'sparkle', false);
this.sessionStateService.setCapturingTokenForSession(hookInput.session_id, capturingToken);

return {
continue: true,
// Mark this message as user-initiated for downstream processing of PRUs
hookSpecificOutput: {
hookEventName: 'UserPromptSubmit',
additionalContext: VSCODE_USER_INITIATED_MESSAGE_MARKER
}
};
return { continue: true };
}
}
registerClaudeHook('UserPromptSubmit', UserPromptSubmitLoggingHook);
Expand Down
58 changes: 38 additions & 20 deletions src/extension/agents/claude/node/test/claudeCodeAgent.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@ import { createExtensionUnitTestingServices } from '../../../../test/node/servic
import { MockChatResponseStream, TestChatRequest } from '../../../../test/node/testHelpers';
import { ClaudeAgentManager, ClaudeCodeSession } from '../claudeCodeAgent';
import { IClaudeCodeSdkService } from '../claudeCodeSdkService';
import { ClaudeLanguageModelServer } from '../claudeLanguageModelServer';
import { MockClaudeCodeSdkService } from './mockClaudeCodeSdkService';

function createMockLangModelServer(): ClaudeLanguageModelServer {
return {
incrementUserInitiatedMessageCount: vi.fn()
} as unknown as ClaudeLanguageModelServer;
}

const TEST_MODEL_ID = 'claude-3-sonnet';

describe('ClaudeAgentManager', () => {
const store = new DisposableStore();
let instantiationService: IInstantiationService;
Expand Down Expand Up @@ -41,7 +50,7 @@ describe('ClaudeAgentManager', () => {
const stream1 = new MockChatResponseStream();

const req1 = new TestChatRequest('Hi');
const res1 = await manager.handleRequest(undefined, req1, { history: [] } as any, stream1, CancellationToken.None);
const res1 = await manager.handleRequest(undefined, req1, { history: [] } as any, stream1, CancellationToken.None, TEST_MODEL_ID);

expect(stream1.output.join('\n')).toContain('Hello from mock!');
expect(res1.claudeSessionId).toBe('sess-1');
Expand All @@ -50,7 +59,7 @@ describe('ClaudeAgentManager', () => {
const stream2 = new MockChatResponseStream();

const req2 = new TestChatRequest('Again');
const res2 = await manager.handleRequest(res1.claudeSessionId, req2, { history: [] } as any, stream2, CancellationToken.None);
const res2 = await manager.handleRequest(res1.claudeSessionId, req2, { history: [] } as any, stream2, CancellationToken.None, TEST_MODEL_ID);

expect(stream2.output.join('\n')).toContain('Hello from mock!');
expect(res2.claudeSessionId).toBe('sess-1');
Expand Down Expand Up @@ -80,24 +89,26 @@ describe('ClaudeCodeSession', () => {

it('processes a single request correctly', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', undefined, undefined));
const mockServer = createMockLangModelServer();
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', TEST_MODEL_ID, undefined));
const stream = new MockChatResponseStream();

await session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, CancellationToken.None);
await session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, CancellationToken.None, TEST_MODEL_ID);

expect(stream.output.join('\n')).toContain('Hello from mock!');
});

it('queues multiple requests and processes them sequentially', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', undefined, undefined));
const mockServer = createMockLangModelServer();
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', TEST_MODEL_ID, undefined));

const stream1 = new MockChatResponseStream();
const stream2 = new MockChatResponseStream();

// Start both requests simultaneously
const promise1 = session.invoke('First', {} as vscode.ChatParticipantToolToken, stream1, CancellationToken.None);
const promise2 = session.invoke('Second', {} as vscode.ChatParticipantToolToken, stream2, CancellationToken.None);
const promise1 = session.invoke('First', {} as vscode.ChatParticipantToolToken, stream1, CancellationToken.None, TEST_MODEL_ID);
const promise2 = session.invoke('Second', {} as vscode.ChatParticipantToolToken, stream2, CancellationToken.None, TEST_MODEL_ID);

// Wait for both to complete
await Promise.all([promise1, promise2]);
Expand All @@ -109,31 +120,35 @@ describe('ClaudeCodeSession', () => {

it('cancels pending requests when cancelled', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', undefined, undefined));
const mockServer = createMockLangModelServer();
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', TEST_MODEL_ID, undefined));
const stream = new MockChatResponseStream();
const source = new CancellationTokenSource();
source.cancel();

await expect(session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, source.token)).rejects.toThrow();
await expect(session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, source.token, TEST_MODEL_ID)).rejects.toThrow();
});

it('cleans up resources when disposed', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const session = instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', undefined, undefined);
const mockServer = createMockLangModelServer();
const session = instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', TEST_MODEL_ID, undefined);

// Dispose the session immediately
session.dispose();

// Any new requests should be rejected
const stream = new MockChatResponseStream();
await expect(session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, CancellationToken.None))
await expect(session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, CancellationToken.None, TEST_MODEL_ID))
.rejects.toThrow('Session disposed');
});

it('handles multiple sessions with different session IDs', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const session1 = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'session-1', undefined, undefined));
const session2 = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'session-2', undefined, undefined));
const mockServer1 = createMockLangModelServer();
const mockServer2 = createMockLangModelServer();
const session1 = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer1, 'session-1', TEST_MODEL_ID, undefined));
const session2 = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer2, 'session-2', TEST_MODEL_ID, undefined));

expect(session1.sessionId).toBe('session-1');
expect(session2.sessionId).toBe('session-2');
Expand All @@ -143,8 +158,8 @@ describe('ClaudeCodeSession', () => {

// Both sessions should work independently
await Promise.all([
session1.invoke('Hello from session 1', {} as vscode.ChatParticipantToolToken, stream1, CancellationToken.None),
session2.invoke('Hello from session 2', {} as vscode.ChatParticipantToolToken, stream2, CancellationToken.None)
session1.invoke('Hello from session 1', {} as vscode.ChatParticipantToolToken, stream1, CancellationToken.None, TEST_MODEL_ID),
session2.invoke('Hello from session 2', {} as vscode.ChatParticipantToolToken, stream2, CancellationToken.None, TEST_MODEL_ID)
]);

expect(stream1.output.join('\n')).toContain('Hello from mock!');
Expand All @@ -153,25 +168,27 @@ describe('ClaudeCodeSession', () => {

it('initializes with model ID from constructor', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', 'claude-3-opus', undefined));
const mockServer = createMockLangModelServer();
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', 'claude-3-opus', undefined));
const stream = new MockChatResponseStream();

await session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, CancellationToken.None);
await session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream, CancellationToken.None, 'claude-3-opus');

expect(stream.output.join('\n')).toContain('Hello from mock!');
});

it('calls setModel when model changes instead of restarting session', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const mockServer = createMockLangModelServer();
const mockService = instantiationService.invokeFunction(accessor => accessor.get(IClaudeCodeSdkService)) as MockClaudeCodeSdkService;
mockService.queryCallCount = 0;
mockService.setModelCallCount = 0;

const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', 'claude-3-sonnet', undefined));
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', 'claude-3-sonnet', undefined));

// First request with initial model
const stream1 = new MockChatResponseStream();
await session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream1, CancellationToken.None);
await session.invoke('Hello', {} as vscode.ChatParticipantToolToken, stream1, CancellationToken.None, 'claude-3-sonnet');
expect(mockService.queryCallCount).toBe(1);

// Second request with different model should call setModel on existing session
Expand All @@ -184,10 +201,11 @@ describe('ClaudeCodeSession', () => {

it('does not restart session when same model is used', async () => {
const serverConfig = { port: 8080, nonce: 'test-nonce' };
const mockServer = createMockLangModelServer();
const mockService = instantiationService.invokeFunction(accessor => accessor.get(IClaudeCodeSdkService)) as MockClaudeCodeSdkService;
mockService.queryCallCount = 0;

const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, 'test-session', 'claude-3-sonnet', undefined));
const session = store.add(instantiationService.createInstance(ClaudeCodeSession, serverConfig, mockServer, 'test-session', 'claude-3-sonnet', undefined));

// First request
const stream1 = new MockChatResponseStream();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export class ClaudeChatSessionParticipant {
return slashResult.result ?? {};
}

const create = async (modelId?: string, permissionMode?: PermissionMode) => {
const create = async (modelId: string, permissionMode?: PermissionMode) => {
const result = await this.claudeAgentManager.handleRequest(undefined, request, context, stream, token, modelId, permissionMode);
if (!result.claudeSessionId) {
// Only show generic warning if we didn't already show a specific error
Expand Down