
Commit 7ac7908

Merge pull request #1 from gusye1234/main
MERGE Master
2 parents 4e3fb7d + b33b2b8, commit 7ac7908

25 files changed: +17315, -917 lines

docs/ROADMAP.md

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 ## Next Version
 
-- [ ] Add neo4j for better visualization
-- [x] Add DSpy for prompt-tuning
+- [ ] Add neo4j for better visualization @gusye1234
+- [ ] Add DSpy for prompt-tuning to make small models(Qwen2 7B, Llama 3.1 8B...) can extract entities. @NumberChiffre
 
 
 

examples/benchmarks/dspy_entity.py

Lines changed: 6 additions & 5 deletions
@@ -114,19 +114,20 @@ async def run_benchmark(text: str):
     system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
     lm = dspy.OpenAI(
         model="deepseek-chat",
-        model_type="chat",
+        model_type="chat",
+        api_provider="openai",
         api_key=os.environ["DEEPSEEK_API_KEY"],
         base_url=os.environ["DEEPSEEK_BASE_URL"],
-        system_prompt=system_prompt_dspy,
+        system_prompt=system_prompt,
         temperature=1.0,
-        top_p=1,
-        max_tokens=4096
+        max_tokens=8192
     )
-    dspy.settings.configure(lm=lm)
+    dspy.settings.configure(lm=lm, experimental=True)
     graph_storage_with_dspy, time_with_dspy = await benchmark_entity_extraction(text, system_prompt_dspy, use_dspy=True)
     print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
     print_extraction_results(graph_storage_with_dspy)

+    import pdb; pdb.set_trace()
     print("Running benchmark without DSPy-AI:")
     system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
     graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(text, system_prompt_no_dspy, use_dspy=False)

examples/finetune_entity_relationship_dspy.ipynb

Lines changed: 14146 additions & 379 deletions
Large diffs are not rendered by default.

examples/generate_entity_relationship_dspy.ipynb

Lines changed: 2062 additions & 0 deletions
Large diffs are not rendered by default.

examples/graphml_visualize.py

Lines changed: 22 additions & 17 deletions
@@ -1,18 +1,19 @@
 import networkx as nx
 import json
+import os
 import webbrowser
 import http.server
 import socketserver
 import threading
 
-# 读取GraphML文件并转换为JSON
+# load GraphML file and transfer to JSON
 def graphml_to_json(graphml_file):
     G = nx.read_graphml(graphml_file)
     data = nx.node_link_data(G)
     return json.dumps(data)
 
 
-# 创建HTML文件
+# create HTML file
 def create_html(html_path):
     html_content = '''
     <!DOCTYPE html>
@@ -242,36 +243,40 @@ def create_json(json_data, json_path):
         f.write(json_data)
 
 
-# 启动简单的HTTP服务器
-def start_server():
+# start simple HTTP server
+def start_server(port):
     handler = http.server.SimpleHTTPRequestHandler
-    with socketserver.TCPServer(("", 8000), handler) as httpd:
-        print("Server started at http://localhost:8000")
+    with socketserver.TCPServer(("", port), handler) as httpd:
+        print(f"Server started at http://localhost:{port}")
         httpd.serve_forever()
 
-# 主函数
-def visualize_graphml(graphml_file, html_path):
+# main function
+def visualize_graphml(graphml_file, html_path, port=8000):
     json_data = graphml_to_json(graphml_file)
-    create_json(json_data, 'graph_json.js')
+    html_dir = os.path.dirname(html_path)
+    if not os.path.exists(html_dir):
+        os.makedirs(html_dir)
+    json_path = os.path.join(html_dir, 'graph_json.js')
+    create_json(json_data, json_path)
     create_html(html_path)
-    # 在后台启动服务器
-    server_thread = threading.Thread(target=start_server)
+    # start server in background
+    server_thread = threading.Thread(target=start_server(port))
     server_thread.daemon = True
     server_thread.start()
 
-    # 打开默认浏览器
-    webbrowser.open('http://localhost:8000/graph_visualization.html')
+    # open default browser
+    webbrowser.open(f'http://localhost:{port}/{html_path}')
 
     print("Visualization is ready. Press Ctrl+C to exit.")
     try:
-        # 保持主线程运行
+        # keep main thread running
        while True:
            pass
     except KeyboardInterrupt:
         print("Shutting down...")
 
-# 使用示例
+# usage
 if __name__ == "__main__":
-    graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml"  # 替换为您的GraphML文件路径
+    graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml"  # replace with your GraphML file path
     html_path = "graph_visualization.html"
-    visualize_graphml(graphml_file, html_path)
+    visualize_graphml(graphml_file, html_path, 11236)
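A side note on this hunk, not part of the commit: `threading.Thread(target=start_server(port))` calls `start_server(port)` immediately on the main thread, where `serve_forever()` blocks, so the lines after it are never reached. A minimal sketch of the usual non-blocking pattern, reusing the `start_server`, `port`, and `html_path` names from the diff above:

    import threading
    import webbrowser

    # assumes start_server(port), port, and html_path as defined in graphml_visualize.py above
    server_thread = threading.Thread(target=start_server, args=(port,), daemon=True)
    server_thread.start()  # HTTP server now runs in the background
    webbrowser.open(f"http://localhost:{port}/{html_path}")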

examples/no_openai_key_at_all.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ async def local_embedding(texts: list[str]) -> np.ndarray:
 async def ollama_model_if_cache(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
+    # remove kwargs that are not supported by ollama
+    kwargs.pop("max_tokens", None)
+
     ollama_client = ollama.AsyncClient()
     messages = []
     if system_prompt:

Lines changed: 28 additions & 53 deletions
@@ -1,68 +1,43 @@
-
-
 from nano_graphrag._utils import encode_string_by_tiktoken
 from nano_graphrag.base import QueryParam
 from nano_graphrag.graphrag import GraphRAG
+from nano_graphrag._op import chunking_by_seperators
 
 
-def chunking_by_specific_separators(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o",
+def chunking_by_token_size(
+    tokens_list: list[list[int]],  # nano-graphrag may pass a batch of docs' tokens
+    doc_keys: list[str],  # nano-graphrag may pass a batch of docs' key ids
+    tiktoken_model,  # a titoken model
+    overlap_token_size=128,
+    max_token_size=1024,
 ):
-    from langchain_text_splitters import RecursiveCharacterTextSplitter
-
 
-    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=max_token_size,
-        chunk_overlap=overlap_token_size,
-        # length_function=lambda x: len(encode_string_by_tiktoken(x)),
-        model_name=tiktoken_model,
-        is_separator_regex=False,
-        separators=[
-            # Paragraph separators
-            "\n\n",
-            "\r\n\r\n",
-            # Line breaks
-            "\n",
-            "\r\n",
-            # Sentence ending punctuation
-            "。",  # Chinese period
-            ".",  # Full-width dot
-            ".",  # English period
-            "!",  # Chinese exclamation mark
-            "!",  # English exclamation mark
-            "?",  # Chinese question mark
-            "?",  # English question mark
-            # Whitespace characters
-            " ",  # Space
-            "\t",  # Tab
-            "\u3000",  # Full-width space
-            # Special characters
-            "\u200b",  # Zero-width space (used in some Asian languages)
-            # Final fallback
-            "",
-        ])
-    texts = text_splitter.split_text(content)
-
     results = []
-    for index, chunk_content in enumerate(texts):
-
-        results.append(
-            {
-                # "tokens": None,
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = []
+        lengths = []
+        for start in range(0, len(tokens), max_token_size - overlap_token_size):
+
+            chunk_token.append(tokens[start : start + max_token_size])
+            lengths.append(min(max_token_size, len(tokens) - start))
+
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
+
     return results
 
 
 WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
 rag = GraphRAG(
     working_dir=WORKING_DIR,
-    chunk_func=chunking_by_specific_separators,
+    chunk_func=chunking_by_seperators,
 )
-
-with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:
-    FAKE_TEXT = f.read()
-
-# rag.insert(FAKE_TEXT)
-print(rag.query("What the main theme of this story?", param=QueryParam(mode="local")))
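For context on the signature above, here is a quick stand-alone sketch of driving such a chunk_func directly with tiktoken. The file path and doc key are placeholders, and tiktoken's `encoding_for_model`, `encode`, and `decode_batch` calls are the ones the function relies on:

    import tiktoken

    # hypothetical driver for the chunking_by_token_size defined in the hunk above
    enc = tiktoken.encoding_for_model("gpt-4o")
    with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:  # placeholder path
        text = f.read()

    chunks = chunking_by_token_size(
        tokens_list=[enc.encode(text)],  # one document's tokens; nano-graphrag batches these
        doc_keys=["doc-0"],              # placeholder document id
        tiktoken_model=enc,
        overlap_token_size=128,
        max_token_size=1024,
    )
    print(len(chunks), chunks[0]["tokens"], chunks[0]["full_doc_id"])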

examples/using_dspy_entity_extraction.py

Lines changed: 4 additions & 4 deletions
@@ -138,14 +138,14 @@ def query():
     """
     lm = dspy.OpenAI(
         model="deepseek-chat",
-        model_type="chat",
+        model_type="chat",
+        api_provider="openai",
         api_key=os.environ["DEEPSEEK_API_KEY"],
         base_url=os.environ["DEEPSEEK_BASE_URL"],
         system_prompt=system_prompt,
         temperature=1.0,
-        top_p=1,
-        max_tokens=4096
+        max_tokens=8192
     )
-    dspy.settings.configure(lm=lm)
+    dspy.settings.configure(lm=lm, experimental=True)
     insert()
     query()

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+import os
+import logging
+import ollama
+import numpy as np
+from openai import AsyncOpenAI
+from nano_graphrag import GraphRAG, QueryParam
+from nano_graphrag import GraphRAG, QueryParam
+from nano_graphrag.base import BaseKVStorage
+from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
+
+logging.basicConfig(level=logging.WARNING)
+logging.getLogger("nano-graphrag").setLevel(logging.INFO)
+
+# Assumed llm model settings
+LLM_BASE_URL = "https://your.api.url"
+LLM_API_KEY = "your_api_key"
+MODEL = "your_model_name"
+
+# Assumed embedding model settings
+EMBEDDING_MODEL = "nomic-embed-text"
+EMBEDDING_MODEL_DIM = 768
+EMBEDDING_MODEL_MAX_TOKENS = 8192
+
+
+async def llm_model_if_cache(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    openai_async_client = AsyncOpenAI(
+        api_key=LLM_API_KEY, base_url=LLM_BASE_URL
+    )
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+
+    # Get the cached response if having-------------------
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(MODEL, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+    # -----------------------------------------------------
+
+    response = await openai_async_client.chat.completions.create(
+        model=MODEL, messages=messages, **kwargs
+    )
+
+    # Cache the response if having-------------------
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
+        )
+    # -----------------------------------------------------
+    return response.choices[0].message.content
+
+
+def remove_if_exist(file):
+    if os.path.exists(file):
+        os.remove(file)
+
+
+WORKING_DIR = "./nano_graphrag_cache_llm_TEST"
+
+
+def query():
+    rag = GraphRAG(
+        working_dir=WORKING_DIR,
+        best_model_func=llm_model_if_cache,
+        cheap_model_func=llm_model_if_cache,
+        embedding_func=ollama_embedding,
+    )
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+
+def insert():
+    from time import time
+
+    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
+        FAKE_TEXT = f.read()
+
+    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
+    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
+    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
+    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
+    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
+
+    rag = GraphRAG(
+        working_dir=WORKING_DIR,
+        enable_llm_cache=True,
+        best_model_func=llm_model_if_cache,
+        cheap_model_func=llm_model_if_cache,
+        embedding_func=ollama_embedding,
+    )
+    start = time()
+    rag.insert(FAKE_TEXT)
+    print("indexing time:", time() - start)
+    # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
+    # rag.insert(FAKE_TEXT[half_len:])
+
+# We're using Ollama to generate embeddings for the BGE model
+@wrap_embedding_func_with_attrs(
+    embedding_dim=EMBEDDING_MODEL_DIM,
+    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
+)
+
+async def ollama_embedding(texts: list[str]) -> np.ndarray:
+    embed_text = []
+    for text in texts:
+        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
+        embed_text.append(data["embedding"])
+
+    return embed_text
+
+if __name__ == "__main__":
+    insert()
+    query()
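One detail in the new example: `ollama_embedding` is annotated `-> np.ndarray` but returns a plain Python list. A hedged variant that matches the annotation (same constants and decorator as above, only the return value changes) would be:

    @wrap_embedding_func_with_attrs(
        embedding_dim=EMBEDDING_MODEL_DIM,
        max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
    )
    async def ollama_embedding(texts: list[str]) -> np.ndarray:
        embed_text = []
        for text in texts:
            # one embedding request per input text against the local Ollama server
            data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
            embed_text.append(data["embedding"])
        return np.array(embed_text)  # ndarray, as the annotation promises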

examples/using_ollama_as_llm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
async def ollama_model_if_cache(
1717
prompt, system_prompt=None, history_messages=[], **kwargs
1818
) -> str:
19+
# remove kwargs that are not supported by ollama
20+
kwargs.pop("max_tokens", None)
21+
1922
ollama_client = ollama.AsyncClient()
2023
messages = []
2124
if system_prompt:
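For readers seeing only this hunk: `max_tokens` is popped because Ollama's chat API does not accept it, while nano-graphrag may pass it along with its other generation kwargs. A minimal sketch of how the rest of the function typically continues (the model name is a placeholder; `AsyncClient.chat` is the standard `ollama` Python client call):

    import ollama

    MODEL = "qwen2"  # placeholder; any model pulled into the local Ollama instance

    async def ollama_model_if_cache(
        prompt, system_prompt=None, history_messages=[], **kwargs
    ) -> str:
        # remove kwargs that are not supported by ollama
        kwargs.pop("max_tokens", None)

        ollama_client = ollama.AsyncClient()
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})

        response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)
        return response["message"]["content"]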
