Update rag.py
This commit is contained in:
parent
86a902de9d
commit
4b16c2210e
173
rag.py
173
rag.py
@ -1,113 +1,124 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
RAG interactif – version alignée sur l'index
|
rag.py — recherche + génération (version robuste, chapitres)
|
||||||
-------------------------------------------
|
============================================================
|
||||||
• Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages.
|
• Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut :
|
||||||
• Recharge **uniquement** les textes correspondants en gardant cet ordre – ainsi, plus
|
chap.idx / chap.meta.json
|
||||||
d'erreur d'index out‑of‑range quelle que soit la découpe.
|
• Reconstitue les textes à partir des fichiers `path` indiqués dans la méta.
|
||||||
• Recherche FAISS (top‑k=4) + génération via mistral7b-fast (Ollama).
|
– Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py).
|
||||||
|
• Recherche : embeddings BGE‑M3 (CPU) + FAISS (cosinus IP) sur tous les index.
|
||||||
|
– top‑k configurable (déf. 20 pour index détaillé, 5 pour index chapitres).
|
||||||
|
– trie ensuite les hits mettant en avant ceux contenant un mot‑clé fourni
|
||||||
|
(ex. « seuil » pour ICS).
|
||||||
|
• Génération : appelle Mistral‑7B (Ollama) avec temperature 0.1 et consigne :
|
||||||
|
« Réponds uniquement à partir du contexte. Si l’info manque : Je ne sais pas. »
|
||||||
|
|
||||||
|
Usage :
|
||||||
|
python rag.py [--k 25] [--kw seuil] [--model mistral7b-fast]
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
import json, readline, re
|
import argparse, json, re, sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import faiss, numpy as np, requests
|
import faiss, numpy as np, requests
|
||||||
from FlagEmbedding import BGEM3FlagModel
|
from FlagEmbedding import BGEM3FlagModel
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
ROOT = Path("Fiches") # dossier racine des fiches
|
# ------------------------- CLI -------------------------------------------
|
||||||
K = 30 # nombre de passages remis au LLM
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--index", nargs="*", default=["chap.idx"],
|
||||||
|
help="Liste des fichiers FAISS à charger (déf. chap.idx)")
|
||||||
|
p.add_argument("--meta", nargs="*", default=["chap.meta.json"],
|
||||||
|
help="Liste des méta JSON assortis (même ordre que --index)")
|
||||||
|
p.add_argument("--k", type=int, default=15, help="top‑k cumulés (déf. 15)")
|
||||||
|
p.add_argument("--kw", default="seuil", help="mot‑clé boosté (déf. seuil)")
|
||||||
|
p.add_argument("--model", default="mistral7b-fast", help="modèle Ollama")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
# ------------------ utilitaires de découpe identiques à l'index -------------
|
if len(args.index) != len(args.meta):
|
||||||
CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py
|
print("[red]Erreur : --index et --meta doivent avoir la même longueur.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
def split(text: str):
|
# ------------------------- charger indexes -------------------------------
|
||||||
sents = re.split(r"(?<=[.!?]) +", text)
|
indexes, metas, start_offset = [], [], []
|
||||||
buf, out = [], []
|
offset = 0
|
||||||
for s in sents:
|
for idx_f, meta_f in zip(args.index, args.meta):
|
||||||
buf.append(s)
|
idx = faiss.read_index(str(idx_f))
|
||||||
if len(" ".join(buf).split()) > CHUNK:
|
meta = json.load(open(meta_f))
|
||||||
out.append(" ".join(buf))
|
if idx.ntotal != len(meta):
|
||||||
buf = buf[-OVERLAP:]
|
print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]")
|
||||||
if buf:
|
indexes.append(idx)
|
||||||
out.append(" ".join(buf))
|
metas.append(meta)
|
||||||
return out
|
start_offset.append(offset)
|
||||||
|
offset += idx.ntotal
|
||||||
|
|
||||||
# ------------------- charger meta et reconstruire passages ------------------
|
total_passages = offset
|
||||||
meta_path = Path("corpus.meta.json")
|
print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)")
|
||||||
if not meta_path.exists():
|
|
||||||
raise SystemExit("corpus.meta.json introuvable – lancez d'abord index.py")
|
|
||||||
meta = json.load(meta_path.open())
|
|
||||||
|
|
||||||
# mapping (file, part) -> chunk text
|
# ------------------------- cache texte -----------------------------------
|
||||||
cache: dict[tuple[str, int], str] = {}
|
DOCS: dict[int,str] = {}
|
||||||
for fp in sorted(ROOT.rglob("*")):
|
for base_offset, meta in zip(start_offset, metas):
|
||||||
if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
|
for i, m in enumerate(meta):
|
||||||
continue
|
DOCS[base_offset + i] = Path(m["path"]).read_text(encoding="utf-8")
|
||||||
chunks = split(fp.read_text(encoding="utf-8"))
|
print("[dim]Cache texte préchargé.[/]")
|
||||||
for i, ch in enumerate(chunks):
|
|
||||||
cache[(fp.name, i)] = ch
|
|
||||||
|
|
||||||
# reconstruire docs dans le même ordre que l'index ---------------------------
|
# ------------------------- modèle embeddings -----------------------------
|
||||||
docs = []
|
embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
|
||||||
for m in meta:
|
|
||||||
# compatibilité avec l’ancien et le nouveau format
|
|
||||||
path = m.get("file") or m.get("path") # nouvelle clé : "path"
|
|
||||||
part = m["part"]
|
|
||||||
key = (Path(path).name, part) # on garde le nom court pour le cache
|
|
||||||
|
|
||||||
docs.append(cache.get(key, "[passage manquant]"))
|
# ------------------------- helpers ---------------------------------------
|
||||||
|
def encode_query(q: str):
|
||||||
|
emb = embedder.encode([q])
|
||||||
|
if isinstance(emb, dict):
|
||||||
|
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
|
||||||
|
v = emb[0]
|
||||||
|
return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1)
|
||||||
|
|
||||||
print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]")
|
def search_all(vec):
|
||||||
|
hits = []
|
||||||
|
for idx, off in zip(indexes, start_offset):
|
||||||
|
D, I = idx.search(vec, min(args.k, idx.ntotal))
|
||||||
|
hits.extend([off + int(i) for i in I[0]])
|
||||||
|
return hits
|
||||||
|
|
||||||
# ---------------- FAISS + modèle embeddings --------------------------------
|
# ------------------------- boucle interactive ----------------------------
|
||||||
idx = faiss.read_index("corpus.idx")
|
print("RAG prêt ! (Ctrl‑D pour quitter)")
|
||||||
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
|
|
||||||
|
|
||||||
# ---------------- boucle interactive ---------------------------------------
|
|
||||||
print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)")
|
|
||||||
try:
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
q = input("❓ > ").strip()
|
q = input("❓ > ").strip()
|
||||||
if not q:
|
|
||||||
continue
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
print("\nBye."); break
|
print("\nBye."); break
|
||||||
|
if not q: continue
|
||||||
|
|
||||||
emb = model.encode([q])
|
# correction rapide de typos courantes (substituabilité…)
|
||||||
if isinstance(emb, dict):
|
q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I)
|
||||||
# récupère le 1er ndarray trouvé
|
|
||||||
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
|
|
||||||
q_emb = emb[0] / np.linalg.norm(emb[0])
|
|
||||||
|
|
||||||
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
|
vec = encode_query(q_norm)
|
||||||
hits = I[0]
|
hits = search_all(vec)
|
||||||
# réordonne pour mettre en tête les passages contenant “Seuil”
|
|
||||||
hits = sorted(hits, key=lambda i: "Seuil" not in docs[int(i)])
|
# Boost lexical : passages contenant le mot‑clé args.kw d’abord
|
||||||
|
kw_lower = args.kw.lower()
|
||||||
|
hits.sort(key=lambda i: kw_lower not in DOCS[i].lower())
|
||||||
|
hits = hits[:args.k]
|
||||||
|
|
||||||
|
context = "\n\n".join(DOCS[i] for i in hits)
|
||||||
|
|
||||||
context = "\n\n".join(docs[int(i)] for i in hits[:K])
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"<system>Réponds en français, de façon précise, et uniquement à partir du contexte fourni. Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
|
"<system>Réponds en français, de façon précise et uniquement à partir du contexte. "
|
||||||
f"<context>{context}</context>\n<user>{q}</user>"
|
"Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
|
||||||
|
f"<context>{context}</context>\n"
|
||||||
|
f"<user>{q}</user>"
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask_llm(p):
|
|
||||||
r = requests.post("http://127.0.0.1:11434/api/generate", json={
|
r = requests.post("http://127.0.0.1:11434/api/generate", json={
|
||||||
"model": "mistral7b-fast",
|
"model": args.model,
|
||||||
"prompt": p,
|
"prompt": prompt,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"options": {"temperature": 0.0, "num_predict": 512}
|
"options": {"temperature": 0.1, "num_predict": 512}
|
||||||
}, timeout=300)
|
}, timeout=300)
|
||||||
return r.json()["response"]
|
answer = r.json().get("response", "(erreur API)")
|
||||||
|
|
||||||
print("\n[bold]Réponse :[/]")
|
print("\n[bold]Réponse :[/]\n", answer)
|
||||||
print(ask_llm(prompt))
|
print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]")
|
||||||
|
|
||||||
print("\n[dim]--- contexte utilisé ---[/]")
|
|
||||||
for rank, idx_id in enumerate(hits, 1):
|
for rank, idx_id in enumerate(hits, 1):
|
||||||
m = meta[int(idx_id)]
|
m = metas[0] # non utilisé ici, on affiche juste le nom
|
||||||
print(f"[{rank}] {Path(m.get('file') or m.get('path')).name} · part {m['part']} → …")
|
path = DOCS[idx_id].splitlines()[0][:60]
|
||||||
except Exception as e:
|
print(f"[{rank}] … {path}…")
|
||||||
print("[red]Erreur :", e)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user