Code/rag.py
2025-05-19 06:21:30 +02:00

59 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import faiss, json, requests, readline, numpy as np
from rich import print
from FlagEmbedding import BGEM3FlagModel
# --- loading ------------------------------------------------------------------
# FAISS index and the per-vector metadata; fetch_passage() looks entries up as
# meta[<faiss id>], so presumably meta holds one entry per vector, in index
# order — verify against the indexing script.
idx = faiss.read_index("corpus.idx")
# Explicit UTF-8 (JSON's standard encoding) instead of the locale default, and
# a context manager so the file handle is closed right after parsing.
with open("corpus.meta.json", encoding="utf-8") as _meta_file:
    meta = json.load(_meta_file)
# Must be the same embedding model as the one used at indexing time.
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
# Small lookup helper: turn a FAISS result id into a labelled passage string.
def fetch_passage(i):
    """Return corpus passage ``i`` prefixed with its source label.

    Reads the entry ``meta[i]`` from the module-level metadata and formats
    its ``file``, ``part`` and ``text`` fields as
    ``"[<file> · part <part>] <text>"``.
    """
    record = meta[i]
    label = f"{record['file']} · part {record['part']}"
    return f"[{label}] {record['text']}"
def ask_llm(prompt, *, model_name="mistral7b-fast",
            url="http://127.0.0.1:11434/api/generate",
            temperature=0.2, num_predict=512, timeout=300):
    """Send ``prompt`` to the local LLM server and return the generated text.

    The request targets an Ollama-style ``/api/generate`` endpoint (11434 is
    Ollama's default port) with ``"stream": False``, so the whole answer
    arrives in a single JSON object whose ``"response"`` field is returned.

    Args:
        prompt: Full prompt text sent to the model.
        model_name: Name of the model served by the endpoint.
        url: URL of the generate endpoint.
        temperature: Sampling temperature, sent in ``options``.
        num_predict: Generation length limit, sent in ``options``.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The ``"response"`` field of the JSON reply.

    Raises:
        requests.HTTPError: The server answered with a non-2xx status; the
            message includes the response body (the server's error text).
        requests.RequestException: Connection failure or timeout.
        KeyError: The JSON reply has no ``"response"`` field.
    """
    r = requests.post(url, json={
        "model": model_name,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature, "num_predict": num_predict},
    }, timeout=timeout)
    # Fail with the server's own error text instead of an opaque
    # KeyError: 'response' when the request is rejected (unknown model, ...).
    if not r.ok:
        raise requests.HTTPError(
            f"{r.status_code} {r.reason} from {url}: {r.text}", response=r
        )
    return r.json()["response"]
# --- interactive loop ---------------------------------------------------------
TOP_K = 4  # number of passages retrieved from FAISS per question


def dense_query_vector(raw):
    """Extract the first dense embedding from ``model.encode`` output.

    ``raw`` is either an array-like of embeddings, or a dict as returned by
    FlagEmbedding's ``BGEM3FlagModel.encode`` (dense vectors under
    ``"dense_vecs"``). ``"embedding"`` / ``"embeddings"`` are also accepted
    for other FlagEmbedding versions.

    Returns:
        A float32 array of shape ``(1, dim)``, L2-normalised (a zero vector is
        left as-is rather than turned into NaNs), ready for ``idx.search``.

    Raises:
        ValueError: ``raw`` is a dict without any known dense-embedding key.
    """
    if isinstance(raw, dict):
        # Explicit ``is not None`` test: ``a or b`` on numpy arrays raises
        # "truth value of an array ... is ambiguous".
        dense = next((raw[key] for key in ("dense_vecs", "embedding", "embeddings")
                      if raw.get(key) is not None), None)
        if dense is None:
            raise ValueError(
                f"no dense embedding in encode() output (keys: {sorted(raw)})"
            )
        raw = dense
    vec = np.asarray(raw, dtype=np.float32)
    if vec.ndim > 1:
        vec = vec[0]  # first (only) query of the batch
    norm = np.linalg.norm(vec)
    if norm > 0:
        vec = vec / norm
    return vec.reshape(1, -1)


def build_context(ids):
    """Join the labelled passages for FAISS result ``ids`` with blank lines.

    FAISS pads results with ``-1`` when the index holds fewer than k vectors;
    those ids are skipped instead of silently wrapping around to ``meta[-1]``.
    """
    return "\n\n".join(fetch_passage(int(i)) for i in ids if i >= 0)


if __name__ == "__main__":
    # Passages and model answers are arbitrary text: without escaping, rich
    # would parse "[...]" spans (e.g. the "[file · part N]" labels) as markup,
    # dropping them or raising MarkupError.
    from rich.markup import escape

    while True:
        try:
            q = input("❓ > ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nBye.")
            break
        if not q:
            continue

        # Embed the question and retrieve the TOP_K nearest passages.
        _, ids = idx.search(dense_query_vector(model.encode([q])), TOP_K)
        context = build_context(ids[0])

        prompt = (
            "<system>Réponds en français, précis et factuel.</system>\n"
            f"<context>{context}</context>\n"
            f"<user>{q}</user>"
        )

        # A down/misconfigured LLM server should not end the session.
        try:
            answer = ask_llm(prompt)
        except requests.RequestException as err:
            print(f"[red]LLM request failed:[/] {escape(str(err))}")
            continue

        print("\n[bold]Réponse :[/]\n")
        print(escape(answer))
        print("\n[dim]--- contexte utilisé ---[/]")
        print(escape(context))