Code/rag.py

#!/usr/bin/env python3
"""
RAG interactif robuste.
• Recharge les passages à partir des fiches (même découpe que l'index) pour disposer du texte.
• Recherche FAISS topk=4 et génération via mistral7b-fast.
"""
import re
import readline  # noqa: F401 -- imported for its side effect (line editing/history for input())
from pathlib import Path

import faiss
import numpy as np
import requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
# ---------------------------------------------------------------------------
ROOT = Path("Fiches")       # directory holding the source notes on the host
CHUNK, OVERLAP = 800, 100   # must match the values used when building the index
K = 4                       # number of passages handed to the model
# --- chunking ----------------------------------------------------------------
def split(text: str):
    """Cut a document into overlapping chunks, sentence by sentence."""
    sents = re.split(r"(?<=[.!?]) +", text)
    buf, out = [], []
    for s in sents:
        buf.append(s)
        if len(" ".join(buf).split()) > CHUNK:   # rough approximation: 1 word ≈ 1 token
            out.append(" ".join(buf))
            buf = buf[-OVERLAP:]                 # keep the tail of the buffer as overlap
    if buf:
        out.append(" ".join(buf))
    return out
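# Note: CHUNK is compared against a word count while OVERLAP keeps the last 100
# buffered *sentences*; the indexing script is assumed to use exactly the same
# logic so that chunk boundaries line up with the vectors stored in corpus.idx.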
# --- load docs + meta in the same order as the index ---------------------------
docs, meta = [], []
for fp in ROOT.rglob("*.md"):   # iteration order assumed identical to the indexing run
    for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))):
        docs.append(chunk)
        meta.append({"file": fp.name, "part": i})
print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]")
# --- existing FAISS index ------------------------------------------------------
idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
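# The query embedding is L2-normalised before the search (see below), which
# suggests corpus.idx holds normalised BGE-M3 dense vectors queried by inner
# product (cosine similarity). This is an assumption: the indexing script that
# built corpus.idx is not shown here.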
# --- Q/A loop ------------------------------------------------------------------
def fetch_passage(i: int):
    """Return a short, source-tagged excerpt of passage i (for the trace)."""
    m = meta[i]
    return f"[{m['file']} · part {m['part']}] {docs[i][:200]}"

def ask_llm(prompt: str):
    """Send the prompt to the local Ollama server and return the generated text."""
    r = requests.post("http://127.0.0.1:11434/api/generate", json={
        "model": "mistral7b-fast",
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": 0.2, "num_predict": 512},
    }, timeout=300)
    r.raise_for_status()
    return r.json()["response"]
print("RAG prêt. Posez vos questions ! (CtrlD pour sortir)")
try:
    while True:
        try:
            q = input("❓ > ").strip()
            if not q:
                continue
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break
        # embed the question (BGE-M3 can return a dict of output types)
        emb = model.encode([q])
        if isinstance(emb, dict):
            emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
        q_emb = emb[0] / np.linalg.norm(emb[0])
        D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
        context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0])
        prompt = (
            f"<system>Réponds en français, précis et factuel.</system>\n"
            f"<context>{context}</context>\n"
            f"<user>{q}</user>"
        )
        print("\n[bold]Réponse :[/]")
        print(ask_llm(prompt))
        # quick trace of the sources that were used
        print("\n[dim]--- contexte utilisé ---[/]")
        for idx_id in I[0]:
            print(fetch_passage(int(idx_id)))
except Exception as e:
    print("[red]Erreur :[/]", e)