Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions frontends/ui/src/adapters/api/websocket-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ export interface NATWebSocketClientOptions {
reconnectAttempts?: number
/** Delay between reconnection attempts in ms (default: 1000) */
reconnectDelay?: number
/** Auth token for backend authentication */
authToken?: string
/** Override WebSocket URL (uses same-origin by default, proxied through UI server) */
websocketUrl?: string
}
Expand Down Expand Up @@ -105,13 +103,7 @@ export class NATWebSocketClient {
this.options.callbacks.onConnectionChange?.('connecting')

try {
const baseWsUrl = this.options.websocketUrl || (await getWebSocketUrl())
let wsUrl = baseWsUrl
if (this.options.authToken) {
const separator = wsUrl.includes('?') ? '&' : '?'
wsUrl = `${wsUrl}${separator}token=${encodeURIComponent(this.options.authToken)}`
}

const wsUrl = this.options.websocketUrl || (await getWebSocketUrl())
this.ws = new WebSocket(wsUrl)
this.setupEventHandlers()
} catch {
Expand Down Expand Up @@ -199,13 +191,6 @@ export class NATWebSocketClient {
return this.ws?.readyState === WebSocket.OPEN
}

/**
* Update the auth token (e.g., after refresh)
*/
updateAuthToken = (authToken: string): void => {
this.options.authToken = authToken
}

/**
* Update conversation ID (e.g., when switching conversations)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ const mockWsClient = {
sendMessage: vi.fn(),
sendInteractionResponse: vi.fn(),
isConnected: vi.fn(() => false),
updateAuthToken: vi.fn(),
updateConversationId: vi.fn(),
}

Expand Down
31 changes: 4 additions & 27 deletions frontends/ui/src/features/chat/hooks/use-websocket-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import { checkBackendHealthCached, invalidateHealthCache } from '@/shared/hooks/
import { useChatStore } from '../store'
import { useConnectionRecovery } from './use-connection-recovery'
import { useLayoutStore } from '@/features/layout/store'
import { WEB_SEARCH_SOURCE_ID } from '@/features/layout/data-sources'
import { useDocumentsStore } from '@/features/documents/store'
import { useAuth } from '@/adapters/auth'
import type {
Expand Down Expand Up @@ -153,9 +152,7 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS
// Ref to track the current status for detecting status changes
const currentStatusRef = useRef<StatusType | null>(null)

// Auth hook for getting user and token
// Note: idToken is used for backend auth, not accessToken
const { user, idToken } = useAuth()
const { user } = useAuth()

// Chat store
const {
Expand Down Expand Up @@ -562,7 +559,6 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS
wsClientRef.current = createNATWebSocketClient({
conversationId: currentConversation.id,
callbacks: createCallbacks(),
authToken: idToken,
})
wsClientRef.current.connect()
} else {
Expand All @@ -577,16 +573,7 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS
wsClientRef.current = null
}
}
}, [currentConversation?.id, autoConnect, idToken, createCallbacks])

/**
* Update auth token when it changes
*/
useEffect(() => {
if (wsClientRef.current && idToken) {
wsClientRef.current.updateAuthToken(idToken)
}
}, [idToken])
}, [currentConversation?.id, autoConnect, createCallbacks])

