Skip to content

Commit c0461ef

Browse files
committed
refactor [llm]
1 parent 5ee3e98 commit c0461ef

File tree

1 file changed

+131
-96
lines changed

1 file changed

+131
-96
lines changed

hf_model_evaluation/scripts/evaluation_manager.py

Lines changed: 131 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""
2121

2222
import argparse
23+
import json
2324
import os
2425
import re
2526
from typing import Any, Dict, List, Optional, Tuple
@@ -130,6 +131,20 @@ def normalize_model_name(name: str) -> tuple[set[str], str]:
130131
return tokens, normalized
131132

132133

134+
def parse_numeric_cell(cell: Optional[str]) -> Optional[float]:
    """Convert a markdown table cell into a float, or ``None`` if not numeric.

    Percent signs and thousands separators are stripped first, so values
    such as ``"85.3%"`` or ``"1,234"`` parse successfully. Empty, blank,
    ``None``, and non-numeric cells all yield ``None``.
    """
    if not cell:
        return None

    try:
        # Non-string inputs raise AttributeError on .replace(); treat as unparseable.
        cleaned = cell.replace("%", "").replace(",", "").strip()
    except AttributeError:
        return None

    if not cleaned:
        return None

    try:
        return float(cleaned)
    except ValueError:
        return None
146+
147+
133148
def find_main_model_column(header: List[str], model_name: str) -> Optional[int]:
134149
"""
135150
Identify the column index that corresponds to the main model.
@@ -333,42 +348,31 @@ def extract_metrics_from_table(
333348

334349
# If we identified a specific column, use it; otherwise use first numeric value
335350
if target_column is not None and target_column < len(row):
336-
try:
337-
value_str = row[target_column].replace("%", "").replace(",", "").strip()
338-
if value_str:
339-
value = float(value_str)
340-
metrics.append({
341-
"name": benchmark_name,
342-
"type": benchmark_name.lower().replace(" ", "_"),
343-
"value": value
344-
})
345-
except (ValueError, IndexError):
346-
pass
351+
value = parse_numeric_cell(row[target_column])
352+
if value is not None:
353+
metrics.append({
354+
"name": benchmark_name,
355+
"type": benchmark_name.lower().replace(" ", "_"),
356+
"value": value
357+
})
347358
else:
348359
# Extract numeric values from remaining columns (original behavior)
349360
for i, cell in enumerate(row[1:], start=1):
350-
try:
351-
# Remove common suffixes and convert to float
352-
value_str = cell.replace("%", "").replace(",", "").strip()
353-
if not value_str:
354-
continue
355-
356-
value = float(value_str)
357-
358-
# Determine metric name
359-
metric_name = benchmark_name
360-
if len(header) > i and header[i].lower() not in ["score", "value", "result"]:
361-
metric_name = f"{benchmark_name} ({header[i]})"
362-
363-
metrics.append({
364-
"name": metric_name,
365-
"type": benchmark_name.lower().replace(" ", "_"),
366-
"value": value
367-
})
368-
break # Only take first numeric value per row
369-
except (ValueError, IndexError):
361+
value = parse_numeric_cell(cell)
362+
if value is None:
370363
continue
371364

365+
metric_name = benchmark_name
366+
if len(header) > i and header[i].lower() not in ["score", "value", "result"]:
367+
metric_name = f"{benchmark_name} ({header[i]})"
368+
369+
metrics.append({
370+
"name": metric_name,
371+
"type": benchmark_name.lower().replace(" ", "_"),
372+
"value": value
373+
})
374+
break # Only take first numeric value per row
375+
372376
elif table_format == "transposed":
373377
# Models are in rows (first column), benchmarks are in columns (header)
374378
# Find the row that matches the target model
@@ -397,20 +401,13 @@ def extract_metrics_from_table(
397401
if not benchmark_name or i >= len(target_row):
398402
continue
399403

400-
try:
401-
value_str = target_row[i].replace("%", "").replace(",", "").strip()
402-
if not value_str:
403-
continue
404-
405-
value = float(value_str)
406-
404+
value = parse_numeric_cell(target_row[i])
405+
if value is not None:
407406
metrics.append({
408407
"name": benchmark_name,
409408
"type": benchmark_name.lower().replace(" ", "_").replace("-", "_"),
410409
"value": value
411410
})
412-
except (ValueError, AttributeError):
413-
continue
414411

415412
else: # table_format == "columns"
416413
# Benchmarks are in columns
@@ -424,20 +421,13 @@ def extract_metrics_from_table(
424421
if not benchmark_name or i >= len(data_row):
425422
continue
426423

427-
try:
428-
value_str = data_row[i].replace("%", "").replace(",", "").strip()
429-
if not value_str:
430-
continue
431-
432-
value = float(value_str)
433-
424+
value = parse_numeric_cell(data_row[i])
425+
if value is not None:
434426
metrics.append({
435427
"name": benchmark_name,
436428
"type": benchmark_name.lower().replace(" ", "_"),
437429
"value": value
438430
})
439-
except ValueError:
440-
continue
441431

442432
return metrics
443433

@@ -598,6 +588,35 @@ def extract_tables_with_parser(markdown_content: str) -> List[Dict[str, Any]]:
598588
return tables
599589

600590

591+
def format_model_index(repo_id: str, results: List[Dict[str, Any]], output_format: str = "yaml") -> str:
    """Serialize a model-index payload for a repo as YAML (default) or JSON.

    Args:
        repo_id: Hugging Face repository id; only the part after the last
            ``/`` is used as the model name.
        results: Evaluation result dicts to embed under ``results``.
        output_format: ``"json"`` for a 2-space-indented JSON string; any
            other value yields YAML with key order preserved.

    Returns:
        The serialized ``model-index`` document as a string.
    """
    model_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
    payload = {"model-index": [{"name": model_name, "results": results}]}

    if output_format == "json":
        return json.dumps(payload, indent=2)
    return yaml.dump(payload, sort_keys=False)
605+
606+
607+
def build_extract_command(repo_id: str, table_number: int, model_name_override: Optional[str]) -> str:
    """Assemble the suggested ``extract-readme`` CLI invocation.

    Args:
        repo_id: Repository id passed via ``--repo-id``.
        table_number: 1-based table index passed via ``--table``.
        model_name_override: Optional column header forced through
            ``--model-name-override``; omitted when falsy.

    Returns:
        A multi-line shell command joined with backslash continuations,
        always ending in ``--dry-run``.
    """
    flags = [
        f'--repo-id "{repo_id}"',
        f"--table {table_number}",
    ]
    if model_name_override:
        flags.append(f'--model-name-override "{model_name_override}"')
    flags.append("--dry-run")

    command_lines = ["python scripts/evaluation_manager.py extract-readme", *flags]
    return " \\\n    ".join(command_lines)
618+
619+
601620
def detect_table_format(table: Dict[str, Any], repo_id: str) -> Dict[str, Any]:
602621
"""Analyze a table to detect its format and identify model columns."""
603622
headers = table.get("headers", [])
@@ -662,7 +681,7 @@ def detect_table_format(table: Dict[str, Any], repo_id: str) -> Dict[str, Any]:
662681
}
663682

664683

665-
def inspect_tables(repo_id: str) -> None:
684+
def inspect_tables(repo_id: str, output_format: str = "text") -> None:
666685
"""Inspect and display all evaluation tables in a model's README."""
667686
try:
668687
card = ModelCard.load(repo_id, token=HF_TOKEN)
@@ -678,76 +697,80 @@ def inspect_tables(repo_id: str) -> None:
678697
print(f"No tables found in README for {repo_id}")
679698
return
680699

