Update rag_md.py
parent 952f0dd92d
commit 54c6a309e6
rag_md.py (117 lines changed)
@@ -5,8 +5,8 @@ import requests
 from sentence_transformers import SentenceTransformer
 import re
 
-# 1. Load the Markdown files and enrich the context
 def collect_markdown_files(root_dir):
+    """Recursively walk the directory to load the .md files"""
     texts, sources, raw_contents = [], [], []
     for root, dirs, files in os.walk(root_dir):
         for f in files:
@@ -25,67 +25,92 @@ def collect_markdown_files(root_dir):
                     print(f"Erreur lecture {full_path}: {e}")
     return texts, sources, raw_contents
 
-# 2. Initialization
-ROOT_DIR = "Corpus"
-print("🔍 Chargement des fichiers markdown...")
-documents, paths, raw_contents = collect_markdown_files(ROOT_DIR)
-print(f"📄 {len(documents)} fichiers chargés.")
+def build_faiss_index(texts, model):
+    """Build the FAISS index from the embeddings"""
+    print("📦 Génération des embeddings...")
+    embeddings = model.encode(texts, show_progress_bar=True)
+    dim = embeddings.shape[1]
+    index = faiss.IndexFlatL2(dim)
+    index.add(np.array(embeddings))
+    return index, embeddings
 
-print("📦 Génération des embeddings...")
-model = SentenceTransformer("all-MiniLM-L6-v2")
-embeddings = model.encode(documents, show_progress_bar=True)
+def search_hybrid(query, embeddings, texts, paths, raw_contents, model, root_dir, k=5):
+    """Run a hybrid search: vectors + keywords"""
+    print("🔗 Recherche vectorielle...")
+    query_vector = model.encode([query])
+    _, faiss_indices = index.search(np.array(query_vector), k)
 
-# 3. FAISS indexing
-dim = embeddings.shape[1]
-index = faiss.IndexFlatL2(dim)
-index.add(np.array(embeddings))
+    vector_results = [(texts[i], paths[i]) for i in faiss_indices[0]]
 
-# 4. Question loop
-while True:
-    query = input("\n🔎 Pose ta question : ").strip()
-    if not query:
-        break
+    print("🔍 Recherche par mot-clé améliorée...")
+    query_lower = query.lower()
+    keywords = set(re.findall(r'\w+', query_lower))
 
-    print("\n🔗 Recherche vectorielle...")
-    query_embedding = model.encode([query])
-    _, faiss_indices = index.search(np.array(query_embedding), k=5)
-
-    vector_results = [(documents[i], paths[i]) for i in faiss_indices[0]]
-
-    print("🔍 Recherche par mot-clé...")
     keyword_hits = []
-    keywords = re.findall(r'\w+', query.lower())
     for i, (path, content) in enumerate(zip(paths, raw_contents)):
-        combined = f"{path.lower()} {content.lower()}"
-        if all(kw in combined for kw in keywords):
-            keyword_hits.append((documents[i], paths[i]))
+        haystack = f"{path} {content}".lower()
+        match_count = sum(1 for kw in keywords if kw in haystack)
+        if match_count >= 2 or 'isg' in haystack:
+            keyword_hits.append((texts[i], paths[i], match_count))
 
-    # 5. Merge results (vector first, then keyword)
-    all_results = vector_results + keyword_hits
-    seen_paths = set()
+    keyword_hits.sort(key=lambda x: -x[2])
+    keyword_results = [(doc, path) for doc, path, _ in keyword_hits[:5]]
+
+    combined = vector_results + keyword_results
+    seen = set()
     unique_results = []
-    for doc, p in all_results:
-        if p not in seen_paths:
-            unique_results.append((doc, p))
-            seen_paths.add(p)
+    for doc, path in combined:
+        if path not in seen:
+            unique_results.append((doc, path))
+            seen.add(path)
 
     top_contexts = [doc for doc, _ in unique_results[:3]]
-    top_sources = [os.path.relpath(p, ROOT_DIR) for _, p in unique_results[:3]]
-    contexte = "\n\n".join(top_contexts)
+    top_sources = [os.path.relpath(p, root_dir) for _, p in unique_results[:3]]
+    return top_contexts, top_sources
+
+def ask_ollama(prompt, model_name="llama3-8b-fast:latest"):
+    """Call the Ollama model"""
+    response = requests.post(
+        "http://localhost:11434/api/generate",
+        json={"model": model_name, "prompt": prompt, "stream": False}
+    )
+    return response.json()["response"]
+
+# === Main ===
+ROOT_DIR = "Corpus"
+MODEL_NAME = "all-MiniLM-L6-v2"
+
+print("🔍 Chargement des fichiers markdown...")
+texts, paths, raw_contents = collect_markdown_files(ROOT_DIR)
+print(f"📄 {len(texts)} fichiers chargés.")
+
+print("🧠 Chargement du modèle d'embedding...")
+model = SentenceTransformer(MODEL_NAME)
+
+index, embeddings = build_faiss_index(texts, model)
+
+# User loop
+while True:
+    query = input("\n🔎 Pose ta question (ou Entrée pour quitter) : ").strip()
+    if not query:
+        print("👋 Fin du programme.")
+        break
+
+    top_contexts, top_sources = search_hybrid(
+        query, embeddings, texts, paths, raw_contents, model, ROOT_DIR, k=10
+    )
+
+    context = "\n\n".join(top_contexts)
     fichiers_utilisés = "\n".join(f"- {src}" for src in top_sources)
 
-    # 6. Prepare the prompt
     prompt = (
-        f"Contexte :\n{contexte}\n\n"
+        f"Contexte :\n{context}\n\n"
         f"Question : {query}\n"
-        f"Réponds clairement et cite les éléments importants si besoin."
+        f"Réponds clairement, cite les seuils ou données si disponibles."
     )
 
     print("\n🧠 Appel au modèle Ollama...\n")
-    res = requests.post(
-        "http://localhost:11434/api/generate",
-        json={"model": "llama3-8b-fast:latest", "prompt": prompt, "stream": False}
-    )
+    reponse = ask_ollama(prompt)
 
     print("📘 Fichiers utilisés :\n", fichiers_utilisés)
-    print("\n🧠 Réponse :\n", res.json()["response"])
+    print("\n🧠 Réponse :\n", reponse)
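A note on the refactored search_hybrid: as committed, the function accepts an embeddings argument it never uses, and it does not receive the FAISS index at all, so the index.search(...) call only resolves because build_faiss_index assigns a module-level index in the main block. The hard-coded 'isg' bonus in the keyword filter likewise looks like a corpus-specific shortcut. A minimal sketch of a more self-contained signature for the vector half, with hypothetical names, not part of this commit:

    import numpy as np

    # Hypothetical refactor sketch: pass the FAISS index in explicitly
    # instead of relying on module-level state.
    def search_vector(query, index, texts, paths, model, k=5):
        """Vector retrieval with the index injected as a parameter."""
        query_vector = model.encode([query])  # ndarray of shape (1, dim)
        _, faiss_indices = index.search(np.array(query_vector), k)
        return [(texts[i], paths[i]) for i in faiss_indices[0]]

Injecting the index (and dropping the unused embeddings parameter) would make the function testable without depending on globals created elsewhere in the script.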