#!/usr/bin/env python3
"""
Mini RAG interactif :
• Recherche sémantique FAISS sur le corpus.idx / corpus.meta.json
• Contexte (topk=4 passages) envoyé à Mistral7B via Ollama.
Robuste aux différentes sorties de BGEM3FlagModel.encode.
"""
import json
import readline  # noqa: F401 -- side effect only: line editing/history for input()
from pathlib import Path
import faiss
import numpy as np
import requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
# ---------------------------------------------------------------------------
# Initial loading
# ---------------------------------------------------------------------------
IDX_FILE = Path("corpus.idx")
META_FILE = Path("corpus.meta.json")
if not IDX_FILE.exists() or not META_FILE.exists():
    raise SystemExit("[bold red]Error:[/] index missing. Run index.py first!")
index = faiss.read_index(str(IDX_FILE))
meta = json.loads(META_FILE.read_text())
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def _normalize(x: np.ndarray) -> np.ndarray:
    """L2-normalize each row (float32); the epsilon avoids division by zero."""
    # On unit vectors, FAISS inner-product scores equal cosine similarity.
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)

def embed(texts):
    """Encode list[str] → ndarray (n, dim), whatever the library returns."""
    out = model.encode(texts)
    # Possible shapes:
    #   • ndarray
    #   • dict {"embedding": ndarray} or {"embeddings": ndarray}
    #   • dict {"sentence_embeds": [...]}, etc.
    if isinstance(out, np.ndarray):
        arr = out
    elif isinstance(out, dict):
        # Pick the first ndarray-like value.
        for v in out.values():
            if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
                arr = np.asarray(v)
                break
        else:
            raise TypeError("encode() returned a dict with no embedding key!")
    else:  # list[list[float]], etc.
        arr = np.asarray(out)
    return _normalize(arr.astype("float32"))
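
# In current FlagEmbedding releases, BGEM3FlagModel.encode with the default
# dense-only settings returns {"dense_vecs": ndarray}; the fallback loop above
# picks that up. E.g. embed(["hello"]) should yield shape (1, 1024), BGE-M3's
# dense dimension.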

def fetch_passage(i: int) -> str:
    """Render one indexed passage with its source file and chunk number."""
    m = meta[i]
    return f"[{m['file']} · part {m['part']}] {m['text']}"
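
# For orientation, a minimal sketch (not the actual index.py) of how
# corpus.idx and corpus.meta.json are presumably built: embed every chunk with
# the same BGE-M3 model, L2-normalize, and store the vectors in an
# inner-product FAISS index so that search scores are cosine similarities.
# The meta fields ("file", "part", "text") are inferred from fetch_passage()
# above; this helper and its chunk format are hypothetical.
def _build_index_sketch(chunks: list[dict]) -> None:
    """chunks: [{"file": ..., "part": ..., "text": ...}, ...]."""
    vecs = embed([c["text"] for c in chunks])  # (n, dim), unit-normalized float32
    idx = faiss.IndexFlatIP(vecs.shape[1])     # inner product == cosine here
    idx.add(vecs)
    faiss.write_index(idx, str(IDX_FILE))
    META_FILE.write_text(json.dumps(chunks, ensure_ascii=False))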

def ask_llm(prompt: str) -> str:
    """Send the prompt to the local Ollama server and return the full answer."""
    r = requests.post(
        "http://127.0.0.1:11434/api/generate",
        json={
            "model": "mistral7b-fast",
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.2, "num_predict": 512},
        },
        timeout=300,
    )
    r.raise_for_status()  # surface HTTP errors instead of a cryptic KeyError
    return r.json()["response"]
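
# A sketch of a streaming variant, assuming Ollama's documented streaming
# behaviour: with "stream": true, /api/generate emits newline-delimited JSON
# chunks, each carrying a partial "response" and a final "done" flag. Untested
# here; shown as one way to print tokens as they arrive.
def ask_llm_stream(prompt: str) -> str:
    import sys  # local import to keep the sketch self-contained
    pieces = []
    with requests.post(
        "http://127.0.0.1:11434/api/generate",
        json={
            "model": "mistral7b-fast",
            "prompt": prompt,
            "stream": True,
            "options": {"temperature": 0.2, "num_predict": 512},
        },
        timeout=300,
        stream=True,
    ) as r:
        r.raise_for_status()
        for line in r.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)
            text = chunk.get("response", "")
            pieces.append(text)
            sys.stdout.write(text)  # plain stdout: rich would re-interpret markup
            sys.stdout.flush()
            if chunk.get("done"):
                break
    return "".join(pieces)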

# ---------------------------------------------------------------------------
# Interactive loop
# ---------------------------------------------------------------------------
print("[bold green]RAG ready.[/] Ask your questions! (Ctrl-D to quit)")
while True:
    try:
        q = input("❓ > ").strip()
        if not q:
            continue
    except (EOFError, KeyboardInterrupt):
        print("\n[dim]Bye.[/]")
        break
    q_emb = embed([q])  # (1, dim)
    D, I = index.search(q_emb, 4)  # D: similarity scores, I: row ids of the top-4 passages
    ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0])
    prompt = (
        "<system>Answer in French, precisely and factually.</system>\n"
        f"<context>{ctx}</context>\n"
        f"<user>{q}</user>"
    )
    print("\n[bold]Answer:[/]\n")
    print(ask_llm(prompt))
    print("\n[dim]--- context used ---[/]")
    print(ctx)
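
# Typical run (assumptions: index.py builds the two corpus files, and a model
# tagged "mistral7b-fast" exists in the local Ollama instance):
#   $ python index.py
#   $ ollama serve        # in another terminal, if not already running
#   $ python rag.py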