diff --git a/CHANGELOG.md b/CHANGELOG.md index 8222f89cf5..14c7b34671 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ All notable changes to this project will be documented in this file. - Documentation and runtime warning for `dataset_mixer_list` format (float=proportion, int=count) (https://github.com/allenai/open-instruct/pull/1434). ### Changed +- Refactor Legacy and DRTulu tool parsers to use OpenAI-format `tool_definitions` instead of Ray `tool_actors`. Removes `import ray` from `parsers.py`, fixes DRTulu parser which was broken after the pool refactor, and fixes `--tool_parser_type` typo in dr_tulu debug script (https://github.com/allenai/open-instruct/pull/1491). - Replaces lambda collators with a "single_example_collator" (https://github.com/allenai/open-instruct/pull/1472). - Clarified `activation_memory_budget` guidance in DPO utils with a practical default (`0.5`) and memory/speed tradeoff notes (https://github.com/allenai/open-instruct/pull/1460). - Let TransformerTrainModule handle FSDP parallelism instead of manual application in DPO (https://github.com/allenai/open-instruct/pull/1458). diff --git a/open_instruct/environments/tools/parsers.py b/open_instruct/environments/tools/parsers.py index ec000968be..7e6e8f60ed 100644 --- a/open_instruct/environments/tools/parsers.py +++ b/open_instruct/environments/tools/parsers.py @@ -18,7 +18,6 @@ from dataclasses import dataclass, field from typing import Any -import ray from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.tool_parsers import ToolParser as VllmNativeToolParser @@ -55,14 +54,11 @@ class OpenInstructLegacyToolParser(ToolParser): Tools are invoked via content tags. The content between tags is passed to the tool's first required parameter. Only works for tools that take a single string parameter. + + Tool names and parameter names are derived from OpenAI-format tool definitions. """ - def __init__( - self, - tool_actors: list[ray.actor.ActorHandle] | None = None, - output_wrap_name: str = "output", - tool_definitions: list[dict[str, Any]] | None = None, - ): + def __init__(self, tool_definitions: list[dict[str, Any]] | None = None, output_wrap_name: str = "output"): self.output_wrap_name = output_wrap_name if tool_definitions: @@ -78,17 +74,6 @@ def __init__( else: properties = params.get("properties", {}) self.tool_param_names[name] = next(iter(properties)) if properties else "text" - elif tool_actors: - self.tool_names = [ray.get(actor.get_call_name.remote()) for actor in tool_actors] - self.tool_param_names = {} - for actor, tool_name in zip(tool_actors, self.tool_names): - params = ray.get(actor.get_parameters.remote()) - required = params.get("required", []) - if required: - self.tool_param_names[tool_name] = required[0] - else: - properties = params.get("properties", {}) - self.tool_param_names[tool_name] = next(iter(properties)) if properties else "text" else: self.tool_names = [] self.tool_param_names = {} @@ -288,21 +273,23 @@ class DRTuluToolParser(ToolParser): """ Parser for DR Tulu style tool calls. Delegates actual parsing to the tool itself. Only detects that a tool call occurred (via stop strings) and passes text to the tool. + + Requires exactly one tool (dr_agent_mcp) in tool_definitions. """ - def __init__(self, tool_actors: list[ray.actor.ActorHandle]): - if len(tool_actors) != 1: - raise ValueError(f"DRTuluToolParser requires exactly one tool (dr_agent_mcp), got {len(tool_actors)}") + def __init__(self, tool_definitions: list[dict[str, Any]], stop_sequences: list[str]): + if len(tool_definitions) != 1: + raise ValueError(f"DRTuluToolParser requires exactly one tool (dr_agent_mcp), got {len(tool_definitions)}") - actor = tool_actors[0] - self.tool_call_name = ray.get(actor.get_call_name.remote()) + self.tool_call_name = tool_definitions[0]["function"]["name"] if self.tool_call_name != "dr_agent_mcp": raise ValueError(f"DRTuluToolParser requires dr_agent_mcp tool, got {self.tool_call_name}") - stop_strings = ray.get(actor.get_stop_strings.remote()) - # Use dict.fromkeys to deduplicate while preserving order - self.stop_sequences = list(dict.fromkeys(stop_strings)) if stop_strings else [] + self.stop_sequences = list(dict.fromkeys(stop_sequences)) if stop_sequences else [] + + if not self.stop_sequences: + logger.warning("DRTuluToolParser initialized with no stop sequences — tool calls will never be detected") def get_tool_calls(self, text: str) -> list[EnvCall]: for stop in self.stop_sequences: @@ -322,8 +309,8 @@ def get_available_parsers() -> list[str]: def create_tool_parser( parser_type: str, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - tool_actors: list[ray.actor.ActorHandle], tool_definitions: list[dict[str, Any]] | None = None, + stop_sequences: list[str] | None = None, ) -> ToolParser: """Create a tool parser instance by type. @@ -333,8 +320,8 @@ def create_tool_parser( - "dr_tulu": DRTuluToolParser for content format - "vllm_*": VllmToolParser variants (vllm_hermes, vllm_llama3_json, vllm_olmo3) tokenizer: Tokenizer for the model (required for all parser types). - tool_actors: List of Ray actor handles for the tools. - tool_definitions: OpenAI-format tool definitions (required for vllm_* parsers). + tool_definitions: OpenAI-format tool definitions. + stop_sequences: a list of stop sequences to use for stopping generations. Returns: A ToolParser instance configured for the specified type. @@ -343,12 +330,12 @@ def create_tool_parser( ValueError: If parser_type is unknown. """ if parser_type == "legacy": - return OpenInstructLegacyToolParser( - tool_actors=tool_actors, output_wrap_name="output", tool_definitions=tool_definitions - ) + return OpenInstructLegacyToolParser(tool_definitions=tool_definitions, output_wrap_name="output") if parser_type == "dr_tulu": - return DRTuluToolParser(tool_actors) + if tool_definitions is None or stop_sequences is None: + raise ValueError("dr_tulu parser requires both tool_definitions and stop_sequences") + return DRTuluToolParser(tool_definitions=tool_definitions, stop_sequences=stop_sequences) if parser_type in VLLM_PARSERS: return create_vllm_parser(parser_type, tokenizer, tool_definitions=tool_definitions) diff --git a/open_instruct/environments/tools/tests/test_parsers.py b/open_instruct/environments/tools/tests/test_parsers.py index c65aeb243c..985aa35ca3 100644 --- a/open_instruct/environments/tools/tests/test_parsers.py +++ b/open_instruct/environments/tools/tests/test_parsers.py @@ -17,64 +17,31 @@ from open_instruct.utils import import_class_from_string -class MockTool: - """Mock tool for testing without ray.""" - - def __init__( - self, - name: str, - param_name: str = "text", - required: list[str] | None = None, - stop_strings: list[str] | None = None, - ): - self.call_name = name - self.param_name = param_name - self.required = required if required is not None else [param_name] - self._stop_strings = stop_strings - - def get_call_name(self): - return self.call_name - - def get_parameters(self): - return {"required": self.required, "properties": {self.param_name: {"type": "string"}}} - - def get_stop_strings(self): - if self._stop_strings is not None: - return self._stop_strings - raise AttributeError("No stop_strings defined") - - -def create_mock_tool_actor( - name: str, param_name: str = "text", required: list[str] | None = None, stop_strings: list[str] | None = None -) -> MagicMock: - """Create a mock tool actor handle that works with ray.get().""" - mock_tool = MockTool(name, param_name, required, stop_strings) - actor_handle = MagicMock() - actor_handle.get_call_name.remote.return_value = mock_tool.get_call_name() - actor_handle.get_parameters.remote.return_value = mock_tool.get_parameters() - actor_handle.get_stop_strings.remote.return_value = stop_strings - - return actor_handle +def make_tool_definition(name: str, param_name: str = "text", required: list[str] | None = None) -> dict: + """Create an OpenAI-format tool definition for testing.""" + if required is None: + required = [param_name] + return { + "type": "function", + "function": { + "name": name, + "description": f"Test tool {name}", + "parameters": { + "type": "object", + "properties": {param_name: {"type": "string", "description": f"The {param_name} parameter"}}, + "required": required, + }, + }, + } class TestOpenInstructLegacyToolParser(unittest.TestCase): """Tests for OpenInstructLegacyToolParser.""" - def setUp(self): - """Set up mock actors for each test.""" - self.patcher = patch("open_instruct.environments.tools.parsers.ray") - self.mock_ray = self.patcher.start() - # Make ray.get return the value directly (simulating sync behavior) - self.mock_ray.get.side_effect = lambda x: x - - def tearDown(self): - """Stop the patcher.""" - self.patcher.stop() - def test_single_tool_extraction(self): """Test extracting a single tool call.""" - mock_actor = create_mock_tool_actor("search", param_name="query") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search", param_name="query")] + parser = OpenInstructLegacyToolParser(defs) text = "I need to search for something. python tutorials" tool_calls = parser.get_tool_calls(text) @@ -85,9 +52,8 @@ def test_single_tool_extraction(self): def test_multiple_tools_extraction(self): """Test extracting multiple different tool calls.""" - mock_search = create_mock_tool_actor("search", param_name="query") - mock_code = create_mock_tool_actor("code", param_name="script") - parser = OpenInstructLegacyToolParser([mock_search, mock_code]) + defs = [make_tool_definition("search", param_name="query"), make_tool_definition("code", param_name="script")] + parser = OpenInstructLegacyToolParser(defs) text = "First search python then run print('hello')" tool_calls = parser.get_tool_calls(text) @@ -98,8 +64,8 @@ def test_multiple_tools_extraction(self): def test_no_tool_calls(self): """Test that no tool calls are returned when none exist.""" - mock_actor = create_mock_tool_actor("search") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search")] + parser = OpenInstructLegacyToolParser(defs) text = "This is just regular text without any tool calls." tool_calls = parser.get_tool_calls(text) @@ -108,8 +74,8 @@ def test_no_tool_calls(self): def test_multiline_content(self): """Test extracting tool calls with multiline content.""" - mock_actor = create_mock_tool_actor("code", param_name="script") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("code", param_name="script")] + parser = OpenInstructLegacyToolParser(defs) code_content = """def hello(): print('Hello, World!') @@ -124,8 +90,8 @@ def test_multiline_content(self): def test_partial_tag_not_matched(self): """Test that incomplete tags are not matched.""" - mock_actor = create_mock_tool_actor("search") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search")] + parser = OpenInstructLegacyToolParser(defs) # Missing closing tag text = "Here's a search query without closing" @@ -135,8 +101,8 @@ def test_partial_tag_not_matched(self): def test_nested_content_with_angle_brackets(self): """Test content containing angle brackets.""" - mock_actor = create_mock_tool_actor("code", param_name="script") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("code", param_name="script")] + parser = OpenInstructLegacyToolParser(defs) text = "if x > 5 and y < 10: print('yes')" tool_calls = parser.get_tool_calls(text) @@ -146,8 +112,8 @@ def test_nested_content_with_angle_brackets(self): def test_format_tool_outputs_single(self): """Test formatting a single tool output.""" - mock_actor = create_mock_tool_actor("search") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search")] + parser = OpenInstructLegacyToolParser(defs) result = parser.format_tool_outputs(["Search result: Found 5 items"]) expected = "\nSearch result: Found 5 items\n\n" @@ -155,8 +121,8 @@ def test_format_tool_outputs_single(self): def test_format_tool_outputs_multiple(self): """Test formatting multiple tool outputs.""" - mock_actor = create_mock_tool_actor("search") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search")] + parser = OpenInstructLegacyToolParser(defs) result = parser.format_tool_outputs(["Result 1", "Result 2"]) expected = "\nResult 1\n\n\n\nResult 2\n\n" @@ -164,8 +130,8 @@ def test_format_tool_outputs_multiple(self): def test_format_tool_outputs_custom_wrap_name(self): """Test formatting with custom output wrap name.""" - mock_actor = create_mock_tool_actor("search") - parser = OpenInstructLegacyToolParser([mock_actor], output_wrap_name="result") + defs = [make_tool_definition("search")] + parser = OpenInstructLegacyToolParser(defs, output_wrap_name="result") result = parser.format_tool_outputs(["Some output"]) expected = "\nSome output\n\n" @@ -173,9 +139,8 @@ def test_format_tool_outputs_custom_wrap_name(self): def test_stop_sequences(self): """Test that stop sequences are correctly generated.""" - mock_search = create_mock_tool_actor("search") - mock_code = create_mock_tool_actor("code") - parser = OpenInstructLegacyToolParser([mock_search, mock_code]) + defs = [make_tool_definition("search"), make_tool_definition("code")] + parser = OpenInstructLegacyToolParser(defs) stop_seqs = parser.stop_sequences @@ -185,8 +150,8 @@ def test_stop_sequences(self): def test_empty_content(self): """Test tool call with empty content between tags.""" - mock_actor = create_mock_tool_actor("search", param_name="query") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search", param_name="query")] + parser = OpenInstructLegacyToolParser(defs) text = "Empty search: " tool_calls = parser.get_tool_calls(text) @@ -196,8 +161,8 @@ def test_empty_content(self): def test_whitespace_only_content(self): """Test tool call with whitespace-only content.""" - mock_actor = create_mock_tool_actor("search", param_name="query") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search", param_name="query")] + parser = OpenInstructLegacyToolParser(defs) text = "Whitespace: \n\t " tool_calls = parser.get_tool_calls(text) @@ -207,8 +172,8 @@ def test_whitespace_only_content(self): def test_tool_without_required_params_uses_first_property(self): """Test that tools without required params use first property name.""" - mock_actor = create_mock_tool_actor("search", param_name="query", required=[]) - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search", param_name="query", required=[])] + parser = OpenInstructLegacyToolParser(defs) text = "test query" tool_calls = parser.get_tool_calls(text) @@ -218,8 +183,8 @@ def test_tool_without_required_params_uses_first_property(self): def test_multiple_calls_same_tool_extracted(self): """Test that all occurrences of the same tool type are extracted.""" - mock_actor = create_mock_tool_actor("search", param_name="query") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("search", param_name="query")] + parser = OpenInstructLegacyToolParser(defs) text = "first query then second query" tool_calls = parser.get_tool_calls(text) @@ -230,9 +195,8 @@ def test_multiple_calls_same_tool_extracted(self): def test_tool_calls_preserve_text_order(self): """Test that tool calls are returned in the order they appear in text.""" - mock_search = create_mock_tool_actor("search", param_name="query") - mock_code = create_mock_tool_actor("code", param_name="script") - parser = OpenInstructLegacyToolParser([mock_search, mock_code]) + defs = [make_tool_definition("search", param_name="query"), make_tool_definition("code", param_name="script")] + parser = OpenInstructLegacyToolParser(defs) # Interleaved tool calls: code, search, code text = "first code then query then second code" @@ -249,16 +213,23 @@ def test_tool_calls_preserve_text_order(self): def test_special_regex_characters_in_tool_name(self): """Test that tool names with regex special chars are properly escaped.""" # Tool name with characters that have meaning in regex - mock_actor = create_mock_tool_actor("tool.name", param_name="input") - parser = OpenInstructLegacyToolParser([mock_actor]) + defs = [make_tool_definition("tool.name", param_name="input")] + parser = OpenInstructLegacyToolParser(defs) - # Should match literal not text = "content" tool_calls = parser.get_tool_calls(text) self.assertEqual(len(tool_calls), 1) self.assertEqual(tool_calls[0].args["input"], "content") + def test_no_definitions(self): + """Test parser with no tool definitions.""" + parser = OpenInstructLegacyToolParser() + + self.assertEqual(parser.tool_names, []) + self.assertEqual(parser.stop_sequences, []) + self.assertEqual(parser.get_tool_calls("any text"), []) + class TestDRTuluToolParser(unittest.TestCase): """Tests for DRTuluToolParser. @@ -267,20 +238,11 @@ class TestDRTuluToolParser(unittest.TestCase): It only detects that a tool call occurred (via stop strings) and passes the full text. """ - def setUp(self): - """Set up mock actors for each test.""" - self.patcher = patch("open_instruct.environments.tools.parsers.ray") - self.mock_ray = self.patcher.start() - self.mock_ray.get.side_effect = lambda x: x if not isinstance(x, list) else [v for v in x] - - def tearDown(self): - """Stop the patcher.""" - self.patcher.stop() + DR_AGENT_DEF = make_tool_definition("dr_agent_mcp") def test_detects_tool_call_with_stop_string(self): """Test that parser detects tool call when stop string is present.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - parser = DRTuluToolParser([mock_actor]) + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""]) text = 'python tutorials' tool_calls = parser.get_tool_calls(text) @@ -291,8 +253,7 @@ def test_detects_tool_call_with_stop_string(self): def test_no_tool_call_without_stop_string(self): """Test that no tool call is returned when stop string is absent.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - parser = DRTuluToolParser([mock_actor]) + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""]) text = "This is just regular text without any tool calls." tool_calls = parser.get_tool_calls(text) @@ -301,8 +262,7 @@ def test_no_tool_call_without_stop_string(self): def test_passes_full_text_to_tool(self): """Test that the full text is passed as the argument.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - parser = DRTuluToolParser([mock_actor]) + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""]) text = """I need to search query here""" @@ -313,8 +273,7 @@ def test_passes_full_text_to_tool(self): def test_format_tool_outputs_single(self): """Test formatting a single tool output.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - parser = DRTuluToolParser([mock_actor]) + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""]) result = parser.format_tool_outputs(["Search result: Found 5 items"]) expected = "\nSearch result: Found 5 items\n\n" @@ -322,40 +281,41 @@ def test_format_tool_outputs_single(self): def test_format_tool_outputs_multiple(self): """Test formatting multiple tool outputs.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - parser = DRTuluToolParser([mock_actor]) + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""]) result = parser.format_tool_outputs(["Result 1", "Result 2"]) expected = "\nResult 1\n\n\n\nResult 2\n\n" self.assertEqual(result, expected) - def test_stop_sequences_default(self): - """Test that empty list is used when tools don't provide stop strings.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp") - parser = DRTuluToolParser([mock_actor]) + def test_stop_sequences_empty(self): + """Test that empty list is used when no stop sequences provided.""" + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[]) self.assertEqual(parser.stop_sequences, []) - def test_stop_sequences_from_tools(self): - """Test that stop sequences are collected from tools that provide them.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=["", ""]) - parser = DRTuluToolParser([mock_actor]) + def test_stop_sequences_from_init(self): + """Test that stop sequences are set from init parameter.""" + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=["", ""]) + + self.assertEqual(parser.stop_sequences, ["", ""]) + + def test_stop_sequences_deduplicated(self): + """Test that duplicate stop sequences are removed.""" + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=["", "", ""]) self.assertEqual(parser.stop_sequences, ["", ""]) def test_rejects_multiple_tools(self): """Test that parser rejects multiple tools.""" - mock_actor1 = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - mock_actor2 = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) + defs = [self.DR_AGENT_DEF, make_tool_definition("other_tool")] with self.assertRaises(ValueError) as context: - DRTuluToolParser([mock_actor1, mock_actor2]) + DRTuluToolParser(defs, stop_sequences=[""]) self.assertIn("exactly one tool", str(context.exception)) def test_uses_tool_call_name(self): """Test that parser uses the tool's call name for routing.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp", stop_strings=[""]) - parser = DRTuluToolParser([mock_actor]) + parser = DRTuluToolParser([self.DR_AGENT_DEF], stop_sequences=[""]) self.assertEqual(parser.tool_call_name, "dr_agent_mcp") @@ -365,10 +325,10 @@ def test_uses_tool_call_name(self): def test_rejects_wrong_tool(self): """Test that parser rejects tools that aren't dr_agent_mcp.""" - mock_actor = create_mock_tool_actor("python", stop_strings=[""]) + defs = [make_tool_definition("python")] with self.assertRaises(ValueError) as context: - DRTuluToolParser([mock_actor]) + DRTuluToolParser(defs, stop_sequences=[""]) self.assertIn("dr_agent_mcp", str(context.exception)) @@ -404,20 +364,16 @@ class TestVllmParserRegistry(unittest.TestCase): @parameterized.expand(VLLM_PARSERS.items()) def test_vllm_parser_config(self, name, config): """Test that a registered vLLM parser has valid configuration.""" - # Check config type self.assertIsInstance(config, VllmParserConfig) - # Verify import_path resolves to a callable class self.assertTrue(config.import_path, "missing import_path") parser_cls = import_class_from_string(config.import_path) self.assertTrue(callable(parser_cls)) - # Verify output_template is usable with .format() self.assertTrue(config.output_template, "missing output_template") formatted = config.output_template.format("test_output") self.assertIn("test_output", formatted) - # Check stop_sequences is a sized iterable (list, tuple, set, etc.) self.assertGreaterEqual(len(config.stop_sequences), 0, "stop_sequences must be a sized iterable") @@ -426,7 +382,6 @@ class TestVllmToolParser(unittest.TestCase): def test_format_tool_outputs_single(self): """Test formatting a single tool output.""" - # Create a mock native parser (we don't need it for format tests) mock_native = MagicMock() parser = VllmToolParser( tool_parser=mock_native, @@ -471,50 +426,40 @@ def test_stop_sequences_custom(self): class TestCreateToolParser(unittest.TestCase): """Tests for create_tool_parser factory function.""" - def setUp(self): - """Set up mock actors for each test.""" - self.patcher = patch("open_instruct.environments.tools.parsers.ray") - self.mock_ray = self.patcher.start() - self.mock_ray.get.side_effect = lambda x: x - self.mock_ray.exceptions.RayActorError = Exception - - def tearDown(self): - """Stop the patcher.""" - self.patcher.stop() - def test_create_legacy_parser(self): """Test creating legacy parser.""" - mock_actor = create_mock_tool_actor("search") mock_tokenizer = MagicMock() + defs = [make_tool_definition("search")] - parser = create_tool_parser("legacy", tokenizer=mock_tokenizer, tool_actors=[mock_actor]) + parser = create_tool_parser("legacy", tokenizer=mock_tokenizer, tool_definitions=defs) self.assertIsInstance(parser, OpenInstructLegacyToolParser) def test_create_dr_tulu_parser(self): """Test creating dr_tulu parser.""" - mock_actor = create_mock_tool_actor("dr_agent_mcp") mock_tokenizer = MagicMock() + defs = [make_tool_definition("dr_agent_mcp")] - parser = create_tool_parser("dr_tulu", tokenizer=mock_tokenizer, tool_actors=[mock_actor]) + parser = create_tool_parser( + "dr_tulu", tokenizer=mock_tokenizer, tool_definitions=defs, stop_sequences=[""] + ) self.assertIsInstance(parser, DRTuluToolParser) @parameterized.expand([(p,) for p in VLLM_PARSERS]) def test_create_vllm_parser(self, parser_type): """Test creating vLLM parsers.""" - mock_actor = create_mock_tool_actor("search") mock_tokenizer = MagicMock() + defs = [make_tool_definition("search")] with patch("open_instruct.environments.tools.parsers.import_class_from_string") as mock_import: mock_import.return_value = MagicMock() - parser = create_tool_parser(parser_type, tokenizer=mock_tokenizer, tool_actors=[mock_actor]) + parser = create_tool_parser(parser_type, tokenizer=mock_tokenizer, tool_definitions=defs) self.assertIsInstance(parser, VllmToolParser) def test_unknown_parser_raises_error(self): """Test that unknown parser types raise an error.""" - mock_actor = create_mock_tool_actor("search") mock_tokenizer = MagicMock() with self.assertRaises(ValueError) as context: - create_tool_parser("unknown_parser", tokenizer=mock_tokenizer, tool_actors=[mock_actor]) + create_tool_parser("unknown_parser", tokenizer=mock_tokenizer) self.assertIn("Unknown parser type", str(context.exception)) self.assertIn("Available:", str(context.exception)) diff --git a/open_instruct/environments/tools/utils.py b/open_instruct/environments/tools/utils.py index 560e477d96..527a69fb86 100644 --- a/open_instruct/environments/tools/utils.py +++ b/open_instruct/environments/tools/utils.py @@ -422,5 +422,9 @@ def get_parameters(self) -> dict[str, Any]: """Get the tool's parameter schema.""" return self.parameters + def get_stop_strings(self) -> list[str]: + """Get stop strings for this tool. Override in subclasses that define custom stop sequences.""" + return [] + def get_tool_definitions(self) -> list[dict[str, Any]]: return [get_openai_tool_definitions(self)] diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3252b24faf..a6764796e3 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1237,6 +1237,7 @@ def create_model_and_optimizer( tools_config: EnvsConfig | None = None, base_env_config: dict | None = None, pools: dict[str, ray.actor.ActorHandle] | None = None, + tool_stop_sequences: list[str] | None = None, ) -> tuple[ ModelGroup, list[vllm_utils.LLMRayActor], int, int, ray.actor.ActorHandle, utils.ModelDims, ray.actor.ActorHandle ]: @@ -1318,6 +1319,7 @@ def create_model_and_optimizer( pg=pg if args.single_gpu_mode else None, tool_parser_type=tools_config.tool_parser_type if tools_config else "legacy", tool_definitions=tool_definitions, + tool_stop_sequences=tool_stop_sequences, max_steps=tools_config.max_steps if tools_config else 5, mask_tool_use=streaming_config.mask_tool_use, pools=pools, @@ -2093,21 +2095,37 @@ def initialize_tools_and_envs( tools_config.tool_call_names = tool_call_names - # Collect tool definitions from all pools + # Collect tool definitions and stop strings from all pools (batched for parallelism) + acquire_refs = [pool.acquire.remote() for pool in pools.values()] + actors = ray.get(acquire_refs) + + def_refs = [actor.get_tool_definitions.remote() for actor in actors] + stop_refs = ( + [actor.get_stop_strings.remote() for actor in actors] if tools_config.tool_parser_type == "dr_tulu" else [] + ) + all_results = ray.get(def_refs + stop_refs) + def_results = all_results[: len(def_refs)] + stop_results = all_results[len(def_refs) :] + tool_definitions: list[dict[str, Any]] = [] - for pool in pools.values(): - actor = ray.get(pool.acquire.remote()) - defs = ray.get(actor.get_tool_definitions.remote()) - pool.release.remote(actor) + for defs in def_results: tool_definitions.extend(defs) + tool_stop_strings: list[str] = [] + for stop_strings in stop_results: + if stop_strings: + tool_stop_strings.extend(stop_strings) + + for pool, actor in zip(pools.values(), actors): + pool.release.remote(actor) + stop_sequences: list[str] = [] if pools: stop_sequences = create_tool_parser( parser_type=tools_config.tool_parser_type, - tool_actors=[], tokenizer=tokenizer, tool_definitions=tool_definitions, + stop_sequences=tool_stop_strings, ).stop_sequences logger.info( @@ -2152,12 +2170,6 @@ def main( dataset_mixer_list=streaming_config.dataset_mixer_list, dataset_mixer_list_splits=streaming_config.dataset_mixer_list_splits, ) - # TODO: Refactor DRTuluToolParser to work with tool_definitions instead of tool_actors. - if pools and tools_config.tool_parser_type == "dr_tulu": - raise ValueError( - "Parser type 'dr_tulu' requires tool_actors which are no longer created (pools are used instead). " - "Use --tool_parser_type legacy or --tool_parser_type vllm_hermes." - ) if tool_stop_sequences: logger.info(f"Adding tool stop sequences to config: {tool_stop_sequences}") streaming_config.stop_strings.extend(tool_stop_sequences) @@ -2247,6 +2259,7 @@ def main( tools_config, base_env_config, pools, + tool_stop_sequences, ) ) diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index a87306c2e5..0e12cb72b1 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -556,6 +556,7 @@ def __init__( *args, tool_parser_type: str = "legacy", tool_definitions: list[dict] | None = None, + tool_stop_sequences: list[str] | None = None, max_steps: int = 5, mask_tool_use: bool = True, pools: dict[str, ray.actor.ActorHandle] | None = None, @@ -572,6 +573,7 @@ def __init__( ): assert_threaded_actor(self) self._tool_definitions = tool_definitions + self._tool_stop_sequences = tool_stop_sequences self._init_config( max_steps, mask_tool_use, pools, inflight_updates, reward_config, train_dataset, eval_dataset ) @@ -636,9 +638,9 @@ def _init_executor(self) -> None: def _init_tool_parser(self, tool_parser_type: str) -> None: self.tool_parser = create_tool_parser( parser_type=tool_parser_type, - tool_actors=[], tokenizer=self.llm_engine.tokenizer, tool_definitions=self._tool_definitions, + stop_sequences=self._tool_stop_sequences, ) def _setup_gpu_visibility(self, noset_visible_devices: bool, distributed_executor_backend: str) -> None: @@ -1084,6 +1086,7 @@ def create_vllm_engines( pg: PlacementGroup | None = None, tool_parser_type: str = "legacy", tool_definitions: list[dict] | None = None, + tool_stop_sequences: list[str] | None = None, max_steps: int = 5, mask_tool_use: bool = True, pools: dict[str, ray.actor.ActorHandle] | None = None, @@ -1171,6 +1174,7 @@ def create_vllm_engines( actor_manager=actor_manager, tool_parser_type=tool_parser_type, tool_definitions=tool_definitions, + tool_stop_sequences=tool_stop_sequences, max_steps=max_steps, mask_tool_use=mask_tool_use, pools=pools, diff --git a/pyproject.toml b/pyproject.toml index bb166b1871..a6f7d302ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,8 @@ code = [ ] dr-tulu = [ "dr_agent", + "authlib", + "scipy", ] [tool.uv] diff --git a/scripts/train/debug/tools/dr_tulu_parser_debug.sh b/scripts/train/debug/tools/dr_tulu_parser_debug.sh index c374b3bbbf..278564c1ef 100755 --- a/scripts/train/debug/tools/dr_tulu_parser_debug.sh +++ b/scripts/train/debug/tools/dr_tulu_parser_debug.sh @@ -64,7 +64,7 @@ VLLM_ALLOW_INSECURE_SERIALIZATION=1 uv run --extra dr-tulu open_instruct/grpo_fa --gradient_checkpointing \ --system_prompt_override_file scripts/train/debug/tools/dr_tulu_system_prompt.txt \ --tools dr_agent_mcp \ - --tool_parser dr_tulu \ + --tool_parser_type dr_tulu \ --tool_configs '{"tool_names": "snippet_search,google_search,browse_webpage", "parser_name": "v20250824", "host": "'"$MCP_HOST"'", "port": '"$MCP_PORT"'}' \ --max_steps 5 \ --pass_tools_to_chat_template false \ diff --git a/uv.lock b/uv.lock index 1ae8e0598b..1ada6e5270 100644 --- a/uv.lock +++ b/uv.lock @@ -2687,7 +2687,9 @@ code = [ { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] dr-tulu = [ + { name = "authlib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "dr-agent", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "scipy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] [package.dev-dependencies] @@ -2708,6 +2710,7 @@ requires-dist = [ { name = "accelerate", specifier = ">=1.10.1" }, { name = "ai2-olmo-core", specifier = "==2.3.0" }, { name = "antlr4-python3-runtime", specifier = "==4.11" }, + { name = "authlib", marker = "extra == 'dr-tulu'" }, { name = "backoff", specifier = ">=2.2.1" }, { name = "bitsandbytes", marker = "sys_platform != 'darwin'", specifier = ">=0.44.1" }, { name = "datasets", specifier = ">=4.0.0" }, @@ -2733,6 +2736,7 @@ requires-dist = [ { name = "pydantic", marker = "extra == 'code'", specifier = ">=2.0.0" }, { name = "ray", extras = ["default"], specifier = ">=2.49.2" }, { name = "requests", marker = "extra == 'code'", specifier = ">=2.28.0" }, + { name = "scipy", marker = "extra == 'dr-tulu'" }, { name = "setuptools", specifier = ">=75.6.0,<80.0.0" }, { name = "tensorboard", specifier = ">=2.18.0" }, { name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.9.0,<2.10" },