From 54c6a309e68eff4cf27143ef35c1202b6aade9ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phan?=
Date: Mon, 19 May 2025 14:22:39 +0200
Subject: [PATCH] Update rag_md.py

---
 rag_md.py | 117 +++++++++++++++++++++++++++++++++---------------------
 1 file changed, 71 insertions(+), 46 deletions(-)

diff --git a/rag_md.py b/rag_md.py
index 5dcaeb4..f52e70e 100644
--- a/rag_md.py
+++ b/rag_md.py
@@ -5,8 +5,8 @@ import requests
 from sentence_transformers import SentenceTransformer
 import re
 
-# 1. Charger les fichiers Markdown et enrichir le contexte
 def collect_markdown_files(root_dir):
+    """Parcourt récursivement le répertoire pour charger les fichiers .md"""
     texts, sources, raw_contents = [], [], []
     for root, dirs, files in os.walk(root_dir):
         for f in files:
@@ -25,67 +25,92 @@ def collect_markdown_files(root_dir):
             print(f"Erreur lecture {full_path}: {e}")
     return texts, sources, raw_contents
 
-# 2. Initialisation
-ROOT_DIR = "Corpus"
-print("🔍 Chargement des fichiers markdown...")
-documents, paths, raw_contents = collect_markdown_files(ROOT_DIR)
-print(f"📄 {len(documents)} fichiers chargés.")
+def build_faiss_index(texts, model):
+    """Crée l'index FAISS avec les embeddings"""
+    print("📦 Génération des embeddings...")
+    embeddings = model.encode(texts, show_progress_bar=True)
+    dim = embeddings.shape[1]
+    index = faiss.IndexFlatL2(dim)
+    index.add(np.array(embeddings))
+    return index, embeddings
 
-print("📦 Génération des embeddings...")
-model = SentenceTransformer("all-MiniLM-L6-v2")
-embeddings = model.encode(documents, show_progress_bar=True)
+def search_hybrid(query, embeddings, texts, paths, raw_contents, model, root_dir, k=5):
+    """Effectue une recherche hybride : vecteurs + mots-clés"""
+    print("🔗 Recherche vectorielle...")
+    query_vector = model.encode([query])
+    _, faiss_indices = index.search(np.array(query_vector), k)
 
-# 3. Indexation FAISS
-dim = embeddings.shape[1]
-index = faiss.IndexFlatL2(dim)
-index.add(np.array(embeddings))
+    vector_results = [(texts[i], paths[i]) for i in faiss_indices[0]]
 
-# 4. Boucle de questions
-while True:
-    query = input("\n🔎 Pose ta question : ").strip()
-    if not query:
-        break
+    print("🔍 Recherche par mot-clé améliorée...")
+    query_lower = query.lower()
+    keywords = set(re.findall(r'\w+', query_lower))
 
-    print("\n🔗 Recherche vectorielle...")
-    query_embedding = model.encode([query])
-    _, faiss_indices = index.search(np.array(query_embedding), k=5)
-
-    vector_results = [(documents[i], paths[i]) for i in faiss_indices[0]]
-
-    print("🔍 Recherche par mot-clé...")
     keyword_hits = []
-    keywords = re.findall(r'\w+', query.lower())
     for i, (path, content) in enumerate(zip(paths, raw_contents)):
-        combined = f"{path.lower()} {content.lower()}"
-        if all(kw in combined for kw in keywords):
-            keyword_hits.append((documents[i], paths[i]))
+        haystack = f"{path} {content}".lower()
+        match_count = sum(1 for kw in keywords if kw in haystack)
+        if match_count >= 2 or 'isg' in haystack:
+            keyword_hits.append((texts[i], paths[i], match_count))
 
-    # 5. Fusionner résultats (vector d'abord, puis keyword)
-    all_results = vector_results + keyword_hits
-    seen_paths = set()
+    keyword_hits.sort(key=lambda x: -x[2])
+    keyword_results = [(doc, path) for doc, path, _ in keyword_hits[:5]]
+
+    combined = vector_results + keyword_results
+    seen = set()
     unique_results = []
-    for doc, p in all_results:
-        if p not in seen_paths:
-            unique_results.append((doc, p))
-            seen_paths.add(p)
+    for doc, path in combined:
+        if path not in seen:
+            unique_results.append((doc, path))
+            seen.add(path)
 
     top_contexts = [doc for doc, _ in unique_results[:3]]
-    top_sources = [os.path.relpath(p, ROOT_DIR) for _, p in unique_results[:3]]
-    contexte = "\n\n".join(top_contexts)
+    top_sources = [os.path.relpath(p, root_dir) for _, p in unique_results[:3]]
+    return top_contexts, top_sources
+
+def ask_ollama(prompt, model_name="llama3-8b-fast:latest"):
+    """Appelle le modèle Ollama"""
+    response = requests.post(
+        "http://localhost:11434/api/generate",
+        json={"model": model_name, "prompt": prompt, "stream": False}
+    )
+    return response.json()["response"]
+
+# === Main ===
+ROOT_DIR = "Corpus"
+MODEL_NAME = "all-MiniLM-L6-v2"
+
+print("🔍 Chargement des fichiers markdown...")
+texts, paths, raw_contents = collect_markdown_files(ROOT_DIR)
+print(f"📄 {len(texts)} fichiers chargés.")
+
+print("🧠 Chargement du modèle d'embedding...")
+model = SentenceTransformer(MODEL_NAME)
+
+index, embeddings = build_faiss_index(texts, model)
+
+# Boucle utilisateur
+while True:
+    query = input("\n🔎 Pose ta question (ou Entrée pour quitter) : ").strip()
+    if not query:
+        print("👋 Fin du programme.")
+        break
+
+    top_contexts, top_sources = search_hybrid(
+        query, embeddings, texts, paths, raw_contents, model, ROOT_DIR, k=10
+    )
+
+    context = "\n\n".join(top_contexts)
     fichiers_utilisés = "\n".join(f"- {src}" for src in top_sources)
 
-    # 6. Préparer le prompt
     prompt = (
-        f"Contexte :\n{contexte}\n\n"
+        f"Contexte :\n{context}\n\n"
         f"Question : {query}\n"
-        f"Réponds clairement et cite les éléments importants si besoin."
+        f"Réponds clairement, cite les seuils ou données si disponibles."
     )
 
     print("\n🧠 Appel au modèle Ollama...\n")
-    res = requests.post(
-        "http://localhost:11434/api/generate",
-        json={"model": "llama3-8b-fast:latest", "prompt": prompt, "stream": False}
-    )
+    reponse = ask_ollama(prompt)
 
     print("📘 Fichiers utilisés :\n", fichiers_utilisés)
-    print("\n🧠 Réponse :\n", res.json()["response"])
+    print("\n🧠 Réponse :\n", reponse)
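
Note on the new search_hybrid(): as committed, the function takes an embeddings argument it never reads and resolves the FAISS index through the module-level name index, which only works because the main block builds the index before the loop calls the function. Below is a minimal sketch of the same function with the index passed in explicitly; the index parameter is an assumed refactor, not part of this patch, and the body otherwise reuses the patch's own logic with the progress prints omitted.

    import os
    import re

    import numpy as np


    # Sketch (not from the patch): search_hybrid with the FAISS index as an explicit parameter.
    def search_hybrid(query, index, texts, paths, raw_contents, model, root_dir, k=5):
        """Hybrid retrieval: FAISS vector search merged with keyword matching."""
        # Vector search against the index the caller passes in (no global lookup).
        query_vector = model.encode([query])
        _, faiss_indices = index.search(np.array(query_vector), k)
        vector_results = [(texts[i], paths[i]) for i in faiss_indices[0]]

        # Keyword search: score each document by how many query words it contains.
        keywords = set(re.findall(r'\w+', query.lower()))
        keyword_hits = []
        for i, (path, content) in enumerate(zip(paths, raw_contents)):
            haystack = f"{path} {content}".lower()
            match_count = sum(1 for kw in keywords if kw in haystack)
            if match_count >= 2 or 'isg' in haystack:
                keyword_hits.append((texts[i], paths[i], match_count))
        keyword_hits.sort(key=lambda x: -x[2])
        keyword_results = [(doc, path) for doc, path, _ in keyword_hits[:5]]

        # Merge, keeping vector hits first and deduplicating by path.
        seen, unique_results = set(), []
        for doc, path in vector_results + keyword_results:
            if path not in seen:
                unique_results.append((doc, path))
                seen.add(path)

        top_contexts = [doc for doc, _ in unique_results[:3]]
        top_sources = [os.path.relpath(p, root_dir) for _, p in unique_results[:3]]
        return top_contexts, top_sources

The call in the main loop would then pass index instead of embeddings:

    top_contexts, top_sources = search_hybrid(
        query, index, texts, paths, raw_contents, model, ROOT_DIR, k=10
    )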