Skip to content

Commit 03543df

Browse files
authored
Raise temperature when doing multiple rollouts (and warn otherwise) (#8748)
* warn once per LM instance for zero-temp rollout * Remove duplicate warnings import in test_lm.py
1 parent 8c167dc commit 03543df

File tree

14 files changed

+101
-37
lines changed

14 files changed

+101
-37
lines changed

docs/docs/cheatsheet.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ This page will contain snippets for frequent usage patterns.
1010

1111
### Forcing fresh LM outputs
1212

13-
DSPy caches LM calls. Provide a unique ``rollout_id`` to bypass an existing
14-
cache entry while still caching the new result:
13+
DSPy caches LM calls. Provide a unique ``rollout_id`` and set a non-zero
14+
``temperature`` (e.g., 1.0) to bypass an existing cache entry while still caching
15+
the new result:
1516

1617
```python
1718
predict = dspy.Predict("question -> answer")
18-
predict(question="1+1", config={"rollout_id": 1})
19+
predict(question="1+1", config={"rollout_id": 1, "temperature": 1.0})
1920
```
2021

2122
### dspy.Signature

docs/docs/learn/programming/language_models.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,28 +167,29 @@ gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', temperature=0.9, max_tokens=3000, st
167167
By default LMs in DSPy are cached. If you repeat the same call, you will get the same outputs. But you can turn off caching by setting `cache=False`.
168168

169169
If you want to keep caching enabled but force a new request (for example, to obtain diverse outputs),
170-
pass a unique `rollout_id` in your call. DSPy hashes both the inputs and the `rollout_id` when
171-
looking up a cache entry, so different values force a new LM request while
170+
pass a unique `rollout_id` and set a non-zero `temperature` in your call. DSPy hashes both the inputs
171+
and the `rollout_id` when looking up a cache entry, so different values force a new LM request while
172172
still caching future calls with the same inputs and `rollout_id`. The ID is also recorded in
173-
`lm.history`, which makes it easy to track or compare different rollouts during experiments.
173+
`lm.history`, which makes it easy to track or compare different rollouts during experiments. Changing
174+
only the `rollout_id` while keeping `temperature=0` will not affect the LM's output.
174175

175176
```python linenums="1"
176-
lm("Say this is a test!", rollout_id=1)
177+
lm("Say this is a test!", rollout_id=1, temperature=1.0)
177178
```
178179

179180
You can pass these LM kwargs directly to DSPy modules as well. Supplying them at
180181
initialization sets the defaults for every call:
181182

182183
```python linenums="1"
183-
predict = dspy.Predict("question -> answer", rollout_id=1)
184+
predict = dspy.Predict("question -> answer", rollout_id=1, temperature=1.0)
184185
```
185186

186187
To override them for a single invocation, provide a ``config`` dictionary when
187188
calling the module:
188189

189190
```python linenums="1"
190191
predict = dspy.Predict("question -> answer")
191-
predict(question="What is 1 + 52?", config={"rollout_id": 5})
192+
predict(question="What is 1 + 52?", config={"rollout_id": 5, "temperature": 1.0})
192193
```
193194

194195
In both cases, ``rollout_id`` is forwarded to the underlying LM, affects

dspy/clients/base_lm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def copy(self, **kwargs):
113113
"""Returns a copy of the language model with possibly updated parameters.
114114
115115
Any provided keyword arguments update the corresponding attributes or LM kwargs of
116-
the copy. For example, ``lm.copy(rollout_id=1)`` returns an LM whose requests use a
117-
different rollout ID to bypass cache collisions.
116+
the copy. For example, ``lm.copy(rollout_id=1, temperature=1.0)`` returns an LM whose
117+
requests use a different rollout ID at non-zero temperature to bypass cache collisions.
118118
"""
119119

120120
import copy
@@ -130,6 +130,8 @@ def copy(self, **kwargs):
130130
new_instance.kwargs.pop(key, None)
131131
else:
132132
new_instance.kwargs[key] = value
133+
if hasattr(new_instance, "_warned_zero_temp_rollout"):
134+
new_instance._warned_zero_temp_rollout = False
133135

134136
return new_instance
135137

dspy/clients/lm.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import re
44
import threading
5+
import warnings
56
from typing import Any, Literal, cast
67

78
import litellm
@@ -61,8 +62,9 @@ def __init__(
6162
from the models available for inference.
6263
rollout_id: Optional integer used to differentiate cache entries for otherwise
6364
identical requests. Different values bypass DSPy's caches while still caching
64-
future calls with the same inputs and rollout ID. This argument is stripped
65-
before sending requests to the provider.
65+
future calls with the same inputs and rollout ID. Note that `rollout_id`
66+
only affects generation when `temperature` is non-zero. This argument is
67+
stripped before sending requests to the provider.
6668
"""
6769
# Remember to update LM.copy() if you modify the constructor!
6870
self.model = model
@@ -75,6 +77,7 @@ def __init__(
7577
self.finetuning_model = finetuning_model
7678
self.launch_kwargs = launch_kwargs or {}
7779
self.train_kwargs = train_kwargs or {}
80+
self._warned_zero_temp_rollout = False
7881

7982
# Handle model-specific configuration for different model families
8083
model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
@@ -96,6 +99,20 @@ def __init__(
9699
if self.kwargs.get("rollout_id") is None:
97100
self.kwargs.pop("rollout_id", None)
98101

102+
self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id"))
103+
104+
def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id):
105+
if (
106+
not self._warned_zero_temp_rollout
107+
and rollout_id is not None
108+
and (temperature is None or temperature == 0)
109+
):
110+
warnings.warn(
111+
"rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.",
112+
stacklevel=3,
113+
)
114+
self._warned_zero_temp_rollout = True
115+
99116
def _get_cached_completion_fn(self, completion_fn, cache):
100117
ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
101118
if cache:
@@ -115,6 +132,7 @@ def forward(self, prompt=None, messages=None, **kwargs):
115132

116133
messages = messages or [{"role": "user", "content": prompt}]
117134
kwargs = {**self.kwargs, **kwargs}
135+
self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
118136
if kwargs.get("rollout_id") is None:
119137
kwargs.pop("rollout_id", None)
120138

@@ -145,6 +163,7 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
145163

146164
messages = messages or [{"role": "user", "content": prompt}]
147165
kwargs = {**self.kwargs, **kwargs}
166+
self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
148167
if kwargs.get("rollout_id") is None:
149168
kwargs.pop("rollout_id", None)
150169

dspy/predict/best_of_n.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ def __init__(
1414
fail_count: int | None = None,
1515
):
1616
"""
17-
Runs a module up to `N` times with different rollout IDs and returns the best prediction
18-
out of `N` attempts or the first prediction that passes the `threshold`.
17+
Runs a module up to `N` times with different rollout IDs at `temperature=1.0` and
18+
returns the best prediction out of `N` attempts or the first prediction that passes the
19+
`threshold`.
1920
2021
Args:
2122
module (Module): The module to run.
@@ -53,14 +54,12 @@ def one_word_answer(args, pred):
5354

5455
def forward(self, **kwargs):
5556
lm = self.module.get_lm() or dspy.settings.lm
56-
base_rollout = lm.kwargs.get("rollout_id")
57-
start = 0 if base_rollout is None else base_rollout
57+
start = lm.kwargs.get("rollout_id", 0)
5858
rollout_ids = [start + i for i in range(self.N)]
59-
rollout_ids = list(dict.fromkeys(rollout_ids))[: self.N]
6059
best_pred, best_trace, best_reward = None, None, -float("inf")
6160

6261
for idx, rid in enumerate(rollout_ids):
63-
lm_ = lm.copy(rollout_id=rid)
62+
lm_ = lm.copy(rollout_id=rid, temperature=1.0)
6463
mod = self.module.deepcopy()
6564
mod.set_lm(lm_)
6665

dspy/predict/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class Predict(Module, Parameter):
2727
invocation by passing a ``config`` dictionary when calling the
2828
module. For example::
2929
30-
predict = dspy.Predict("q -> a", rollout_id=1)
31-
predict(q="What is 1 + 52?", config={"rollout_id": 2})
30+
predict = dspy.Predict("q -> a", rollout_id=1, temperature=1.0)
31+
predict(q="What is 1 + 52?", config={"rollout_id": 2, "temperature": 1.0})
3232
"""
3333

3434
def __init__(self, signature: str | type[Signature], callbacks: list[BaseCallback] | None = None, **config):

dspy/predict/refine.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def __init__(
4848
fail_count: int | None = None,
4949
):
5050
"""
51-
Refines a module by running it up to N times with different rollout IDs and returns the best prediction.
51+
Refines a module by running it up to N times with different rollout IDs at `temperature=1.0`
52+
and returns the best prediction.
5253
5354
This module runs the provided module multiple times with varying rollout identifiers and selects
5455
either the first prediction that exceeds the specified threshold or the one with the highest reward.
@@ -96,16 +97,14 @@ def one_word_answer(args, pred):
9697

9798
def forward(self, **kwargs):
9899
lm = self.module.get_lm() or dspy.settings.lm
99-
base_rollout = lm.kwargs.get("rollout_id")
100-
start = 0 if base_rollout is None else base_rollout
100+
start = lm.kwargs.get("rollout_id", 0)
101101
rollout_ids = [start + i for i in range(self.N)]
102-
rollout_ids = list(dict.fromkeys(rollout_ids))[: self.N]
103102
best_pred, best_trace, best_reward = None, None, -float("inf")
104103
advice = None
105104
adapter = dspy.settings.adapter or dspy.ChatAdapter()
106105

107106
for idx, rid in enumerate(rollout_ids):
108-
lm_ = lm.copy(rollout_id=rid)
107+
lm_ = lm.copy(rollout_id=rid, temperature=1.0)
109108
mod = self.module.deepcopy()
110109
mod.set_lm(lm_)
111110

dspy/propose/grounded_proposer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ def __init__(
284284
set_tip_randomly=True,
285285
set_history_randomly=True,
286286
verbose=False,
287-
rng=None
287+
rng=None,
288+
init_temperature: float = 1.0,
288289
):
289290
super().__init__()
290291
self.program_aware = program_aware
@@ -299,6 +300,7 @@ def __init__(
299300
self.rng = rng or random
300301

301302
self.prompt_model = get_prompt_model(prompt_model)
303+
self.init_temperature = init_temperature
302304

303305
self.program_code_string = None
304306
if self.program_aware:
@@ -412,7 +414,10 @@ def propose_instruction_for_predictor(
412414
)
413415

414416
# Generate a new instruction for our predictor using a unique rollout id to bypass cache
415-
rollout_lm = self.prompt_model.copy(rollout_id=self.rng.randint(0, 10**9))
417+
rollout_lm = self.prompt_model.copy(
418+
rollout_id=self.rng.randint(0, 10**9),
419+
temperature=self.init_temperature,
420+
)
416421

417422
with dspy.settings.context(lm=rollout_lm):
418423
proposed_instruction = instruction_generator(

dspy/teleprompt/bootstrap.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def __init__(
4747
"""A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
4848
These demos come from a combination of labeled examples in the training set, and bootstrapped demos.
4949
50+
Each bootstrap round copies the LM with a new ``rollout_id`` at ``temperature=1.0`` to
51+
bypass caches and gather diverse traces.
52+
5053
Args:
5154
metric (Callable): A function that compares an expected value and predicted value,
5255
outputting the result of that comparison.
@@ -181,7 +184,8 @@ def _bootstrap_one_example(self, example, round_idx=0):
181184
try:
182185
with dspy.settings.context(trace=[], **self.teacher_settings):
183186
lm = dspy.settings.lm
184-
lm = lm.copy(rollout_id=round_idx) if round_idx > 0 else lm
187+
# Use a fresh rollout with temperature=1.0 to bypass caches.
188+
lm = lm.copy(rollout_id=round_idx, temperature=1.0) if round_idx > 0 else lm
185189
new_settings = {"lm": lm} if round_idx > 0 else {}
186190

187191
with dspy.settings.context(**new_settings):

dspy/teleprompt/infer_rules.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ class CustomRulesInduction(dspy.Signature):
143143

144144
def forward(self, examples_text):
145145
with dspy.settings.context(**self.teacher_settings):
146-
lm = dspy.settings.lm.copy(rollout_id=self.rng.randint(0, 10**9))
146+
# Generate rules with a fresh rollout and non-zero temperature.
147+
lm = dspy.settings.lm.copy(
148+
rollout_id=self.rng.randint(0, 10**9), temperature=1.0
149+
)
147150
with dspy.settings.context(lm=lm):
148151
rules = self.rules_induction(examples_text=examples_text).natural_language_rules
149152

0 commit comments

Comments
 (0)