Update rag.py
parent: 852b81ba93
commit: c56a46545f

rag.py (134 lines changed)
@@ -1,107 +1,91 @@
 #!/usr/bin/env python3
 """
-Interactive mini RAG:
-• FAISS semantic search over corpus.idx / corpus.meta.json
-• Context (top‑k=4 passages) sent to Mistral‑7B via Ollama.
-
-Robust to the different output formats of BGEM3FlagModel.encode.
+Robust interactive RAG.
+• Reloads the passages from the Fiches files (same chunking as the index) so their text is available.
+• FAISS top‑k=4 search and generation via mistral7b-fast.
 """
-import json
-import readline
+import os, json, readline, re
 from pathlib import Path

-import faiss
-import numpy as np
-import requests
+import faiss, numpy as np, requests
 from FlagEmbedding import BGEM3FlagModel
 from rich import print

 # ---------------------------------------------------------------------------
 # Initial loading
 # ---------------------------------------------------------------------------
 IDX_FILE = Path("corpus.idx")
 META_FILE = Path("corpus.meta.json")
+ROOT = Path("Fiches")        # directory holding the Fiches notes on the host
+CHUNK, OVERLAP = 800, 100    # must match the values used at indexing time
+K = 4                        # number of passages handed to the model

 if not IDX_FILE.exists() or not META_FILE.exists():
     raise SystemExit("[bold red]Erreur :[/] index absent. Lancez d'abord index.py !")
+
+# --- chunking ----------------------------------------------------------------

-index = faiss.read_index(str(IDX_FILE))
-meta = json.loads(META_FILE.read_text())
+def split(text: str):
+    sents = re.split(r"(?<=[.!?]) +", text)
+    buf, out = [], []
+    for s in sents:
+        buf.append(s)
+        if len(" ".join(buf).split()) > CHUNK:   # rough approximation: 1 word ≈ 1 token
+            out.append(" ".join(buf))
+            buf = buf[-OVERLAP:]
+    if buf:
+        out.append(" ".join(buf))
+    return out
+
+# --- reload docs + meta in the same order as the index ------------------------
+
+docs, meta = [], []
+for fp in ROOT.rglob("*.md"):
+    for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))):
+        docs.append(chunk)
+        meta.append({"file": fp.name, "part": i})
+
+print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]")
+
+# --- existing FAISS index ------------------------------------------------------
+
+idx = faiss.read_index("corpus.idx")
 model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")

-# ---------------------------------------------------------------------------
-# Utilities
-# ---------------------------------------------------------------------------
+# --- Q/A loop ------------------------------------------------------------------

-def _normalize(x: np.ndarray) -> np.ndarray:
-    """L2-normalize each row (float32)."""
-    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)
-
-
-def embed(texts):
-    """Encode list[str] → ndarray (n, dim), whatever the library returns."""
-    out = model.encode(texts)
-    # Possible shapes:
-    #   • ndarray
-    #   • dict {"embedding": ndarray} or {"embeddings": ndarray}
-    #   • dict {"sentence_embeds": [...]} etc.
-    if isinstance(out, np.ndarray):
-        arr = out
-    elif isinstance(out, dict):
-        # pick the first ndarray-like value
-        for v in out.values():
-            if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
-                arr = np.asarray(v)
-                break
-        else:
-            raise TypeError("encode() dict sans clé embedding !")
-    else:  # list[list[float]] etc.
-        arr = np.asarray(out)
-    return _normalize(arr.astype("float32"))
-
-
-def fetch_passage(i: int) -> str:
+def fetch_passage(i: int):
     m = meta[i]
-    return f"[{m['file']} · part {m['part']}] {m['text']}"
+    return f"[{m['file']} · part {m['part']}] {docs[i][:200]}…"

-def ask_llm(prompt: str) -> str:
-    r = requests.post(
-        "http://127.0.0.1:11434/api/generate",
-        json={
+def ask_llm(prompt: str):
+    r = requests.post("http://127.0.0.1:11434/api/generate", json={
         "model": "mistral7b-fast",
         "prompt": prompt,
         "stream": False,
-        "options": {"temperature": 0.2, "num_predict": 512},
-        },
-        timeout=300,
-    )
+        "options": {"temperature": 0.2, "num_predict": 512}
+    }, timeout=300)
     return r.json()["response"]

 # ---------------------------------------------------------------------------
 # Interactive loop
 # ---------------------------------------------------------------------------
-print("[bold green]RAG prêt.[/] Posez vos questions ! (Ctrl‑D pour sortir)")
+print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)")
 try:
     while True:
         try:
             q = input("❓ > ").strip()
             if not q:
                 continue
         except (EOFError, KeyboardInterrupt):
-            print("\n[dim]Bye.[/]")
-            break
+            print("\nBye."); break

-        q_emb = embed([q])                      # (1, dim)
-        D, I = index.search(q_emb, 4)
-        ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0])
+        emb = model.encode([q])
+        if isinstance(emb, dict):
+            emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
+        q_emb = emb[0] / np.linalg.norm(emb[0])

-        prompt = (
-            "<system>Réponds en français, précis et factuel.</system>\n"
-            f"<context>{ctx}</context>\n"
-            f"<user>{q}</user>"
-        )
+        D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)

-        print("\n[bold]Réponse :[/]\n")
+        context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0])
+        prompt = f"""<system>Réponds en français, précis et factuel.</system>\n<context>{context}</context>\n<user>{q}</user>"""
+
+        print("\n[bold]Réponse :[/]")
         print(ask_llm(prompt))

         # small trace of the sources
-        print("\n[dim]--- contexte utilisé ---[/]")
-        print(ctx)
+        for idx_id in I[0]:
+            print(fetch_passage(int(idx_id)))
 except Exception as e:
     print("[red]Erreur :", e)
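The isinstance(emb, dict) branch in the query loop exists because BGEM3FlagModel.encode usually returns a dict rather than a bare array; in recent FlagEmbedding releases the dense vectors are reported under the "dense_vecs" key (1024-dimensional for bge-m3). A quick standalone check, under that assumption about the key name:

# Quick check of what BGEM3FlagModel.encode returns -- the "dense_vecs" key and
# the 1024-dim shape are assumptions about the installed FlagEmbedding version.
import numpy as np
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
out = model.encode(["test sentence"])
vec = out["dense_vecs"] if isinstance(out, dict) else np.asarray(out)
print(type(out).__name__, np.asarray(vec).shape)   # expected: dict (1, 1024)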