/**
* Send a message via WebSocket
Expand All @@ -597,14 +584,7 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS

// Collect metadata about data sources and files before adding user message
const layoutState = useLayoutStore.getState()
let enabledDataSources = layoutState.enabledDataSourceIds

// Filter out authenticated sources if user doesn't have a valid idToken
if (!idToken) {
enabledDataSources = enabledDataSources.filter(
(sourceId) => sourceId === WEB_SEARCH_SOURCE_ID
)
}
const enabledDataSources = layoutState.enabledDataSourceIds

// Get session files
const sessionId = useChatStore.getState().currentConversation?.id
Expand Down Expand Up @@ -683,7 +663,6 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS
wsClientRef.current = createNATWebSocketClient({
conversationId,
callbacks,
authToken: idToken,
})
wsClientRef.current.connect()
} else {
Expand All @@ -702,7 +681,6 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS
setCurrentStatus,
setStreaming,
setLoading,
idToken,
createCallbacks,
]
)
Expand Down Expand Up @@ -771,11 +749,10 @@ export const useWebSocketChat = (options: UseWebSocketChatOptions = {}): UseWebS
wsClientRef.current = createNATWebSocketClient({
conversationId: currentConversation.id,
callbacks: createCallbacks(),
authToken: idToken,
})
wsClientRef.current.connect()
}
}, [currentConversation, idToken, createCallbacks])
}, [currentConversation, createCallbacks])

// Activate recovery polling when connection error cards are visible
useConnectionRecovery(connect)
Expand Down
4 changes: 4 additions & 0 deletions frontends/ui/src/lib/pdf/ReactPdfDocument.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ Font.register({
fonts: [{ src: 'Helvetica' }, { src: 'Helvetica-Bold', fontWeight: 'bold' }],
})

// Disable hyphenation — react-pdf's default English hyphenation inserts
// hyphens into long words (including URLs), making them unusable when copied.
Font.registerHyphenationCallback((word) => [word])

const styles = StyleSheet.create({
page: {
padding: 56.69,
Expand Down
85 changes: 72 additions & 13 deletions src/aiq_agent/agents/deep_researcher/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import logging
import re
from collections.abc import Sequence
from datetime import datetime
from pathlib import Path
Expand All @@ -14,6 +15,7 @@
from deepagents.backends import CompositeBackend
from deepagents.backends import StateBackend
from langchain.agents.middleware import ModelRetryMiddleware
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
from langchain_core.tools import tool
Expand All @@ -36,6 +38,11 @@

logger = logging.getLogger(__name__)

# Minimum character count for a report to be considered substantive.
# Used by both _extract_report_content (to decide if write_file fallback is needed)
# and _is_report_complete (to reject too-short reports).
_MIN_REPORT_LENGTH = 1500

# Path to this agent's directory (for loading prompts)
AGENT_DIR = Path(__file__).parent

Expand Down Expand Up @@ -255,6 +262,29 @@ def backend(runtime):
)
return agent.with_config({"recursion_limit": 1000})

@staticmethod
def _extract_report_content(messages: list) -> str:
    """Return the report text carried by the last message in ``messages``.

    The text of the last message is returned as-is when it is at least
    ``_MIN_REPORT_LENGTH`` characters long. Otherwise, if that message is an
    ``AIMessage`` carrying ``write_file`` tool calls, the longest ``content``
    argument among those calls replaces the text, provided it is longer than
    the text itself.

    Args:
        messages: Conversation messages; only the last one is inspected.

    Returns:
        The extracted report text, or ``""`` when ``messages`` is empty.
    """
    if not messages:
        return ""
    last_msg = messages[-1]
    raw = last_msg.content or ""
    if isinstance(raw, list):
        # List content may mix plain strings with block dicts. Keep plain
        # strings and the text of "text" blocks; skip every other block type
        # and any text payload that is not a string (join would raise on it).
        parts = []
        for part in raw:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                text = part.get("text", "")
                if isinstance(text, str):
                    parts.append(text)
        content = " ".join(parts)
    else:
        content = raw if isinstance(raw, str) else str(raw)
    if len(content) >= _MIN_REPORT_LENGTH:
        return content
    # If the last message is an AIMessage with a write_file tool call,
    # the LLM may have written the report via tool instead of text output.
    if isinstance(last_msg, AIMessage) and getattr(last_msg, "tool_calls", None):
        for tc in last_msg.tool_calls:
            if tc.get("name") != "write_file":
                continue
            args = tc.get("args")
            if not isinstance(args, dict):
                # Malformed tool call without an argument mapping: ignore it
                # rather than fail the whole extraction.
                continue
            file_content = args.get("content", "")
            # Keep the longest candidate seen so far.
            if isinstance(file_content, str) and len(file_content) > len(content):
                content = file_content
    return content

def _is_report_complete(self, result: dict | Any) -> tuple[bool, str]:
"""
Check if the agent produced a complete report using tool calls or heuristics.
Expand All @@ -266,10 +296,9 @@ def _is_report_complete(self, result: dict | Any) -> tuple[bool, str]:
if not messages:
return False, "no_messages"

last_msg = messages[-1]
content = last_msg.content or ""
content = self._extract_report_content(messages)

if len(content) < 1500:
if len(content) < _MIN_REPORT_LENGTH:
return False, f"too_short ({len(content)} chars)"

if content.count("## ") < 2:
Expand Down Expand Up @@ -299,7 +328,7 @@ def _is_report_complete(self, result: dict | Any) -> tuple[bool, str]:
url_match = _URL_IN_LINE_RE.search(ref_text)
if url_match:
url = url_match.group(0).rstrip(".,;)")
if registry.has_url(url):
if registry.resolve_url(url):
has_any_valid = True
break
continue
Expand Down Expand Up @@ -391,7 +420,12 @@ async def run(self, state: DeepResearchAgentState) -> DeepResearchAgentState:
if source_list:
feedback_msg += "\n\n" + source_list

feedback_msg += " Please fix this immediately and call 'submit_final_report' when done."
feedback_msg += (
" IMPORTANT: Do NOT restart the research from scratch."
" First check if /report.md already exists using read_file."
" If it does, use that content as your report — just fix the specific issue above"
" and return the corrected report in your final message."
)

if isinstance(result, dict):
next_state = {**result}
Expand All @@ -400,27 +434,52 @@ async def run(self, state: DeepResearchAgentState) -> DeepResearchAgentState:
next_state = result.model_dump() if hasattr(result, "model_dump") else dict(result)
messages = getattr(result, "messages", next_state.get("messages", []))
next_state["messages"] = list(messages) + [HumanMessage(content=feedback_msg)]
result = await agent.ainvoke(
next_state,
config={"callbacks": self.callbacks} if self.callbacks else None,
)

try:
result = await agent.ainvoke(
next_state,
config={"callbacks": self.callbacks} if self.callbacks else None,
)
last_error = None
except Exception as ex:
logger.error("Deep Research feedback retry %d failed: %s", attempt + 1, ex, exc_info=True)
last_error = ex
if "recursion" in str(ex).lower() or "reuse already awaited" in str(ex):
raise ex
# Non-fatal: ainvoke raised before producing a result, so
# `result` still holds the previous iteration's value.
# The next loop iteration will rebuild next_state from it.
continue

# Evaluate the feedback-retry result before the next iteration
is_complete, reason = self._is_report_complete(result)
if is_complete:
logger.info(f"Report completed after feedback retry. Reason: {reason}")
break

# Update state so next iteration builds on progress, not the original state
state = result

if result is None and last_error is not None:
raise last_error

final_message = "Research failed to produce a report."
if result and result.get("messages"):
final_content = result["messages"][-1].content
final_message = final_content if isinstance(final_content, str) else str(final_content)
final_message = self._extract_report_content(result["messages"])

# Post-process: verify citations against source registry
if self.source_registry_middleware._get_registry().all_sources():
verification = verify_citations(final_message, self.source_registry_middleware._get_registry())
if verification.removed_citations:
removed_details = []
for c in verification.removed_citations:
url_match = re.search(r"https?://\S+", c.get("line", ""))
url_str = url_match.group(0).rstrip(".,;)") if url_match else "(no url)"
removed_details.append(f"[{c['number']}] {c['reason']}: {url_str}")
logger.info(
"Citation verification removed %d invalid citations: %s",
"Citation verification removed %d invalid citation(s):\n %s",
len(verification.removed_citations),
[c["reason"] for c in verification.removed_citations],
"\n ".join(removed_details),
)
final_message = verification.verified_report
else:
Expand Down
Loading
Loading