-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_mlx_rerank.py
More file actions
91 lines (68 loc) · 2.79 KB
/
test_mlx_rerank.py
File metadata and controls
91 lines (68 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Test MLX-native cross-encoder reranking with Qwen3-Reranker.
Qwen3-Reranker is a generative model that acts as a cross-encoder:
- Input: query + document formatted via the tokenizer's chat template
- Output: compares logit probabilities of "yes" vs "no" tokens
- Score: sigmoid of (yes_logit - no_logit)
"""
import math
import time
import mlx.core as mx
from mlx_lm import load
# Hugging Face Hub ID of the MLX-quantized Qwen3 reranker (4B, mxfp8).
MODEL_ID = "mlx-community/Qwen3-Reranker-4B-mxfp8"

# System instruction under which the reranker emits a single "yes"/"no"
# answer token (see module docstring: the score compares those two logits).
SYSTEM_PROMPT = (
    "Judge whether the Document is relevant to the Query. "
    'Answer only "yes" or "no".'
)

# Demo query plus a mix of relevant and clearly irrelevant documents.
QUERY = "fraud audit"
DOCUMENTS = [
    "CMS audit and inspect books and records",
    "The contractor shall perform annual financial audits for fraud detection and compliance verification",
    "Chocolate chip cookies were invented in 1938 by Ruth Graves Wakefield",
    "The weather forecast for tomorrow calls for partly cloudy skies",
    "Federal agencies must conduct fraud risk assessments per OMB Circular A-123",
]
def build_input_ids(tokenizer, query: str, document: str) -> list[int]:
    """Encode one (query, document) pair into model-ready token IDs.

    Builds a two-turn chat (system instruction + user turn carrying the
    pair in "<Query>: ... / <Document>: ..." form), renders it through the
    tokenizer's chat template with the generation prompt appended and
    thinking disabled, then tokenizes the rendered text.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"<Query>: {query}\n<Document>: {document}"},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    return tokenizer.encode(rendered)
def _sigmoid(x: float) -> float:
    """Numerically stable logistic function.

    The naive form 1/(1 + exp(-x)) raises OverflowError once -x exceeds
    ~709 (the largest argument math.exp accepts for a C double), so branch
    on the sign and only ever exponentiate a non-positive value.
    """
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def main():
    """Rerank DOCUMENTS against QUERY with the Qwen3 generative cross-encoder.

    Loads the model, scores each document as sigmoid(yes_logit - no_logit)
    taken at the final prompt position, and prints a table of documents
    sorted by descending relevance score.
    """
    print(f"Loading model: {MODEL_ID}")
    t0 = time.time()
    model, tokenizer = load(MODEL_ID)
    print(f"Model loaded in {time.time() - t0:.1f}s")

    # "yes"/"no" may tokenize to more than one piece; the answer token the
    # model emits is the last one, so index [-1].
    yes_id = tokenizer.encode("yes", add_special_tokens=False)[-1]
    no_id = tokenizer.encode("no", add_special_tokens=False)[-1]
    print(f"Token IDs — yes: {yes_id}, no: {no_id}")

    # Show what the template produces for the first doc
    sample_ids = build_input_ids(tokenizer, QUERY, DOCUMENTS[0])
    print(f"\nSample prompt tokens ({len(sample_ids)} tokens):")
    print(tokenizer.decode(sample_ids))
    print("---")
    print(f"\nQuery: '{QUERY}'\n{'='*70}")

    results = []
    for doc in DOCUMENTS:
        input_ids = build_input_ids(tokenizer, QUERY, doc)
        input_arr = mx.array([input_ids])
        logits = model(input_arr)
        mx.eval(logits)  # force lazy MLX evaluation before .item() reads
        last_logits = logits[0, -1, :]
        yes_logit = last_logits[yes_id].item()
        no_logit = last_logits[no_id].item()
        # Relevance = sigmoid of the yes-vs-no logit margin (overflow-safe).
        score = _sigmoid(yes_logit - no_logit)
        results.append((score, yes_logit, no_logit, doc))

    results.sort(key=lambda x: x[0], reverse=True)
    print(f"\n{'Score':>8} {'Yes':>10} {'No':>10} Document")
    print("-" * 70)
    for score, yes_l, no_l, doc in results:
        label = doc[:55] + "..." if len(doc) > 55 else doc
        print(f"{score:8.4f} {yes_l:10.2f} {no_l:10.2f} {label}")
# Script entry point: run the reranking demo when executed directly.
if __name__ == "__main__":
    main()