Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ All notable changes to this project will be documented in this file.
- Documentation and runtime warning for `dataset_mixer_list` format (float=proportion, int=count) (https://github.com/allenai/open-instruct/pull/1434).

### Changed
- Refactor Legacy and DRTulu tool parsers to use OpenAI-format `tool_definitions` instead of Ray `tool_actors`. Removes `import ray` from `parsers.py`, fixes DRTulu parser which was broken after the pool refactor, and fixes `--tool_parser_type` typo in dr_tulu debug script (https://github.com/allenai/open-instruct/pull/1491).
- Replace lambda collators with a `single_example_collator` (https://github.com/allenai/open-instruct/pull/1472).
- Clarified `activation_memory_budget` guidance in DPO utils with a practical default (`0.5`) and memory/speed tradeoff notes (https://github.com/allenai/open-instruct/pull/1460).
- Let TransformerTrainModule handle FSDP parallelism instead of manual application in DPO (https://github.com/allenai/open-instruct/pull/1458).
Expand Down
53 changes: 20 additions & 33 deletions open_instruct/environments/tools/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from dataclasses import dataclass, field
from typing import Any

import ray
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.tool_parsers import ToolParser as VllmNativeToolParser
Expand Down Expand Up @@ -55,14 +54,11 @@ class OpenInstructLegacyToolParser(ToolParser):
Tools are invoked via <tool_name>content</tool_name> tags.
The content between tags is passed to the tool's first required parameter.
Only works for tools that take a single string parameter.

Tool names and parameter names are derived from OpenAI-format tool definitions.
"""

def __init__(
self,
tool_actors: list[ray.actor.ActorHandle] | None = None,
output_wrap_name: str = "output",
tool_definitions: list[dict[str, Any]] | None = None,
):
def __init__(self, tool_definitions: list[dict[str, Any]] | None = None, output_wrap_name: str = "output"):
self.output_wrap_name = output_wrap_name

if tool_definitions:
Expand All @@ -78,17 +74,6 @@ def __init__(
else:
properties = params.get("properties", {})
self.tool_param_names[name] = next(iter(properties)) if properties else "text"
elif tool_actors:
self.tool_names = [ray.get(actor.get_call_name.remote()) for actor in tool_actors]
self.tool_param_names = {}
for actor, tool_name in zip(tool_actors, self.tool_names):
params = ray.get(actor.get_parameters.remote())
required = params.get("required", [])
if required:
self.tool_param_names[tool_name] = required[0]
else:
properties = params.get("properties", {})
self.tool_param_names[tool_name] = next(iter(properties)) if properties else "text"
else:
self.tool_names = []
self.tool_param_names = {}
Expand Down Expand Up @@ -288,21 +273,23 @@ class DRTuluToolParser(ToolParser):
"""
Parser for DR Tulu style tool calls. Delegates actual parsing to the tool itself.
Only detects that a tool call occurred (via stop strings) and passes text to the tool.

Requires exactly one tool (dr_agent_mcp) in tool_definitions.
"""

def __init__(self, tool_actors: list[ray.actor.ActorHandle]):
if len(tool_actors) != 1:
raise ValueError(f"DRTuluToolParser requires exactly one tool (dr_agent_mcp), got {len(tool_actors)}")
def __init__(self, tool_definitions: list[dict[str, Any]], stop_sequences: list[str]):
if len(tool_definitions) != 1:
raise ValueError(f"DRTuluToolParser requires exactly one tool (dr_agent_mcp), got {len(tool_definitions)}")

actor = tool_actors[0]
self.tool_call_name = ray.get(actor.get_call_name.remote())
self.tool_call_name = tool_definitions[0]["function"]["name"]

if self.tool_call_name != "dr_agent_mcp":
raise ValueError(f"DRTuluToolParser requires dr_agent_mcp tool, got {self.tool_call_name}")

stop_strings = ray.get(actor.get_stop_strings.remote())
# Use dict.fromkeys to deduplicate while preserving order
self.stop_sequences = list(dict.fromkeys(stop_strings)) if stop_strings else []
self.stop_sequences = list(dict.fromkeys(stop_sequences)) if stop_sequences else []

if not self.stop_sequences:
logger.warning("DRTuluToolParser initialized with no stop sequences — tool calls will never be detected")

def get_tool_calls(self, text: str) -> list[EnvCall]:
for stop in self.stop_sequences:
Expand All @@ -322,8 +309,8 @@ def get_available_parsers() -> list[str]:
def create_tool_parser(
parser_type: str,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
tool_actors: list[ray.actor.ActorHandle],
tool_definitions: list[dict[str, Any]] | None = None,
stop_sequences: list[str] | None = None,
) -> ToolParser:
"""Create a tool parser instance by type.

Expand All @@ -333,8 +320,8 @@ def create_tool_parser(
- "dr_tulu": DRTuluToolParser for <call_tool name="...">content</call_tool> format
- "vllm_*": VllmToolParser variants (vllm_hermes, vllm_llama3_json, vllm_olmo3)
tokenizer: Tokenizer for the model (required for all parser types).
tool_actors: List of Ray actor handles for the tools.
tool_definitions: OpenAI-format tool definitions (required for vllm_* parsers).
tool_definitions: OpenAI-format tool definitions.
stop_sequences: Stop strings used to detect that a tool call occurred (required for the "dr_tulu" parser).

Returns:
A ToolParser instance configured for the specified type.
Expand All @@ -343,12 +330,12 @@ def create_tool_parser(
ValueError: If parser_type is unknown, or if "dr_tulu" is requested without both tool_definitions and stop_sequences.
"""
if parser_type == "legacy":
return OpenInstructLegacyToolParser(
tool_actors=tool_actors, output_wrap_name="output", tool_definitions=tool_definitions
)
return OpenInstructLegacyToolParser(tool_definitions=tool_definitions, output_wrap_name="output")

if parser_type == "dr_tulu":
return DRTuluToolParser(tool_actors)
if tool_definitions is None or stop_sequences is None:
raise ValueError("dr_tulu parser requires both tool_definitions and stop_sequences")
return DRTuluToolParser(tool_definitions=tool_definitions, stop_sequences=stop_sequences)

if parser_type in VLLM_PARSERS:
return create_vllm_parser(parser_type, tokenizer, tool_definitions=tool_definitions)
Expand Down
Loading