Skip to content

Commit 635752f

Browse files
XuhuiZhoubugszautofix-ci[bot]astrophie
authored
Feature/sotopia demo UI (sotopia-lab#261)
* initial * initial ui * merge main * add new ui * switch to fastAPI * websocket check * fix render episode error * add page; make a simplified page and still WIP * [autofix.ci] apply automated fixes * fix simplified streaming version * semi-done character page + avatar assets * Fixed character card styling * [autofix.ci] apply automated fixes * unified rendering and chat display * updated chat character icons * add some tags * add typing * temp fix * add characters avatar to simulation * fix episode full avatar * go to modal config * clean up code * add modal streamlit app * clean codebase except websocket * remove repeated local css * clean websocket * fix get name error * fix errors * pre render scenario * add custom eval * change streamlit to dynamic path * new uv * revert to previous install commands * a fix for modal * add customized dimension * [autofix.ci] apply automated fixes * sort scenarios in simulation * for demo video * update deploy instruction * update intro page * update intro page * [autofix.ci] apply automated fixes * update intro page * add customized dimensions * update api link and modal environment * move folder * fix relative import * update modal image build * use uv to build environment * change folder name * change test * fix modal serve * environment change * refactor * fix ui --------- Co-authored-by: Zhe Su <[email protected]> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: astrophie <[email protected]>
1 parent fe8deea commit 635752f

33 files changed

+2921
-805
lines changed

docs/pages/contribution/contribution.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ Please refer to [Dev Containers](https://containers.dev/supporting#editors) to s
133133

134134
You can also set up the development environment without Dev Containers. There are three things you will need to set up manually:
135135

136-
- Python and uv: Please start from an environment supporting Python 3.10+ and install uv using `pip install uv; uv sync --all-extra`.
136+
- Python and uv: Please start from an environment supporting Python 3.10+ and install uv using `pip install uv; uv sync --all-extras`. (Note that this will install all the extra dependencies)
137137
- Redis: Please refer to introduction page for the set up of Redis.
138138
- Local LLM (optional): If you don't have access to model endpoints (e.g. OpenAI, Anthropic or others), you can use a local model. You can use Ollama, Llama.cpp, vLLM or many others which support OpenAI compatible endpoints.
139139

examples/experimental/websocket/websocket_test_client.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sotopia.database import EnvironmentProfile, AgentProfile
77

88
import asyncio
9-
import websockets
9+
import aiohttp
1010
import sys
1111
from pathlib import Path
1212

@@ -28,40 +28,39 @@ async def connect(self) -> None:
2828
url_with_token = f"{self.url}?token=test_token_{self.client_id}"
2929

3030
try:
31-
async with websockets.connect(url_with_token) as websocket:
32-
print(f"Client {self.client_id}: Connected to {self.url}")
33-
34-
# Send initial message
35-
# Note: You'll need to implement the logic to get agent_ids and env_id
36-
# This is just an example structure
37-
agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]]
38-
env_id = EnvironmentProfile.find().all()[0].pk
39-
start_message = {
40-
"type": "START_SIM",
41-
"data": {
42-
"env_id": env_id, # Replace with actual env_id
43-
"agent_ids": agent_ids, # Replace with actual agent_ids
44-
},
45-
}
46-
await websocket.send(json.dumps(start_message))
47-
print(f"Client {self.client_id}: Sent START_SIM message")
48-
49-
# Receive and process messages
50-
while True:
51-
try:
52-
message = await websocket.recv()
53-
print(
54-
f"\nClient {self.client_id} received message:",
55-
json.dumps(json.loads(message), indent=2),
56-
)
57-
assert isinstance(message, str)
58-
await self.save_message(message)
59-
except websockets.ConnectionClosed:
60-
print(f"Client {self.client_id}: Connection closed")
61-
break
62-
except Exception as e:
63-
print(f"Client {self.client_id} error:", str(e))
64-
break
31+
async with aiohttp.ClientSession() as session:
32+
async with session.ws_connect(url_with_token) as ws:
33+
print(f"Client {self.client_id}: Connected to {self.url}")
34+
35+
# Send initial message
36+
# Note: You'll need to implement the logic to get agent_ids and env_id
37+
# This is just an example structure
38+
agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]]
39+
env_id = EnvironmentProfile.find().all()[0].pk
40+
start_message = {
41+
"type": "START_SIM",
42+
"data": {
43+
"env_id": env_id, # Replace with actual env_id
44+
"agent_ids": agent_ids, # Replace with actual agent_ids
45+
},
46+
}
47+
await ws.send_json(start_message)
48+
print(f"Client {self.client_id}: Sent START_SIM message")
49+
50+
# Receive and process messages
51+
async for msg in ws:
52+
if msg.type == aiohttp.WSMsgType.TEXT:
53+
print(
54+
f"\nClient {self.client_id} received message:",
55+
json.dumps(json.loads(msg.data), indent=2),
56+
)
57+
await self.save_message(msg.data)
58+
elif msg.type == aiohttp.WSMsgType.CLOSED:
59+
print(f"Client {self.client_id}: Connection closed")
60+
break
61+
elif msg.type == aiohttp.WSMsgType.ERROR:
62+
print(f"Client {self.client_id}: Connection error")
63+
break
6564

