diff --git a/rag_md.py b/rag_md.py
new file mode 100644
index 0000000..5b95144
--- /dev/null
+++ b/rag_md.py
@@ -0,0 +1,98 @@
+import os
+import re
+
+import faiss
+import numpy as np
+import requests
+from sentence_transformers import SentenceTransformer
+
+# 1. Load the Markdown files and enrich each one with its relative path as context
+def collect_markdown_files(root_dir):
+    texts, sources, raw_contents = [], [], []
+    for root, dirs, files in os.walk(root_dir):
+        for f in files:
+            if f.endswith(".md"):
+                full_path = os.path.join(root, f)
+                rel_path = os.path.relpath(full_path, root_dir)
+                try:
+                    with open(full_path, "r", encoding="utf-8") as file:
+                        content = file.read().strip()
+                    if content:
+                        # Prepend the file path so the embedding also captures where the note lives
+                        enriched = f"[File: {rel_path}]\n\n{content}"
+                        texts.append(enriched)
+                        sources.append(full_path)
+                        raw_contents.append(content)
+                except Exception as e:
+                    print(f"Error reading {full_path}: {e}")
+    return texts, sources, raw_contents
+
+# 2. Initialisation
+ROOT_DIR = "mes_fiches"
+print("🔍 Loading markdown files...")
+documents, paths, raw_contents = collect_markdown_files(ROOT_DIR)
+print(f"📄 {len(documents)} files loaded.")
+if not documents:
+    raise SystemExit(f"No markdown files found under {ROOT_DIR}.")
+
+print("📦 Generating embeddings...")
+model = SentenceTransformer("all-MiniLM-L6-v2")
+embeddings = model.encode(documents, show_progress_bar=True)
+
+# 3. FAISS indexing (exact L2 search over the document embeddings)
+dim = embeddings.shape[1]
+index = faiss.IndexFlatL2(dim)
+index.add(np.asarray(embeddings, dtype="float32"))
+
+# 4. Question loop (an empty input exits)
+while True:
+    query = input("\n🔎 Ask your question: ").strip()
+    if not query:
+        break
+
+    print("\n🔗 Vector search...")
+    query_embedding = model.encode([query])
+    # Never ask FAISS for more neighbours than there are documents (it pads with -1 otherwise)
+    k = min(5, len(documents))
+    _, faiss_indices = index.search(np.asarray(query_embedding, dtype="float32"), k)
+
+    vector_results = [(documents[i], paths[i]) for i in faiss_indices[0]]
+
+    print("🔍 Keyword search...")
+    keyword_hits = []
+    keywords = re.findall(r"\w+", query.lower())
+    for i, (path, content) in enumerate(zip(paths, raw_contents)):
+        combined = f"{path.lower()} {content.lower()}"
+        # Keep a document only if every query keyword appears in its path or content
+        if all(kw in combined for kw in keywords):
+            keyword_hits.append((documents[i], paths[i]))
+
+    # 5. Merge results (vector hits first, then keyword hits), deduplicated by path
+    all_results = vector_results + keyword_hits
+    seen_paths = set()
+    unique_results = []
+    for doc, p in all_results:
+        if p not in seen_paths:
+            unique_results.append((doc, p))
+            seen_paths.add(p)
+
+    top_contexts = [doc for doc, _ in unique_results[:3]]
+    top_sources = [os.path.relpath(p, ROOT_DIR) for _, p in unique_results[:3]]
+    contexte = "\n\n".join(top_contexts)
+    fichiers_utilisés = "\n".join(f"- {src}" for src in top_sources)
+
+    # 6. Build the prompt from the retrieved context
+    prompt = (
+        f"Context:\n{contexte}\n\n"
+        f"Question: {query}\n"
+        f"Answer clearly and cite the important elements if needed."
+    )
+
+    print("\n🧠 Calling the Ollama model...\n")
+    res = requests.post(
+        "http://localhost:11434/api/generate",
+        json={"model": "llama3", "prompt": prompt, "stream": False},
+    )
+    res.raise_for_status()
+
+    print("📘 Files used:\n", fichiers_utilisés)
+    print("\n🧠 Answer:\n", res.json()["response"])