diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8222f89cf5..14c7b34671 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ All notable changes to this project will be documented in this file.
- Documentation and runtime warning for `dataset_mixer_list` format (float=proportion, int=count) (https://github.com/allenai/open-instruct/pull/1434).
### Changed
+- Refactors Legacy and DRTulu tool parsers to use OpenAI-format `tool_definitions` instead of Ray `tool_actors`. Removes `import ray` from `parsers.py`, fixes the DRTulu parser, which had been broken since the pool refactor, and fixes a `--tool_parser_type` typo in the dr_tulu debug script (https://github.com/allenai/open-instruct/pull/1491).
- Replaces lambda collators with a "single_example_collator" (https://github.com/allenai/open-instruct/pull/1472).
- Clarified `activation_memory_budget` guidance in DPO utils with a practical default (`0.5`) and memory/speed tradeoff notes (https://github.com/allenai/open-instruct/pull/1460).
- Let TransformerTrainModule handle FSDP parallelism instead of manual application in DPO (https://github.com/allenai/open-instruct/pull/1458).
diff --git a/open_instruct/environments/tools/parsers.py b/open_instruct/environments/tools/parsers.py
index ec000968be..7e6e8f60ed 100644
--- a/open_instruct/environments/tools/parsers.py
+++ b/open_instruct/environments/tools/parsers.py
@@ -18,7 +18,6 @@
from dataclasses import dataclass, field
from typing import Any
-import ray
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.tool_parsers import ToolParser as VllmNativeToolParser
@@ -55,14 +54,11 @@ class OpenInstructLegacyToolParser(ToolParser):
Tools are invoked via <tool_name>content</tool_name> tags.
The content between tags is passed to the tool's first required parameter.
Only works for tools that take a single string parameter.
+
+ Tool names and parameter names are derived from OpenAI-format tool definitions.
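+
+ Example (a minimal sketch of the expected OpenAI-format shape; the "search"
+ tool and its "query" parameter here are purely illustrative):
+
+     {"type": "function",
+      "function": {"name": "search",
+                   "parameters": {"type": "object",
+                                  "properties": {"query": {"type": "string"}},
+                                  "required": ["query"]}}}
+
+ The parser derives the parameter name from the first "required" entry,
+ falling back to the first property, then to "text".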
"""
- def __init__(
- self,
- tool_actors: list[ray.actor.ActorHandle] | None = None,
- output_wrap_name: str = "output",
- tool_definitions: list[dict[str, Any]] | None = None,
- ):
+ def __init__(self, tool_definitions: list[dict[str, Any]] | None = None, output_wrap_name: str = "output"):
self.output_wrap_name = output_wrap_name
if tool_definitions:
@@ -78,17 +74,6 @@ def __init__(
else:
properties = params.get("properties", {})
self.tool_param_names[name] = next(iter(properties)) if properties else "text"
- elif tool_actors:
- self.tool_names = [ray.get(actor.get_call_name.remote()) for actor in tool_actors]
- self.tool_param_names = {}
- for actor, tool_name in zip(tool_actors, self.tool_names):
- params = ray.get(actor.get_parameters.remote())
- required = params.get("required", [])
- if required:
- self.tool_param_names[tool_name] = required[0]
- else:
- properties = params.get("properties", {})
- self.tool_param_names[tool_name] = next(iter(properties)) if properties else "text"
else:
self.tool_names = []
self.tool_param_names = {}
@@ -288,21 +273,23 @@ class DRTuluToolParser(ToolParser):
"""
Parser for DR Tulu style tool calls. Delegates actual parsing to the tool itself.
Only detects that a tool call occurred (via stop strings) and passes text to the tool.
+
+ Requires exactly one tool (dr_agent_mcp) in tool_definitions.
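+
+ A minimal construction sketch (stop sequences are supplied by the caller,
+ typically collected from the tool pools; the variable names are illustrative):
+
+     DRTuluToolParser(tool_definitions=[dr_agent_mcp_def], stop_sequences=stop_strings)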
"""
- def __init__(self, tool_actors: list[ray.actor.ActorHandle]):
- if len(tool_actors) != 1:
- raise ValueError(f"DRTuluToolParser requires exactly one tool (dr_agent_mcp), got {len(tool_actors)}")
+ def __init__(self, tool_definitions: list[dict[str, Any]], stop_sequences: list[str]):
+ if len(tool_definitions) != 1:
+ raise ValueError(f"DRTuluToolParser requires exactly one tool (dr_agent_mcp), got {len(tool_definitions)}")
- actor = tool_actors[0]
- self.tool_call_name = ray.get(actor.get_call_name.remote())
+ self.tool_call_name = tool_definitions[0]["function"]["name"]
if self.tool_call_name != "dr_agent_mcp":
raise ValueError(f"DRTuluToolParser requires dr_agent_mcp tool, got {self.tool_call_name}")
- stop_strings = ray.get(actor.get_stop_strings.remote())
- # Use dict.fromkeys to deduplicate while preserving order
- self.stop_sequences = list(dict.fromkeys(stop_strings)) if stop_strings else []
+ self.stop_sequences = list(dict.fromkeys(stop_sequences)) if stop_sequences else []
+
+ if not self.stop_sequences:
+ logger.warning("DRTuluToolParser initialized with no stop sequences — tool calls will never be detected")
def get_tool_calls(self, text: str) -> list[EnvCall]:
for stop in self.stop_sequences:
@@ -322,8 +309,8 @@ def get_available_parsers() -> list[str]:
def create_tool_parser(
parser_type: str,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
- tool_actors: list[ray.actor.ActorHandle],
tool_definitions: list[dict[str, Any]] | None = None,
+ stop_sequences: list[str] | None = None,
) -> ToolParser:
"""Create a tool parser instance by type.
@@ -333,8 +320,8 @@ def create_tool_parser(
- "dr_tulu": DRTuluToolParser for content format
- "vllm_*": VllmToolParser variants (vllm_hermes, vllm_llama3_json, vllm_olmo3)
tokenizer: Tokenizer for the model (required for all parser types).
- tool_actors: List of Ray actor handles for the tools.
- tool_definitions: OpenAI-format tool definitions (required for vllm_* parsers).
+ tool_definitions: OpenAI-format tool definitions.
+ stop_sequences: Stop sequences that mark the end of a tool call (required by the dr_tulu parser; ignored by other parser types).
Returns:
A ToolParser instance configured for the specified type.
@@ -343,12 +330,12 @@ def create_tool_parser(
ValueError: If parser_type is unknown.
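+
+ Example (illustrative; assumes `tokenizer`, `defs`, and `stop_strings` are
+ already built by the caller):
+
+     parser = create_tool_parser(
+         "dr_tulu", tokenizer=tokenizer, tool_definitions=defs, stop_sequences=stop_strings
+     )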
"""
if parser_type == "legacy":
- return OpenInstructLegacyToolParser(
- tool_actors=tool_actors, output_wrap_name="output", tool_definitions=tool_definitions
- )
+ return OpenInstructLegacyToolParser(tool_definitions=tool_definitions, output_wrap_name="output")
if parser_type == "dr_tulu":
- return DRTuluToolParser(tool_actors)
+ if tool_definitions is None or stop_sequences is None:
+ raise ValueError("dr_tulu parser requires both tool_definitions and stop_sequences")
+ return DRTuluToolParser(tool_definitions=tool_definitions, stop_sequences=stop_sequences)
if parser_type in VLLM_PARSERS:
return create_vllm_parser(parser_type, tokenizer, tool_definitions=tool_definitions)
diff --git a/open_instruct/environments/tools/tests/test_parsers.py b/open_instruct/environments/tools/tests/test_parsers.py
index c65aeb243c..985aa35ca3 100644
--- a/open_instruct/environments/tools/tests/test_parsers.py
+++ b/open_instruct/environments/tools/tests/test_parsers.py
@@ -17,64 +17,31 @@
from open_instruct.utils import import_class_from_string
-class MockTool:
- """Mock tool for testing without ray."""
-
- def __init__(
- self,
- name: str,
- param_name: str = "text",
- required: list[str] | None = None,
- stop_strings: list[str] | None = None,
- ):
- self.call_name = name
- self.param_name = param_name
- self.required = required if required is not None else [param_name]
- self._stop_strings = stop_strings
-
- def get_call_name(self):
- return self.call_name
-
- def get_parameters(self):
- return {"required": self.required, "properties": {self.param_name: {"type": "string"}}}
-
- def get_stop_strings(self):
- if self._stop_strings is not None:
- return self._stop_strings
- raise AttributeError("No stop_strings defined")
-
-
-def create_mock_tool_actor(
- name: str, param_name: str = "text", required: list[str] | None = None, stop_strings: list[str] | None = None
-) -> MagicMock:
- """Create a mock tool actor handle that works with ray.get()."""
- mock_tool = MockTool(name, param_name, required, stop_strings)
- actor_handle = MagicMock()
- actor_handle.get_call_name.remote.return_value = mock_tool.get_call_name()
- actor_handle.get_parameters.remote.return_value = mock_tool.get_parameters()
- actor_handle.get_stop_strings.remote.return_value = stop_strings
-
- return actor_handle
+def make_tool_definition(name: str, param_name: str = "text", required: list[str] | None = None) -> dict:
+ """Create an OpenAI-format tool definition for testing."""
+ if required is None:
+ required = [param_name]
+ return {
+ "type": "function",
+ "function": {
+ "name": name,
+ "description": f"Test tool {name}",
+ "parameters": {
+ "type": "object",
+ "properties": {param_name: {"type": "string", "description": f"The {param_name} parameter"}},
+ "required": required,
+ },
+ },
+ }
class TestOpenInstructLegacyToolParser(unittest.TestCase):
"""Tests for OpenInstructLegacyToolParser."""
- def setUp(self):
- """Set up mock actors for each test."""
- self.patcher = patch("open_instruct.environments.tools.parsers.ray")
- self.mock_ray = self.patcher.start()
- # Make ray.get return the value directly (simulating sync behavior)
- self.mock_ray.get.side_effect = lambda x: x
-
- def tearDown(self):
- """Stop the patcher."""
- self.patcher.stop()
-
def test_single_tool_extraction(self):
"""Test extracting a single tool call."""
- mock_actor = create_mock_tool_actor("search", param_name="query")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search", param_name="query")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "I need to search for something. python tutorials"
tool_calls = parser.get_tool_calls(text)
@@ -85,9 +52,8 @@ def test_single_tool_extraction(self):
def test_multiple_tools_extraction(self):
"""Test extracting multiple different tool calls."""
- mock_search = create_mock_tool_actor("search", param_name="query")
- mock_code = create_mock_tool_actor("code", param_name="script")
- parser = OpenInstructLegacyToolParser([mock_search, mock_code])
+ defs = [make_tool_definition("search", param_name="query"), make_tool_definition("code", param_name="script")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "First search python then run print('hello')"
tool_calls = parser.get_tool_calls(text)
@@ -98,8 +64,8 @@ def test_multiple_tools_extraction(self):
def test_no_tool_calls(self):
"""Test that no tool calls are returned when none exist."""
- mock_actor = create_mock_tool_actor("search")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "This is just regular text without any tool calls."
tool_calls = parser.get_tool_calls(text)
@@ -108,8 +74,8 @@ def test_no_tool_calls(self):
def test_multiline_content(self):
"""Test extracting tool calls with multiline content."""
- mock_actor = create_mock_tool_actor("code", param_name="script")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("code", param_name="script")]
+ parser = OpenInstructLegacyToolParser(defs)
code_content = """def hello():
print('Hello, World!')
@@ -124,8 +90,8 @@ def test_multiline_content(self):
def test_partial_tag_not_matched(self):
"""Test that incomplete tags are not matched."""
- mock_actor = create_mock_tool_actor("search")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search")]
+ parser = OpenInstructLegacyToolParser(defs)
# Missing closing tag
text = "Here's a search query without closing"
@@ -135,8 +101,8 @@ def test_partial_tag_not_matched(self):
def test_nested_content_with_angle_brackets(self):
"""Test content containing angle brackets."""
- mock_actor = create_mock_tool_actor("code", param_name="script")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("code", param_name="script")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "if x > 5 and y < 10: print('yes')"
tool_calls = parser.get_tool_calls(text)
@@ -146,8 +112,8 @@ def test_nested_content_with_angle_brackets(self):
def test_format_tool_outputs_single(self):
"""Test formatting a single tool output."""
- mock_actor = create_mock_tool_actor("search")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search")]
+ parser = OpenInstructLegacyToolParser(defs)
result = parser.format_tool_outputs(["Search result: Found 5 items"])
expected = "\n"
@@ -155,8 +121,8 @@ def test_format_tool_outputs_single(self):
def test_format_tool_outputs_multiple(self):
"""Test formatting multiple tool outputs."""
- mock_actor = create_mock_tool_actor("search")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search")]
+ parser = OpenInstructLegacyToolParser(defs)
result = parser.format_tool_outputs(["Result 1", "Result 2"])
expected = "\n\n\n"
@@ -164,8 +130,8 @@ def test_format_tool_outputs_multiple(self):
def test_format_tool_outputs_custom_wrap_name(self):
"""Test formatting with custom output wrap name."""
- mock_actor = create_mock_tool_actor("search")
- parser = OpenInstructLegacyToolParser([mock_actor], output_wrap_name="result")
+ defs = [make_tool_definition("search")]
+ parser = OpenInstructLegacyToolParser(defs, output_wrap_name="result")
result = parser.format_tool_outputs(["Some output"])
expected = "\nSome output\n\n"
@@ -173,9 +139,8 @@ def test_format_tool_outputs_custom_wrap_name(self):
def test_stop_sequences(self):
"""Test that stop sequences are correctly generated."""
- mock_search = create_mock_tool_actor("search")
- mock_code = create_mock_tool_actor("code")
- parser = OpenInstructLegacyToolParser([mock_search, mock_code])
+ defs = [make_tool_definition("search"), make_tool_definition("code")]
+ parser = OpenInstructLegacyToolParser(defs)
stop_seqs = parser.stop_sequences
@@ -185,8 +150,8 @@ def test_stop_sequences(self):
def test_empty_content(self):
"""Test tool call with empty content between tags."""
- mock_actor = create_mock_tool_actor("search", param_name="query")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search", param_name="query")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "Empty search: "
tool_calls = parser.get_tool_calls(text)
@@ -196,8 +161,8 @@ def test_empty_content(self):
def test_whitespace_only_content(self):
"""Test tool call with whitespace-only content."""
- mock_actor = create_mock_tool_actor("search", param_name="query")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search", param_name="query")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "Whitespace: \n\t "
tool_calls = parser.get_tool_calls(text)
@@ -207,8 +172,8 @@ def test_whitespace_only_content(self):
def test_tool_without_required_params_uses_first_property(self):
"""Test that tools without required params use first property name."""
- mock_actor = create_mock_tool_actor("search", param_name="query", required=[])
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search", param_name="query", required=[])]
+ parser = OpenInstructLegacyToolParser(defs)
text = "test query"
tool_calls = parser.get_tool_calls(text)
@@ -218,8 +183,8 @@ def test_tool_without_required_params_uses_first_property(self):
def test_multiple_calls_same_tool_extracted(self):
"""Test that all occurrences of the same tool type are extracted."""
- mock_actor = create_mock_tool_actor("search", param_name="query")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("search", param_name="query")]
+ parser = OpenInstructLegacyToolParser(defs)
text = "first query then second query"
tool_calls = parser.get_tool_calls(text)
@@ -230,9 +195,8 @@ def test_multiple_calls_same_tool_extracted(self):
def test_tool_calls_preserve_text_order(self):
"""Test that tool calls are returned in the order they appear in text."""
- mock_search = create_mock_tool_actor("search", param_name="query")
- mock_code = create_mock_tool_actor("code", param_name="script")
- parser = OpenInstructLegacyToolParser([mock_search, mock_code])
+ defs = [make_tool_definition("search", param_name="query"), make_tool_definition("code", param_name="script")]
+ parser = OpenInstructLegacyToolParser(defs)
# Interleaved tool calls: code, search, code
text = "first code then query then second code"
@@ -249,16 +213,23 @@ def test_tool_calls_preserve_text_order(self):
def test_special_regex_characters_in_tool_name(self):
"""Test that tool names with regex special chars are properly escaped."""
# Tool name with characters that have meaning in regex
- mock_actor = create_mock_tool_actor("tool.name", param_name="input")
- parser = OpenInstructLegacyToolParser([mock_actor])
+ defs = [make_tool_definition("tool.name", param_name="input")]
+ parser = OpenInstructLegacyToolParser(defs)
- # Should match literal not
text = "content"
tool_calls = parser.get_tool_calls(text)
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0].args["input"], "content")
+ def test_no_definitions(self):
+ """Test parser with no tool definitions."""
+ parser = OpenInstructLegacyToolParser()
+
+ self.assertEqual(parser.tool_names, [])
+ self.assertEqual(parser.stop_sequences, [])
+ self.assertEqual(parser.get_tool_calls("any text"), [])
+
class TestDRTuluToolParser(unittest.TestCase):
"""Tests for DRTuluToolParser.
@@ -267,20 +238,11 @@ class TestDRTuluToolParser(unittest.TestCase):
It only detects that a tool call occurred (via stop strings) and passes the full text.
"""
- def setUp(self):
- """Set up mock actors for each test."""
- self.patcher = patch("open_instruct.environments.tools.parsers.ray")
- self.mock_ray = self.patcher.start()
- self.mock_ray.get.side_effect = lambda x: x if not isinstance(x, list) else [v for v in x]
-
- def tearDown(self):
- """Stop the patcher."""
- self.patcher.stop()
+ DR_AGENT_DEF = make_tool_definition("dr_agent_mcp")
def test_detects_tool_call_with_stop_string(self):
"""Test that parser detects tool call when stop string is present."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- parser = DRTuluToolParser([mock_actor])
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""])
text = 'python tutorials'
tool_calls = parser.get_tool_calls(text)
@@ -291,8 +253,7 @@ def test_detects_tool_call_with_stop_string(self):
def test_no_tool_call_without_stop_string(self):
"""Test that no tool call is returned when stop string is absent."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- parser = DRTuluToolParser([mock_actor])
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""])
text = "This is just regular text without any tool calls."
tool_calls = parser.get_tool_calls(text)
@@ -301,8 +262,7 @@ def test_no_tool_call_without_stop_string(self):
def test_passes_full_text_to_tool(self):
"""Test that the full text is passed as the argument."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- parser = DRTuluToolParser([mock_actor])
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""])
text = """I need to search
query here"""
@@ -313,8 +273,7 @@ def test_passes_full_text_to_tool(self):
def test_format_tool_outputs_single(self):
"""Test formatting a single tool output."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- parser = DRTuluToolParser([mock_actor])
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""])
result = parser.format_tool_outputs(["Search result: Found 5 items"])
expected = "\nSearch result: Found 5 items\n\n"
@@ -322,40 +281,41 @@ def test_format_tool_outputs_single(self):
def test_format_tool_outputs_multiple(self):
"""Test formatting multiple tool outputs."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- parser = DRTuluToolParser([mock_actor])
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""])
result = parser.format_tool_outputs(["Result 1", "Result 2"])
expected = "\nResult 1\n\n\n\nResult 2\n\n"
self.assertEqual(result, expected)
- def test_stop_sequences_default(self):
- """Test that empty list is used when tools don't provide stop strings."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp")
- parser = DRTuluToolParser([mock_actor])
+ def test_stop_sequences_empty(self):
+ """Test that empty list is used when no stop sequences provided."""
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[])
self.assertEqual(parser.stop_sequences, [])
- def test_stop_sequences_from_tools(self):
- """Test that stop sequences are collected from tools that provide them."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=["", ""])
- parser = DRTuluToolParser([mock_actor])
+ def test_stop_sequences_from_init(self):
+ """Test that stop sequences are set from init parameter."""
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=["", ""])
+
+ self.assertEqual(parser.stop_sequences, ["", ""])
+
+ def test_stop_sequences_deduplicated(self):
+ """Test that duplicate stop sequences are removed."""
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=["", "", ""])
self.assertEqual(parser.stop_sequences, ["", ""])
def test_rejects_multiple_tools(self):
"""Test that parser rejects multiple tools."""
- mock_actor1 = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- mock_actor2 = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
+ defs = [self.DR_AGENT_DEF, make_tool_definition("other_tool")]
with self.assertRaises(ValueError) as context:
- DRTuluToolParser([mock_actor1, mock_actor2])
+ DRTuluToolParser(defs, stop_sequences=[""])
self.assertIn("exactly one tool", str(context.exception))
def test_uses_tool_call_name(self):
"""Test that parser uses the tool's call name for routing."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""])
- parser = DRTuluToolParser([mock_actor])
+ parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""])
self.assertEqual(parser.tool_call_name, "dr_agent_mcp")
@@ -365,10 +325,10 @@ def test_uses_tool_call_name(self):
def test_rejects_wrong_tool(self):
"""Test that parser rejects tools that aren't dr_agent_mcp."""
- mock_actor = create_mock_tool_actor("python", stop_strings=[""])
+ defs = [make_tool_definition("python")]
with self.assertRaises(ValueError) as context:
- DRTuluToolParser([mock_actor])
+ DRTuluToolParser(defs, stop_sequences=[""])
self.assertIn("dr_agent_mcp", str(context.exception))
@@ -404,20 +364,16 @@ class TestVllmParserRegistry(unittest.TestCase):
@parameterized.expand(VLLM_PARSERS.items())
def test_vllm_parser_config(self, name, config):
"""Test that a registered vLLM parser has valid configuration."""
- # Check config type
self.assertIsInstance(config, VllmParserConfig)
- # Verify import_path resolves to a callable class
self.assertTrue(config.import_path, "missing import_path")
parser_cls = import_class_from_string(config.import_path)
self.assertTrue(callable(parser_cls))
- # Verify output_template is usable with .format()
self.assertTrue(config.output_template, "missing output_template")
formatted = config.output_template.format("test_output")
self.assertIn("test_output", formatted)
- # Check stop_sequences is a sized iterable (list, tuple, set, etc.)
self.assertGreaterEqual(len(config.stop_sequences), 0, "stop_sequences must be a sized iterable")
@@ -426,7 +382,6 @@ class TestVllmToolParser(unittest.TestCase):
def test_format_tool_outputs_single(self):
"""Test formatting a single tool output."""
- # Create a mock native parser (we don't need it for format tests)
mock_native = MagicMock()
parser = VllmToolParser(
tool_parser=mock_native,
@@ -471,50 +426,40 @@ def test_stop_sequences_custom(self):
class TestCreateToolParser(unittest.TestCase):
"""Tests for create_tool_parser factory function."""
- def setUp(self):
- """Set up mock actors for each test."""
- self.patcher = patch("open_instruct.environments.tools.parsers.ray")
- self.mock_ray = self.patcher.start()
- self.mock_ray.get.side_effect = lambda x: x
- self.mock_ray.exceptions.RayActorError = Exception
-
- def tearDown(self):
- """Stop the patcher."""
- self.patcher.stop()
-
def test_create_legacy_parser(self):
"""Test creating legacy parser."""
- mock_actor = create_mock_tool_actor("search")
mock_tokenizer = MagicMock()
+ defs = [make_tool_definition("search")]
- parser = create_tool_parser("legacy", tokenizer=mock_tokenizer, tool_actors=[mock_actor])
+ parser = create_tool_parser("legacy", tokenizer=mock_tokenizer, tool_definitions=defs)
self.assertIsInstance(parser, OpenInstructLegacyToolParser)
def test_create_dr_tulu_parser(self):
"""Test creating dr_tulu parser."""
- mock_actor = create_mock_tool_actor("dr_agent_mcp")
mock_tokenizer = MagicMock()
+ defs = [make_tool_definition("dr_agent_mcp")]
- parser = create_tool_parser("dr_tulu", tokenizer=mock_tokenizer, tool_actors=[mock_actor])
+ parser = create_tool_parser(
+ "dr_tulu", tokenizer=mock_tokenizer, tool_definitions=defs, stop_sequences=[""]
+ )
self.assertIsInstance(parser, DRTuluToolParser)
@parameterized.expand([(p,) for p in VLLM_PARSERS])
def test_create_vllm_parser(self, parser_type):
"""Test creating vLLM parsers."""
- mock_actor = create_mock_tool_actor("search")
mock_tokenizer = MagicMock()
+ defs = [make_tool_definition("search")]
with patch("open_instruct.environments.tools.parsers.import_class_from_string") as mock_import:
mock_import.return_value = MagicMock()
- parser = create_tool_parser(parser_type, tokenizer=mock_tokenizer, tool_actors=[mock_actor])
+ parser = create_tool_parser(parser_type, tokenizer=mock_tokenizer, tool_definitions=defs)
self.assertIsInstance(parser, VllmToolParser)
def test_unknown_parser_raises_error(self):
"""Test that unknown parser types raise an error."""
- mock_actor = create_mock_tool_actor("search")
mock_tokenizer = MagicMock()
with self.assertRaises(ValueError) as context:
- create_tool_parser("unknown_parser", tokenizer=mock_tokenizer, tool_actors=[mock_actor])
+ create_tool_parser("unknown_parser", tokenizer=mock_tokenizer)
self.assertIn("Unknown parser type", str(context.exception))
self.assertIn("Available:", str(context.exception))
diff --git a/open_instruct/environments/tools/utils.py b/open_instruct/environments/tools/utils.py
index 560e477d96..527a69fb86 100644
--- a/open_instruct/environments/tools/utils.py
+++ b/open_instruct/environments/tools/utils.py
@@ -422,5 +422,9 @@ def get_parameters(self) -> dict[str, Any]:
"""Get the tool's parameter schema."""
return self.parameters
+ def get_stop_strings(self) -> list[str]:
+ """Get stop strings for this tool. Override in subclasses that define custom stop sequences."""
+ return []
+
def get_tool_definitions(self) -> list[dict[str, Any]]:
return [get_openai_tool_definitions(self)]
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 3252b24faf..a6764796e3 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -1237,6 +1237,7 @@ def create_model_and_optimizer(
tools_config: EnvsConfig | None = None,
base_env_config: dict | None = None,
pools: dict[str, ray.actor.ActorHandle] | None = None,
+ tool_stop_sequences: list[str] | None = None,
) -> tuple[
ModelGroup, list[vllm_utils.LLMRayActor], int, int, ray.actor.ActorHandle, utils.ModelDims, ray.actor.ActorHandle
]:
@@ -1318,6 +1319,7 @@ def create_model_and_optimizer(
pg=pg if args.single_gpu_mode else None,
tool_parser_type=tools_config.tool_parser_type if tools_config else "legacy",
tool_definitions=tool_definitions,
+ tool_stop_sequences=tool_stop_sequences,
max_steps=tools_config.max_steps if tools_config else 5,
mask_tool_use=streaming_config.mask_tool_use,
pools=pools,
@@ -2093,21 +2095,37 @@ def initialize_tools_and_envs(
tools_config.tool_call_names = tool_call_names
- # Collect tool definitions from all pools
+ # Collect tool definitions and stop strings from all pools (batched for parallelism)
+ acquire_refs = [pool.acquire.remote() for pool in pools.values()]
+ actors = ray.get(acquire_refs)
+
+ def_refs = [actor.get_tool_definitions.remote() for actor in actors]
+ stop_refs = (
+ [actor.get_stop_strings.remote() for actor in actors] if tools_config.tool_parser_type == "dr_tulu" else []
+ )
+ all_results = ray.get(def_refs + stop_refs)
+ def_results = all_results[: len(def_refs)]
+ stop_results = all_results[len(def_refs) :]
+
tool_definitions: list[dict[str, Any]] = []
- for pool in pools.values():
- actor = ray.get(pool.acquire.remote())
- defs = ray.get(actor.get_tool_definitions.remote())
- pool.release.remote(actor)
+ for defs in def_results:
tool_definitions.extend(defs)
+ tool_stop_strings: list[str] = []
+ for stop_strings in stop_results:
+ if stop_strings:
+ tool_stop_strings.extend(stop_strings)
+
+ for pool, actor in zip(pools.values(), actors):
+ pool.release.remote(actor)
+
stop_sequences: list[str] = []
if pools:
stop_sequences = create_tool_parser(
parser_type=tools_config.tool_parser_type,
- tool_actors=[],
tokenizer=tokenizer,
tool_definitions=tool_definitions,
+ stop_sequences=tool_stop_strings,
).stop_sequences
logger.info(
@@ -2152,12 +2170,6 @@ def main(
dataset_mixer_list=streaming_config.dataset_mixer_list,
dataset_mixer_list_splits=streaming_config.dataset_mixer_list_splits,
)
- # TODO: Refactor DRTuluToolParser to work with tool_definitions instead of tool_actors.
- if pools and tools_config.tool_parser_type == "dr_tulu":
- raise ValueError(
- "Parser type 'dr_tulu' requires tool_actors which are no longer created (pools are used instead). "
- "Use --tool_parser_type legacy or --tool_parser_type vllm_hermes."
- )
if tool_stop_sequences:
logger.info(f"Adding tool stop sequences to config: {tool_stop_sequences}")
streaming_config.stop_strings.extend(tool_stop_sequences)
@@ -2247,6 +2259,7 @@ def main(
tools_config,
base_env_config,
pools,
+ tool_stop_sequences,
)
)
diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py
index a87306c2e5..0e12cb72b1 100644
--- a/open_instruct/vllm_utils.py
+++ b/open_instruct/vllm_utils.py
@@ -556,6 +556,7 @@ def __init__(
*args,
tool_parser_type: str = "legacy",
tool_definitions: list[dict] | None = None,
+ tool_stop_sequences: list[str] | None = None,
max_steps: int = 5,
mask_tool_use: bool = True,
pools: dict[str, ray.actor.ActorHandle] | None = None,
@@ -572,6 +573,7 @@ def __init__(
):
assert_threaded_actor(self)
self._tool_definitions = tool_definitions
+ self._tool_stop_sequences = tool_stop_sequences
self._init_config(
max_steps, mask_tool_use, pools, inflight_updates, reward_config, train_dataset, eval_dataset
)
@@ -636,9 +638,9 @@ def _init_executor(self) -> None:
def _init_tool_parser(self, tool_parser_type: str) -> None:
self.tool_parser = create_tool_parser(
parser_type=tool_parser_type,
- tool_actors=[],
tokenizer=self.llm_engine.tokenizer,
tool_definitions=self._tool_definitions,
+ stop_sequences=self._tool_stop_sequences,
)
def _setup_gpu_visibility(self, noset_visible_devices: bool, distributed_executor_backend: str) -> None:
@@ -1084,6 +1086,7 @@ def create_vllm_engines(
pg: PlacementGroup | None = None,
tool_parser_type: str = "legacy",
tool_definitions: list[dict] | None = None,
+ tool_stop_sequences: list[str] | None = None,
max_steps: int = 5,
mask_tool_use: bool = True,
pools: dict[str, ray.actor.ActorHandle] | None = None,
@@ -1171,6 +1174,7 @@ def create_vllm_engines(
actor_manager=actor_manager,
tool_parser_type=tool_parser_type,
tool_definitions=tool_definitions,
+ tool_stop_sequences=tool_stop_sequences,
max_steps=max_steps,
mask_tool_use=mask_tool_use,
pools=pools,
diff --git a/pyproject.toml b/pyproject.toml
index bb166b1871..a6f7d302ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,6 +81,8 @@ code = [
]
dr-tulu = [
"dr_agent",
+ "authlib",
+ "scipy",
]
[tool.uv]
diff --git a/scripts/train/debug/tools/dr_tulu_parser_debug.sh b/scripts/train/debug/tools/dr_tulu_parser_debug.sh
index c374b3bbbf..278564c1ef 100755
--- a/scripts/train/debug/tools/dr_tulu_parser_debug.sh
+++ b/scripts/train/debug/tools/dr_tulu_parser_debug.sh
@@ -64,7 +64,7 @@ VLLM_ALLOW_INSECURE_SERIALIZATION=1 uv run --extra dr-tulu open_instruct/grpo_fa
--gradient_checkpointing \
--system_prompt_override_file scripts/train/debug/tools/dr_tulu_system_prompt.txt \
--tools dr_agent_mcp \
- --tool_parser dr_tulu \
+ --tool_parser_type dr_tulu \
--tool_configs '{"tool_names": "snippet_search,google_search,browse_webpage", "parser_name": "v20250824", "host": "'"$MCP_HOST"'", "port": '"$MCP_PORT"'}' \
--max_steps 5 \
--pass_tools_to_chat_template false \
diff --git a/uv.lock b/uv.lock
index 1ae8e0598b..1ada6e5270 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2687,7 +2687,9 @@ code = [
{ name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
]
dr-tulu = [
+ { name = "authlib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "dr-agent", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+ { name = "scipy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
]
[package.dev-dependencies]
@@ -2708,6 +2710,7 @@ requires-dist = [
{ name = "accelerate", specifier = ">=1.10.1" },
{ name = "ai2-olmo-core", specifier = "==2.3.0" },
{ name = "antlr4-python3-runtime", specifier = "==4.11" },
+ { name = "authlib", marker = "extra == 'dr-tulu'" },
{ name = "backoff", specifier = ">=2.2.1" },
{ name = "bitsandbytes", marker = "sys_platform != 'darwin'", specifier = ">=0.44.1" },
{ name = "datasets", specifier = ">=4.0.0" },
@@ -2733,6 +2736,7 @@ requires-dist = [
{ name = "pydantic", marker = "extra == 'code'", specifier = ">=2.0.0" },
{ name = "ray", extras = ["default"], specifier = ">=2.49.2" },
{ name = "requests", marker = "extra == 'code'", specifier = ">=2.28.0" },
+ { name = "scipy", marker = "extra == 'dr-tulu'" },
{ name = "setuptools", specifier = ">=75.6.0,<80.0.0" },
{ name = "tensorboard", specifier = ">=2.18.0" },
{ name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.9.0,<2.10" },