Commit 68fda25

8688 csv output (#8725)
* adding save_as_csv and save_as_json to evaluate
* reformat
* add documentation and type hints
* removed reformatting
* more whitespace fixes
* whitespace changes again
1 parent 060c391 commit 68fda25
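For orientation, a minimal usage sketch of the two new options (not part of the commit itself): it assumes a compiled `program` and a `devset` already exist, uses an illustrative metric function, and the output paths are placeholders.

import dspy

def exact_match(example, prediction, trace=None):
    # Illustrative metric with the usual dspy metric signature;
    # assumes both the example and the prediction expose an `answer` field.
    return example.answer == prediction.answer

evaluator = dspy.Evaluate(
    devset=devset,                     # assumed to be defined elsewhere
    metric=exact_match,
    num_threads=8,
    display_progress=True,
    save_as_csv="eval_results.csv",    # new: one CSV row per devset example
    save_as_json="eval_results.json",  # new: the same rows as a JSON list
)
result = evaluator(program)            # `program` assumed to be defined elsewhere

# Both options can also be set (or overridden) per call:
evaluator(program, save_as_csv="rerun_results.csv")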

File tree: 1 file changed (+55, −8)


dspy/evaluate/evaluate.py

Lines changed: 55 additions & 8 deletions
@@ -1,4 +1,6 @@
+import csv
 import importlib
+import json
 import logging
 import types
 from typing import TYPE_CHECKING, Any, Callable
@@ -77,6 +79,8 @@ def __init__(
         max_errors: int | None = None,
         provide_traceback: bool | None = None,
         failure_score: float = 0.0,
+        save_as_csv: str | None = None,
+        save_as_json: str | None = None,
         **kwargs,
     ):
         """
@@ -91,6 +95,9 @@ def __init__(
                 stopping evaluation. If ``None``, inherits from ``dspy.settings.max_errors``.
             provide_traceback (Optional[bool]): Whether to provide traceback information during evaluation.
             failure_score (float): The default score to use if evaluation fails due to an exception.
+            save_as_csv (Optional[str]): The file name where the csv will be saved.
+            save_as_json (Optional[str]): The file name where the json will be saved.
+
         """
         self.devset = devset
         self.metric = metric
@@ -100,6 +107,8 @@ def __init__(
         self.max_errors = max_errors
         self.provide_traceback = provide_traceback
         self.failure_score = failure_score
+        self.save_as_csv = save_as_csv
+        self.save_as_json = save_as_json

         if "return_outputs" in kwargs:
             raise ValueError("`return_outputs` is no longer supported. Results are always returned inside the `results` field of the `EvaluationResult` object.")
@@ -114,6 +123,8 @@ def __call__(
         display_progress: bool | None = None,
         display_table: bool | int | None = None,
         callback_metadata: dict[str, Any] | None = None,
+        save_as_csv: str | None = None,
+        save_as_json: str | None = None,
     ) -> EvaluationResult:
         """
         Args:
@@ -140,6 +151,8 @@ def __call__(
         num_threads = num_threads if num_threads is not None else self.num_threads
         display_progress = display_progress if display_progress is not None else self.display_progress
         display_table = display_table if display_table is not None else self.display_table
+        save_as_csv = save_as_csv if save_as_csv is not None else self.save_as_csv
+        save_as_json = save_as_json if save_as_json is not None else self.save_as_json

         if callback_metadata:
             logger.debug(f"Evaluate is called with callback metadata: {callback_metadata}")
@@ -179,11 +192,52 @@ def process_item(example):
             else:
                 logger.warning("Skipping table display since `pandas` is not installed.")

+        if save_as_csv:
+            metric_name = (
+                metric.__name__
+                if isinstance(metric, types.FunctionType)
+                else metric.__class__.__name__
+            )
+            data = self._prepare_results_output(results, metric_name)
+
+            with open(save_as_csv, "w", newline="") as csvfile:
+                fieldnames = data[0].keys()
+                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+                writer.writeheader()
+                for row in data:
+                    writer.writerow(row)
+        if save_as_json:
+            metric_name = (
+                metric.__name__
+                if isinstance(metric, types.FunctionType)
+                else metric.__class__.__name__
+            )
+            data = self._prepare_results_output(results, metric_name)
+            with open(
+                save_as_json,
+                "w",
+            ) as f:
+                json.dump(data, f)
+
         return EvaluationResult(
             score=round(100 * ncorrect / ntotal, 2),
             results=results,
         )

+    @staticmethod
+    def _prepare_results_output(
+        results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str
+    ):
+        return [
+            (
+                merge_dicts(example, prediction) | {metric_name: score}
+                if prediction_is_dictlike(prediction)
+                else dict(example) | {"prediction": prediction, metric_name: score}
+            )
+            for example, prediction, score in results
+        ]
+
     def _construct_result_table(
         self, results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str
     ) -> "pd.DataFrame":
@@ -200,14 +254,7 @@ def _construct_result_table(
         """
         import pandas as pd

-        data = [
-            (
-                merge_dicts(example, prediction) | {"correct": score}
-                if prediction_is_dictlike(prediction)
-                else dict(example) | {"prediction": prediction, "correct": score}
-            )
-            for example, prediction, score in results
-        ]
+        data = self._prepare_results_output(results, metric_name)

         # Truncate every cell in the DataFrame (DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0)
         result_df = pd.DataFrame(data)
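Each saved row comes from the new `_prepare_results_output` helper: the example fields merged with the prediction fields (or a single `prediction` column when the prediction is not dict-like), plus one score column named after the metric. A short sketch of reading the files back, reusing the illustrative paths from above:

import csv
import json

# Sketch only; the paths match the illustrative names used earlier.
with open("eval_results.json") as f:
    rows = json.load(f)  # list of dicts: example fields | prediction fields | {metric_name: score}

with open("eval_results.csv", newline="") as f:
    for row in csv.DictReader(f):
        print(row)  # csv.DictWriter stringifies cells, so every value here is a str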