681-
print(f"\n{'='*70}")
682-
print(f"Tables found in README for: {repo_id}")
683-
print(f"{'='*70}")
700+
summary: Dict[str, Any] = {"repo_id": repo_id, "tables": []}
701+
702+
if output_format == "text":
703+
print(f"\n{'='*70}")
704+
print(f"Tables found in README for: {repo_id}")
705+
print(f"{'='*70}")
684706

685707
eval_table_count = 0
686-
for table in tables:
708+
for idx, table in enumerate(tables, start=1):
687709
analysis = detect_table_format(table, repo_id)
688710

689711
if analysis["format"] == "unknown" and not analysis.get("sample_rows"):
690712
continue
691713

692714
eval_table_count += 1
693-
print(f"\n## Table {eval_table_count}")
715+
716+
override_value = None
717+
if analysis["format"] == "comparison":
718+
exact = next((c for c in analysis.get("model_columns", []) if c["is_exact_match"]), None)
719+
if exact:
720+
override_value = exact["header"]
721+
else:
722+
partial = next((c for c in analysis.get("model_columns", []) if c["is_partial_match"]), None)
723+
override_value = partial["header"] if partial else None
724+
725+
suggested_command = build_extract_command(repo_id, idx, override_value)
726+
727+
table_summary = {
728+
"table_number": idx,
729+
"format": analysis["format"],
730+
"row_count": analysis["row_count"],
731+
"columns": analysis["columns"],
732+
"model_columns": analysis.get("model_columns", []),
733+
"sample_rows": analysis.get("sample_rows", []),
734+
"suggested_command": suggested_command,
735+
}
736+
summary["tables"].append(table_summary)
737+
738+
if output_format == "json":
739+
continue
740+
741+
print(f"\n## Table {idx}")
694742
print(f" Format: {analysis['format']}")
695743
print(f" Rows: {analysis['row_count']}")
696744