6665
except Exception as e:
6766
print(f"Client {self.client_id} connection error:", str(e))

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ examples = ["transformers", "datasets", "scipy", "torch", "pandas"]
4040
api = [
4141
"fastapi[standard]",
4242
"uvicorn",
43+
"streamlit",
44+
"websockets",
45+
"modal"
4346
]
4447
test = ["pytest", "pytest-cov", "pytest-asyncio"]
4548

scripts/modal/modal_api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55

66
import redis
7-
from sotopia.ui.fastapi_server import SotopiaFastAPI
7+
from sotopia.api.fastapi_server import SotopiaFastAPI
88

99
# Create persistent volume for Redis data
1010
redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
> [!CAUTION]
33
> Work in progress: the API endpoints are being implemented. And will be released in the future major version.
44
5+
## Deploy to Modal
6+
First you need to have a Modal account and logged in with `modal setup`
7+
8+
To deploy the FastAPI server to Modal, run the following command:
9+
```bash
10+
cd sotopia/ui/fastapi_server
11+
modal deploy modal_api_server.py
12+
```
513
## FastAPI Server
614

715
To run the FastAPI server, you can use the following command:
Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
NonStreamingSimulationStatus,
1818
CustomEvaluationDimensionList,
1919
CustomEvaluationDimension,
20+
BaseEnvironmentProfile,
21+
BaseAgentProfile,
22+
BaseRelationshipProfile,
2023
)
2124
from sotopia.envs.parallel import ParallelSotopiaEnv
2225
from sotopia.envs.evaluators import (
@@ -37,7 +40,7 @@
3740
from fastapi.middleware.cors import CORSMiddleware
3841
from pydantic import BaseModel, model_validator, field_validator, Field
3942

40-
from sotopia.ui.websocket_utils import (
43+
from sotopia.api.websocket_utils import (
4144
WebSocketSotopiaSimulator,
4245
WSMessageType,
4346
ErrorType,
@@ -68,56 +71,6 @@
6871
] = {} # TODO check whether this is the correct way to store the active simulations
6972

7073

71-
class RelationshipWrapper(BaseModel):
72-
pk: str = ""
73-
agent_1_id: str = ""
74-
agent_2_id: str = ""
75-
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
76-
backstory: str = ""
77-
tag: str = ""
78-
79-
80-
class AgentProfileWrapper(BaseModel):
81-
"""
82-
Wrapper for AgentProfile to avoid pydantic v2 issues
83-
"""
84-
85-
pk: str = ""
86-
first_name: str
87-
last_name: str
88-
age: int = 0
89-
occupation: str = ""
90-
gender: str = ""
91-
gender_pronoun: str = ""
92-
public_info: str = ""
93-
big_five: str = ""
94-
moral_values: list[str] = []
95-
schwartz_personal_values: list[str] = []
96-
personality_and_values: str = ""
97-
decision_making_style: str = ""
98-
secret: str = ""
99-
model_id: str = ""
100-
mbti: str = ""
101-
tag: str = ""
102-
103-
104-
class EnvironmentProfileWrapper(BaseModel):
105-
"""
106-
Wrapper for EnvironmentProfile to avoid pydantic v2 issues
107-
"""
108-
109-
pk: str = ""
110-
codename: str
111-
source: str = ""
112-
scenario: str = ""
113-
agent_goals: list[str] = []
114-
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
115-
age_constraint: str | None = None
116-
occupation_constraint: str | None = None
117-
agent_constraint: list[list[str]] | None = None
118-
tag: str = ""
119-
120-
12174
class CustomEvaluationDimensionsWrapper(BaseModel):
12275
pk: str = ""
12376
name: str = Field(
@@ -484,23 +437,23 @@ def setup_routes(self) -> None:
484437
)(get_evaluation_dimensions)
485438

486439
@self.post("/scenarios/", response_model=str)
487-
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
440+
async def create_scenario(scenario: BaseEnvironmentProfile) -> str:
488441
scenario_profile = EnvironmentProfile(**scenario.model_dump())
489442
scenario_profile.save()
490443
pk = scenario_profile.pk
491444
assert pk is not None
492445
return pk
493446

494447
@self.post("/agents/", response_model=str)
495-
async def create_agent(agent: AgentProfileWrapper) -> str:
448+
async def create_agent(agent: BaseAgentProfile) -> str:
496449
agent_profile = AgentProfile(**agent.model_dump())
497450
agent_profile.save()
498451
pk = agent_profile.pk
499452
assert pk is not None
500453
return pk
501454

502455
@self.post("/relationship/", response_model=str)
503-
async def create_relationship(relationship: RelationshipWrapper) -> str:
456+
async def create_relationship(relationship: BaseRelationshipProfile) -> str:
504457
relationship_profile = RelationshipProfile(**relationship.model_dump())
505458
relationship_profile.save()
506459
pk = relationship_profile.pk
Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,20 +160,18 @@ def __init__(
160160

161161
async def arun(self) -> AsyncGenerator[dict[str, Any], None]:
162162
# Use sotopia to run the simulation
163-
generator = arun_one_episode(
163+
generator = await arun_one_episode(
164164
env=self.env,
165165
agent_list=list(self.agents.values()),
166166
push_to_db=False,
167167
streaming=True,
168168
)
169169

170-
# assert isinstance(
171-
# generator, AsyncGenerator
172-
# ), "generator should be async generator, but got {}".format(
173-
# type(generator)
174-
# )
170+
assert isinstance(
171+
generator, AsyncGenerator
172+
), "generator should be async generator, but got {}".format(type(generator))
175173

176-
async for messages in await generator: # type: ignore
174+
async for messages in generator:
177175
reasoning, rewards = "", [0.0, 0.0]
178176
if messages[-1][0][0] == "Evaluation":
179177
reasoning = messages[-1][0][2].to_natural_language()
@@ -192,9 +190,6 @@ async def arun(self) -> AsyncGenerator[dict[str, Any], None]:
192190
rewards=rewards,
193191
rewards_prompt="",
194192
)
195-
# agent_profiles, parsed_messages = epilog.render_for_humans()
196-
# if not eval_available:
197-
# parsed_messages = parsed_messages[:-2]
198193

199194
yield {
200195
"type": "messages",

sotopia/database/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@
22
from redis_om import JsonModel, Migrator
33
from .annotators import Annotator
44
from .env_agent_combo_storage import EnvAgentComboStorage
5-
from .logs import AnnotationForEpisode, EpisodeLog, NonStreamingSimulationStatus
5+
from .logs import (
6+
AnnotationForEpisode,
7+
BaseEpisodeLog,
8+
NonStreamingSimulationStatus,
9+
EpisodeLog,
10+
)
611
from .persistent_profile import (
712
AgentProfile,
13+
BaseAgentProfile,
814
EnvironmentProfile,
15+
BaseEnvironmentProfile,
16+
BaseRelationshipProfile,
917
RelationshipProfile,
1018
RelationshipType,
1119
)
@@ -42,12 +50,16 @@
4250

4351
__all__ = [
4452
"AgentProfile",
53+
"BaseAgentProfile",
4554
"EnvironmentProfile",
55+
"BaseEnvironmentProfile",
4656
"EpisodeLog",
57+
"BaseEpisodeLog",
4758
"NonStreamingSimulationStatus",
4859
"EnvAgentComboStorage",
4960
"AnnotationForEpisode",
5061
"Annotator",
62+
"BaseRelationshipProfile",
5163
"RelationshipProfile",
5264
"RelationshipType",
5365
"RedisCommunicationMixin",

sotopia/database/logs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
else:
66
from typing_extensions import Self
77

8-
from pydantic import model_validator
8+
from pydantic import model_validator, BaseModel
99
from redis_om import JsonModel
1010
from redis_om.model.model import Field
1111
from typing import Literal
@@ -17,7 +17,7 @@ class NonStreamingSimulationStatus(JsonModel):
1717
status: Literal["Started", "Error", "Completed"]
1818

1919

20-
class EpisodeLog(JsonModel):
20+
class BaseEpisodeLog(BaseModel):
2121
# Note that we did not validate the following constraints:
2222
# 1. The number of turns in messages and rewards should be the same or off by 1
2323
# 2. The agents in the messages are the same as the agetns
@@ -77,6 +77,10 @@ def render_for_humans(self) -> tuple[list[AgentProfile], list[str]]:
7777
return agent_profiles, messages_and_rewards
7878

7979

80+
class EpisodeLog(BaseEpisodeLog, JsonModel):
81+
pass
82+
83+
8084
class AnnotationForEpisode(JsonModel):
8185
episode: str = Field(index=True, description="the pk id of episode log")
8286
annotator_id: str = Field(index=True, full_text_search=True)

0 commit comments

Comments
 (0)