2020)
2121from sotopia .envs .evaluators import (
2222 EvaluationForTwoAgents ,
23- ReachGoalLLMEvaluator ,
23+ EpisodeLLMEvaluator ,
2424 RuleBasedTerminatedEvaluator ,
2525 SotopiaDimensions ,
2626)
2727from sotopia .envs .parallel import ParallelSotopiaEnv
28- from sotopia .generation_utils .generate import LLM_Name
2928from sotopia .messages .message_classes import AgentAction , Observation
3029from sotopia .samplers .base_sampler import BaseSampler , EnvAgentCombo
3130from sotopia .server import run_async_server
@@ -92,10 +91,8 @@ def find_combo_pk(
9291def get_combo_model_map (
9392 all_episodes : List [EpisodeLog ],
9493 all_combos_map : Dict [str , EnvAgentComboStorage ],
95- ) -> Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]]:
96- combo_model_map : Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]] = (
97- defaultdict (Counter )
98- )
94+ ) -> Dict [str , Counter [tuple [str , str , str ]]]:
95+ combo_model_map : Dict [str , Counter [tuple [str , str , str ]]] = defaultdict (Counter )
9996 bad_combos = []
10097 valid_count = 0
10198 invalid_count = 0
@@ -132,9 +129,7 @@ def get_combo_model_map(
132129 all_combos_map ,
133130 )
134131 if curr_combo_pk :
135- model_pair : tuple [LLM_Name , LLM_Name , LLM_Name ] = cast (
136- tuple [LLM_Name , LLM_Name , LLM_Name ], tuple (curr_ep .models )
137- )
132+ model_pair : tuple [str , str , str ] = tuple (curr_ep .models ) # type: ignore
138133 combo_model_map [curr_combo_pk ][model_pair ] += 1
139134 valid_count += 1
140135 else :
@@ -153,8 +148,8 @@ def get_combo_model_map(
153148
154149
155150def get_all_model_pairs (
156- combo_model_map : Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]],
157- ) -> Set [tuple [LLM_Name , LLM_Name , LLM_Name ]]:
151+ combo_model_map : Dict [str , Counter [tuple [str , str , str ]]],
152+ ) -> Set [tuple [str , str , str ]]:
158153 all_model_pairs = set ()
159154 for key in combo_model_map :
160155 for combo in combo_model_map [key ]:
@@ -169,12 +164,12 @@ def get_all_model_pairs(
169164
170165
171166def get_all_missing_model_pairs (
172- combo_model_map : Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]],
173- all_model_pairs : Set [tuple [LLM_Name , LLM_Name , LLM_Name ]],
167+ combo_model_map : Dict [str , Counter [tuple [str , str , str ]]],
168+ all_model_pairs : Set [tuple [str , str , str ]],
174169 num_required : int ,
175- ) -> Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]]:
176- combo_missing_model_map : Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]] = (
177- defaultdict ( Counter )
170+ ) -> Dict [str , Counter [tuple [str , str , str ]]]:
171+ combo_missing_model_map : Dict [str , Counter [tuple [str , str , str ]]] = defaultdict (
172+ Counter
178173 )
179174 missing_count = 0
180175 for key in combo_model_map :
@@ -192,9 +187,9 @@ def get_all_missing_model_pairs(
192187# temporarily used to make sure each (env, agents, models) setting is unique; needs to change
193188# to follow the Counter in the case where multiple experiments must be run for one setting
194189def get_missing_model_combo_map (
195- combo_missing_model_map : Dict [str , Counter [tuple [LLM_Name , LLM_Name , LLM_Name ]]],
190+ combo_missing_model_map : Dict [str , Counter [tuple [str , str , str ]]],
196191 all_combos_map : Dict [str , EnvAgentComboStorage ],
197- ) -> Dict [tuple [LLM_Name , LLM_Name ], List [tuple [str , str , str ]]]:
192+ ) -> Dict [tuple [str , str ], List [tuple [str , str , str ]]]:
198193 missing_model_combo_map = defaultdict (list )
199194 for combo_pk in combo_missing_model_map :
200195 model_counter = combo_missing_model_map [combo_pk ]
@@ -216,7 +211,7 @@ def get_missing_model_combo_map(
216211
217212
218213def yield_env_agent_combo (
219- combo_ids : list [tuple [str , str , str ]], model_names : dict [str , LLM_Name ]
214+ combo_ids : list [tuple [str , str , str ]], model_names : dict [str , str ]
220215) -> Generator [EnvAgentCombo [Observation , AgentAction ], None , None ]:
221216 for combo_id in combo_ids :
222217 env_id , agent_id1 , agent_id2 = combo_id
@@ -229,7 +224,7 @@ def yield_env_agent_combo(
229224 RuleBasedTerminatedEvaluator (max_turn_number = 20 , max_stale_turn = 2 ),
230225 ],
231226 terminal_evaluators = [
232- ReachGoalLLMEvaluator (
227+ EpisodeLLMEvaluator (
233228 model_names ["env" ],
234229 EvaluationForTwoAgents [SotopiaDimensions ],
235230 ),
@@ -249,8 +244,8 @@ def yield_env_agent_combo(
249244
250245@gin .configurable
251246def re_run_missing_episodes (
252- combo_with_models : dict [tuple [LLM_Name , LLM_Name ], list [tuple [str , str , str ]]],
253- model_names : dict [str , LLM_Name ] = {
247+ combo_with_models : dict [tuple [str , str ], list [tuple [str , str , str ]]],
248+ model_names : dict [str , str ] = {
254249 "env" : "gpt-4" ,
255250 "agent1" : "gpt-4o-mini" ,
256251 "agent2" : "gpt-4o-mini" ,
0 commit comments