Update rag.py

parent b206375c6c
commit a569c71ad4

rag.py · 57 changed lines

@@ -1,30 +1,49 @@
-# rag.py
-import faiss, json, requests, numpy as np
-from sentence_transformers import SentenceTransformer
+#!/usr/bin/env python3
+import faiss, json, requests, readline, numpy as np  # readline: line editing/history for input()
+from rich import print
+from FlagEmbedding import BGEM3FlagModel
+
-INDEX = faiss.read_index("corpus.idx")
-META = json.load(open("corpus.meta.json"))
-EMBMOD = SentenceTransformer("WhereIsAI/bge-base-fr", device="cpu")
+# --- loading ------------------------------------------------------------------
+idx = faiss.read_index("corpus.idx")
+meta = json.load(open("corpus.meta.json"))
+model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")  # same model as at indexing time
+
+# small helper to quickly pull a passage back out of the metadata
+def fetch_passage(i):
+    m = meta[i]
+    return f"[{m['file']} · part {m['part']}] {m['text']}"
+
 def ask_llm(prompt):
     r = requests.post("http://127.0.0.1:11434/api/generate", json={
         "model": "mistral7b-fast",
         "prompt": prompt,
         "stream": False,
-        "options": {"temperature": 0.2, "num_predict": 512}
+        "options": {"temperature": 0.2, "num_predict": 512},
     }, timeout=300)
     return r.json()["response"]
 
-def query(q, k=4):
-    v = EMBMOD.encode([q], normalize_embeddings=True)
-    D, I = INDEX.search(v.astype("float32"), k)
-    ctx = "\n\n".join(f"[{i}] {docs[I[0][i]]}" for i in range(k))
-    prompt = f"""<system>Tu réponds de façon concise en français.</system>
-<context>{ctx}</context>
-<user>{q}</user>"""
-    return ask_llm(prompt)
+# --- interactive loop ----------------------------------------------------------
+while True:
+    try:
+        q = input("❓ > ").strip()
+        if not q: continue
+    except (KeyboardInterrupt, EOFError):
+        print("\nBye."); break
+
-if __name__ == "__main__":
-    while True:
-        q = input("Question › ")
-        print(query(q))
+    # embeddings & FAISS search (top-k=4)
+    q_emb = model.encode([q])["dense_vecs"]  # encode() returns a dict; dense vectors sit under "dense_vecs"
+    D, I = idx.search(q_emb.astype("float32"), 4)
+
+    ctx_blocks = []
+    for idx_id in I[0]:
+        ctx_blocks.append(fetch_passage(idx_id))
+    context = "\n\n".join(ctx_blocks)
+
+    prompt = f"""<system>Réponds en français, précis et factuel.</system>
+<context>{context}</context>
+<user>{q}</user>"""
+
print("\n[bold]Réponse :[/]\n")
|
||||
print(ask_llm(prompt))
|
||||
print("\n[dim]--- contexte utilisé ---[/]")
|
||||
print(context)
|
||||
|
||||
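
Note: rag.py reads corpus.idx and corpus.meta.json, but this commit does not show how they are built. Below is a minimal indexing sketch that would produce compatible artifacts, assuming the same BAAI/bge-m3 model, a FAISS inner-product index over its dense vectors, and metadata entries shaped like {"file", "part", "text"}. The file name build_index.py, the docs/*.txt corpus location, and the fixed 1000-character chunking are hypothetical; the commit does not specify the real chunking scheme.

#!/usr/bin/env python3
# build_index.py — hypothetical companion script; writes corpus.idx and
# corpus.meta.json in the shape rag.py expects.
import glob
import json
import faiss
import numpy as np
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")  # must match the model in rag.py

meta, texts = [], []
for path in sorted(glob.glob("docs/*.txt")):  # assumed corpus location
    raw = open(path, encoding="utf-8").read()
    # naive fixed-size chunking (assumption; the real scheme is not in this commit)
    for part, start in enumerate(range(0, len(raw), 1000)):
        chunk = raw[start:start + 1000]
        meta.append({"file": path, "part": part, "text": chunk})
        texts.append(chunk)

# encode() returns a dict; the dense embeddings live under "dense_vecs"
vecs = np.asarray(model.encode(texts)["dense_vecs"], dtype="float32")

index = faiss.IndexFlatIP(vecs.shape[1])  # inner product over unit vectors == cosine
index.add(vecs)
faiss.write_index(index, "corpus.idx")
json.dump(meta, open("corpus.meta.json", "w", encoding="utf-8"), ensure_ascii=False)

Since BGEM3FlagModel normalizes embeddings by default, inner-product search in IndexFlatIP is equivalent to cosine similarity, which matches the idx.search call in rag.py.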