Update rag.py

parent c56a46545f
commit be9c3709db

rag.py
@@ -1,65 +1,67 @@
 #!/usr/bin/env python3
 """
-Robust interactive RAG.
-• Reloads the passages from the fiches (same chunking as the index) to have the text available.
-• FAISS search, top-k=4, and generation via mistral7b-fast.
+Interactive RAG – version aligned with the index
+------------------------------------------------
+• Uses corpus.idx + corpus.meta.json to know the exact order of the passages.
+• Reloads **only** the matching texts, keeping that order – so no more
+  index out-of-range errors, whatever the chunking.
+• FAISS search (top-k=4) + generation via mistral7b-fast (Ollama).
 """
-import os, json, readline, re
+import json, readline, re
 from pathlib import Path
+from collections import defaultdict
 
 import faiss, numpy as np, requests
 from FlagEmbedding import BGEM3FlagModel
 from rich import print
 
-# ---------------------------------------------------------------------------
-ROOT = Path("Fiches")            # root folder of the fiches
-K = 4                            # number of passages handed to the LLM
-
-# --- chunking ----------------------------------------------------------------
-CHUNK, OVERLAP = 800, 100        # keep consistent with index.py
+ROOT = Path("Fiches")            # folder of the fiches on the host
+CHUNK, OVERLAP = 800, 100        # identical to the indexing
+K = 4                            # number of passages handed to the model
+
+# ------------------ chunking utilities identical to the index ----------------
 
 def split(text: str):
     sents = re.split(r"(?<=[.!?]) +", text)
     buf, out = [], []
     for s in sents:
         buf.append(s)
-        if len(" ".join(buf).split()) > CHUNK:  # approx. 1 word = 1 token
+        if len(" ".join(buf).split()) > CHUNK:
             out.append(" ".join(buf))
             buf = buf[-OVERLAP:]
     if buf:
         out.append(" ".join(buf))
     return out
 
-# --- load docs + meta in the same order as the index -------------------------
-docs, meta = [], []
-for fp in ROOT.rglob("*.md"):
-    for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))):
-        docs.append(chunk)
-        meta.append({"file": fp.name, "part": i})
-
-print(f"[dim]Loaded {len(docs)} passages from {ROOT}.[/]")
+# ------------------- load meta and rebuild the passages ----------------------
+meta_path = Path("corpus.meta.json")
+if not meta_path.exists():
+    raise SystemExit("corpus.meta.json not found – run index.py first")
+meta = json.load(meta_path.open())
+
+# mapping (file, part) -> chunk text
+cache: dict[tuple[str, int], str] = {}
+for fp in sorted(ROOT.rglob("*")):
+    if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
+        continue
+    chunks = split(fp.read_text(encoding="utf-8"))
+    for i, ch in enumerate(chunks):
+        cache[(fp.name, i)] = ch
+
+# rebuild docs in the same order as the index ---------------------------------
+docs = []
+for m in meta:
+    key = (m["file"], m["part"])
+    docs.append(cache.get(key, "[missing passage]"))
+
+print(f"[dim]Passages reloaded: {len(docs)} (order matches the index).[/]")
 
-# --- existing FAISS index ----------------------------------------------------
+# ---------------- FAISS + embedding model ------------------------------------
 idx = faiss.read_index("corpus.idx")
 model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
 
-# --- Q/A loop -----------------------------------------------------------------
-
-def fetch_passage(i: int):
-    m = meta[i]
-    return f"[{m['file']} · part {m['part']}] {docs[i][:200]}…"
-
-def ask_llm(prompt: str):
-    r = requests.post("http://127.0.0.1:11434/api/generate", json={
-        "model": "mistral7b-fast",
-        "prompt": prompt,
-        "stream": False,
-        "options": {"temperature": 0.2, "num_predict": 512}
-    }, timeout=300)
-    return r.json()["response"]
-
+# ---------------- interactive loop -------------------------------------------
 print("RAG ready. Ask your questions! (Ctrl-D to exit)")
 try:
     while True:
@@ -72,20 +74,34 @@ try:
 
         emb = model.encode([q])
         if isinstance(emb, dict):
+            # grab the first ndarray found
             emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
         q_emb = emb[0] / np.linalg.norm(emb[0])
 
         D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
+        hits = I[0]
 
-        context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0])
-        prompt = f"""<system>Réponds en français, précis et factuel.</system>\n<context>{context}</context>\n<user>{q}</user>"""
+        context = "\n\n".join(docs[int(i)] for i in hits)
+        prompt = (
+            "<system>Réponds en français, précis et factuel.</system>\n"
+            f"<context>{context}</context>\n<user>{q}</user>"
+        )
+
+        def ask_llm(p):
+            r = requests.post("http://127.0.0.1:11434/api/generate", json={
+                "model": "mistral7b-fast",
+                "prompt": p,
+                "stream": False,
+                "options": {"temperature": 0.2, "num_predict": 512}
+            }, timeout=300)
+            return r.json()["response"]
 
         print("\n[bold]Answer:[/]")
         print(ask_llm(prompt))
 
-        # small trace of the sources
         print("\n[dim]--- context used ---[/]")
-        for idx_id in I[0]:
-            print(fetch_passage(int(idx_id)))
+        for rank, idx_id in enumerate(hits, 1):
+            m = meta[int(idx_id)]
+            print(f"[{rank}] {m['file']} · part {m['part']} → {docs[int(idx_id)][:120]}…")
 except Exception as e:
     print("[red]Error:", e)
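The diff above only works if corpus.idx and corpus.meta.json were written with the very same split() / CHUNK / OVERLAP and with meta entries shaped as {"file": ..., "part": ...} in index order. For reference, a minimal sketch of what the matching writer side could look like; this is hypothetical (index.py is not part of this commit) and assumes BGE-M3 dense vectors stored unit-normalized in an inner-product FAISS index, which matches the query normalization rag.py performs:

#!/usr/bin/env python3
# Hypothetical sketch of the matching index.py (not part of this commit).
# It must share split()/CHUNK/OVERLAP with rag.py and write the meta entries
# in the exact order the vectors are added to the index.
import json, re
from pathlib import Path

import faiss, numpy as np
from FlagEmbedding import BGEM3FlagModel

ROOT = Path("Fiches")
CHUNK, OVERLAP = 800, 100  # must stay identical to rag.py

def split(text: str):
    # identical to rag.py; note CHUNK counts words while OVERLAP keeps
    # trailing *sentences*, kept as-is to stay aligned
    sents = re.split(r"(?<=[.!?]) +", text)
    buf, out = [], []
    for s in sents:
        buf.append(s)
        if len(" ".join(buf).split()) > CHUNK:
            out.append(" ".join(buf))
            buf = buf[-OVERLAP:]
    if buf:
        out.append(" ".join(buf))
    return out

docs, meta = [], []
for fp in sorted(ROOT.rglob("*")):
    if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
        continue
    for i, ch in enumerate(split(fp.read_text(encoding="utf-8"))):
        docs.append(ch)
        meta.append({"file": fp.name, "part": i})

model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
emb = model.encode(docs)["dense_vecs"]                  # dense BGE-M3 vectors
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # unit norm -> IP = cosine
index = faiss.IndexFlatIP(emb.shape[1])
index.add(np.asarray(emb, dtype="float32"))
faiss.write_index(index, "corpus.idx")
json.dump(meta, Path("corpus.meta.json").open("w"))

One caveat visible on both sides: passages are keyed by fp.name alone, so two files with the same name in different subfolders of Fiches would collide in the cache.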
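The point of the rewrite is that every id FAISS returns is now a valid index into docs. A quick regression check for exactly that invariant; a sketch that assumes rag.py's globals (model, idx, docs, np, K) are in scope, e.g. pasted at the bottom of the script:

# Hypothetical regression check for the out-of-range bug this commit fixes:
# every id FAISS returns must be a valid index into the rebuilt docs list.
emb = model.encode(["question test"])
if isinstance(emb, dict):
    emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
q_emb = (emb[0] / np.linalg.norm(emb[0])).astype("float32").reshape(1, -1)
_, I = idx.search(q_emb, K)
assert all(0 <= int(i) < len(docs) for i in I[0]), "id out of range"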
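Since ask_llm is now defined inside the question loop, a wrong endpoint or model name only surfaces after the first question is typed. A one-off smoke test of the same Ollama /api/generate call, sketched under the assumption of a local Ollama on its default port with the mistral7b-fast model already created:

import requests

# Hypothetical smoke test: same endpoint and payload shape as ask_llm in rag.py.
r = requests.post("http://127.0.0.1:11434/api/generate", json={
    "model": "mistral7b-fast",
    "prompt": "ping",
    "stream": False,
}, timeout=60)
r.raise_for_status()
print(r.json()["response"])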