Skip to content

Commit 08313e5

Browse files
committed
move collect
1 parent 7f152d4 commit 08313e5

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

machine-learning/export/immich_model_exporter/parse_eval_data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def collapsed_table(language: str, df: pl.DataFrame) -> str:
8080
(pl.col("image_retrieval_recall@1") + pl.col("image_retrieval_recall@5") + pl.col("image_retrieval_recall@10"))
8181
* (100 / 3)
8282
).round(2)
83-
).collect()
84-
eval_df.write_parquet("model_info.parquet")
83+
)
8584

8685
pareto_front = eval_df.join_where(
8786
eval_df.select("language", "peak_rss", "exec_time_ms", "recall").rename(
@@ -104,9 +103,12 @@ def collapsed_table(language: str, df: pl.DataFrame) -> str:
104103
)
105104
eval_df = eval_df.join(pareto_front, on=["pretrained_model", "language"], how="left")
106105
eval_df = eval_df.with_columns(is_pareto=pl.col("recall_other").is_null())
107-
eval_df = eval_df.drop("peak_rss_other", "exec_time_ms_other", "recall_other", "language_other").unique(
108-
subset=["pretrained_model", "language"]
106+
eval_df = (
107+
eval_df.drop("peak_rss_other", "exec_time_ms_other", "recall_other", "language_other")
108+
.unique(subset=["pretrained_model", "language"])
109+
.collect()
109110
)
111+
eval_df.write_parquet("model_info.parquet")
110112

111113
eval_df = eval_df.filter(pl.col("recall") >= 20)
112114
eval_df = eval_df.select(

0 commit comments

Comments
 (0)