Skip to content

Commit 696c230

Browse files
authored
feat: Add integration tests
1 parent 677d286 commit 696c230

File tree

13 files changed

+1280
-745
lines changed

13 files changed

+1280
-745
lines changed

examples/art-e.ipynb

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"outputs": [],
4545
"source": [
4646
"%%capture\n",
47-
"!uv pip install openpipe-art==0.3.11.post5 langchain-core tenacity datasets \"gql<4\" --prerelease allow --no-cache-dir"
47+
"!uv pip install openpipe-art==0.4.7 vllm==0.9.2 langchain-core tenacity datasets \"gql<4\" --prerelease allow --no-cache-dir"
4848
]
4949
},
5050
{
@@ -581,6 +581,8 @@
581581
"metadata": {},
582582
"outputs": [],
583583
"source": [
584+
"import torch\n",
585+
"\n",
584586
"import art\n",
585587
"from art.local import LocalBackend\n",
586588
"\n",
@@ -594,15 +596,16 @@
594596
")\n",
595597
"\n",
596598
"# To run on a T4, we need to override some config defaults.\n",
597-
"model._internal_config = art.dev.InternalModelConfig(\n",
598-
" init_args=art.dev.InitArgs(\n",
599-
" max_seq_length=8192,\n",
600-
" ),\n",
601-
" engine_args=art.dev.EngineArgs(\n",
602-
" enforce_eager=True,\n",
603-
" gpu_memory_utilization=0.8,\n",
604-
" ),\n",
605-
")\n",
599+
"if torch.cuda.get_device_properties(0).major < 8:\n",
600+
" model._internal_config = art.dev.InternalModelConfig(\n",
601+
" init_args=art.dev.InitArgs(\n",
602+
" max_seq_length=8192,\n",
603+
" ),\n",
604+
" engine_args=art.dev.EngineArgs(\n",
605+
" enforce_eager=True,\n",
606+
" gpu_memory_utilization=0.8,\n",
607+
" ),\n",
608+
" )\n",
606609
"\n",
607610
"# Initialize the server\n",
608611
"backend = LocalBackend(\n",

examples/prisoners-dilemma.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"\n",
1919
"BASE_MODEL = \"Qwen/Qwen2.5-7B-Instruct\"\n",
2020
"PRISONERS_DILEMMA_ROUNDS = 10\n",
21+
"TRAINING_STEPS = 1_000\n",
2122
"\n",
2223
"backend = LocalBackend()\n",
2324
"model = art.TrainableModel(\n",
@@ -117,7 +118,7 @@
117118
" return trajectories\n",
118119
"\n",
119120
"\n",
120-
"for _ in range(await model.get_step(), 1_000):\n",
121+
"for _ in range(await model.get_step(), TRAINING_STEPS):\n",
121122
" # Simultaneously rollout self-play games, and games versus the base model.\n",
122123
" self_play_trajectories, base_play_trajectories = await asyncio.gather(\n",
123124
" art.gather_trajectories(\n",

examples/rock-paper-tool-use.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"\n",
4848
"MODEL_NAME = \"001\"\n",
4949
"BASE_MODEL = \"Qwen/Qwen2.5-7B-Instruct\"\n",
50+
"TRAINING_STEPS = 1_000\n",
5051
"\n",
5152
"model = art.TrainableModel(\n",
5253
" name=MODEL_NAME, project=\"rock-paper-tool-use\", base_model=BASE_MODEL\n",
@@ -175,7 +176,7 @@
175176
" return trajectories[0]\n",
176177
"\n",
177178
"\n",
178-
"for i in range(await model.get_step(), 1_000):\n",
179+
"for i in range(await model.get_step(), TRAINING_STEPS):\n",
179180
" trajectories = await art.gather_trajectories(\n",
180181
" (rollout() for _ in range(64)), max_exceptions=64\n",
181182
" )\n",
@@ -202,7 +203,7 @@
202203
"name": "python",
203204
"nbconvert_exporter": "python",
204205
"pygments_lexer": "ipython3",
205-
"version": "3.10.13"
206+
"version": "3.10.16"
206207
}
207208
},
208209
"nbformat": 4,

examples/temporal_clue/temporal-clue.ipynb

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"outputs": [],
4141
"source": [
4242
"%%capture\n",
43-
"!uv pip install openpipe-art==0.3.11.post3 \"gql<4\" --prerelease allow --no-cache-dir"
43+
"!uv pip install openpipe-art==0.4.7 vllm==0.9.2 \"gql<4\" --prerelease allow --no-cache-dir"
4444
]
4545
},
4646
{
@@ -221,8 +221,12 @@
221221
"metadata": {},
222222
"outputs": [],
223223
"source": [
224-
"stride = 4\n",
225-
"for i in range(await model.get_step(), 1_000):\n",
224+
"STRIDE = 4\n",
225+
"TRAINING_STEPS = 1_000\n",
226+
"ROLLOUTS_PER_STEP = 50\n",
227+
"LEARNING_RATE = 5e-5\n",
228+
"\n",
229+
"for i in range(await model.get_step(), TRAINING_STEPS):\n",
226230
" val_groups, train_groups = await asyncio.gather(\n",
227231
" art.gather_trajectory_groups(\n",
228232
" (\n",
@@ -233,8 +237,10 @@
233237
" ),\n",
234238
" art.gather_trajectory_groups(\n",
235239
" (\n",
236-
" art.TrajectoryGroup(rollout(model, puzzle) for _ in range(50))\n",
237-
" for puzzle in train_puzzles[i * stride : (i + 1) * stride]\n",
240+
" art.TrajectoryGroup(\n",
241+
" rollout(model, puzzle) for _ in range(ROLLOUTS_PER_STEP)\n",
242+
" )\n",
243+
" for puzzle in train_puzzles[i * STRIDE : (i + 1) * STRIDE]\n",
238244
" ),\n",
239245
" pbar_desc=\"train\",\n",
240246
" ),\n",
@@ -243,7 +249,7 @@
243249
" await model.delete_checkpoints()\n",
244250
" await model.train(\n",
245251
" train_groups,\n",
246-
" config=art.TrainConfig(learning_rate=5e-5),\n",
252+
" config=art.TrainConfig(learning_rate=LEARNING_RATE),\n",
247253
" )"
248254
]
249255
},

0 commit comments

Comments
 (0)