Update rag.py

parent b206375c6c
commit a569c71ad4

rag.py: 57 lines changed
@@ -1,30 +1,49 @@
-# rag.py
-import faiss, json, requests, numpy as np
-from sentence_transformers import SentenceTransformer
+#!/usr/bin/env python3
+import faiss, json, requests, readline, numpy as np
+from rich import print
+from FlagEmbedding import BGEM3FlagModel
 
-INDEX = faiss.read_index("corpus.idx")
-META = json.load(open("corpus.meta.json"))
-EMBMOD = SentenceTransformer("WhereIsAI/bge-base-fr", device="cpu")
+# --- load index, metadata and embedding model ---------------------------------
+idx = faiss.read_index("corpus.idx")
+meta = json.load(open("corpus.meta.json"))
+model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")  # same model as at indexing time
+
+# small helper to render a retrieved passage
+def fetch_passage(i):
+    m = meta[i]
+    return f"[{m['file']} · part {m['part']}] {m['text']}"
 
 def ask_llm(prompt):
     r = requests.post("http://127.0.0.1:11434/api/generate", json={
         "model": "mistral7b-fast",
         "prompt": prompt,
         "stream": False,
         "options": {"temperature": 0.2, "num_predict": 512}
     }, timeout=300)
     return r.json()["response"]
 
-def query(q, k=4):
-    v = EMBMOD.encode([q], normalize_embeddings=True)
-    D, I = INDEX.search(v.astype("float32"), k)
-    ctx = "\n\n".join(f"[{i}] {docs[I[0][i]]}" for i in range(k))
-    prompt = f"""<system>Tu réponds de façon concise en français.</system>
-<context>{ctx}</context>
-<user>{q}</user>"""
-    return ask_llm(prompt)
-
-if __name__ == "__main__":
-    while True:
-        q = input("Question › ")
-        print(query(q))
+# --- interactive loop ----------------------------------------------------------
+while True:
+    try:
+        q = input("❓ > ").strip()
+        if not q: continue
+    except (KeyboardInterrupt, EOFError):
+        print("\nBye."); break
+
+    # embed the question & FAISS search (top-k=4)
+    q_emb = model.encode([q])["dense_vecs"]  # BGE-M3 encode returns a dict
+    D, I = idx.search(q_emb.astype("float32"), 4)
+
+    ctx_blocks = []
+    for idx_id in I[0]:
+        ctx_blocks.append(fetch_passage(idx_id))
+    context = "\n\n".join(ctx_blocks)
+
+    prompt = f"""<system>Réponds en français, précis et factuel.</system>
+<context>{context}</context>
+<user>{q}</user>"""
+
+    print("\n[bold]Réponse :[/]\n")
+    print(ask_llm(prompt))
+    print("\n[dim]--- contexte utilisé ---[/]")
+    print(context)
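One caveat with the embedding swap: SentenceTransformer.encode returns an ndarray and accepts normalize_embeddings=, but FlagEmbedding's BGEM3FlagModel.encode returns a dict, with the dense vectors under "dense_vecs" (1024-dimensional for bge-m3), which is why the search line now indexes into that dict. The previous corpus.idx was built from bge-base-fr vectors, so it must have been rebuilt with bge-m3 for this code to work. A minimal sanity check along those lines, assuming the corpus.idx and constructor from this commit (the script name and test string are hypothetical):

# embed_check.py -- hypothetical sanity check: does the FAISS index dimension
# match what BGE-M3 produces? (bge-m3 dense vectors are 1024-dimensional)
import faiss
from FlagEmbedding import BGEM3FlagModel

idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")  # same constructor as rag.py

out = model.encode(["question de test"])   # returns a dict, not an ndarray
vec = out["dense_vecs"]                    # shape (1, 1024) for bge-m3
assert vec.shape[1] == idx.d, f"index dim {idx.d} != embedding dim {vec.shape[1]}"
print("ok: embedding dim", vec.shape[1], "matches index dim", idx.d)

If the assert fires, re-run the indexing script with the same BGE-M3 model before querying.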
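ask_llm itself is untouched: it still expects an Ollama server on 127.0.0.1:11434 serving a model tagged mistral7b-fast, and with "stream": False the /api/generate endpoint returns a single JSON object whose "response" field holds the full completion. A quick way to verify that end of the pipeline in isolation, assuming the same local setup (the prompt string here is just an example):

# ollama_check.py -- sketch of a standalone check for the generation endpoint;
# assumes a local Ollama instance with the "mistral7b-fast" tag from the commit
import requests

r = requests.post("http://127.0.0.1:11434/api/generate", json={
    "model": "mistral7b-fast",
    "prompt": "Réponds simplement: bonjour.",
    "stream": False,
}, timeout=60)
r.raise_for_status()        # a 404 here usually means the tag isn't pulled/created
print(r.json()["response"]) # one JSON object since streaming is off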