697745
print(f"\n Columns ({len(analysis['columns'])}):")
698746
for col_info in analysis.get("model_columns", []):
699-
idx = col_info["index"]
747+
col_idx = col_info["index"]
700748
header = col_info["header"]
701749
if col_info["is_exact_match"]:
702-
print(f" [{idx}] {header} ✓ EXACT MATCH")
750+
print(f" [{col_idx}] {header} ✓ EXACT MATCH")
703751
elif col_info["is_partial_match"]:
704-
print(f" [{idx}] {header} ~ partial match")
752+
print(f" [{col_idx}] {header} ~ partial match")
705753
else:
706-
print(f" [{idx}] {header}")
754+
print(f" [{col_idx}] {header}")
707755

708756
if analysis.get("sample_rows"):
709757
print(f"\n Sample rows (first column):")
710758
for row_val in analysis["sample_rows"][:5]:
711759
print(f" - {row_val}")
712760

713-
# Build suggested command
714-
cmd_parts = [
715-
"python scripts/evaluation_manager.py extract-readme",
716-
f'--repo-id "{repo_id}"',
717-
f"--table {eval_table_count}"
718-
]
761+
if override_value and not any(c["is_exact_match"] for c in analysis.get("model_columns", [])):
762+
print(f"\n ⚠ No exact match. Best candidate: {override_value}")
719763

720-
override_value = None
721-
if analysis["format"] == "comparison":
722-
exact = next((c for c in analysis.get("model_columns", []) if c["is_exact_match"]), None)
723-
if exact:
724-
print(f"\n ✓ Column match: {exact['header']}")
725-
else:
726-
partial = next((c for c in analysis.get("model_columns", []) if c["is_partial_match"]), None)
727-
if partial:
728-
override_value = partial["header"]
729-
print(f"\n ⚠ No exact match. Best candidate: {partial['header']}")
730-
elif analysis.get("model_columns"):
731-
print(f"\n ⚠ Could not identify model column. Options:")
732-
for col_info in analysis.get("model_columns", []):
733-
print(f' "{col_info["header"]}"')
734-
override_value = analysis["model_columns"][0]["header"]
735-
736-
if override_value:
737-
cmd_parts.append(f'--model-name-override "{override_value}"')
738-
739-
cmd_parts.append("--dry-run")
740-
741-
print(f"\n Suggested command:")
742-
print(f" {cmd_parts[0]} \\")
743-
for part in cmd_parts[1:-1]:
744-
print(f" {part} \\")
745-
print(f" {cmd_parts[-1]}")
746-
747-
if eval_table_count == 0:
764+
print(f"\n Suggested command:\n {suggested_command}")
765+
766+
if eval_table_count == 0 and output_format == "text":
748767
print("\nNo evaluation tables detected.")
749768

750-
print(f"\n{'='*70}\n")
769+
if output_format == "json":
770+
print(json.dumps(summary, indent=2))
771+
772+
if output_format == "text":
773+
print(f"\n{'='*70}\n")
751774

752775
except Exception as e:
753776
print(f"Error inspecting tables: {e}")
@@ -1065,6 +1088,12 @@ def main():
10651088
extract_parser.add_argument("--dataset-type", type=str, default="benchmark", help="Dataset type")
10661089
extract_parser.add_argument("--create-pr", action="store_true", help="Create PR instead of direct push")
10671090
extract_parser.add_argument("--dry-run", action="store_true", help="Preview YAML without updating")
1091+
extract_parser.add_argument(
1092+
"--output-format",
1093+
choices=["yaml", "json"],
1094+
default="yaml",
1095+
help="Output format for --dry-run"
1096+
)
10681097

10691098
# Import from AA command
10701099
aa_parser = subparsers.add_parser(
@@ -1104,6 +1133,12 @@ def main():
11041133
"""
11051134
)
11061135
inspect_parser.add_argument("--repo-id", type=str, required=True, help="HF repository ID")
1136+
inspect_parser.add_argument(
1137+
"--output-format",
1138+
choices=["text", "json"],
1139+
default="text",
1140+
help="Choose machine-readable JSON for LLM workflows"
1141+
)
11071142

11081143
args = parser.parse_args()
11091144

@@ -1128,7 +1163,7 @@ def main():
11281163

11291164
if args.dry_run:
11301165
print("\nPreview of extracted evaluations:")
1131-
print(yaml.dump({"model-index": [{"name": args.repo_id.split("/")[-1], "results": results}]}, sort_keys=False))
1166+
print(format_model_index(args.repo_id, results, args.output_format))
11321167
else:
11331168
update_model_card_with_evaluations(
11341169
repo_id=args.repo_id,
@@ -1162,7 +1197,7 @@ def main():
11621197
validate_model_index(args.repo_id)
11631198

11641199
elif args.command == "inspect-tables":
1165-
inspect_tables(args.repo_id)
1200+
inspect_tables(args.repo_id, output_format=args.output_format)
11661201

11671202

11681203
if __name__ == "__main__":

0 commit comments

Comments (0)