
Commit 6018c83

Merge pull request #16 from qiandl2000/mcp_vlm_r1
add sse for mcp and other changes in a4a
2 parents 74a9eab + be8428a commit 6018c83

21 files changed: +1913 additions, -295 deletions

examples/mcp_example/server.py

Lines changed: 300 additions & 14 deletions
@@ -1,22 +1,308 @@
-from vlm_r1 import run_vlm_r1
-from mcp.server.fastmcp import FastMCP
+import os
+import argparse
+import sys
+import tempfile
+import shutil
+import uuid
+from typing import Optional, Dict, Any, List
+import re
+from urllib.parse import urlparse
 
-# Create an MCP server
-mcp = FastMCP("VLM-R1-Server")
+# Ensure the latest fastmcp is installed -----------------------------------------------------------
+try:
+    from fastmcp import FastMCP, Context
+except ImportError:
+    print("fastmcp not found – installing from GitHub ...")
+    import subprocess
+    subprocess.check_call([
+        sys.executable,
+        "-m",
+        "pip",
+        "install",
+        "git+https://github.com/jlowin/fastmcp.git"
+    ])
+    from fastmcp import FastMCP, Context
 
+# Try to import requests for URL handling
+try:
+    import requests
+except ImportError:
+    print("requests not found – installing...")
+    import subprocess
+    subprocess.check_call([
+        sys.executable,
+        "-m",
+        "pip",
+        "install",
+        "requests"
+    ])
+    import requests
 
-# Add an addition tool
+# Import our VLM-R1 model ---------------------------------------------------------------------------
+try:
+    from vlm_r1 import VLMR1
+except ImportError as e:
+    print(f"Error importing VLMR1: {e}")
+    print("Make sure the src/vlm_r1.py file exists and all dependencies are installed.")
+    print("You may need to install additional packages, e.g.:\n pip install torch transformers pillow flash-attn bitsandbytes")
+    sys.exit(1)
+
+# -----------------------------------------------------------------------------------------------
+# Create the MCP server instance
+mcp = FastMCP("VLM-R1 Server – fastmcp 2.x")
+
+# Keep a global handle to the loaded model so that we only pay the load cost once
+_model: Optional[VLMR1] = None
+
+# Global temp directory for downloaded images
+_temp_dir = None
+
+def get_temp_dir():
+    """Get or create a temporary directory for downloaded images."""
+    global _temp_dir
+    if _temp_dir is None or not os.path.exists(_temp_dir):
+        _temp_dir = tempfile.mkdtemp(prefix="vlm_r1_images_")
+    return _temp_dir
+
+def is_url(path: str) -> bool:
+    """Check if the given string is a URL."""
+    try:
+        result = urlparse(path)
+        return all([result.scheme, result.netloc])
+    except:
+        return False
+
+def download_image(url: str) -> str:
+    """Download an image from URL and return the local path."""
+    try:
+        response = requests.get(url, stream=True, timeout=10)
+        response.raise_for_status()
+
+        # Try to get the filename from the URL or generate a random one
+        content_type = response.headers.get('Content-Type', '')
+        if 'image' not in content_type:
+            raise ValueError(f"URL does not point to an image (content-type: {content_type})")
+
+        # Determine file extension
+        if 'image/jpeg' in content_type or 'image/jpg' in content_type:
+            ext = '.jpg'
+        elif 'image/png' in content_type:
+            ext = '.png'
+        elif 'image/gif' in content_type:
+            ext = '.gif'
+        elif 'image/webp' in content_type:
+            ext = '.webp'
+        elif 'image/bmp' in content_type:
+            ext = '.bmp'
+        else:
+            ext = '.jpg'  # Default to jpg
+
+        # Create temporary file
+        temp_dir = get_temp_dir()
+        temp_path = os.path.join(temp_dir, f"{uuid.uuid4()}{ext}")
+
+        # Save the image
+        with open(temp_path, 'wb') as f:
+            shutil.copyfileobj(response.raw, f)
+
+        return temp_path
+    except Exception as e:
+        raise ValueError(f"Failed to download image from URL: {str(e)}")
+
+def init_model(
+    model_path: str,
+    use_flash_attention: bool = True,
+    low_cpu_mem_usage: bool = True,
+    load_in_8bit: bool = False,
+    specific_device: Optional[str] = None,
+):
+    """Lazy-load the VLM-R1 model (only once per process)."""
+    global _model
+    if _model is None:
+        print(f"[fastmcp-server] Loading VLM-R1 from '{model_path}' … This can take a few minutes.")
+        # When we pin the device, we must keep low_cpu_mem_usage=True per transformers semantics
+        if specific_device is not None:
+            low_cpu_mem_usage = True
+
+        _model = VLMR1.load(
+            model_path=model_path,
+            use_flash_attention=use_flash_attention,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            load_in_8bit=load_in_8bit,
+            specific_device=specific_device,
+        )
+        print("[fastmcp-server] Model ready! 🚀")
+    return _model
+
+# -------------------------------------------------------------------------------------------------
+# RESOURCE: expose images so that remote clients can fetch binary data if they wish
+@mcp.resource("image://{image_path}")
+def image_resource(image_path: str) -> bytes:  # noqa: D401
+    """Return the raw bytes of *image_path* so that clients can embed / inspect it."""
+    if is_url(image_path):
+        try:
+            local_path = download_image(image_path)
+            with open(local_path, "rb") as fh:
+                return fh.read()
+        except Exception as e:
+            raise ValueError(f"Failed to fetch image from URL: {str(e)}")
+    else:
+        if not os.path.exists(image_path):
+            raise ValueError(f"Image not found at '{image_path}'.")
+        with open(image_path, "rb") as fh:
+            return fh.read()
+
+# -------------------------------------------------------------------------------------------------
+# TOOL: generic analyse image
 @mcp.tool()
-def vlm_r1(image_path, question):
-    """A VLM to solve question of an image. Please input the image path and question."""
-    model_name = "omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321"
-    if image_path is None:
-        image_path = "/data0/qdl/test/old_women.png"
+async def analyze_image(
+    image_path: str,
+    question: Optional[str] = None,
+    max_new_tokens: int = 1024,
+    max_image_size: int = 448,
+    resize_mode: str = "shorter",
+    ctx: Context | None = None,
+) -> Dict[str, Any]:
+    """Run the multimodal VLM-R1 model on *image_path*.
+
+    The *image_path* can be a local file path or a URL to an image.
+    The default *question* asks for a detailed description. A custom question can be supplied by
+    callers. The returned dict mirrors the output of :py:meth:`VLMR1.predict`.
+    """
+    global _model
+    if _model is None:
+        _model = init_model(DEFAULT_MODEL_PATH)
+
+    if _model is None:
+        raise RuntimeError("Model not initialised – call init_model() first or start the server with --model-path …")
+
+    local_image_path = image_path
+
+    # If image_path is a URL, download it
+    if is_url(image_path):
+        try:
+            local_image_path = download_image(image_path)
+        except Exception as e:
+            return {"error": f"Failed to download image from URL: {str(e)}"}
+    elif not os.path.exists(image_path):
+        return {"error": f"Image not found: {image_path}"}
+
     if question is None:
-        question = "Describe this image."
-    return run_vlm_r1(model_name, image_path, question)
+        question = (
+            "Describe this image in detail. First output the thinking process in <think></think> tags "
+            "and then output the final answer in <answer></answer> tags."
+        )
+
+    # Run prediction directly
+    try:
+        result = _model.predict(
+            image_path=local_image_path,
+            question=question,
+            max_new_tokens=max_new_tokens,
+            max_image_size=max_image_size,
+            resize_mode=resize_mode,
+        )
+        return result
+    except Exception as e:
+        return {"error": f"Error during prediction: {str(e)}"}
+
+# -------------------------------------------------------------------------------------------------
+# TOOL: object detection helper
+@mcp.tool()
+async def detect_objects(
+    image_path: str,
+    max_new_tokens: int = 1024,
+    max_image_size: int = 448,
+    ctx: Context | None = None,
+) -> Dict[str, Any]:
+    """Detect objects in *image_path* using VLM-R1. The image_path can be a local file or URL."""
+    global _model
+    if _model is None:
+        _model = init_model(DEFAULT_MODEL_PATH)
+
+    if _model is None:
+        raise RuntimeError("Model not initialised – call init_model() first or start the server with --model-path …")
+
+    local_image_path = image_path
+
+    # If image_path is a URL, download it
+    if is_url(image_path):
+        try:
+            local_image_path = download_image(image_path)
+        except Exception as e:
+            return {"error": f"Failed to download image from URL: {str(e)}"}
+    elif not os.path.exists(image_path):
+        return {"error": f"Image not found: {image_path}"}
+
+    # Run prediction directly
+    try:
+        result = _model.predict(
+            image_path=local_image_path,
+            question="Detect all objects in this image. Provide bounding boxes if possible.",
+            max_new_tokens=max_new_tokens,
+            max_image_size=max_image_size,
+            resize_mode="shorter",
+        )
+        return result
+    except Exception as e:
+        return {"error": f"Error during prediction: {str(e)}"}
+
+# -------------------------------------------------------------------------------------------------
+# TOOL: list available images in a directory
+@mcp.tool()
+def list_images(directory: str = ".") -> List[str]:
+    """Return a list of image files (by path) found in *directory*."""
+    exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
+    if not os.path.exists(directory):
+        return {"error": f"Directory not found: {directory}"}
+    return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.splitext(f)[1].lower() in exts]
+
+# -------------------------------------------------------------------------------------------------
+# PROMPT helper – illustrates prompt templates
+@mcp.prompt()
+def image_analysis_prompt(image_path: str) -> str:
+    """Generate a prompt to analyze an image (can be a local file or URL)."""
+    return (
+        "Please analyse the image at "
+        f"{image_path}. First describe what you see, then identify key objects or elements in the image."
+    )
+
+# -------------------------------------------------------------------------------------------------
+# Command-line interface so that users can run this file directly
+DEFAULT_MODEL_PATH = "omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321"
+
+def _parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(description="Run a VLM-R1 server powered by fastmcp 2.x")
+    p.add_argument("--model-path", default=DEFAULT_MODEL_PATH, help="HuggingFace repo or local checkpoint directory")
+    p.add_argument("--device", default="cuda:0", help="Device to run on (e.g. cuda:0 or cpu)")
+    p.add_argument("--use-flash-attention", action="store_true", help="Enable flash-attention kernels if available")
+    p.add_argument("--low-cpu-mem", action="store_true", help="Load with low CPU memory footprint")
+    p.add_argument("--load-in-8bit", action="store_true", help="Load in 8-bit precision")
+    return p.parse_args()
+
+
+def main():
+    args = _parse_args()
+
+    # Pre-load model so that first request is fast (optional but helpful)
+    init_model(
+        model_path=args.model_path,
+        use_flash_attention=args.use_flash_attention,
+        low_cpu_mem_usage=args.low_cpu_mem,
+        load_in_8bit=args.load_in_8bit,
+        specific_device=args.device,
+    )
+
+    # Create temp directory for downloaded images
+    get_temp_dir()
+
+    try:
+        mcp.run(transport="sse", host="0.0.0.0", port=8008)
+    finally:
+        # Clean up temp directory on exit
+        if _temp_dir and os.path.exists(_temp_dir):
+            shutil.rmtree(_temp_dir, ignore_errors=True)
 
 
 if __name__ == "__main__":
-    # Initialize and run the server
-    mcp.run()
+    main()

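The core change in this file is the move to fastmcp 2.x with an SSE transport: mcp.run(transport="sse", host="0.0.0.0", port=8008) exposes the tools over HTTP instead of stdio. Below is a minimal client-side sketch, not part of the commit, assuming the fastmcp 2.x Client API and a server already started with the new CLI (for example: python examples/mcp_example/server.py --model-path omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321 --device cuda:0). The tool name and argument keys mirror the analyze_image signature in the diff; the image URL is a placeholder.

# Hypothetical client sketch (not part of this commit); assumes fastmcp 2.x is installed
# and server.py is already listening on port 8008.
import asyncio
from fastmcp import Client

async def demo():
    # The "/sse" suffix tells the fastmcp Client to use the SSE transport.
    async with Client("http://localhost:8008/sse") as client:
        tools = await client.list_tools()
        print("available tools:", [t.name for t in tools])

        # Arguments mirror analyze_image() in server.py; the URL is a placeholder.
        result = await client.call_tool(
            "analyze_image",
            {"image_path": "https://example.com/sample.jpg", "question": "Describe this image."},
        )
        print(result)

asyncio.run(demo())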
examples/mcp_example/test_mcp.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 import os
-os.environ["OMAGENT_MODE"] = "lite"
+os.environ["OMAGENT_MODE"] = "pro"
 from typing import List
 from pydantic import Field
 from omagent_core.models.llms.base import BaseLLMBackend
@@ -25,9 +25,9 @@
 class LLMTest(BaseLLMBackend):
     llm: OpenaiGPTLLM ={
         "name": "OpenaiGPTLLM",
-        "model_id": "gpt-4o-mini",
+        "model_id": "gpt-4o",
         "api_key": os.getenv("custom_openai_key"),
-        "endpoint": "https://api.openai.com/v1",
+        "endpoint": os.getenv("custom_openai_endpoint"),
         "vision": False,
         "response_format": "text",
         "use_default_sys_prompt": False,
@@ -40,6 +40,6 @@ class LLMTest(BaseLLMBackend):
 llm_test = LLMTest(workflow_instance_id="temp")
 
 tool_manager = llm_test.tool_manager
-x = tool_manager.execute_task("command ls -l for the current directory")
+x = tool_manager.execute_task("describe /data0/qdl/test/old_women.png",)
 print(x)

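The updated test now reads both the API key and the endpoint from environment variables (custom_openai_key, custom_openai_endpoint) and runs in "pro" mode. A small sketch of the environment the test expects, with placeholder values only (not part of the commit):

# Hypothetical setup sketch; all values are placeholders.
import os

os.environ["OMAGENT_MODE"] = "pro"                                   # mode used by the updated tests
os.environ["custom_openai_key"] = "sk-..."                           # OpenAI-compatible API key
os.environ["custom_openai_endpoint"] = "https://api.openai.com/v1"   # or a self-hosted endpoint

# With these set, the test drives the tool manager end to end:
#   x = tool_manager.execute_task("describe /data0/qdl/test/old_women.png")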
examples/mcp_example/test_mcp_vlm-r1.py

Lines changed: 2 additions & 7 deletions
@@ -1,5 +1,5 @@
 import os
-os.environ["OMAGENT_MODE"] = "lite"
+os.environ["OMAGENT_MODE"] = "pro"
 from typing import List
 from pydantic import Field
 from omagent_core.models.llms.base import BaseLLMBackend
@@ -14,12 +14,8 @@
 from omagent_core.utils.general import encode_image, read_image
 from omagent_core.tool_system.manager import ToolManager
 import asyncio
-import os
 
-# Set the CUDA_VISIBLE_DEVICES environment variable
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use only the first GPU
 # Set current working directory path
-
 CURRENT_PATH = Path(__file__).parents[0]
 
 # Import registered modules
@@ -44,7 +40,6 @@ class LLMTest(BaseLLMBackend):
 llm_test = LLMTest(workflow_instance_id="temp")
 
 tool_manager = llm_test.tool_manager
-#x = tool_manager.execute_task("Please describe this image '/data0/qdl/test/old_women.png'",)
-x = tool_manager.execute_task("Which tools can you use?",)
+x = tool_manager.execute_task("describe /data0/qdl/test/old_women.png",)
 print(x)
