Import of private_gpt and improvements to the AI analysis
This commit is contained in:
parent c4fffb829c
commit 95ede9c6f1

IA/get_regeneration_plan.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from datetime import datetime
from collections import defaultdict, deque
import os
import sys
from datetime import timezone

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# Adapt to your environment
from config import GITEA_URL, ORGANISATION, DEPOT_FICHES, ENV
from utils.gitea import recuperer_date_dernier_commit
from IA.make_config import MAKE  # MAKE must be imported from a config file

def get_mtime(path):
    try:
        return datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc)
    except FileNotFoundError:
        return None

def get_commit_time(path_relative):
    commits_url = f"{GITEA_URL}/repos/{ORGANISATION}/{DEPOT_FICHES}/commits?path={path_relative.replace('Fiches', 'Documents')}&sha={ENV}"
    return recuperer_date_dernier_commit(commits_url)

def resolve_path_from_where(where_str):
    parts = where_str.split(".")
    current = MAKE
    path_stack = []

    for part in parts:
        if isinstance(current, dict) and part in current:
            path_stack.append((part, current))
            current = current[part]
        else:
            return None

    if not isinstance(current, str):
        return None

    for i in range(len(path_stack) - 1, -1, -1):
        key, context = path_stack[i]
        if "directory" in context:
            directory = context["directory"]
            if "fiches" in where_str:
                return os.path.join("Fiches", directory, current)
            else:
                return os.path.join(directory, current)

    return None

def identifier_type_fiche(path):
    for type_fiche, data in MAKE["fiches"].items():
        if not isinstance(data, dict):
            continue
        directory = data.get("directory", "")
        prefix = data.get("prefix", "")
        base = os.path.join("Fiches", directory, prefix)
        if path.startswith(base):
            return type_fiche, data
    raise ValueError("Type de fiche non reconnu")

def doit_regenerer(fichier, doc_deps, fiche_data=None):
    mtime_fichier = get_mtime(fichier)

    if fiche_data:
        gitea_dep = fiche_data.get("depends_on", {}).get("gitea", {})
        if gitea_dep.get("compare") == "file2commit":
            commit_time = get_commit_time(fichier)
            if commit_time and mtime_fichier and commit_time > mtime_fichier:
                return True

    for _, dep in doc_deps.items():
        if isinstance(dep, dict) and "where" in dep:
            source_path = resolve_path_from_where(dep["where"])
            if source_path:
                if dep["compare"] == "file2file":
                    mtime_source = get_mtime(source_path)
                elif dep["compare"] == "file2commit":
                    mtime_source = get_commit_time(source_path)
                else:
                    continue
                if mtime_source and mtime_fichier and mtime_source > mtime_fichier:
                    return True
    return False

def get_regeneration_plan(fiche_path):
    def build_dependency_graph_complete(path, graph=None, visited=None):
        if graph is None:
            graph = defaultdict(set)
        if visited is None:
            visited = set()
        if path in visited:
            return graph
        visited.add(path)

        try:
            _, fiche_data = identifier_type_fiche(path)
        except ValueError:
            return graph

        depends = fiche_data.get("depends_on", {})
        doc_deps = depends.get("document", {})

        for _, dep_info in doc_deps.items():
            if isinstance(dep_info, dict) and "where" in dep_info:
                dep_path = resolve_path_from_where(dep_info["where"])
                if dep_path:
                    graph[path].add(dep_path)
                    build_dependency_graph_complete(dep_path, graph, visited)

        if "file" in fiche_data:
            for fichier in fiche_data["file"].values():
                dir_fiche = fiche_data.get("directory", "")
                fichier_path = os.path.join("Fiches", dir_fiche, fichier)
                if fichier_path not in graph:
                    graph[fichier_path] = set()

        return graph

    def topological_sort(graph):
        in_degree = defaultdict(int)
        for node in graph:
            for dep in graph[node]:
                in_degree[dep] += 1
        queue = deque([node for node in graph if in_degree[node] == 0])
        result = []

        while queue:
            node = queue.popleft()
            result.append(node)
            for dep in graph[node]:
                in_degree[dep] -= 1
                if in_degree[dep] == 0:
                    queue.append(dep)

        all_nodes = set(graph.keys()).union(*graph.values())
        for node in all_nodes:
            if node not in result:
                result.append(node)

        return result[::-1]

    graph = build_dependency_graph_complete(fiche_path)
    sorted_fiches = topological_sort(graph)
    if fiche_path not in sorted_fiches:
        sorted_fiches.append(fiche_path)


    to_regen = []
    regen_flags = {}

    for fiche in sorted_fiches:
        print(f"=> {fiche}")
        try:
            _, fiche_data = identifier_type_fiche(fiche)
        except ValueError:
            fiche_data = None
        depends = fiche_data.get("depends_on", {}) if fiche_data else {}
        doc_deps = depends.get("document", {}) if depends else {}

        doit = doit_regenerer(fiche, doc_deps, fiche_data)
        if any(regen_flags.get(dep, False) for dep in graph.get(fiche, [])):
            doit = True

        regen_flags[fiche] = doit
        if doit:
            to_regen.append(fiche)

    return to_regen

plan = get_regeneration_plan("Fiches/Minerai/Fiche minerai antimoine.md")
print(plan)
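For orientation, here is a sketch of what the example call at the bottom of the file could return. The file names come from the MAKE table in IA/make_config.py below; the actual list depends on file mtimes and Gitea commit dates, so this output is illustrative only:

# Illustrative run, not captured output. Dependencies (the criticité fiches
# the minerai fiche reads) come first in the topological order, and a fiche
# is regenerated if any of its dependencies is newer or itself regenerated.
plan = get_regeneration_plan("Fiches/Minerai/Fiche minerai antimoine.md")
# plan might look like:
# ["Fiches/Criticités/Fiche technique IHH.md",
#  "Fiches/Criticités/Fiche technique ISG.md",
#  "Fiches/Minerai/Fiche minerai antimoine.md"]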

IA/make_config.py (new file, 141 lines)
@@ -0,0 +1,141 @@
from utils.gitea import recuperer_date_dernier_commit
#
# from config import GITEA_URL, GITEA_TOKEN, ORGANISATION, DEPOT_FICHES, DEPOT_CODE, ENV, ENV_CODE, DOT_FILE
#
# def recuperer_date_dernier_commit(url):
#     headers = {"Authorization": f"token {GITEA_TOKEN}"}
#     try:
#         response = requests.get(url, headers=headers, timeout=10)
#         response.raise_for_status()
#         commits = response.json()
#         if commits:
#             return parser.isoparse(commits[0]["commit"]["author"]["date"])
#     except Exception as e:
#         logging.error(f"Erreur récupération commit schema : {e}")
#     return None
#
# path_relative = f"Documents/{dossier_choisi}/{fiche_choisie}"
# commits_url = f"{GITEA_URL}/repos/{ORGANISATION}/{DEPOT_FICHES}/commits?path={path_relative}&sha={ENV}"
#
# local_mtime = datetime.fromtimestamp(os.path.getmtime(path_relative), tz=timezone.utc)
# remote_mtime = recuperer_date_dernier_commit(commits_url)

MAKE = {
    "assets": {
        "directory": "assets",
        "seuils": {
            "depends_on": "None",
        },
        "file": {
            "seuils": "config.yaml"
        }
    },
    "fiches": {
        "directory": "Fiches",
        "criticites": {
            "directory": "Criticités",
            "prefix": "Fiche technique ",
            "depends_on": {
                "gitea": {
                    "compare": "file2commit"
                },
                "document": {
                    "seuils": {
                        "where": "assets.file.seuils",
                        "compare": "file2file"
                    }
                }
            },
            "file": {
                "ihh": "Fiche technique IHH.md",
                "isg": "Fiche technique ISG.md",
                "ivc": "Fiche technique IVC.md",
                "ics": "Fiche technique ICS.md"
            }
        },
        "assemblage": {
            "directory": "Assemblage",
            "prefix": "Fiche assemblage ",
            "depends_on": {
                "gitea": {
                    "compare": "file2commit"
                },
                "document": {
                    "ihh": {
                        "where": "fiches.criticites.file.ihh",
                        "compare": "file2file"
                    },
                    "isg": {
                        "where": "fiches.criticites.file.isg",
                        "compare": "file2file"
                    }
                }
            }
        },
        "fabrication": {
            "directory": "Fabrication",
            "prefix": "Fiche fabrication ",
            "depends_on": {
                "gitea": {
                    "compare": "file2commit"
                },
                "document": {
                    "ihh": {
                        "where": "fiches.criticites.file.ihh",
                        "compare": "file2file"
                    },
                    "isg": {
                        "where": "fiches.criticites.file.isg",
                        "compare": "file2file"
                    }
                }
            }
        },
        "connexe": {
            "directory": "Connexe",
            "prefix": "Fiche assemblage ",
            "depends_on": {
                "gitea": {
                    "compare": "file2commit"
                },
                "document": {
                    "ihh": {
                        "where": "fiches.criticites.file.ihh",
                        "compare": "file2file"
                    },
                    "isg": {
                        "where": "fiches.criticites.file.isg",
                        "compare": "file2file"
                    }
                }
            }
        },
        "minerai": {
            "directory": "Minerai",
            "prefix": "Fiche minerai ",
            "depends_on": {
                "gitea": {
                    "compare": "file2commit"
                },
                "document": {
                    "ihh": {
                        "where": "fiches.criticites.file.ihh",
                        "compare": "file2file"
                    },
                    "isg": {
                        "where": "fiches.criticites.file.isg",
                        "compare": "file2file"
                    },
                    "ics": {
                        "where": "fiches.criticites.file.ics",
                        "compare": "file2file"
                    },
                    "ivc": {
                        "where": "fiches.criticites.file.ivc",
                        "compare": "file2file"
                    }
                }
            }
        }
    }
}
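A quick sketch of how resolve_path_from_where (in IA/get_regeneration_plan.py above) resolves a "where" string against this table; the result follows directly from the "directory" and "file" entries:

from IA.make_config import MAKE  # the table above
from IA.get_regeneration_plan import resolve_path_from_where

# Walks MAKE["fiches"]["criticites"]["file"]["ihh"], then climbs back up to
# the nearest enclosing "directory" ("Criticités"); because the where-string
# contains "fiches", the "Fiches" root is prepended:
resolve_path_from_where("fiches.criticites.file.ihh")
# -> "Fiches/Criticités/Fiche technique IHH.md"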
@@ -1,5 +1,5 @@
 version: 1.1
-date: 2025-05-06
+date: 2025-05-27

 seuils:
   IVC: # Indice de vulnérabilité concurrentielle

@@ -1139,7 +1139,7 @@ def generate_operations_section(data, results, config):
     if total_share > 0:
         isg_combined = isg_weighted_sum / total_share
         color, suffix = determine_threshold_color(isg_combined, "ISG", config.get('thresholds'))
-        template.append(f"\n**ISG combiné: {isg_combined:.0f} - {color} ({suffix})**\n")
+        template.append(f"\n**ISG combiné: {isg_combined:.0f} - {color} ({suffix})**\n\n")

     # IHH
     ihh_file = find_corpus_file("matrice-des-risques-liés-à-la-fabrication/indice-de-herfindahl-hirschmann", f"Fabrication/Fiche fabrication {component_slug}")

@@ -1647,7 +1647,7 @@ def generate_report(data, results, config):

     return "\n".join(report), file_names

-def generate_text(input_file, full_prompt, system_message, temperature = "0.1"):
+def generate_text(input_file, full_prompt, system_message, temperature = "0.1", use_context = True):
     """Génère du texte avec l'API PrivateGPT"""
     try:

@@ -1657,7 +1657,7 @@ def generate_text(input_file, full_prompt, system_message, temperature = "0.1"):
             {"role": "system", "content": system_message},
             {"role": "user", "content": full_prompt}
         ],
-        "use_context": True, # Active la recherche RAG dans les documents ingérés
+        "use_context": use_context, # Active la recherche RAG dans les documents ingérés
         "temperature": temperature, # Température réduite pour plus de cohérence
         "stream": False
     }

@@ -1733,6 +1733,8 @@ def ia_analyse(file_names):
     Votre analyse doit être rigoureuse, accessible, pertinente pour la prise de décision stratégique, et conforme à la méthodologie définie ci-dessous :

     {PROMPT_METHODOLOGIE}
+
+    /no_think
     """

     reponse[produit_final] = f"\n**{produit_final}**\n\n" + generate_text(file, full_prompt, system_message).split("</think>")[-1].strip()

@@ -1758,6 +1760,8 @@ def ia_analyse(file_names):
     - Être fluide, agréable à lire, avec un ton sobre et professionnel.

     Répondez uniquement avec l'introduction rédigée. Ne fournissez aucune autre explication complémentaire.
+
+    /no_think
     """

@@ -1794,6 +1798,8 @@ def ia_analyse(file_names):
     - Inviter de manière dynamique le COMEX à passer immédiatement à l'action.

     Votre rédaction doit être fluide, professionnelle, claire et immédiatement exploitable par des dirigeants. Ne fournissez aucune explication supplémentaire. Ne répondez que par la conclusion demandée.
+
+    /no_think
     """

     conclusion = generate_text("", full_prompt, system_message, "0.7").split("</think>")[-1].strip()

@@ -1811,36 +1817,47 @@ def ia_analyse(file_names):
         "\n\n## Méthodologie\n\n" + \
         PROMPT_METHODOLOGIE

-    fichier_a_reviser = Path(TEMPLATE_PATH.name.replace(".md", " - analyse à relire.md"))
-    write_report(analyse, TEMP_SECTIONS / fichier_a_reviser)
-    ingest_document(TEMP_SECTIONS / fichier_a_reviser)
+    # fichier_a_reviser = Path(TEMPLATE_PATH.name.replace(".md", " - analyse à relire.md"))
+    # write_report(analyse, TEMP_SECTIONS / fichier_a_reviser)
+    # ingest_document(TEMP_SECTIONS / fichier_a_reviser)

-    full_prompt = f"""
-    Le fichier à réviser est {fichier_a_reviser}. Suivre scrupuleusement les consignes.
+    full_prompt = """
+    Suivre scrupuleusement les consignes.
     """

     system_message = f"""
     Vous êtes un réviseur professionnel expert en écriture stratégique, maîtrisant parfaitement la langue française et habitué à réviser des textes destinés à des dirigeants de haut niveau (COMEX).

-    Votre unique tâche est d'améliorer la qualité rédactionnelle du texte dans le fichier {fichier_a_reviser}, sans modifier ni sa structure, ni son sens initial, ni ajouter d’informations nouvelles. Cette révision doit :
+    Votre tâche unique est d'améliorer strictement la qualité rédactionnelle du texte suivant, sans modifier en aucune manière :
+    - la structure existante (sections, titres, sous-titres),
+    - l'ordre des paragraphes et des idées,
+    - le sens précis du contenu original,
+    - sans ajouter aucune information nouvelle.
+
+    Votre révision doit impérativement respecter les points suivants :
     - Éliminer toutes répétitions ou redondances et varier systématiquement les tournures entre les paragraphes.
-    - Rendre chaque phrase claire, directe et concise. Si une phrase est trop longue, scindez-la en plusieurs phrases courtes.
-    - Scinder les paragraphes en 2 ou 3 parties cohérentes et bien enchaînées avec des termes de coordinations, d'implication, …
-    - Remplacer systématiquement les acronymes par les expressions suivantes :
+    - Rendre chaque phrase claire, directe et concise. Si une phrase est trop longue, scindez-la clairement en plusieurs phrases courtes.
+    - Structurer chaque paragraphe en 2 à 3 parties cohérentes, reliées entre elles par des termes logiques (coordination, implication, opposition, etc.) et séparées par des retours à la ligne.
+    - Remplacer systématiquement les acronymes par ces expressions précises :
     - ICS → « capacité à substituer un minerai »
     - IHH → « concentration géographique ou industrielle »
     - ISG → « stabilité géopolitique »
     - IVC → « concurrence intersectorielle pour les minerais »

-    Votre texte final doit être fluide, agréable à lire, parfaitement adapté à un COMEX, avec un ton professionnel et sobre.
+    Votre texte final doit être parfaitement fluide, agréable à lire, adapté à un COMEX, avec un ton professionnel et sobre.

-    Répondez uniquement avec le texte révisé, sans autre commentaire.
+    **Important : Ne répondez strictement que par le texte révisé ci-dessous, sans aucun commentaire ou explication supplémentaire.**
+
+    Voici le texte à réviser précisément :
+
+    {analyse}
+
+    /no_think
     """
-    corps = generate_text(fichier_a_reviser, full_prompt, system_message, "0.6").split("</think>")[-1].strip()
+    revision = generate_text("", full_prompt, system_message, "0.1", False).split("</think>")[-1].strip()
     print("Relecture")

-    return analyse
+    return revision

 def write_report(report, fichier):
     """Écrit le rapport généré dans le fichier spécifié."""

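In short, the hunks above thread a use_context flag through generate_text: the per-section analysis still queries the ingested corpus, while the new revision pass embeds the text inline and disables retrieval. A schematic of the two call styles, with argument values as in the diff (the surrounding variables are assumed in scope inside ia_analyse):

# RAG enabled (default): the prompt is answered against ingested documents.
section = generate_text(file, full_prompt, system_message)

# RAG disabled: the full text to revise is already embedded in the prompt,
# so retrieval is skipped; temperature "0.1", use_context False.
revision = generate_text("", full_prompt, system_message, "0.1", False)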
@@ -1 +1,7 @@
-{}
+{
+    "stephan": {
+        "status": "en cours",
+        "timestamp": 1748358233.0347543,
+        "position": 0
+    }
+}

pgpt/.docker/router.yml (new file, 16 lines)
@@ -0,0 +1,16 @@
http:
  services:
    ollama:
      loadBalancer:
        healthCheck:
          interval: 5s
          path: /
        servers:
          - url: http://ollama-cpu:11434
          - url: http://ollama-cuda:11434
          - url: http://host.docker.internal:11434

  routers:
    ollama-router:
      rule: "PathPrefix(`/`)"
      service: ollama

pgpt/Dockerfile.ollama (new file, 51 lines)
@@ -0,0 +1,51 @@
FROM python:3.11.6-slim-bookworm AS base

# Install poetry
RUN pip install pipx
RUN python3 -m pipx ensurepath
RUN pipx install poetry==1.8.3
ENV PATH="/root/.local/bin:$PATH"
ENV PATH=".venv/bin/:$PATH"

# https://python-poetry.org/docs/configuration/#virtualenvsin-project
ENV POETRY_VIRTUALENVS_IN_PROJECT=true

FROM base AS dependencies
WORKDIR /home/worker/app
COPY pyproject.toml poetry.lock ./

ARG POETRY_EXTRAS="ui vector-stores-qdrant llms-ollama embeddings-ollama"
RUN poetry install --no-root --extras "${POETRY_EXTRAS}"

FROM base AS app
ENV PYTHONUNBUFFERED=1
ENV PORT=8080
ENV APP_ENV=prod
ENV PYTHONPATH="$PYTHONPATH:/home/worker/app/private_gpt/"
EXPOSE 8080

# Prepare a non-root user
# More info about how to configure UIDs and GIDs in Docker:
# https://github.com/systemd/systemd/blob/main/docs/UIDS-GIDS.md

# Define the User ID (UID) for the non-root user
# UID 100 is chosen to avoid conflicts with existing system users
ARG UID=100

# Define the Group ID (GID) for the non-root user
# GID 65534 is often used for the 'nogroup' or 'nobody' group
ARG GID=65534

RUN adduser --system --gid ${GID} --uid ${UID} --home /home/worker worker
WORKDIR /home/worker/app

RUN chown worker /home/worker/app
RUN mkdir local_data && chown worker local_data
RUN mkdir models && chown worker models
COPY --chown=worker --from=dependencies /home/worker/app/.venv/ .venv
COPY --chown=worker private_gpt/ private_gpt
COPY --chown=worker *.yaml .
COPY --chown=worker scripts/ scripts

USER worker
ENTRYPOINT python -m private_gpt

pgpt/docker-compose.yaml (new file, 121 lines)
@@ -0,0 +1,121 @@
services:
  #-----------------------------------
  #---- Private-GPT services ---------
  #-----------------------------------

  # Private-GPT service for the Ollama CPU and GPU modes
  # This service builds from an external Dockerfile and runs the Ollama mode.
  private-gpt-ollama:
    image: ${PGPT_IMAGE:-zylonai/private-gpt}:${PGPT_TAG:-0.6.2}-ollama # x-release-please-version
    user: root
    build:
      context: .
      dockerfile: Dockerfile.ollama
    volumes:
      - /home/fabnum/fabnum-dev/Fiches:/home/worker/app/local_data/Fiches:Z
    ports:
      - "127.0.0.1:8001:8001"
    environment:
      PORT: 8001
      PGPT_PROFILES: docker
      PGPT_MODE: ollama
      PGPT_EMBED_MODE: ollama
      PGPT_OLLAMA_API_BASE: http://ollama:11434
      HF_TOKEN: ${HF_TOKEN:-}
    profiles:
      - ""
      - ollama-cpu
      - ollama-cuda
      - ollama-api
    depends_on:
      ollama:
        condition: service_healthy

  # Private-GPT service for the local mode
  # This service builds from a local Dockerfile and runs the application in local mode.
  private-gpt-llamacpp-cpu:
    image: ${PGPT_IMAGE:-zylonai/private-gpt}:${PGPT_TAG:-0.6.2}-llamacpp-cpu # x-release-please-version
    user: root
    build:
      context: .
      dockerfile: Dockerfile.llamacpp-cpu
    volumes:
      - ./local_data/:/home/worker/app/local_data
      - ./models/:/home/worker/app/models
    entrypoint: sh -c ".venv/bin/python scripts/setup && .venv/bin/python -m private_gpt"
    ports:
      - "127.0.0.1:8001:8001"
    environment:
      PORT: 8001
      PGPT_PROFILES: local
      HF_TOKEN: ${HF_TOKEN:-}
    profiles:
      - llamacpp-cpu

  #-----------------------------------
  #---- Ollama services --------------
  #-----------------------------------

  # Traefik reverse proxy for the Ollama service
  # This will route requests to the Ollama service based on the profile.
  ollama:
    image: traefik:v2.10
    healthcheck:
      test:
        [
          "CMD",
          "sh",
          "-c",
          "wget -q --spider http://ollama:11434 || exit 1",
        ]
      interval: 10s
      retries: 3
      start_period: 5s
      timeout: 5s
    ports:
      - "127.0.0.1:8080:8080"
    command:
      - "--providers.file.filename=/etc/router.yml"
      - "--log.level=ERROR"
      - "--api.insecure=true"
      - "--providers.docker=true"
      - "--providers.docker.exposedbydefault=false"
      - "--entrypoints.web.address=:11434"
    volumes:
      - /var/run/docker.sock:/var/run/docker.sock:ro
      - ./.docker/router.yml:/etc/router.yml:ro
    extra_hosts:
      - "host.docker.internal:host-gateway"
    profiles:
      - ""
      - ollama-cpu
      - ollama-cuda
      - ollama-api

  # Ollama service for the CPU mode
  ollama-cpu:
    image: ollama/ollama:latest
    ports:
      - "127.0.0.1:11434:11434"
    volumes:
      - ./models:/root/.ollama:Z
    profiles:
      - ""
      - ollama-cpu

  # Ollama service for the CUDA mode
  ollama-cuda:
    image: ollama/ollama:latest
    ports:
      - "11434:11434"
    volumes:
      - ./models:/root/.ollama
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    profiles:
      - ollama-cuda

pgpt/private_gpt/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""private-gpt."""

import logging
import os

# Set to 'DEBUG' to have extensive logging turned on, even for libraries
ROOT_LOG_LEVEL = "INFO"

PRETTY_LOG_FORMAT = (
    "%(asctime)s.%(msecs)03d [%(levelname)-8s] %(name)+25s - %(message)s"
)
logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
logging.captureWarnings(True)

# Disable gradio analytics
# This is done this way because gradio does not solely rely on what values are
# passed to gr.Blocks(enable_analytics=...) but also on the environment
# variable GRADIO_ANALYTICS_ENABLED. `gradio.strings` actually reads this env
# directly, so to fully disable gradio analytics we need to set this env var.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Disable chromaDB telemetry
# It is already disabled, see PR#1144
# os.environ["ANONYMIZED_TELEMETRY"] = "False"

# adding tiktoken cache path within repo to be able to run in offline environment.
os.environ["TIKTOKEN_CACHE_DIR"] = "tiktoken_cache"

pgpt/private_gpt/__main__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
# Start a FastAPI server with uvicorn

import uvicorn

from private_gpt.main import app
from private_gpt.settings.settings import settings

# Set log_config=None so uvicorn does not apply its own logging configuration
# and ours is used instead. For reference, see below:
# https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)

pgpt/private_gpt/components/__init__.py (new file, empty)

pgpt/private_gpt/components/embedding/__init__.py (new file, empty)

pgpt/private_gpt/components/embedding/custom/sagemaker.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# mypy: ignore-errors
import json
from typing import Any

import boto3
from llama_index.core.base.embeddings.base import BaseEmbedding
from pydantic import Field, PrivateAttr


class SagemakerEmbedding(BaseEmbedding):
    """Sagemaker Embedding Endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker embedding model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    endpoint_name: str = Field(description="")

    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field

    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        return "SagemakerEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        request_params = {
            "inputs": sentences,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        response_json = json.loads(response_str)

        return response_json["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings."""
        return self._embed(texts)
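A minimal usage sketch for this class; the endpoint name is hypothetical, and credentials are resolved by boto3 as described in the docstring above:

# Hypothetical endpoint name; get_query_embedding is the public wrapper
# that BaseEmbedding exposes around _get_query_embedding.
embedding = SagemakerEmbedding(endpoint_name="my-embedding-endpoint")
vector = embedding.get_query_embedding("stabilité géopolitique des minerais")
print(len(vector))  # dimensionality of the deployed model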

pgpt/private_gpt/components/embedding/embedding_component.py (new file, 167 lines)
@@ -0,0 +1,167 @@
import logging

from injector import inject, singleton
from llama_index.core.embeddings import BaseEmbedding, MockEmbedding

from private_gpt.paths import models_cache_path
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class EmbeddingComponent:
    embedding_model: BaseEmbedding

    @inject
    def __init__(self, settings: Settings) -> None:
        embedding_mode = settings.embedding.mode
        logger.info("Initializing the embedding model in mode=%s", embedding_mode)
        match embedding_mode:
            case "huggingface":
                try:
                    from llama_index.embeddings.huggingface import (  # type: ignore
                        HuggingFaceEmbedding,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Local dependencies not found, install with `poetry install --extras embeddings-huggingface`"
                    ) from e

                self.embedding_model = HuggingFaceEmbedding(
                    model_name=settings.huggingface.embedding_hf_model_name,
                    cache_folder=str(models_cache_path),
                    trust_remote_code=settings.huggingface.trust_remote_code,
                )
            case "sagemaker":
                try:
                    from private_gpt.components.embedding.custom.sagemaker import (
                        SagemakerEmbedding,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Sagemaker dependencies not found, install with `poetry install --extras embeddings-sagemaker`"
                    ) from e

                self.embedding_model = SagemakerEmbedding(
                    endpoint_name=settings.sagemaker.embedding_endpoint_name,
                )
            case "openai":
                try:
                    from llama_index.embeddings.openai import (  # type: ignore
                        OpenAIEmbedding,
                    )
                except ImportError as e:
                    raise ImportError(
                        "OpenAI dependencies not found, install with `poetry install --extras embeddings-openai`"
                    ) from e

                api_base = (
                    settings.openai.embedding_api_base or settings.openai.api_base
                )
                api_key = settings.openai.embedding_api_key or settings.openai.api_key
                model = settings.openai.embedding_model

                self.embedding_model = OpenAIEmbedding(
                    api_base=api_base,
                    api_key=api_key,
                    model=model,
                )
            case "ollama":
                try:
                    from llama_index.embeddings.ollama import (  # type: ignore
                        OllamaEmbedding,
                    )
                    from ollama import Client  # type: ignore
                except ImportError as e:
                    raise ImportError(
                        "Local dependencies not found, install with `poetry install --extras embeddings-ollama`"
                    ) from e

                ollama_settings = settings.ollama

                # Calculate the embedding model name. If no tag is provided, "latest" is used.
                model_name = (
                    ollama_settings.embedding_model + ":latest"
                    if ":" not in ollama_settings.embedding_model
                    else ollama_settings.embedding_model
                )

                self.embedding_model = OllamaEmbedding(
                    model_name=model_name,
                    base_url=ollama_settings.embedding_api_base,
                )

                if ollama_settings.autopull_models:
                    from private_gpt.utils.ollama import (
                        check_connection,
                        pull_model,
                    )

                    # TODO: Reuse llama-index client when llama-index is updated
                    client = Client(
                        host=ollama_settings.embedding_api_base,
                        timeout=ollama_settings.request_timeout,
                    )

                    if not check_connection(client):
                        raise ValueError(
                            f"Failed to connect to Ollama, "
                            f"check if Ollama server is running on {ollama_settings.api_base}"
                        )
                    pull_model(client, model_name)

            case "azopenai":
                try:
                    from llama_index.embeddings.azure_openai import (  # type: ignore
                        AzureOpenAIEmbedding,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Azure OpenAI dependencies not found, install with `poetry install --extras embeddings-azopenai`"
                    ) from e

                azopenai_settings = settings.azopenai
                self.embedding_model = AzureOpenAIEmbedding(
                    model=azopenai_settings.embedding_model,
                    deployment_name=azopenai_settings.embedding_deployment_name,
                    api_key=azopenai_settings.api_key,
                    azure_endpoint=azopenai_settings.azure_endpoint,
                    api_version=azopenai_settings.api_version,
                )
            case "gemini":
                try:
                    from llama_index.embeddings.gemini import (  # type: ignore
                        GeminiEmbedding,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Gemini dependencies not found, install with `poetry install --extras embeddings-gemini`"
                    ) from e

                self.embedding_model = GeminiEmbedding(
                    api_key=settings.gemini.api_key,
                    model_name=settings.gemini.embedding_model,
                )
            case "mistralai":
                try:
                    from llama_index.embeddings.mistralai import (  # type: ignore
                        MistralAIEmbedding,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Mistral dependencies not found, install with `poetry install --extras embeddings-mistral`"
                    ) from e

                api_key = settings.openai.api_key
                model = settings.openai.embedding_model

                self.embedding_model = MistralAIEmbedding(
                    api_key=api_key,
                    model=model,
                )
            case "mock":
                # Not a random number: it is the dimensionality used by
                # the default embedding model
                self.embedding_model = MockEmbedding(384)
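For context, a sketch of how this component is usually obtained in private-gpt; the global_injector entry point is assumed from the project's dependency-injection setup and is not part of this diff:

# Assumed wiring, not shown in this commit: private_gpt.di exposes a global
# injector that builds singletons such as EmbeddingComponent from Settings.
from private_gpt.di import global_injector
from private_gpt.components.embedding.embedding_component import EmbeddingComponent

component = global_injector.get(EmbeddingComponent)
vector = component.embedding_model.get_query_embedding("concentration géographique")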

pgpt/private_gpt/components/ingest/__init__.py (new file, empty)

pgpt/private_gpt/components/ingest/ingest_component.py (new file, 517 lines)
@@ -0,0 +1,517 @@
import abc
import itertools
import logging
import multiprocessing
import multiprocessing.pool
import os
import threading
from pathlib import Path
from queue import Queue
from typing import Any

from llama_index.core.data_structs import IndexDict
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import BaseNode, Document, TransformComponent
from llama_index.core.storage import StorageContext

from private_gpt.components.ingest.ingest_helper import IngestionHelper
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings
from private_gpt.utils.eta import eta

logger = logging.getLogger(__name__)


class BaseIngestComponent(abc.ABC):
    def __init__(
        self,
        storage_context: StorageContext,
        embed_model: EmbedType,
        transformations: list[TransformComponent],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        logger.debug("Initializing base ingest component type=%s", type(self).__name__)
        self.storage_context = storage_context
        self.embed_model = embed_model
        self.transformations = transformations

    @abc.abstractmethod
    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        pass

    @abc.abstractmethod
    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        pass

    @abc.abstractmethod
    def delete(self, doc_id: str) -> None:
        pass


class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
    def __init__(
        self,
        storage_context: StorageContext,
        embed_model: EmbedType,
        transformations: list[TransformComponent],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)

        self.show_progress = True
        self._index_thread_lock = (
            threading.Lock()
        )  # Thread lock! Not a multiprocessing lock
        self._index = self._initialize_index()

    def _initialize_index(self) -> BaseIndex[IndexDict]:
        """Initialize the index from the storage context."""
        try:
            # Load the index with store_nodes_override=True to be able to delete them
            index = load_index_from_storage(
                storage_context=self.storage_context,
                store_nodes_override=True,  # Force store nodes in index and document stores
                show_progress=self.show_progress,
                embed_model=self.embed_model,
                transformations=self.transformations,
            )
        except ValueError:
            # There is no index in the storage context, creating a new one
            logger.info("Creating a new vector store index")
            index = VectorStoreIndex.from_documents(
                [],
                storage_context=self.storage_context,
                store_nodes_override=True,  # Force store nodes in index and document stores
                show_progress=self.show_progress,
                embed_model=self.embed_model,
                transformations=self.transformations,
            )
            index.storage_context.persist(persist_dir=local_data_path)
        return index

    def _save_index(self) -> None:
        self._index.storage_context.persist(persist_dir=local_data_path)

    def delete(self, doc_id: str) -> None:
        with self._index_thread_lock:
            # Delete the document from the index
            self._index.delete_ref_doc(doc_id, delete_from_docstore=True)

            # Save the index
            self._save_index()


class SimpleIngestComponent(BaseIngestComponentWithIndex):
    def __init__(
        self,
        storage_context: StorageContext,
        embed_model: EmbedType,
        transformations: list[TransformComponent],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        saved_documents = []
        for file_name, file_data in files:
            documents = IngestionHelper.transform_file_into_documents(
                file_name, file_data
            )
            saved_documents.extend(self._save_docs(documents))
        return saved_documents

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        with self._index_thread_lock:
            for document in documents:
                self._index.insert(document, show_progress=True)
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents


class BatchIngestComponent(BaseIngestComponentWithIndex):
    """Parallelize the file reading and parsing on multiple CPU cores.

    This also causes the embeddings to be computed in batches (on GPU or CPU).
    """

    def __init__(
        self,
        storage_context: StorageContext,
        embed_model: EmbedType,
        transformations: list[TransformComponent],
        count_workers: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
        # To make efficient use of the CPU and GPU, the embedding
        # must be in the transformations
        assert (
            len(self.transformations) >= 2
        ), "Embeddings must be in the transformations"
        assert count_workers > 0, "count_workers must be > 0"
        self.count_workers = count_workers

        self._file_to_documents_work_pool = multiprocessing.Pool(
            processes=self.count_workers
        )

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        documents = list(
            itertools.chain.from_iterable(
                self._file_to_documents_work_pool.starmap(
                    IngestionHelper.transform_file_into_documents, files
                )
            )
        )
        logger.info(
            "Transformed count=%s files into count=%s documents",
            len(files),
            len(documents),
        )
        return self._save_docs(documents)

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        nodes = run_transformations(
            documents,  # type: ignore[arg-type]
            self.transformations,
            show_progress=self.show_progress,
        )
        # Locking the index to avoid concurrent writes
        with self._index_thread_lock:
            logger.info("Inserting count=%s nodes in the index", len(nodes))
            self._index.insert_nodes(nodes, show_progress=True)
            for document in documents:
                self._index.docstore.set_document_hash(
                    document.get_doc_id(), document.hash
                )
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents


class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
    """Parallelize the file ingestion (file reading, embeddings, and index insertion).

    This uses the CPU and GPU in parallel (both running at the same time), and
    reduces memory pressure by not loading all the files in memory at the same time.
    """

    def __init__(
        self,
        storage_context: StorageContext,
        embed_model: EmbedType,
        transformations: list[TransformComponent],
        count_workers: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
        # To make efficient use of the CPU and GPU, the embeddings
        # must be in the transformations (to be computed in batches)
        assert (
            len(self.transformations) >= 2
        ), "Embeddings must be in the transformations"
        assert count_workers > 0, "count_workers must be > 0"
        self.count_workers = count_workers
        # We are doing our own multiprocessing.
        # To avoid colliding with the multiprocessing of huggingface, we disable it.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        self._ingest_work_pool = multiprocessing.pool.ThreadPool(
            processes=self.count_workers
        )

        self._file_to_documents_work_pool = multiprocessing.Pool(
            processes=self.count_workers
        )

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        # Running in a single (1) process to release the current
        # thread, and take a dedicated CPU core for computation
        documents = self._file_to_documents_work_pool.apply(
            IngestionHelper.transform_file_into_documents, (file_name, file_data)
        )
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        # Lightweight threads, used to parallelize the
        # underlying IO calls made in the ingestion

        documents = list(
            itertools.chain.from_iterable(
                self._ingest_work_pool.starmap(self.ingest, files)
            )
        )
        return documents

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        nodes = run_transformations(
            documents,  # type: ignore[arg-type]
            self.transformations,
            show_progress=self.show_progress,
        )
        # Locking the index to avoid concurrent writes
        with self._index_thread_lock:
            logger.info("Inserting count=%s nodes in the index", len(nodes))
            self._index.insert_nodes(nodes, show_progress=True)
            for document in documents:
                self._index.docstore.set_document_hash(
                    document.get_doc_id(), document.hash
                )
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents

    def __del__(self) -> None:
        # We need to do the appropriate cleanup of the multiprocessing pools
        # when the object is deleted. Using the root logger so the logger is
        # not deleted before the pools.
        logging.debug("Closing the ingest work pool")
        self._ingest_work_pool.close()
        self._ingest_work_pool.join()
        self._ingest_work_pool.terminate()
        logging.debug("Closing the file to documents work pool")
        self._file_to_documents_work_pool.close()
        self._file_to_documents_work_pool.join()
        self._file_to_documents_work_pool.terminate()


class PipelineIngestComponent(BaseIngestComponentWithIndex):
    """Pipeline ingestion - keeping the embedding worker pool as busy as possible.

    This class implements a threaded ingestion pipeline, which comprises two threads
    and two queues. The primary thread is responsible for reading and parsing files
    into documents. These documents are then placed into a queue, which is
    distributed to a pool of worker processes for embedding computation. After
    embedding, the documents are transferred to another queue where they are
    accumulated until a threshold is reached. Upon reaching this threshold, the
    accumulated documents are flushed to the document store, index, and vector
    store.

    Exception handling ensures robustness against erroneous files. However, in the
    pipelined design, one error can lead to the discarding of multiple files. Any
    discarded files will be reported.
    """

    NODE_FLUSH_COUNT = 5000  # Save the index every # nodes.

    def __init__(
        self,
        storage_context: StorageContext,
        embed_model: EmbedType,
        transformations: list[TransformComponent],
        count_workers: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
        assert (
            len(self.transformations) >= 2
        ), "Embeddings must be in the transformations"
        assert count_workers > 0, "count_workers must be > 0"
        self.count_workers = count_workers
        # We are doing our own multiprocessing.
        # To avoid colliding with the multiprocessing of huggingface, we disable it.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # doc_q stores parsed files as Document chunks.
        # Using a shallow queue causes the filesystem parser to block
        # when it reaches capacity. This ensures it doesn't outpace the
        # computationally intensive embeddings phase, avoiding unnecessary
        # memory consumption. The semaphore is used to bound the async worker
        # embedding computations to cause the doc Q to fill and block.
        self.doc_semaphore = multiprocessing.Semaphore(
            self.count_workers
        )  # limit the doc queue to # items.
        self.doc_q: Queue[tuple[str, str | None, list[Document] | None]] = Queue(20)
        # node_q stores documents parsed into nodes (embeddings).
        # Larger queue size so we don't block the embedding workers during a slow
        # index update.
        self.node_q: Queue[
            tuple[str, str | None, list[Document] | None, list[BaseNode] | None]
        ] = Queue(40)
        threading.Thread(target=self._doc_to_node, daemon=True).start()
        threading.Thread(target=self._write_nodes, daemon=True).start()

    def _doc_to_node(self) -> None:
        # Parse documents into nodes
        with multiprocessing.pool.ThreadPool(processes=self.count_workers) as pool:
            while True:
                try:
                    cmd, file_name, documents = self.doc_q.get(
                        block=True
                    )  # Documents for a file
                    if cmd == "process":
                        # Push CPU/GPU embedding work to the worker pool
                        # Acquire semaphore to control access to worker pool
                        self.doc_semaphore.acquire()
                        pool.apply_async(
                            self._doc_to_node_worker, (file_name, documents)
                        )
                    elif cmd == "quit":
                        break
                finally:
                    if cmd != "process":
                        self.doc_q.task_done()  # unblock Q joins

    def _doc_to_node_worker(self, file_name: str, documents: list[Document]) -> None:
        # CPU/GPU intensive work in its own process
        try:
            nodes = run_transformations(
                documents,  # type: ignore[arg-type]
                self.transformations,
                show_progress=self.show_progress,
            )
            self.node_q.put(("process", file_name, documents, list(nodes)))
        finally:
            self.doc_semaphore.release()
            self.doc_q.task_done()  # unblock Q joins

    def _save_docs(
        self, files: list[str], documents: list[Document], nodes: list[BaseNode]
    ) -> None:
        try:
            logger.info(
                f"Saving {len(files)} files ({len(documents)} documents / {len(nodes)} nodes)"
            )
            self._index.insert_nodes(nodes)
            for document in documents:
                self._index.docstore.set_document_hash(
                    document.get_doc_id(), document.hash
                )
            self._save_index()
        except Exception:
            # Tell the user so they can investigate these files
            logger.exception(f"Processing files {files}")
        finally:
            # Clearing work, even on exception, maintains a clean state.
            nodes.clear()
            documents.clear()
            files.clear()

    def _write_nodes(self) -> None:
        # Save nodes to index. I/O intensive.
        node_stack: list[BaseNode] = []
        doc_stack: list[Document] = []
        file_stack: list[str] = []
        while True:
            try:
                cmd, file_name, documents, nodes = self.node_q.get(block=True)
                if cmd in ("flush", "quit"):
                    if file_stack:
                        self._save_docs(file_stack, doc_stack, node_stack)
                    if cmd == "quit":
                        break
                elif cmd == "process":
                    node_stack.extend(nodes)  # type: ignore[arg-type]
                    doc_stack.extend(documents)  # type: ignore[arg-type]
                    file_stack.append(file_name)  # type: ignore[arg-type]
                    # Constant saving is heavy on I/O - accumulate to a threshold
                    if len(node_stack) >= self.NODE_FLUSH_COUNT:
                        self._save_docs(file_stack, doc_stack, node_stack)
            finally:
                self.node_q.task_done()

    def _flush(self) -> None:
        self.doc_q.put(("flush", None, None))
        self.doc_q.join()
        self.node_q.put(("flush", None, None, None))
        self.node_q.join()

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
        self.doc_q.put(("process", file_name, documents))
        self._flush()
        return documents

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        docs = []
        for file_name, file_data in eta(files):
            try:
                documents = IngestionHelper.transform_file_into_documents(
                    file_name, file_data
                )
                self.doc_q.put(("process", file_name, documents))
                docs.extend(documents)
            except Exception:
                logger.exception(f"Skipping {file_data.name}")
        self._flush()
        return docs


def get_ingestion_component(
    storage_context: StorageContext,
    embed_model: EmbedType,
    transformations: list[TransformComponent],
    settings: Settings,
) -> BaseIngestComponent:
    """Get the ingestion component for the given configuration."""
    ingest_mode = settings.embedding.ingest_mode
    if ingest_mode == "batch":
        return BatchIngestComponent(
            storage_context=storage_context,
            embed_model=embed_model,
            transformations=transformations,
            count_workers=settings.embedding.count_workers,
        )
    elif ingest_mode == "parallel":
        return ParallelizedIngestComponent(
            storage_context=storage_context,
            embed_model=embed_model,
            transformations=transformations,
            count_workers=settings.embedding.count_workers,
        )
    elif ingest_mode == "pipeline":
        return PipelineIngestComponent(
            storage_context=storage_context,
            embed_model=embed_model,
            transformations=transformations,
            count_workers=settings.embedding.count_workers,
        )
    else:
        return SimpleIngestComponent(
            storage_context=storage_context,
            embed_model=embed_model,
            transformations=transformations,
        )
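A schematic of how the factory above is driven by settings; the storage_context, embed_model, transformations and settings values are assumed to be in scope, as they are in private-gpt's ingest service:

# embedding.ingest_mode selects the strategy:
#   "batch"    -> BatchIngestComponent        (parallel parsing, batched embeddings)
#   "parallel" -> ParallelizedIngestComponent (parsing and embedding in parallel)
#   "pipeline" -> PipelineIngestComponent     (queued, threaded pipeline)
#   otherwise  -> SimpleIngestComponent       (sequential)
component = get_ingestion_component(
    storage_context, embed_model, transformations, settings
)
docs = component.bulk_ingest(
    [("Fiche technique IHH.md", Path("local_data/Fiches/Criticités/Fiche technique IHH.md"))]
)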

pgpt/private_gpt/components/ingest/ingest_helper.py (new file, 111 lines)
@ -0,0 +1,111 @@
import logging
from pathlib import Path

from llama_index.core.readers import StringIterableReader
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.json import JSONReader
from llama_index.core.schema import Document

logger = logging.getLogger(__name__)


# Inspired by the `llama_index.core.readers.file.base` module
def _try_loading_included_file_formats() -> dict[str, type[BaseReader]]:
    try:
        from llama_index.readers.file.docs import (  # type: ignore
            DocxReader,
            HWPReader,
            PDFReader,
        )
        from llama_index.readers.file.epub import EpubReader  # type: ignore
        from llama_index.readers.file.image import ImageReader  # type: ignore
        from llama_index.readers.file.ipynb import IPYNBReader  # type: ignore
        from llama_index.readers.file.markdown import MarkdownReader  # type: ignore
        from llama_index.readers.file.mbox import MboxReader  # type: ignore
        from llama_index.readers.file.slides import PptxReader  # type: ignore
        from llama_index.readers.file.tabular import PandasCSVReader  # type: ignore
        from llama_index.readers.file.video_audio import (  # type: ignore
            VideoAudioReader,
        )
    except ImportError as e:
        raise ImportError("`llama-index-readers-file` package not found") from e

    default_file_reader_cls: dict[str, type[BaseReader]] = {
        ".hwp": HWPReader,
        ".pdf": PDFReader,
        ".docx": DocxReader,
        ".pptx": PptxReader,
        ".ppt": PptxReader,
        ".pptm": PptxReader,
        ".jpg": ImageReader,
        ".png": ImageReader,
        ".jpeg": ImageReader,
        ".mp3": VideoAudioReader,
        ".mp4": VideoAudioReader,
        ".csv": PandasCSVReader,
        ".epub": EpubReader,
        ".md": MarkdownReader,
        ".mbox": MboxReader,
        ".ipynb": IPYNBReader,
    }
    return default_file_reader_cls


# Patching the default file reader to support other file types
FILE_READER_CLS = _try_loading_included_file_formats()
FILE_READER_CLS.update(
    {
        ".json": JSONReader,
    }
)


class IngestionHelper:
    """Helper class to transform a file into a list of documents.

    This class should be used to transform a file into a list of documents.
    These methods are thread-safe (and multiprocessing-safe).
    """

    @staticmethod
    def transform_file_into_documents(
        file_name: str, file_data: Path
    ) -> list[Document]:
        documents = IngestionHelper._load_file_to_documents(file_name, file_data)
        for document in documents:
            document.metadata["file_name"] = file_name
        IngestionHelper._exclude_metadata(documents)
        return documents

    @staticmethod
    def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
        logger.debug("Transforming file_name=%s into documents", file_name)
        extension = Path(file_name).suffix
        reader_cls = FILE_READER_CLS.get(extension)
        if reader_cls is None:
            logger.debug(
                "No reader found for extension=%s, using default string reader",
                extension,
            )
            # Read as plain text
            string_reader = StringIterableReader()
            return string_reader.load_data([file_data.read_text()])

        logger.debug("Specific reader found for extension=%s", extension)
        documents = reader_cls().load_data(file_data)

        # Sanitize NUL bytes in text which can't be stored in Postgres
        for i in range(len(documents)):
            documents[i].text = documents[i].text.replace("\u0000", "")

        return documents

    @staticmethod
    def _exclude_metadata(documents: list[Document]) -> None:
        logger.debug("Excluding metadata from count=%s documents", len(documents))
        for document in documents:
            document.metadata["doc_id"] = document.doc_id
            # We don't want the Embeddings search to receive this metadata
            document.excluded_embed_metadata_keys = ["doc_id"]
            # We don't want the LLM to receive these metadata in the context
            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
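A quick usage sketch for the helper above (illustrative only, not part of the commit; the file path is hypothetical):

```python
from pathlib import Path

# Extension-based dispatch: ".pdf" resolves to PDFReader via FILE_READER_CLS.
docs = IngestionHelper.transform_file_into_documents(
    file_name="report.pdf",
    file_data=Path("/tmp/report.pdf"),  # hypothetical local file
)
print(len(docs), docs[0].metadata["file_name"])  # every document is tagged with file_name
```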
1
pgpt/private_gpt/components/llm/__init__.py
Normal file
@ -0,0 +1 @@
"""LLM implementations."""
0
pgpt/private_gpt/components/llm/custom/__init__.py
Normal file
276
pgpt/private_gpt/components/llm/custom/sagemaker.py
Normal file
@ -0,0 +1,276 @@
# mypy: ignore-errors
from __future__ import annotations

import io
import json
import logging
from typing import TYPE_CHECKING, Any

import boto3  # type: ignore
from llama_index.core.base.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import (
    CompletionResponse,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import (
    llm_chat_callback,
    llm_completion_callback,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from llama_index.callbacks import CallbackManager
    from llama_index.llms import (
        ChatMessage,
        ChatResponse,
        ChatResponseGen,
        CompletionResponseGen,
    )

logger = logging.getLogger(__name__)


class LineIterator:
    r"""A helper class for parsing the byte stream input from a TGI container.

    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"
    b'}}'
    ...
    ```

    While usually each PayloadPart event from the event stream will contain a byte array
    with a full json, this is not guaranteed and some of the json objects may be split
    across PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by concatenating bytes written via the 'write' function
    and then exposing a method which will return lines (ending with a '\n' character)
    within the buffer via the 'scan_lines' function. It maintains the position of the
    last read position to ensure that previous bytes are not exposed again. It will
    also save any pending lines that do not end with a '\n' to make sure truncations
    are concatenated.
    """

    def __init__(self, stream: Any) -> None:
        """Line iterator initializer."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self) -> Any:
        """Self iterator."""
        return self

    def __next__(self) -> Any:
        """Next element from iterator."""
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                logger.warning("Unknown event type=%s", chunk)
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


class SagemakerLLM(CustomLLM):
    """Sagemaker Inference Endpoint models.

    To use, you must supply the endpoint name from your deployed
    Sagemaker model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    endpoint_name: str = Field(description="")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    messages_to_prompt: Any = Field(
        description="The function to convert messages to a prompt.", exclude=True
    )
    completion_to_prompt: Any = Field(
        description="The function to convert a completion to a prompt.", exclude=True
    )
    generate_kwargs: dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(description="Whether to print verbose output.")

    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field

    def __init__(
        self,
        endpoint_name: str | None = "",
        temperature: float = 0.1,
        max_new_tokens: int = 512,  # to review defaults
        context_window: int = 2048,  # to review defaults
        messages_to_prompt: Any = None,
        completion_to_prompt: Any = None,
        callback_manager: CallbackManager | None = None,
        generate_kwargs: dict[str, Any] | None = None,
        model_kwargs: dict[str, Any] | None = None,
        verbose: bool = True,
    ) -> None:
        """SagemakerLLM initializer."""
        model_kwargs = model_kwargs or {}
        model_kwargs.update({"n_ctx": context_window, "verbose": verbose})

        messages_to_prompt = messages_to_prompt or {}
        completion_to_prompt = completion_to_prompt or {}

        generate_kwargs = generate_kwargs or {}
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )

        super().__init__(
            endpoint_name=endpoint_name,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    @property
    def inference_params(self):
        # TODO expose the rest of params
        return {
            "do_sample": True,
            "top_p": 0.7,
            "temperature": self.temperature,
            "top_k": 50,
            "max_new_tokens": self.max_new_tokens,
        }

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name="Sagemaker LLama 2",
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        self.generate_kwargs.update({"stream": False})

        is_formatted = kwargs.pop("formatted", False)
        if not is_formatted:
            prompt = self.completion_to_prompt(prompt)

        request_params = {
            "inputs": prompt,
            "stream": False,
            "parameters": self.inference_params,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        response_dict = json.loads(response_str)

        return CompletionResponse(
            text=response_dict[0]["generated_text"][len(prompt) :], raw=resp
        )

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        def get_stream():
            text = ""

            request_params = {
                "inputs": prompt,
                "stream": True,
                "parameters": self.inference_params,
            }
            resp = self._boto_client.invoke_endpoint_with_response_stream(
                EndpointName=self.endpoint_name,
                Body=json.dumps(request_params),
                ContentType="application/json",
            )

            event_stream = resp["Body"]
            start_json = b"{"
            stop_token = "<|endoftext|>"
            first_token = True

            for line in LineIterator(event_stream):
                if line != b"" and start_json in line:
                    data = json.loads(line[line.find(start_json) :].decode("utf-8"))
                    special = data["token"]["special"]
                    stop = data["token"]["text"] == stop_token
                    if not special and not stop:
                        delta = data["token"]["text"]
                        # trim the leading space for the first token if present
                        if first_token:
                            delta = delta.lstrip()
                            first_token = False
                        text += delta
                        yield CompletionResponse(delta=delta, text=text, raw=data)

        return get_stream()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)
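Illustrative usage of the class above (not part of the commit; the endpoint name is hypothetical and boto3 is assumed to resolve AWS credentials on its own):

```python
# Sketch only: querying a deployed SageMaker endpoint.
llm = SagemakerLLM(
    endpoint_name="my-llama2-endpoint",  # hypothetical endpoint
    temperature=0.1,
    max_new_tokens=256,
)
# formatted=True skips completion_to_prompt, which defaults to a non-callable {} here.
response = llm.complete("What is RAG?", formatted=True)
print(response.text)
```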
225
pgpt/private_gpt/components/llm/llm_component.py
Normal file
@ -0,0 +1,225 @@
import logging
from collections.abc import Callable
from typing import Any

from injector import inject, singleton
from llama_index.core.llms import LLM, MockLLM
from llama_index.core.settings import Settings as LlamaIndexSettings
from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer  # type: ignore

from private_gpt.components.llm.prompt_helper import get_prompt_style
from private_gpt.paths import models_cache_path, models_path
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class LLMComponent:
    llm: LLM

    @inject
    def __init__(self, settings: Settings) -> None:
        llm_mode = settings.llm.mode
        if settings.llm.tokenizer and settings.llm.mode != "mock":
            # Try to download the tokenizer. If it fails, the LLM will still work
            # using the default one, which is less accurate.
            try:
                set_global_tokenizer(
                    AutoTokenizer.from_pretrained(
                        pretrained_model_name_or_path=settings.llm.tokenizer,
                        cache_dir=str(models_cache_path),
                        token=settings.huggingface.access_token,
                    )
                )
            except Exception as e:
                logger.warning(
                    f"Failed to download tokenizer {settings.llm.tokenizer}: {e!s}. "
                    f"Please follow the instructions in the documentation to download it if needed: "
                    f"https://docs.privategpt.dev/installation/getting-started/troubleshooting#tokenizer-setup. "
                    f"Falling back to default tokenizer."
                )

        logger.info("Initializing the LLM in mode=%s", llm_mode)
        match settings.llm.mode:
            case "llamacpp":
                try:
                    from llama_index.llms.llama_cpp import LlamaCPP  # type: ignore
                except ImportError as e:
                    raise ImportError(
                        "Local dependencies not found, install with `poetry install --extras llms-llama-cpp`"
                    ) from e

                prompt_style = get_prompt_style(settings.llm.prompt_style)
                settings_kwargs = {
                    "tfs_z": settings.llamacpp.tfs_z,  # ollama and llama-cpp
                    "top_k": settings.llamacpp.top_k,  # ollama and llama-cpp
                    "top_p": settings.llamacpp.top_p,  # ollama and llama-cpp
                    "repeat_penalty": settings.llamacpp.repeat_penalty,  # ollama llama-cpp
                    "n_gpu_layers": -1,
                    "offload_kqv": True,
                }
                self.llm = LlamaCPP(
                    model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
                    temperature=settings.llm.temperature,
                    max_new_tokens=settings.llm.max_new_tokens,
                    context_window=settings.llm.context_window,
                    generate_kwargs={},
                    callback_manager=LlamaIndexSettings.callback_manager,
                    # All to GPU
                    model_kwargs=settings_kwargs,
                    # transform inputs into Llama2 format
                    messages_to_prompt=prompt_style.messages_to_prompt,
                    completion_to_prompt=prompt_style.completion_to_prompt,
                    verbose=True,
                )

            case "sagemaker":
                try:
                    from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
                except ImportError as e:
                    raise ImportError(
                        "Sagemaker dependencies not found, install with `poetry install --extras llms-sagemaker`"
                    ) from e

                self.llm = SagemakerLLM(
                    endpoint_name=settings.sagemaker.llm_endpoint_name,
                    max_new_tokens=settings.llm.max_new_tokens,
                    context_window=settings.llm.context_window,
                )
            case "openai":
                try:
                    from llama_index.llms.openai import OpenAI  # type: ignore
                except ImportError as e:
                    raise ImportError(
                        "OpenAI dependencies not found, install with `poetry install --extras llms-openai`"
                    ) from e

                openai_settings = settings.openai
                self.llm = OpenAI(
                    api_base=openai_settings.api_base,
                    api_key=openai_settings.api_key,
                    model=openai_settings.model,
                )
            case "openailike":
                try:
                    from llama_index.llms.openai_like import OpenAILike  # type: ignore
                except ImportError as e:
                    raise ImportError(
                        "OpenAILike dependencies not found, install with `poetry install --extras llms-openai-like`"
                    ) from e
                prompt_style = get_prompt_style(settings.llm.prompt_style)
                openai_settings = settings.openai
                self.llm = OpenAILike(
                    api_base=openai_settings.api_base,
                    api_key=openai_settings.api_key,
                    model=openai_settings.model,
                    is_chat_model=True,
                    max_tokens=settings.llm.max_new_tokens,
                    api_version="",
                    temperature=settings.llm.temperature,
                    context_window=settings.llm.context_window,
                    messages_to_prompt=prompt_style.messages_to_prompt,
                    completion_to_prompt=prompt_style.completion_to_prompt,
                    tokenizer=settings.llm.tokenizer,
                    timeout=openai_settings.request_timeout,
                    reuse_client=False,
                )
            case "ollama":
                try:
                    from llama_index.llms.ollama import Ollama  # type: ignore
                except ImportError as e:
                    raise ImportError(
                        "Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
                    ) from e

                ollama_settings = settings.ollama

                settings_kwargs = {
                    "tfs_z": ollama_settings.tfs_z,  # ollama and llama-cpp
                    "num_predict": ollama_settings.num_predict,  # ollama only
                    "top_k": ollama_settings.top_k,  # ollama and llama-cpp
                    "top_p": ollama_settings.top_p,  # ollama and llama-cpp
                    "repeat_last_n": ollama_settings.repeat_last_n,  # ollama
                    "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
                }

                # Compute the LLM model name. If no tag is provided, default to ":latest".
                model_name = (
                    ollama_settings.llm_model + ":latest"
                    if ":" not in ollama_settings.llm_model
                    else ollama_settings.llm_model
                )

                llm = Ollama(
                    model=model_name,
                    base_url=ollama_settings.api_base,
                    temperature=settings.llm.temperature,
                    context_window=settings.llm.context_window,
                    additional_kwargs=settings_kwargs,
                    request_timeout=ollama_settings.request_timeout,
                )

                if ollama_settings.autopull_models:
                    from private_gpt.utils.ollama import check_connection, pull_model

                    if not check_connection(llm.client):
                        raise ValueError(
                            f"Failed to connect to Ollama, "
                            f"check if Ollama server is running on {ollama_settings.api_base}"
                        )
                    pull_model(llm.client, model_name)

                if (
                    ollama_settings.keep_alive
                    != ollama_settings.model_fields["keep_alive"].default
                ):
                    # Modify Ollama methods to use the "keep_alive" field.
                    def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
                        def wrapper(*args: Any, **kwargs: Any) -> Any:
                            kwargs["keep_alive"] = ollama_settings.keep_alive
                            return func(*args, **kwargs)

                        return wrapper

                    Ollama.chat = add_keep_alive(Ollama.chat)  # type: ignore
                    Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)  # type: ignore
                    Ollama.complete = add_keep_alive(Ollama.complete)  # type: ignore
                    Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)  # type: ignore

                self.llm = llm

            case "azopenai":
                try:
                    from llama_index.llms.azure_openai import (  # type: ignore
                        AzureOpenAI,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
                    ) from e

                azopenai_settings = settings.azopenai
                self.llm = AzureOpenAI(
                    model=azopenai_settings.llm_model,
                    deployment_name=azopenai_settings.llm_deployment_name,
                    api_key=azopenai_settings.api_key,
                    azure_endpoint=azopenai_settings.azure_endpoint,
                    api_version=azopenai_settings.api_version,
                )
            case "gemini":
                try:
                    from llama_index.llms.gemini import (  # type: ignore
                        Gemini,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Google Gemini dependencies not found, install with `poetry install --extras llms-gemini`"
                    ) from e
                gemini_settings = settings.gemini
                self.llm = Gemini(
                    model_name=gemini_settings.model, api_key=gemini_settings.api_key
                )
            case "mock":
                self.llm = MockLLM()
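Since `LLMComponent` is a singleton resolved through the injector, downstream services never construct backends directly. A sketch (illustrative only; it assumes the `global_injector` defined in `di.py` further below):

```python
from private_gpt.di import global_injector
from private_gpt.components.llm.llm_component import LLMComponent

llm_component = global_injector.get(LLMComponent)  # built once, then cached
print(llm_component.llm.metadata.model_name)
```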
310
pgpt/private_gpt/components/llm/prompt_helper.py
Normal file
@ -0,0 +1,310 @@
import abc
import logging
from collections.abc import Sequence
from typing import Any, Literal

from llama_index.core.llms import ChatMessage, MessageRole

logger = logging.getLogger(__name__)


class AbstractPromptStyle(abc.ABC):
    """Abstract class for prompt styles.

    This class is used to format a series of messages into a prompt that can be
    understood by the models. A series of messages represents the interaction(s)
    between a user and an assistant. This series of messages can be considered as a
    session between a user X and an assistant Y. This session holds, through the
    messages, the state of the conversation. This session, to be understood by the
    model, needs to be formatted into a prompt (i.e. a string that the models
    can understand). Prompts can be formatted in different ways,
    depending on the model.

    The implementations of this class represent the different ways to format a
    series of messages into a prompt.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)

    @abc.abstractmethod
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        pass

    @abc.abstractmethod
    def _completion_to_prompt(self, completion: str) -> str:
        pass

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = self._messages_to_prompt(messages)
        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
        return prompt

    def completion_to_prompt(self, prompt: str) -> str:
        completion = prompt  # Fix: Llama-index parameter has to be named as prompt
        prompt = self._completion_to_prompt(completion)
        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
        return prompt


class DefaultPromptStyle(AbstractPromptStyle):
    """Default prompt style that uses the defaults from llama_utils.

    It basically passes None to the LLM, indicating it should use
    the default functions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Hacky way to override the functions
        # Override the functions to be None, and pass None to the LLM.
        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return ""

    def _completion_to_prompt(self, completion: str) -> str:
        return ""


class Llama2PromptStyle(AbstractPromptStyle):
    """Simple prompt style that uses llama 2 prompt style.

    Inspired by llama_index/legacy/llms/llama_utils.py

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <s> [INST] <<SYS>> your system prompt here. <</SYS>>

    user message here [/INST] assistant (model) response here </s>
    ```
    """

    BOS, EOS = "<s>", "</s>"
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. \
Always answer as helpfully as possible and follow ALL given instructions. \
Do not speculate or make up information. \
Do not reference any given instructions or context. \
"""

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        string_messages: list[str] = []
        if messages[0].role == MessageRole.SYSTEM:
            # pull out the system message (if it exists in messages)
            system_message_str = messages[0].content or ""
            messages = messages[1:]
        else:
            system_message_str = self.DEFAULT_SYSTEM_PROMPT

        system_message_str = f"{self.B_SYS} {system_message_str.strip()} {self.E_SYS}"

        for i in range(0, len(messages), 2):
            # first message should always be a user
            user_message = messages[i]
            assert user_message.role == MessageRole.USER

            if i == 0:
                # make sure system prompt is included at the start
                str_message = f"{self.BOS} {self.B_INST} {system_message_str} "
            else:
                # end previous user-assistant interaction
                string_messages[-1] += f" {self.EOS}"
                # no need to include system prompt
                str_message = f"{self.BOS} {self.B_INST} "

            # include user message content
            str_message += f"{user_message.content} {self.E_INST}"

            if len(messages) > (i + 1):
                # if assistant message exists, add to str_message
                assistant_message = messages[i + 1]
                assert assistant_message.role == MessageRole.ASSISTANT
                str_message += f" {assistant_message.content}"

            string_messages.append(str_message)

        return "".join(string_messages)

    def _completion_to_prompt(self, completion: str) -> str:
        system_prompt_str = self.DEFAULT_SYSTEM_PROMPT

        return (
            f"{self.BOS} {self.B_INST} {self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
            f"{completion.strip()} {self.E_INST}"
        )


class Llama3PromptStyle(AbstractPromptStyle):
    r"""Template for Meta's Llama 3.1.

    The format follows this structure:
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>

    [System message content]<|eot_id|>
    <|start_header_id|>user<|end_header_id|>

    [User message content]<|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>

    [Assistant message content]<|eot_id|>
    ...
    (Repeat for each message, including possible 'ipython' role)
    """

    BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
    B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
    EOT = "<|eot_id|>"
    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
    ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
    DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. \
Always answer as helpfully as possible and follow ALL given instructions. \
Do not speculate or make up information. \
Do not reference any given instructions or context. \
"""

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = ""
        has_system_message = False

        for i, message in enumerate(messages):
            if not message or message.content is None:
                continue
            if message.role == MessageRole.SYSTEM:
                prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
                has_system_message = True
            else:
                role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
                prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"

            # Add assistant header if the last message is not from the assistant
            if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
                prompt += f"{self.ASSISTANT_INST}\n\n"

        # Add default system prompt if no system message was provided
        if not has_system_message:
            prompt = (
                f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + prompt
            )

        # TODO: Implement tool handling logic

        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return (
            f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
            f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
            f"{self.ASSISTANT_INST}\n\n"
        )


class TagPromptStyle(AbstractPromptStyle):
    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <|system|>: your system prompt here.
    <|user|>: user message here
    (possibly with context and question)
    <|assistant|>: assistant (model) response here.
    ```

    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
    """

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format message to prompt with `<|ROLE|>: MSG` style."""
        prompt = ""
        for message in messages:
            role = message.role
            content = message.content or ""
            message_from_user = f"<|{role.lower()}|>: {content.strip()}"
            message_from_user += "\n"
            prompt += message_from_user
        # we are missing the last <|assistant|> tag that will trigger a completion
        prompt += "<|assistant|>: "
        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )


class MistralPromptStyle(AbstractPromptStyle):
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        inst_buffer = []
        text = ""
        for message in messages:
            if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER:
                inst_buffer.append(str(message.content).strip())
            elif message.role == MessageRole.ASSISTANT:
                text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
                text += " " + str(message.content).strip() + "</s>"
                inst_buffer.clear()
            else:
                raise ValueError(f"Unknown message role {message.role}")

        if len(inst_buffer) > 0:
            text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"

        return text

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )


class ChatMLPromptStyle(AbstractPromptStyle):
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = "<|im_start|>system\n"
        for message in messages:
            role = message.role
            content = message.content or ""
            if role.lower() == "system":
                message_from_user = f"{content.strip()}"
                prompt += message_from_user
            elif role.lower() == "user":
                prompt += "<|im_end|>\n<|im_start|>user\n"
                message_from_user = f"{content.strip()}<|im_end|>\n"
                prompt += message_from_user
        prompt += "<|im_start|>assistant\n"
        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )


def get_prompt_style(
    prompt_style: (
        Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
    )
) -> AbstractPromptStyle:
    """Get the prompt style to use from the given string.

    :param prompt_style: The prompt style to use.
    :return: The prompt style to use.
    """
    if prompt_style is None or prompt_style == "default":
        return DefaultPromptStyle()
    elif prompt_style == "llama2":
        return Llama2PromptStyle()
    elif prompt_style == "llama3":
        return Llama3PromptStyle()
    elif prompt_style == "tag":
        return TagPromptStyle()
    elif prompt_style == "mistral":
        return MistralPromptStyle()
    elif prompt_style == "chatml":
        return ChatMLPromptStyle()
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")
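A worked example of the dispatch above (illustrative only, not part of the commit):

```python
from llama_index.core.llms import ChatMessage, MessageRole

style = get_prompt_style("llama2")
prompt = style.messages_to_prompt([
    ChatMessage(role=MessageRole.SYSTEM, content="You are terse."),
    ChatMessage(role=MessageRole.USER, content="Hi"),
])
# Roughly: "<s> [INST] <<SYS>>\n You are terse. \n<</SYS>>\n\n Hi [/INST]"
```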
0
pgpt/private_gpt/components/node_store/__init__.py
Normal file
68
pgpt/private_gpt/components/node_store/node_store_component.py
Normal file
@ -0,0 +1,68 @@
import logging

from injector import inject, singleton
from llama_index.core.storage.docstore import BaseDocumentStore, SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.storage.index_store.types import BaseIndexStore

from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class NodeStoreComponent:
    index_store: BaseIndexStore
    doc_store: BaseDocumentStore

    @inject
    def __init__(self, settings: Settings) -> None:
        match settings.nodestore.database:
            case "simple":
                try:
                    self.index_store = SimpleIndexStore.from_persist_dir(
                        persist_dir=str(local_data_path)
                    )
                except FileNotFoundError:
                    logger.debug("Local index store not found, creating a new one")
                    self.index_store = SimpleIndexStore()

                try:
                    self.doc_store = SimpleDocumentStore.from_persist_dir(
                        persist_dir=str(local_data_path)
                    )
                except FileNotFoundError:
                    logger.debug("Local document store not found, creating a new one")
                    self.doc_store = SimpleDocumentStore()

            case "postgres":
                try:
                    from llama_index.storage.docstore.postgres import (  # type: ignore
                        PostgresDocumentStore,
                    )
                    from llama_index.storage.index_store.postgres import (  # type: ignore
                        PostgresIndexStore,
                    )
                except ImportError:
                    raise ImportError(
                        "Postgres dependencies not found, install with `poetry install --extras storage-nodestore-postgres`"
                    ) from None

                if settings.postgres is None:
                    raise ValueError("Postgres index/doc store settings not found.")

                self.index_store = PostgresIndexStore.from_params(
                    **settings.postgres.model_dump(exclude_none=True)
                )

                self.doc_store = PostgresDocumentStore.from_params(
                    **settings.postgres.model_dump(exclude_none=True)
                )

            case _:
                # Should be unreachable
                # The settings validator should have caught this
                raise ValueError(
                    f"Database {settings.nodestore.database} not supported"
                )
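Illustrative only: with `nodestore.database: simple`, both stores fall back to fresh in-memory instances on first run, persisted under `local_data_path` later:

```python
component = NodeStoreComponent(settings)  # normally resolved via the injector
print(len(component.doc_store.docs))      # 0 on a fresh "simple" store
```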
106
pgpt/private_gpt/components/vector_store/batched_chroma.py
Normal file
@ -0,0 +1,106 @@
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Any

from llama_index.core.schema import BaseNode, MetadataMode
from llama_index.core.vector_stores.utils import node_to_metadata_dict
from llama_index.vector_stores.chroma import ChromaVectorStore  # type: ignore

if TYPE_CHECKING:
    from collections.abc import Mapping


def chunk_list(
    lst: Sequence[BaseNode], max_chunk_size: int
) -> Generator[Sequence[BaseNode], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size

    Yields:
        Generator[List[BaseNode], None, None]: list of nodes with embeddings
    """
    for i in range(0, len(lst), max_chunk_size):
        yield lst[i : i + max_chunk_size]


class BatchedChromaVectorStore(ChromaVectorStore):  # type: ignore
    """Chroma vector store, batching additions to avoid reaching the max batch limit.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        self.chroma_client = chroma_client

    def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
        """Add nodes to index, batching the insertion to avoid issues.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings
            add_kwargs: _
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        max_chunk_size = self.chroma_client.max_batch_size
        node_chunks = chunk_list(nodes, max_chunk_size)

        all_ids = []
        for node_chunk in node_chunks:
            embeddings: list[Sequence[float]] = []
            metadatas: list[Mapping[str, Any]] = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadatas.append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids
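`chunk_list` is what keeps each Chroma insert below `max_batch_size`; a quick sketch (it is typed for `BaseNode` sequences, but the slicing works on any sequence):

```python
batches = list(chunk_list(list(range(10)), max_chunk_size=4))
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```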
217
pgpt/private_gpt/components/vector_store/vector_store_component.py
Normal file
@ -0,0 +1,217 @@
import logging
import typing

from injector import inject, singleton
from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
from llama_index.core.vector_stores.types import (
    BasePydanticVectorStore,
    FilterCondition,
    MetadataFilter,
    MetadataFilters,
)

from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


def _doc_id_metadata_filter(
    context_filter: ContextFilter | None,
) -> MetadataFilters:
    filters = MetadataFilters(filters=[], condition=FilterCondition.OR)

    if context_filter is not None and context_filter.docs_ids is not None:
        for doc_id in context_filter.docs_ids:
            filters.filters.append(MetadataFilter(key="doc_id", value=doc_id))

    return filters


@singleton
class VectorStoreComponent:
    settings: Settings
    vector_store: BasePydanticVectorStore

    @inject
    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        match settings.vectorstore.database:
            case "postgres":
                try:
                    from llama_index.vector_stores.postgres import (  # type: ignore
                        PGVectorStore,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Postgres dependencies not found, install with `poetry install --extras vector-stores-postgres`"
                    ) from e

                if settings.postgres is None:
                    raise ValueError(
                        "Postgres settings not found. Please provide settings."
                    )

                self.vector_store = typing.cast(
                    BasePydanticVectorStore,
                    PGVectorStore.from_params(
                        **settings.postgres.model_dump(exclude_none=True),
                        table_name="embeddings",
                        embed_dim=settings.embedding.embed_dim,
                    ),
                )

            case "chroma":
                try:
                    import chromadb  # type: ignore
                    from chromadb.config import (  # type: ignore
                        Settings as ChromaSettings,
                    )

                    from private_gpt.components.vector_store.batched_chroma import (
                        BatchedChromaVectorStore,
                    )
                except ImportError as e:
                    raise ImportError(
                        "ChromaDB dependencies not found, install with `poetry install --extras vector-stores-chroma`"
                    ) from e

                chroma_settings = ChromaSettings(anonymized_telemetry=False)
                chroma_client = chromadb.PersistentClient(
                    path=str((local_data_path / "chroma_db").absolute()),
                    settings=chroma_settings,
                )
                chroma_collection = chroma_client.get_or_create_collection(
                    "make_this_parameterizable_per_api_call"
                )  # TODO

                self.vector_store = typing.cast(
                    BasePydanticVectorStore,
                    BatchedChromaVectorStore(
                        chroma_client=chroma_client, chroma_collection=chroma_collection
                    ),
                )

            case "qdrant":
                try:
                    from llama_index.vector_stores.qdrant import (  # type: ignore
                        QdrantVectorStore,
                    )
                    from qdrant_client import QdrantClient  # type: ignore
                except ImportError as e:
                    raise ImportError(
                        "Qdrant dependencies not found, install with `poetry install --extras vector-stores-qdrant`"
                    ) from e

                if settings.qdrant is None:
                    logger.info(
                        "Qdrant config not found. Using default settings. "
                        "Trying to connect to Qdrant at localhost:6333."
                    )
                    client = QdrantClient()
                else:
                    client = QdrantClient(
                        **settings.qdrant.model_dump(exclude_none=True)
                    )
                self.vector_store = typing.cast(
                    BasePydanticVectorStore,
                    QdrantVectorStore(
                        client=client,
                        collection_name="make_this_parameterizable_per_api_call",
                    ),  # TODO
                )

            case "milvus":
                try:
                    from llama_index.vector_stores.milvus import (  # type: ignore
                        MilvusVectorStore,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Milvus dependencies not found, install with `poetry install --extras vector-stores-milvus`"
                    ) from e

                if settings.milvus is None:
                    logger.info(
                        "Milvus config not found. Using default settings.\n"
                        "Trying to connect to Milvus at local_data/private_gpt/milvus/milvus_local.db "
                        "with collection 'make_this_parameterizable_per_api_call'."
                    )

                    self.vector_store = typing.cast(
                        BasePydanticVectorStore,
                        MilvusVectorStore(
                            dim=settings.embedding.embed_dim,
                            collection_name="make_this_parameterizable_per_api_call",
                            overwrite=True,
                        ),
                    )

                else:
                    self.vector_store = typing.cast(
                        BasePydanticVectorStore,
                        MilvusVectorStore(
                            dim=settings.embedding.embed_dim,
                            uri=settings.milvus.uri,
                            token=settings.milvus.token,
                            collection_name=settings.milvus.collection_name,
                            overwrite=settings.milvus.overwrite,
                        ),
                    )

            case "clickhouse":
                try:
                    from clickhouse_connect import (  # type: ignore
                        get_client,
                    )
                    from llama_index.vector_stores.clickhouse import (  # type: ignore
                        ClickHouseVectorStore,
                    )
                except ImportError as e:
                    raise ImportError(
                        "ClickHouse dependencies not found, install with `poetry install --extras vector-stores-clickhouse`"
                    ) from e

                if settings.clickhouse is None:
                    raise ValueError(
                        "ClickHouse settings not found. Please provide settings."
                    )

                clickhouse_client = get_client(
                    host=settings.clickhouse.host,
                    port=settings.clickhouse.port,
                    username=settings.clickhouse.username,
                    password=settings.clickhouse.password,
                )
                self.vector_store = ClickHouseVectorStore(
                    clickhouse_client=clickhouse_client
                )
            case _:
                # Should be unreachable
                # The settings validator should have caught this
                raise ValueError(
                    f"Vectorstore database {settings.vectorstore.database} not supported"
                )

    def get_retriever(
        self,
        index: VectorStoreIndex,
        context_filter: ContextFilter | None = None,
        similarity_top_k: int = 2,
    ) -> VectorIndexRetriever:
        # This way we support qdrant (using doc_ids) and the rest (using filters)
        return VectorIndexRetriever(
            index=index,
            similarity_top_k=similarity_top_k,
            doc_ids=context_filter.docs_ids if context_filter else None,
            filters=(
                _doc_id_metadata_filter(context_filter)
                if self.settings.vectorstore.database != "qdrant"
                else None
            ),
        )

    def close(self) -> None:
        if hasattr(self.vector_store.client, "close"):
            self.vector_store.client.close()
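Illustrative only: restricting retrieval to specific ingested documents through `get_retriever` (the index variable and the document ID are hypothetical):

```python
retriever = vector_store_component.get_retriever(
    index=index,  # a VectorStoreIndex built over the same vector store
    context_filter=ContextFilter(docs_ids=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]),
    similarity_top_k=4,
)
nodes = retriever.retrieve("What does the contract say about termination?")
```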
3
pgpt/private_gpt/constants.py
Normal file
@ -0,0 +1,3 @@
from pathlib import Path

PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
19
pgpt/private_gpt/di.py
Normal file
@ -0,0 +1,19 @@
from injector import Injector

from private_gpt.settings.settings import Settings, unsafe_typed_settings


def create_application_injector() -> Injector:
    _injector = Injector(auto_bind=True)
    _injector.binder.bind(Settings, to=unsafe_typed_settings)
    return _injector


"""
Global injector for the application.

Avoid using this reference, it will make your code harder to test.

Instead, use the `request.state.injector` reference, which is bound to every request
"""
global_injector: Injector = create_application_injector()
69
pgpt/private_gpt/launcher.py
Normal file
@ -0,0 +1,69 @@
"""FastAPI app creation, logger configuration and main API routes."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from injector import Injector
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
from llama_index.core.callbacks.global_handlers import create_global_handler
|
||||
from llama_index.core.settings import Settings as LlamaIndexSettings
|
||||
|
||||
from private_gpt.server.chat.chat_router import chat_router
|
||||
from private_gpt.server.chunks.chunks_router import chunks_router
|
||||
from private_gpt.server.completions.completions_router import completions_router
|
||||
from private_gpt.server.embeddings.embeddings_router import embeddings_router
|
||||
from private_gpt.server.health.health_router import health_router
|
||||
from private_gpt.server.ingest.ingest_router import ingest_router
|
||||
from private_gpt.server.recipes.summarize.summarize_router import summarize_router
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(root_injector: Injector) -> FastAPI:
|
||||
|
||||
# Start the API
|
||||
async def bind_injector_to_request(request: Request) -> None:
|
||||
request.state.injector = root_injector
|
||||
|
||||
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
||||
|
||||
app.include_router(completions_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(chunks_router)
|
||||
app.include_router(ingest_router)
|
||||
app.include_router(summarize_router)
|
||||
app.include_router(embeddings_router)
|
||||
app.include_router(health_router)
|
||||
|
||||
# Add LlamaIndex simple observability
|
||||
global_handler = create_global_handler("simple")
|
||||
if global_handler:
|
||||
LlamaIndexSettings.callback_manager = CallbackManager([global_handler])
|
||||
|
||||
settings = root_injector.get(Settings)
|
||||
if settings.server.cors.enabled:
|
||||
logger.debug("Setting up CORS middleware")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_credentials=settings.server.cors.allow_credentials,
|
||||
allow_origins=settings.server.cors.allow_origins,
|
||||
allow_origin_regex=settings.server.cors.allow_origin_regex,
|
||||
allow_methods=settings.server.cors.allow_methods,
|
||||
allow_headers=settings.server.cors.allow_headers,
|
||||
)
|
||||
|
||||
if settings.ui.enabled:
|
||||
logger.debug("Importing the UI module")
|
||||
try:
|
||||
from private_gpt.ui.ui import PrivateGptUi
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"UI dependencies not found, install with `poetry install --extras ui`"
|
||||
) from e
|
||||
|
||||
ui = root_injector.get(PrivateGptUi)
|
||||
ui.mount_in_app(app, settings.ui.path)
|
||||
|
||||
return app
|
||||
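Illustrative only: `create_app` is exactly what `main.py` (below) calls at import time; an equivalent programmatic launch, assuming uvicorn is installed (the host and port are assumptions):

```python
import uvicorn

from private_gpt.di import global_injector
from private_gpt.launcher import create_app

app = create_app(global_injector)
uvicorn.run(app, host="127.0.0.1", port=8001)  # port is an assumption
```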
6
pgpt/private_gpt/main.py
Normal file
@ -0,0 +1,6 @@
"""FastAPI app creation, logger configuration and main API routes."""
|
||||
|
||||
from private_gpt.di import global_injector
|
||||
from private_gpt.launcher import create_app
|
||||
|
||||
app = create_app(global_injector)
|
||||
1
pgpt/private_gpt/open_ai/__init__.py
Normal file
@ -0,0 +1 @@
"""OpenAI compatibility utilities."""
|
||||
1
pgpt/private_gpt/open_ai/extensions/__init__.py
Normal file
@ -0,0 +1 @@
"""OpenAI API extensions."""
|
||||
7
pgpt/private_gpt/open_ai/extensions/context_filter.py
Normal file
@ -0,0 +1,7 @@
from pydantic import BaseModel, Field


class ContextFilter(BaseModel):
    docs_ids: list[str] | None = Field(
        examples=[["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]]
    )
122
pgpt/private_gpt/open_ai/openai_models.py
Normal file
@ -0,0 +1,122 @@
import time
import uuid
from collections.abc import Iterator
from typing import Literal

from llama_index.core.llms import ChatResponse, CompletionResponse
from pydantic import BaseModel, Field

from private_gpt.server.chunks.chunks_service import Chunk


class OpenAIDelta(BaseModel):
    """A piece of completion that needs to be concatenated to get the full message."""

    content: str | None


class OpenAIMessage(BaseModel):
    """Inference result, with the source of the message.

    Role could be the assistant or system
    (providing a default response, not AI generated).
    """

    role: Literal["assistant", "system", "user"] = Field(default="user")
    content: str | None


class OpenAIChoice(BaseModel):
    """Response from AI.

    Either the delta or the message will be present, but never both.
    Sources used will be returned in case context retrieval was enabled.
    """

    finish_reason: str | None = Field(examples=["stop"])
    delta: OpenAIDelta | None = None
    message: OpenAIMessage | None = None
    sources: list[Chunk] | None = None
    index: int = 0


class OpenAICompletion(BaseModel):
    """Clone of OpenAI Completion model.

    For more information see: https://platform.openai.com/docs/api-reference/chat/object
    """

    id: str
    object: Literal["completion", "completion.chunk"] = Field(default="completion")
    created: int = Field(..., examples=[1623340000])
    model: Literal["private-gpt"]
    choices: list[OpenAIChoice]

    @classmethod
    def from_text(
        cls,
        text: str | None,
        finish_reason: str | None = None,
        sources: list[Chunk] | None = None,
    ) -> "OpenAICompletion":
        return OpenAICompletion(
            id=str(uuid.uuid4()),
            object="completion",
            created=int(time.time()),
            model="private-gpt",
            choices=[
                OpenAIChoice(
                    message=OpenAIMessage(role="assistant", content=text),
                    finish_reason=finish_reason,
                    sources=sources,
                )
            ],
        )

    @classmethod
    def json_from_delta(
        cls,
        *,
        text: str | None,
        finish_reason: str | None = None,
        sources: list[Chunk] | None = None,
    ) -> str:
        chunk = OpenAICompletion(
            id=str(uuid.uuid4()),
            object="completion.chunk",
            created=int(time.time()),
            model="private-gpt",
            choices=[
                OpenAIChoice(
                    delta=OpenAIDelta(content=text),
                    finish_reason=finish_reason,
                    sources=sources,
                )
            ],
        )

        return chunk.model_dump_json()


def to_openai_response(
    response: str | ChatResponse, sources: list[Chunk] | None = None
) -> OpenAICompletion:
    if isinstance(response, ChatResponse):
        return OpenAICompletion.from_text(response.delta, finish_reason="stop")
    else:
        return OpenAICompletion.from_text(
            response, finish_reason="stop", sources=sources
        )


def to_openai_sse_stream(
    response_generator: Iterator[str | CompletionResponse | ChatResponse],
    sources: list[Chunk] | None = None,
) -> Iterator[str]:
    for response in response_generator:
        if isinstance(response, CompletionResponse | ChatResponse):
            yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
        else:
            yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
    yield f"data: {OpenAICompletion.json_from_delta(text='', finish_reason='stop')}\n\n"
    yield "data: [DONE]\n\n"
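Illustrative only: the SSE framing produced by `to_openai_sse_stream` for a stream of plain strings (field values abbreviated):

```python
for event in to_openai_sse_stream(iter(["Hello", " world"])):
    print(event, end="")
# data: {"id": "...", "object": "completion.chunk", ..., "delta": {"content": "Hello"}, ...}
# data: {"id": "...", ..., "delta": {"content": " world"}, ...}
# data: {"id": "...", ..., "delta": {"content": ""}, "finish_reason": "stop", ...}
# data: [DONE]
```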
18
pgpt/private_gpt/paths.py
Normal file
@ -0,0 +1,18 @@
from pathlib import Path

from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.settings.settings import settings


def _absolute_or_from_project_root(path: str) -> Path:
    if path.startswith("/"):
        return Path(path)
    return PROJECT_ROOT_PATH / path


models_path: Path = PROJECT_ROOT_PATH / "models"
models_cache_path: Path = models_path / "cache"
docs_path: Path = PROJECT_ROOT_PATH / "docs"
local_data_path: Path = _absolute_or_from_project_root(
    settings().data.local_data_folder
)
1
pgpt/private_gpt/server/__init__.py
Normal file
@ -0,0 +1 @@
"""private-gpt server."""
|
||||
0
pgpt/private_gpt/server/chat/__init__.py
Normal file
115
pgpt/private_gpt/server/chat/chat_router.py
Normal file
@ -0,0 +1,115 @@
from fastapi import APIRouter, Depends, Request
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
from private_gpt.open_ai.openai_models import (
|
||||
OpenAICompletion,
|
||||
OpenAIMessage,
|
||||
to_openai_response,
|
||||
to_openai_sse_stream,
|
||||
)
|
||||
from private_gpt.server.chat.chat_service import ChatService
|
||||
from private_gpt.server.utils.auth import authenticated
|
||||
|
||||
chat_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
|
||||
|
||||
|
||||
class ChatBody(BaseModel):
|
||||
messages: list[OpenAIMessage]
|
||||
use_context: bool = False
|
||||
context_filter: ContextFilter | None = None
|
||||
include_sources: bool = True
|
||||
stream: bool = False
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a rapper. Always answer with a rap.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "How do you fry an egg?",
|
||||
},
|
||||
],
|
||||
"stream": False,
|
||||
"use_context": True,
|
||||
"include_sources": True,
|
||||
"context_filter": {
|
||||
"docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@chat_router.post(
|
||||
"/chat/completions",
|
||||
response_model=None,
|
||||
responses={200: {"model": OpenAICompletion}},
|
||||
tags=["Contextual Completions"],
|
||||
openapi_extra={
|
||||
"x-fern-streaming": {
|
||||
"stream-condition": "stream",
|
||||
"response": {"$ref": "#/components/schemas/OpenAICompletion"},
|
||||
"response-stream": {"$ref": "#/components/schemas/OpenAICompletion"},
|
||||
}
|
||||
},
|
||||
)
|
||||
def chat_completion(
|
||||
request: Request, body: ChatBody
|
||||
) -> OpenAICompletion | StreamingResponse:
|
||||
"""Given a list of messages comprising a conversation, return a response.
|
||||
|
||||
Optionally include an initial `role: system` message to influence the way
|
||||
the LLM answers.
|
||||
|
||||
If `use_context` is set to `true`, the model will use context coming
|
||||
from the ingested documents to create the response. The documents being used can
|
||||
be filtered using the `context_filter` and passing the document IDs to be used.
|
||||
Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
|
||||
all ingested documents to be used, remove `context_filter` altogether.
|
||||
|
||||
When using `'include_sources': true`, the API will return the source Chunks used
|
||||
to create the response, which come from the context provided.
|
||||
|
||||
When using `'stream': true`, the API will return data chunks following [OpenAI's
|
||||
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
|
||||
```
|
||||
{"id":"12345","object":"completion.chunk","created":1694268190,
|
||||
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
|
||||
"finish_reason":null}]}
|
||||
```
|
||||
"""
|
||||
service = request.state.injector.get(ChatService)
|
||||
all_messages = [
|
||||
ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
|
||||
]
|
||||
if body.stream:
|
||||
completion_gen = service.stream_chat(
|
||||
messages=all_messages,
|
||||
use_context=body.use_context,
|
||||
context_filter=body.context_filter,
|
||||
)
|
||||
return StreamingResponse(
|
||||
to_openai_sse_stream(
|
||||
completion_gen.response,
|
||||
completion_gen.sources if body.include_sources else None,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
completion = service.chat(
|
||||
messages=all_messages,
|
||||
use_context=body.use_context,
|
||||
context_filter=body.context_filter,
|
||||
)
|
||||
return to_openai_response(
|
||||
completion.response, completion.sources if body.include_sources else None
|
||||
)
|
||||
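As a usage illustration (not included in the commit), the endpoint can be called with any HTTP client; the host and port are assumptions, the port matching the server default of 8001 declared in the settings further below:

```python
import requests

resp = requests.post(
    "http://localhost:8001/v1/chat/completions",  # assumed host/port
    json={
        "messages": [{"role": "user", "content": "How do you fry an egg?"}],
        "use_context": False,
        "include_sources": False,
        "stream": False,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```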
217
pgpt/private_gpt/server/chat/chat_service.py
Normal file
@ -0,0 +1,217 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from injector import inject, singleton
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
from llama_index.core.chat_engine.types import (
    BaseChatEngine,
)
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.postprocessor import (
    SentenceTransformerRerank,
    SimilarityPostprocessor,
)
from llama_index.core.storage import StorageContext
from llama_index.core.types import TokenGen
from pydantic import BaseModel

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings

if TYPE_CHECKING:
    from llama_index.core.postprocessor.types import BaseNodePostprocessor


class Completion(BaseModel):
    response: str
    sources: list[Chunk] | None = None


class CompletionGen(BaseModel):
    response: TokenGen
    sources: list[Chunk] | None = None


@dataclass
class ChatEngineInput:
    system_message: ChatMessage | None = None
    last_message: ChatMessage | None = None
    chat_history: list[ChatMessage] | None = None

    @classmethod
    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
        # Detect if there is a system message, extract the last message and chat history
        system_message = (
            messages[0]
            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
            else None
        )
        last_message = (
            messages[-1]
            if len(messages) > 0 and messages[-1].role == MessageRole.USER
            else None
        )
        # Remove from messages list the system message and last message,
        # if they exist. The rest is the chat history.
        if system_message:
            messages.pop(0)
        if last_message:
            messages.pop(-1)
        chat_history = messages if len(messages) > 0 else None

        return cls(
            system_message=system_message,
            last_message=last_message,
            chat_history=chat_history,
        )


@singleton
class ChatService:
    settings: Settings

    @inject
    def __init__(
        self,
        settings: Settings,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.settings = settings
        self.llm_component = llm_component
        self.embedding_component = embedding_component
        self.vector_store_component = vector_store_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        self.index = VectorStoreIndex.from_vector_store(
            vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=llm_component.llm,
            embed_model=embedding_component.embedding_model,
            show_progress=True,
        )

    def _chat_engine(
        self,
        system_prompt: str | None = None,
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> BaseChatEngine:
        settings = self.settings
        if use_context:
            vector_index_retriever = self.vector_store_component.get_retriever(
                index=self.index,
                context_filter=context_filter,
                similarity_top_k=self.settings.rag.similarity_top_k,
            )
            node_postprocessors: list[BaseNodePostprocessor] = [
                MetadataReplacementPostProcessor(target_metadata_key="window"),
            ]
            if settings.rag.similarity_value:
                node_postprocessors.append(
                    SimilarityPostprocessor(
                        similarity_cutoff=settings.rag.similarity_value
                    )
                )

            if settings.rag.rerank.enabled:
                rerank_postprocessor = SentenceTransformerRerank(
                    model=settings.rag.rerank.model, top_n=settings.rag.rerank.top_n
                )
                node_postprocessors.append(rerank_postprocessor)

            return ContextChatEngine.from_defaults(
                system_prompt=system_prompt,
                retriever=vector_index_retriever,
                llm=self.llm_component.llm,  # Takes no effect at the moment
                node_postprocessors=node_postprocessors,
            )
        else:
            return SimpleChatEngine.from_defaults(
                system_prompt=system_prompt,
                llm=self.llm_component.llm,
            )

    def stream_chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> CompletionGen:
        chat_engine_input = ChatEngineInput.from_messages(messages)
        last_message = (
            chat_engine_input.last_message.content
            if chat_engine_input.last_message
            else None
        )
        system_prompt = (
            chat_engine_input.system_message.content
            if chat_engine_input.system_message
            else None
        )
        chat_history = (
            chat_engine_input.chat_history if chat_engine_input.chat_history else None
        )

        chat_engine = self._chat_engine(
            system_prompt=system_prompt,
            use_context=use_context,
            context_filter=context_filter,
        )
        streaming_response = chat_engine.stream_chat(
            message=last_message if last_message is not None else "",
            chat_history=chat_history,
        )
        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
        completion_gen = CompletionGen(
            response=streaming_response.response_gen, sources=sources
        )
        return completion_gen

    def chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> Completion:
        chat_engine_input = ChatEngineInput.from_messages(messages)
        last_message = (
            chat_engine_input.last_message.content
            if chat_engine_input.last_message
            else None
        )
        system_prompt = (
            chat_engine_input.system_message.content
            if chat_engine_input.system_message
            else None
        )
        chat_history = (
            chat_engine_input.chat_history if chat_engine_input.chat_history else None
        )

        chat_engine = self._chat_engine(
            system_prompt=system_prompt,
            use_context=use_context,
            context_filter=context_filter,
        )
        wrapped_response = chat_engine.chat(
            message=last_message if last_message is not None else "",
            chat_history=chat_history,
        )
        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
        completion = Completion(response=wrapped_response.response, sources=sources)
        return completion
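A minimal sketch (not part of the commit) of how `ChatEngineInput.from_messages` splits an OpenAI-style message list into system prompt, history, and the final user message:

```python
from llama_index.core.llms import ChatMessage, MessageRole

from private_gpt.server.chat.chat_service import ChatEngineInput

msgs = [
    ChatMessage(role=MessageRole.SYSTEM, content="Be brief."),
    ChatMessage(role=MessageRole.USER, content="Hi"),
    ChatMessage(role=MessageRole.ASSISTANT, content="Hello!"),
    ChatMessage(role=MessageRole.USER, content="Summarize our chat."),
]
parsed = ChatEngineInput.from_messages(msgs)
# The system message becomes the system prompt, the trailing user message is
# the one sent to the engine, and everything in between is chat history.
assert parsed.system_message.content == "Be brief."
assert parsed.last_message.content == "Summarize our chat."
assert [m.content for m in parsed.chat_history] == ["Hi", "Hello!"]
```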
0
pgpt/private_gpt/server/chunks/__init__.py
Normal file
55
pgpt/private_gpt/server/chunks/chunks_router.py
Normal file
@ -0,0 +1,55 @@
from typing import Literal

from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field

from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.utils.auth import authenticated

chunks_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])


class ChunksBody(BaseModel):
    text: str = Field(examples=["Q3 2023 sales"])
    context_filter: ContextFilter | None = None
    limit: int = 10
    prev_next_chunks: int = Field(default=0, examples=[2])


class ChunksResponse(BaseModel):
    object: Literal["list"]
    model: Literal["private-gpt"]
    data: list[Chunk]


@chunks_router.post("/chunks", tags=["Context Chunks"])
def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
    """Given a `text`, returns the most relevant chunks from the ingested documents.

    The returned information can be used to generate prompts that can be
    passed to `/completions` or `/chat/completions` APIs. Note: it is usually a very
    fast API, because only the Embeddings model is involved, not the LLM. The
    returned information contains the relevant chunk `text` together with the source
    `document` it is coming from. It also contains a score that can be used to
    compare different results.

    The max number of chunks to be returned is set using the `limit` param.

    Previous and next chunks (pieces of text that appear right before or after in the
    document) can be fetched by using the `prev_next_chunks` field.

    The documents being used can be filtered using the `context_filter` and passing
    the document IDs to be used. Ingested document IDs can be found using the
    `/ingest/list` endpoint. If you want all ingested documents to be used,
    remove `context_filter` altogether.
    """
    service = request.state.injector.get(ChunksService)
    results = service.retrieve_relevant(
        body.text, body.context_filter, body.limit, body.prev_next_chunks
    )
    return ChunksResponse(
        object="list",
        model="private-gpt",
        data=results,
    )
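A hypothetical client call for this route (not in the commit; host/port assumed from the default server settings):

```python
import requests

resp = requests.post(
    "http://localhost:8001/v1/chunks",  # assumed host/port
    json={"text": "Q3 2023 sales", "limit": 3, "prev_next_chunks": 1},
)
# Each returned chunk carries its score, source document and text.
for chunk in resp.json()["data"]:
    print(chunk["score"], chunk["document"]["doc_id"], chunk["text"][:60])
```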
125
pgpt/private_gpt/server/chunks/chunks_service.py
Normal file
@ -0,0 +1,125 @@
from typing import TYPE_CHECKING, Literal

from injector import inject, singleton
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.schema import NodeWithScore
from llama_index.core.storage import StorageContext
from pydantic import BaseModel, Field

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.ingest.model import IngestedDoc

if TYPE_CHECKING:
    from llama_index.core.schema import RelatedNodeInfo


class Chunk(BaseModel):
    object: Literal["context.chunk"]
    score: float = Field(examples=[0.023])
    document: IngestedDoc
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    previous_texts: list[str] | None = Field(
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
        return cls(
            object="context.chunk",
            score=node.score or 0.0,
            document=IngestedDoc(
                object="ingest.document",
                doc_id=doc_id,
                doc_metadata=node.metadata,
            ),
            text=node.get_content(),
        )


@singleton
class ChunksService:
    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.vector_store_component = vector_store_component
        self.llm_component = llm_component
        self.embedding_component = embedding_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )

    def _get_sibling_nodes_text(
        self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
    ) -> list[str]:
        explored_nodes_texts = []
        current_node = node_with_score.node
        for _ in range(related_number):
            explored_node_info: RelatedNodeInfo | None = (
                current_node.next_node if forward else current_node.prev_node
            )
            if explored_node_info is None:
                break

            explored_node = self.storage_context.docstore.get_node(
                explored_node_info.node_id
            )

            explored_nodes_texts.append(explored_node.get_content())
            current_node = explored_node

        return explored_nodes_texts

    def retrieve_relevant(
        self,
        text: str,
        context_filter: ContextFilter | None = None,
        limit: int = 10,
        prev_next_chunks: int = 0,
    ) -> list[Chunk]:
        index = VectorStoreIndex.from_vector_store(
            self.vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=self.llm_component.llm,
            embed_model=self.embedding_component.embedding_model,
            show_progress=True,
        )
        vector_index_retriever = self.vector_store_component.get_retriever(
            index=index, context_filter=context_filter, similarity_top_k=limit
        )
        nodes = vector_index_retriever.retrieve(text)
        nodes.sort(key=lambda n: n.score or 0.0, reverse=True)

        retrieved_nodes = []
        for node in nodes:
            chunk = Chunk.from_node(node)
            chunk.previous_texts = self._get_sibling_nodes_text(
                node, prev_next_chunks, False
            )
            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
            retrieved_nodes.append(chunk)

        return retrieved_nodes
1
pgpt/private_gpt/server/completions/__init__.py
Normal file
@ -0,0 +1 @@
"""Deprecated OpenAI compatibility endpoint."""
92
pgpt/private_gpt/server/completions/completions_router.py
Normal file
@ -0,0 +1,92 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from starlette.responses import StreamingResponse

from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
    OpenAICompletion,
    OpenAIMessage,
)
from private_gpt.server.chat.chat_router import ChatBody, chat_completion
from private_gpt.server.utils.auth import authenticated

completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])


class CompletionsBody(BaseModel):
    prompt: str
    system_prompt: str | None = None
    use_context: bool = False
    context_filter: ContextFilter | None = None
    include_sources: bool = True
    stream: bool = False

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "How do you fry an egg?",
                    "system_prompt": "You are a rapper. Always answer with a rap.",
                    "stream": False,
                    "use_context": False,
                    "include_sources": False,
                }
            ]
        }
    }


@completions_router.post(
    "/completions",
    response_model=None,
    summary="Completion",
    responses={200: {"model": OpenAICompletion}},
    tags=["Contextual Completions"],
    openapi_extra={
        "x-fern-streaming": {
            "stream-condition": "stream",
            "response": {"$ref": "#/components/schemas/OpenAICompletion"},
            "response-stream": {"$ref": "#/components/schemas/OpenAICompletion"},
        }
    },
)
def prompt_completion(
    request: Request, body: CompletionsBody
) -> OpenAICompletion | StreamingResponse:
    """We recommend most users use our Chat completions API.

    Given a prompt, the model will return one predicted completion.

    Optionally include a `system_prompt` to influence the way the LLM answers.

    If `use_context` is set to `true`, the model will use context coming from the
    ingested documents to create the response. The documents being used can be
    filtered using the `context_filter` and passing the document IDs to be used.
    Ingested document IDs can be found using the `/ingest/list` endpoint. If you want
    all ingested documents to be used, remove `context_filter` altogether.

    When using `'include_sources': true`, the API will return the source Chunks used
    to create the response, which come from the context provided.

    When using `'stream': true`, the API will return data chunks following [OpenAI's
    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
    ```
    {"id":"12345","object":"completion.chunk","created":1694268190,
    "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
    "finish_reason":null}]}
    ```
    """
    messages = [OpenAIMessage(content=body.prompt, role="user")]
    # If system prompt is passed, create a fake message with the system prompt.
    if body.system_prompt:
        messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))

    chat_body = ChatBody(
        messages=messages,
        use_context=body.use_context,
        stream=body.stream,
        include_sources=body.include_sources,
        context_filter=body.context_filter,
    )
    return chat_completion(request, chat_body)
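A hypothetical call to this compatibility route (not in the commit; host/port assumed). Since the handler delegates to `chat_completion`, the response has the same chat-completion shape:

```python
import requests

resp = requests.post(
    "http://localhost:8001/v1/completions",  # assumed host/port
    json={
        "prompt": "How do you fry an egg?",
        "system_prompt": "You are a rapper. Always answer with a rap.",
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```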
0
pgpt/private_gpt/server/embeddings/__init__.py
Normal file
35
pgpt/private_gpt/server/embeddings/embeddings_router.py
Normal file
@ -0,0 +1,35 @@
from typing import Literal

from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel

from private_gpt.server.embeddings.embeddings_service import (
    Embedding,
    EmbeddingsService,
)
from private_gpt.server.utils.auth import authenticated

embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])


class EmbeddingsBody(BaseModel):
    input: str | list[str]


class EmbeddingsResponse(BaseModel):
    object: Literal["list"]
    model: Literal["private-gpt"]
    data: list[Embedding]


@embeddings_router.post("/embeddings", tags=["Embeddings"])
def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
    """Get a vector representation of a given input.

    That vector representation can be easily consumed
    by machine learning models and algorithms.
    """
    service = request.state.injector.get(EmbeddingsService)
    input_texts = body.input if isinstance(body.input, list) else [body.input]
    embeddings = service.texts_embeddings(input_texts)
    return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)
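A hypothetical request against this route (not in the commit; host/port assumed), showing that a list input yields one embedding per text:

```python
import requests

resp = requests.post(
    "http://localhost:8001/v1/embeddings",  # assumed host/port
    json={"input": ["Q3 sales report", "marketing budget"]},
)
vectors = [item["embedding"] for item in resp.json()["data"]]
print(len(vectors), len(vectors[0]))  # 2 embeddings, one vector per input
```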
30
pgpt/private_gpt/server/embeddings/embeddings_service.py
Normal file
@ -0,0 +1,30 @@
from typing import Literal

from injector import inject, singleton
from pydantic import BaseModel, Field

from private_gpt.components.embedding.embedding_component import EmbeddingComponent


class Embedding(BaseModel):
    index: int
    object: Literal["embedding"]
    embedding: list[float] = Field(examples=[[0.0023064255, -0.009327292]])


@singleton
class EmbeddingsService:
    @inject
    def __init__(self, embedding_component: EmbeddingComponent) -> None:
        self.embedding_model = embedding_component.embedding_model

    def texts_embeddings(self, texts: list[str]) -> list[Embedding]:
        texts_embeddings = self.embedding_model.get_text_embedding_batch(texts)
        return [
            Embedding(
                # enumerate keeps indices stable even when two inputs embed
                # to the same vector (list.index would return the first match)
                index=i,
                object="embedding",
                embedding=embedding,
            )
            for i, embedding in enumerate(texts_embeddings)
        ]
0
pgpt/private_gpt/server/health/__init__.py
Normal file
17
pgpt/private_gpt/server/health/health_router.py
Normal file
@ -0,0 +1,17 @@
from typing import Literal

from fastapi import APIRouter
from pydantic import BaseModel, Field

# No authentication or authorization required to get the health status.
health_router = APIRouter()


class HealthResponse(BaseModel):
    status: Literal["ok"] = Field(default="ok")


@health_router.get("/health", tags=["Health"])
def health() -> HealthResponse:
    """Return ok if the system is up."""
    return HealthResponse(status="ok")
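A one-line liveness check (illustrative only; host/port assumed):

```python
import requests

# assumed host/port; no Authorization header is needed for this route
assert requests.get("http://localhost:8001/health").json() == {"status": "ok"}
```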
0
pgpt/private_gpt/server/ingest/__init__.py
Normal file
104
pgpt/private_gpt/server/ingest/ingest_router.py
Normal file
@ -0,0 +1,104 @@
from typing import Literal

from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field

from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.ingest.model import IngestedDoc
from private_gpt.server.utils.auth import authenticated

ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])


class IngestTextBody(BaseModel):
    file_name: str = Field(examples=["Avatar: The Last Airbender"])
    text: str = Field(
        examples=[
            "Avatar is set in an Asian and Arctic-inspired world in which some "
            "people can telekinetically manipulate one of the four elements—water, "
            "earth, fire or air—through practices known as 'bending', inspired by "
            "Chinese martial arts."
        ]
    )


class IngestResponse(BaseModel):
    object: Literal["list"]
    model: Literal["private-gpt"]
    data: list[IngestedDoc]


@ingest_router.post("/ingest", tags=["Ingestion"], deprecated=True)
def ingest(request: Request, file: UploadFile) -> IngestResponse:
    """Ingests and processes a file.

    Deprecated. Use `/ingest/file` instead.
    """
    return ingest_file(request, file)


@ingest_router.post("/ingest/file", tags=["Ingestion"])
def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
    """Ingests and processes a file, storing its chunks to be used as context.

    The context obtained from files is later used in
    `/chat/completions`, `/completions`, and `/chunks` APIs.

    Most common document formats are supported, but you may be prompted to
    install an extra dependency to manage a specific file type.

    A file can generate different Documents (for example a PDF generates one Document
    per page). All Document IDs are returned in the response, together with the
    extracted Metadata (which is later used to improve context retrieval). Those IDs
    can be used to filter the context used to create responses in
    `/chat/completions`, `/completions`, and `/chunks` APIs.
    """
    service = request.state.injector.get(IngestService)
    if file.filename is None:
        raise HTTPException(400, "No file name provided")
    ingested_documents = service.ingest_bin_data(file.filename, file.file)
    return IngestResponse(object="list", model="private-gpt", data=ingested_documents)


@ingest_router.post("/ingest/text", tags=["Ingestion"])
def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
    """Ingests and processes a text, storing its chunks to be used as context.

    The context obtained from files is later used in
    `/chat/completions`, `/completions`, and `/chunks` APIs.

    A Document will be generated with the given text. The Document
    ID is returned in the response, together with the
    extracted Metadata (which is later used to improve context retrieval). That ID
    can be used to filter the context used to create responses in
    `/chat/completions`, `/completions`, and `/chunks` APIs.
    """
    service = request.state.injector.get(IngestService)
    if len(body.file_name) == 0:
        raise HTTPException(400, "No file name provided")
    ingested_documents = service.ingest_text(body.file_name, body.text)
    return IngestResponse(object="list", model="private-gpt", data=ingested_documents)


@ingest_router.get("/ingest/list", tags=["Ingestion"])
def list_ingested(request: Request) -> IngestResponse:
    """Lists already ingested Documents including their Document ID and metadata.

    Those IDs can be used to filter the context used to create responses
    in `/chat/completions`, `/completions`, and `/chunks` APIs.
    """
    service = request.state.injector.get(IngestService)
    ingested_documents = service.list_ingested()
    return IngestResponse(object="list", model="private-gpt", data=ingested_documents)


@ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
def delete_ingested(request: Request, doc_id: str) -> None:
    """Delete the specified ingested Document.

    The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
    The document will be effectively deleted from your storage context.
    """
    service = request.state.injector.get(IngestService)
    service.delete(doc_id)
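A hypothetical ingestion round-trip (not in the commit; host/port and file name are assumptions):

```python
import requests

base = "http://localhost:8001/v1"  # assumed host/port

# Ingest a file; each resulting Document ID can later feed context_filter.
with open("report.pdf", "rb") as f:  # hypothetical file
    resp = requests.post(f"{base}/ingest/file", files={"file": f})
doc_ids = [doc["doc_id"] for doc in resp.json()["data"]]

# Or ingest raw text under a chosen file name.
requests.post(
    f"{base}/ingest/text",
    json={"file_name": "note.txt", "text": "Bending is inspired by martial arts."},
)
```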
125
pgpt/private_gpt/server/ingest/ingest_service.py
Normal file
@ -0,0 +1,125 @@
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, AnyStr, BinaryIO

from injector import inject, singleton
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.storage import StorageContext

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.ingest.ingest_component import get_ingestion_component
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.server.ingest.model import IngestedDoc
from private_gpt.settings.settings import settings

if TYPE_CHECKING:
    from llama_index.core.storage.docstore.types import RefDocInfo

logger = logging.getLogger(__name__)


@singleton
class IngestService:
    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.llm_service = llm_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        node_parser = SentenceWindowNodeParser.from_defaults()

        self.ingest_component = get_ingestion_component(
            self.storage_context,
            embed_model=embedding_component.embedding_model,
            transformations=[node_parser, embedding_component.embedding_model],
            settings=settings(),
        )

    def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
        logger.debug("Got file data of size=%s to ingest", len(file_data))
        # llama-index mainly supports reading from files, so
        # we have to create a tmp file to read for it to work
        # delete=False to avoid a Windows 11 permission error.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            try:
                path_to_tmp = Path(tmp.name)
                if isinstance(file_data, bytes):
                    path_to_tmp.write_bytes(file_data)
                else:
                    path_to_tmp.write_text(str(file_data))
                return self.ingest_file(file_name, path_to_tmp)
            finally:
                tmp.close()
                path_to_tmp.unlink()

    def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
        logger.info("Ingesting file_name=%s", file_name)
        documents = self.ingest_component.ingest(file_name, file_data)
        logger.info("Finished ingestion file_name=%s", file_name)
        return [IngestedDoc.from_document(document) for document in documents]

    def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
        logger.debug("Ingesting text data with file_name=%s", file_name)
        return self._ingest_data(file_name, text)

    def ingest_bin_data(
        self, file_name: str, raw_file_data: BinaryIO
    ) -> list[IngestedDoc]:
        logger.debug("Ingesting binary data with file_name=%s", file_name)
        file_data = raw_file_data.read()
        return self._ingest_data(file_name, file_data)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
        logger.info("Ingesting file_names=%s", [f[0] for f in files])
        documents = self.ingest_component.bulk_ingest(files)
        logger.info("Finished ingestion file_name=%s", [f[0] for f in files])
        return [IngestedDoc.from_document(document) for document in documents]

    def list_ingested(self) -> list[IngestedDoc]:
        ingested_docs: list[IngestedDoc] = []
        try:
            docstore = self.storage_context.docstore
            ref_docs: dict[str, RefDocInfo] | None = docstore.get_all_ref_doc_info()

            if not ref_docs:
                return ingested_docs

            for doc_id, ref_doc_info in ref_docs.items():
                doc_metadata = None
                if ref_doc_info is not None and ref_doc_info.metadata is not None:
                    doc_metadata = IngestedDoc.curate_metadata(ref_doc_info.metadata)
                ingested_docs.append(
                    IngestedDoc(
                        object="ingest.document",
                        doc_id=doc_id,
                        doc_metadata=doc_metadata,
                    )
                )
        except ValueError:
            logger.warning("Got an exception when getting list of docs", exc_info=True)
        logger.debug("Found count=%s ingested documents", len(ingested_docs))
        return ingested_docs

    def delete(self, doc_id: str) -> None:
        """Delete an ingested document.

        :raises ValueError: if the document does not exist
        """
        logger.info(
            "Deleting the ingested document=%s in the doc and index store", doc_id
        )
        self.ingest_component.delete(doc_id)
45
pgpt/private_gpt/server/ingest/ingest_watcher.py
Normal file
@ -0,0 +1,45 @@
from collections.abc import Callable
from pathlib import Path
from typing import Any

from watchdog.events import (
    FileCreatedEvent,
    FileModifiedEvent,
    FileSystemEvent,
    FileSystemEventHandler,
)
from watchdog.observers import Observer


class IngestWatcher:
    def __init__(
        self, watch_path: Path, on_file_changed: Callable[[Path], None]
    ) -> None:
        self.watch_path = watch_path
        self.on_file_changed = on_file_changed

        class Handler(FileSystemEventHandler):
            def on_modified(self, event: FileSystemEvent) -> None:
                if isinstance(event, FileModifiedEvent):
                    on_file_changed(Path(event.src_path))

            def on_created(self, event: FileSystemEvent) -> None:
                if isinstance(event, FileCreatedEvent):
                    on_file_changed(Path(event.src_path))

        event_handler = Handler()
        observer: Any = Observer()
        self._observer = observer
        self._observer.schedule(event_handler, str(watch_path), recursive=True)

    def start(self) -> None:
        self._observer.start()
        while self._observer.is_alive():
            try:
                self._observer.join(1)
            except KeyboardInterrupt:
                break

    def stop(self) -> None:
        self._observer.stop()
        self._observer.join()
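A minimal usage sketch for the watcher (not in the commit; the watched folder and callback body are assumptions):

```python
from pathlib import Path

from private_gpt.server.ingest.ingest_watcher import IngestWatcher


def on_changed(path: Path) -> None:
    print(f"would re-ingest {path}")  # in practice, call IngestService here


watcher = IngestWatcher(Path("local_data/watched"), on_changed)  # assumed folder
watcher.start()  # blocks until KeyboardInterrupt; then call watcher.stop()
```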
32
pgpt/private_gpt/server/ingest/model.py
Normal file
@ -0,0 +1,32 @@
from typing import Any, Literal

from llama_index.core.schema import Document
from pydantic import BaseModel, Field


class IngestedDoc(BaseModel):
    object: Literal["ingest.document"]
    doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
    doc_metadata: dict[str, Any] | None = Field(
        examples=[
            {
                "page_label": "2",
                "file_name": "Sales Report Q3 2023.pdf",
            }
        ]
    )

    @staticmethod
    def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
        """Remove unwanted metadata keys."""
        for key in ["doc_id", "window", "original_text"]:
            metadata.pop(key, None)
        return metadata

    @staticmethod
    def from_document(document: Document) -> "IngestedDoc":
        return IngestedDoc(
            object="ingest.document",
            doc_id=document.doc_id,
            doc_metadata=IngestedDoc.curate_metadata(document.metadata),
        )
86
pgpt/private_gpt/server/recipes/summarize/summarize_router.py
Normal file
@ -0,0 +1,86 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from starlette.responses import StreamingResponse

from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
    to_openai_sse_stream,
)
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
from private_gpt.server.utils.auth import authenticated

summarize_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])


class SummarizeBody(BaseModel):
    text: str | None = None
    use_context: bool = False
    context_filter: ContextFilter | None = None
    prompt: str | None = None
    instructions: str | None = None
    stream: bool = False


class SummarizeResponse(BaseModel):
    summary: str


@summarize_router.post(
    "/summarize",
    response_model=None,
    summary="Summarize",
    responses={200: {"model": SummarizeResponse}},
    tags=["Recipes"],
)
def summarize(
    request: Request, body: SummarizeBody
) -> SummarizeResponse | StreamingResponse:
    """Given a text, the model will return a summary.

    Optionally include `instructions` to influence the way the summary is generated.

    If `use_context` is set to `true`, the model will also use the content coming
    from the ingested documents in the summary. The documents being used can
    be filtered by their metadata using the `context_filter`.
    Ingested documents metadata can be found using the `/ingest/list` endpoint.
    If you want all ingested documents to be used, remove `context_filter` altogether.

    If `prompt` is set, it will be used as the prompt for the summarization,
    otherwise the default prompt will be used.

    When using `'stream': true`, the API will return data chunks following [OpenAI's
    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
    ```
    {"id":"12345","object":"completion.chunk","created":1694268190,
    "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
    "finish_reason":null}]}
    ```
    """
    service: SummarizeService = request.state.injector.get(SummarizeService)

    if body.stream:
        completion_gen = service.stream_summarize(
            text=body.text,
            instructions=body.instructions,
            use_context=body.use_context,
            context_filter=body.context_filter,
            prompt=body.prompt,
        )
        return StreamingResponse(
            to_openai_sse_stream(
                response_generator=completion_gen,
            ),
            media_type="text/event-stream",
        )
    else:
        completion = service.summarize(
            text=body.text,
            instructions=body.instructions,
            use_context=body.use_context,
            context_filter=body.context_filter,
            prompt=body.prompt,
        )
        return SummarizeResponse(
            summary=completion,
        )
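A hypothetical call to the summarize recipe (not in the commit; host/port and the input text are assumptions):

```python
import requests

resp = requests.post(
    "http://localhost:8001/v1/summarize",  # assumed host/port
    json={
        "text": "PrivateGPT exposes chat, completion and ingestion APIs ...",
        "instructions": "Keep it under two sentences.",
        "stream": False,
    },
)
print(resp.json()["summary"])
```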
182
pgpt/private_gpt/server/recipes/summarize/summarize_service.py
Normal file
@ -0,0 +1,182 @@
from itertools import chain

from injector import inject, singleton
from llama_index.core import (
    Document,
    StorageContext,
    SummaryIndex,
)
from llama_index.core.base.response.schema import Response, StreamingResponse
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.storage.docstore.types import RefDocInfo
from llama_index.core.types import TokenGen

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.settings.settings import Settings

#######
# Modification by SPC
#
# Force the use of the prompt from settings.yaml instead of a default prompt
#
# DEFAULT_SUMMARIZE_PROMPT = (
#     "Provide a comprehensive summary of the provided context information. "
#     "The summary should cover all the key points and main ideas presented in "
#     "the original text, while also condensing the information into a concise "
#     "and easy-to-understand format. Please ensure that the summary includes "
#     "relevant details and examples that support the main ideas, while avoiding "
#     "any unnecessary information or repetition."
# )

from private_gpt.settings.settings import settings

DEFAULT_SUMMARIZE_PROMPT = settings().ui.default_summarization_system_prompt
#
# End of modification by SPC
#######


@singleton
class SummarizeService:
    @inject
    def __init__(
        self,
        settings: Settings,
        llm_component: LLMComponent,
        node_store_component: NodeStoreComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
    ) -> None:
        self.settings = settings
        self.llm_component = llm_component
        self.node_store_component = node_store_component
        self.vector_store_component = vector_store_component
        self.embedding_component = embedding_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )

    @staticmethod
    def _filter_ref_docs(
        ref_docs: dict[str, RefDocInfo], context_filter: ContextFilter | None
    ) -> list[RefDocInfo]:
        if context_filter is None or not context_filter.docs_ids:
            return list(ref_docs.values())

        return [
            ref_doc
            for doc_id, ref_doc in ref_docs.items()
            if doc_id in context_filter.docs_ids
        ]

    def _summarize(
        self,
        use_context: bool = False,
        stream: bool = False,
        text: str | None = None,
        instructions: str | None = None,
        context_filter: ContextFilter | None = None,
        prompt: str | None = None,
    ) -> str | TokenGen:

        nodes_to_summarize = []

        # Add text to summarize
        if text:
            text_documents = [Document(text=text)]
            nodes_to_summarize += (
                SentenceSplitter.from_defaults().get_nodes_from_documents(
                    text_documents
                )
            )

        # Add context documents to summarize
        if use_context:
            # 1. Recover all ref docs
            ref_docs: dict[str, RefDocInfo] | None = (
                self.storage_context.docstore.get_all_ref_doc_info()
            )
            if ref_docs is None:
                raise ValueError("No documents have been ingested yet.")

            # 2. Filter documents based on context_filter (if provided)
            filtered_ref_docs = self._filter_ref_docs(ref_docs, context_filter)

            # 3. Get all nodes from the filtered documents
            filtered_node_ids = chain.from_iterable(
                [ref_doc.node_ids for ref_doc in filtered_ref_docs]
            )
            filtered_nodes = self.storage_context.docstore.get_nodes(
                node_ids=list(filtered_node_ids),
            )

            nodes_to_summarize += filtered_nodes

        # Create a SummaryIndex to summarize the nodes
        summary_index = SummaryIndex(
            nodes=nodes_to_summarize,
            storage_context=StorageContext.from_defaults(),  # In memory SummaryIndex
            show_progress=True,
        )

        # Make a tree summarization query
        # above the set of all candidate nodes
        query_engine = summary_index.as_query_engine(
            llm=self.llm_component.llm,
            response_mode=ResponseMode.TREE_SUMMARIZE,
            streaming=stream,
            use_async=self.settings.summarize.use_async,
        )

        prompt = prompt or DEFAULT_SUMMARIZE_PROMPT

        summarize_query = prompt + "\n" + (instructions or "")

        response = query_engine.query(summarize_query)
        if isinstance(response, Response):
            return response.response or ""
        elif isinstance(response, StreamingResponse):
            return response.response_gen
        else:
            raise TypeError(f"The result is not of a supported type: {type(response)}")

    def summarize(
        self,
        use_context: bool = False,
        text: str | None = None,
        instructions: str | None = None,
        context_filter: ContextFilter | None = None,
        prompt: str | None = None,
    ) -> str:
        return self._summarize(
            use_context=use_context,
            stream=False,
            text=text,
            instructions=instructions,
            context_filter=context_filter,
            prompt=prompt,
        )  # type: ignore

    def stream_summarize(
        self,
        use_context: bool = False,
        text: str | None = None,
        instructions: str | None = None,
        context_filter: ContextFilter | None = None,
        prompt: str | None = None,
    ) -> TokenGen:
        return self._summarize(
            use_context=use_context,
            stream=True,
            text=text,
            instructions=instructions,
            context_filter=context_filter,
            prompt=prompt,
        )  # type: ignore
0
pgpt/private_gpt/server/utils/__init__.py
Normal file
69
pgpt/private_gpt/server/utils/auth.py
Normal file
@ -0,0 +1,69 @@
"""Authentication mechanism for the API.

Define a simple mechanism to authenticate requests.
More complex authentication mechanisms can be defined here, and be placed in the
`authenticated` method (being a 'bean' injected in fastapi routers).

Authorization can also be made after the authentication, and depends on
the authentication. Authorization should not be implemented in this file.

Authorization can be done by following fastapi's guides:
* https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/
* https://fastapi.tiangolo.com/tutorial/security/
* https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/
"""

# mypy: ignore-errors
# Disabled mypy error: All conditional function variants must have identical signatures
# We are changing the implementation of the authenticated method, based on
# the config. If the auth is not enabled, we are not defining the complex method
# with its dependencies.
import logging
import secrets
from typing import Annotated

from fastapi import Depends, Header, HTTPException

from private_gpt.settings.settings import settings

# 401 signifies that the request requires authentication.
# 403 signifies that the authenticated user is not authorized to perform the operation.
NOT_AUTHENTICATED = HTTPException(
    status_code=401,
    detail="Not authenticated",
    headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'},
)

logger = logging.getLogger(__name__)


def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
    """Check if the request is authenticated."""
    if not secrets.compare_digest(authorization, settings().server.auth.secret):
        # If the "Authorization" header is not the expected one, raise an exception.
        raise NOT_AUTHENTICATED
    return True


if not settings().server.auth.enabled:
    logger.debug(
        "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
    )

    # Define a dummy authentication method that always returns True.
    def authenticated() -> bool:
        """Check if the request is authenticated."""
        return True

else:
    logger.info("Defining the given authentication mechanism for the API")

    # Method to be used as a dependency to check if the request is authenticated.
    def authenticated(
        _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
    ) -> bool:
        """Check if the request is authenticated."""
        assert settings().server.auth.enabled
        if not _simple_authentication:
            raise NOT_AUTHENTICATED
        return True
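A sketch of how a client authenticates when auth is enabled (not in the commit). Note that the whole `Authorization` header value is compared against the configured secret, so the example value below is hypothetical and must match `server.auth.secret` exactly:

```python
import requests

SECRET = "Basic c2VjcmV0LWtleQ=="  # hypothetical value of server.auth.secret

# With auth enabled, every /v1 route expects the exact configured header value.
resp = requests.get(
    "http://localhost:8001/v1/ingest/list",  # assumed host/port
    headers={"Authorization": SECRET},
)
print(resp.status_code)  # 401 without (or with a wrong) Authorization header
```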
1
pgpt/private_gpt/settings/__init__.py
Normal file
@ -0,0 +1 @@
"""Settings."""
638
pgpt/private_gpt/settings/settings.py
Normal file
638
pgpt/private_gpt/settings/settings.py
Normal file
@ -0,0 +1,638 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from private_gpt.settings.settings_loader import load_active_settings
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
"""CORS configuration.
|
||||
|
||||
For more details on the CORS configuration, see:
|
||||
# * https://fastapi.tiangolo.com/tutorial/cors/
|
||||
# * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
description="Flag indicating if CORS headers are set or not."
|
||||
"If set to True, the CORS headers will be set to allow all origins, methods and headers.",
|
||||
default=False,
|
||||
)
|
||||
allow_credentials: bool = Field(
|
||||
description="Indicate that cookies should be supported for cross-origin requests",
|
||||
default=False,
|
||||
)
|
||||
allow_origins: list[str] = Field(
|
||||
description="A list of origins that should be permitted to make cross-origin requests.",
|
||||
default=[],
|
||||
)
|
||||
allow_origin_regex: list[str] = Field(
|
||||
description="A regex string to match against origins that should be permitted to make cross-origin requests.",
|
||||
default=None,
|
||||
)
|
||||
allow_methods: list[str] = Field(
|
||||
description="A list of HTTP methods that should be allowed for cross-origin requests.",
|
||||
default=[
|
||||
"GET",
|
||||
],
|
||||
)
|
||||
allow_headers: list[str] = Field(
|
||||
description="A list of HTTP request headers that should be supported for cross-origin requests.",
|
||||
default=[],
|
||||
)
|
||||
|
||||
|
||||
class AuthSettings(BaseModel):
|
||||
"""Authentication configuration.
|
||||
|
||||
The implementation of the authentication strategy must
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
description="Flag indicating if authentication is enabled or not.",
|
||||
default=False,
|
||||
)
|
||||
secret: str = Field(
|
||||
description="The secret to be used for authentication. "
|
||||
"It can be any non-blank string. For HTTP basic authentication, "
|
||||
"this value should be the whole 'Authorization' header that is expected"
|
||||
)
|
||||
|
||||
|
||||
class IngestionSettings(BaseModel):
|
||||
"""Ingestion configuration.
|
||||
|
||||
This configuration is used to control the ingestion of data into the system
|
||||
using non-server methods. This is useful for local development and testing;
|
||||
or to ingest in bulk from a folder.
|
||||
|
||||
Please note that this configuration is not secure and should be used in
|
||||
a controlled environment only (setting right permissions, etc.).
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
description="Flag indicating if local ingestion is enabled or not.",
|
||||
default=False,
|
||||
)
|
||||
allow_ingest_from: list[str] = Field(
|
||||
description="A list of folders that should be permitted to make ingest requests.",
|
||||
default=[],
|
||||
)
|
||||
|
||||
|
||||
class ServerSettings(BaseModel):
|
||||
env_name: str = Field(
|
||||
description="Name of the environment (prod, staging, local...)"
|
||||
)
|
||||
port: int = Field(description="Port of PrivateGPT FastAPI server, defaults to 8001")
|
||||
cors: CorsSettings = Field(
|
||||
description="CORS configuration", default=CorsSettings(enabled=False)
|
||||
)
|
||||
auth: AuthSettings = Field(
|
||||
description="Authentication configuration",
|
||||
default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
|
||||
)
|
||||
|
||||
|
||||
class DataSettings(BaseModel):
|
||||
local_ingestion: IngestionSettings = Field(
|
||||
description="Ingestion configuration",
|
||||
default_factory=lambda: IngestionSettings(allow_ingest_from=["*"]),
|
||||
)
|
||||
local_data_folder: str = Field(
|
||||
description="Path to local storage."
|
||||
"It will be treated as an absolute path if it starts with /"
|
||||
)
|
||||
|
||||
|
||||
class LLMSettings(BaseModel):
    mode: Literal[
        "llamacpp",
        "openai",
        "openailike",
        "azopenai",
        "sagemaker",
        "mock",
        "ollama",
        "gemini",
    ]
    max_new_tokens: int = Field(
        256,
        description="The maximum number of tokens that the LLM is authorized to generate in one completion.",
    )
    context_window: int = Field(
        3900,
        description="The maximum number of context tokens for the model.",
    )
    tokenizer: str = Field(
        None,
        description="The model id of a predefined tokenizer hosted inside a model repo on "
        "huggingface.co. Valid model ids can be located at the root-level, like "
        "`bert-base-uncased`, or namespaced under a user or organization name, "
        "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
        "gpt-3.5-turbo LLM.",
    )
    temperature: float = Field(
        0.1,
        description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
    )
    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] = (
        Field(
            "llama2",
            description=(
                "The prompt style to use for the chat engine. "
                "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
                "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
                "If `llama3` - use the llama3 prompt style from the llama_index.\n"
                "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
                "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]`.\n"
                "`llama2` is the historic behaviour. `default` might work better with your custom models."
            ),
        )
    )


class VectorstoreSettings(BaseModel):
    database: Literal["chroma", "qdrant", "postgres", "clickhouse", "milvus"]


class NodeStoreSettings(BaseModel):
    database: Literal["simple", "postgres"]


class LlamaCPPSettings(BaseModel):
    llm_hf_repo_id: str
    llm_hf_model_file: str
    tfs_z: float = Field(
        1.0,
        description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
    )
    top_k: int = Field(
        40,
        description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
    )
    top_p: float = Field(
        0.9,
        description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
    )
    repeat_penalty: float = Field(
        1.1,
        description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
    )


class HuggingFaceSettings(BaseModel):
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings"
    )
    access_token: str = Field(
        None,
        description="Huggingface access token, required to download some models",
    )
    trust_remote_code: bool = Field(
        False,
        description="If set to True, the code from the remote model will be trusted and executed.",
    )


class EmbeddingSettings(BaseModel):
    mode: Literal[
        "huggingface",
        "openai",
        "azopenai",
        "sagemaker",
        "ollama",
        "mock",
        "gemini",
        "mistralai",
    ]
    ingest_mode: Literal["simple", "batch", "parallel", "pipeline"] = Field(
        "simple",
        description=(
            "The ingest mode to use for the embedding engine:\n"
            "If `simple` - ingest files sequentially and one by one. It is the historic behaviour.\n"
            "If `batch` - if multiple files, parse all the files in parallel, "
            "and send them in batch to the embedding model.\n"
            "If `pipeline` - the embedding engine is kept as busy as possible.\n"
            "If `parallel` - parse the files in parallel using multiple cores, and embed them in parallel.\n"
            "`parallel` is the fastest mode for local setup, as it parallelizes IO RW in the index.\n"
            "For modes that leverage parallelization, you can specify the number of "
            "workers to use with `count_workers`.\n"
        ),
    )
    count_workers: int = Field(
        2,
        description=(
            "The number of workers to use for file ingestion.\n"
            "In `batch` mode, this is the number of workers used to parse the files.\n"
            "In `parallel` mode, this is the number of workers used to parse the files and embed them.\n"
            "In `pipeline` mode, this is the number of workers that can perform embeddings.\n"
            "This is only used if `ingest_mode` is not `simple`.\n"
            "Do not go too high with this number, as it might cause memory issues. (especially in `parallel` mode)\n"
            "Do not set it higher than your number of threads of your CPU."
        ),
    )
    embed_dim: int = Field(
        384,
        description="The dimension of the embeddings stored in the Postgres database",
    )


class SagemakerSettings(BaseModel):
    llm_endpoint_name: str
    embedding_endpoint_name: str


class OpenAISettings(BaseModel):
    api_base: str = Field(
        None,
        description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
    )
    api_key: str
    model: str = Field(
        "gpt-3.5-turbo",
        description="OpenAI Model to use. Example: 'gpt-4'.",
    )
    request_timeout: float = Field(
        120.0,
        description="Time elapsed until openailike server times out the request. Default is 120s. Format is float.",
    )
    embedding_api_base: str = Field(
        None,
        description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
    )
    embedding_api_key: str
    embedding_model: str = Field(
        "text-embedding-ada-002",
        description="OpenAI embedding Model to use. Example: 'text-embedding-3-large'.",
    )


class GeminiSettings(BaseModel):
    api_key: str
    model: str = Field(
        "models/gemini-pro",
        description="Google Model to use. Example: 'models/gemini-pro'.",
    )
    embedding_model: str = Field(
        "models/embedding-001",
        description="Google Embedding Model to use. Example: 'models/embedding-001'.",
    )


class OllamaSettings(BaseModel):
    api_base: str = Field(
        "http://localhost:11434",
        description="Base URL of Ollama API. Example: 'https://localhost:11434'.",
    )
    embedding_api_base: str = Field(
        "http://localhost:11434",
        description="Base URL of Ollama embedding API. Example: 'https://localhost:11434'.",
    )
    llm_model: str = Field(
        None,
        description="Model to use. Example: 'llama2-uncensored'.",
    )
    embedding_model: str = Field(
        None,
        description="Model to use. Example: 'nomic-embed-text'.",
    )
    keep_alive: str = Field(
        "5m",
        description="Time the model will stay loaded in memory after a request. Examples: 5m, 5h, '-1'.",
    )
    tfs_z: float = Field(
        1.0,
        description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
    )
    num_predict: int = Field(
        None,
        description="Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)",
    )
    top_k: int = Field(
        40,
        description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
    )
    top_p: float = Field(
        0.9,
        description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
    )
    repeat_last_n: int = Field(
        64,
        description="Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
    )
    repeat_penalty: float = Field(
        1.1,
        description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
    )
    request_timeout: float = Field(
        120.0,
        description="Time elapsed until ollama times out the request. Default is 120s. Format is float.",
    )
    autopull_models: bool = Field(
        False,
        description="If set to True, Ollama will automatically pull the models from the API base.",
    )


class AzureOpenAISettings(BaseModel):
    api_key: str
    azure_endpoint: str
    api_version: str = Field(
        "2023-05-15",
        description="The API version to use for this operation. This follows the YYYY-MM-DD format.",
    )
    embedding_deployment_name: str
    embedding_model: str = Field(
        "text-embedding-ada-002",
        description="OpenAI Model to use. Example: 'text-embedding-ada-002'.",
    )
    llm_deployment_name: str
    llm_model: str = Field(
        "gpt-35-turbo",
        description="OpenAI Model to use. Example: 'gpt-4'.",
    )


class UISettings(BaseModel):
    enabled: bool
    path: str
    default_mode: Literal["RAG", "Search", "Basic", "Summarize"] = Field(
        "RAG",
        description="The default mode.",
    )
    default_chat_system_prompt: str = Field(
        None,
        description="The default system prompt to use for the chat mode.",
    )
    default_query_system_prompt: str = Field(
        None, description="The default system prompt to use for the query mode."
    )
    default_summarization_system_prompt: str = Field(
        None,
        description="The default system prompt to use for the summarization mode.",
    )
    delete_file_button_enabled: bool = Field(
        True, description="Whether the button to delete a file is enabled."
    )
    delete_all_files_button_enabled: bool = Field(
        False, description="Whether the button to delete all files is enabled."
    )


class RerankSettings(BaseModel):
    enabled: bool = Field(
        False,
        description="This value controls whether a reranker should be included in the RAG pipeline.",
    )
    model: str = Field(
        "cross-encoder/ms-marco-MiniLM-L-2-v2",
        description="Rerank model to use. Limited to SentenceTransformer cross-encoder models.",
    )
    top_n: int = Field(
        2,
        description="This value controls the number of documents returned by the RAG pipeline.",
    )


class RagSettings(BaseModel):
    similarity_top_k: int = Field(
        2,
        description="This value controls the number of documents returned by the RAG pipeline or considered for reranking if enabled.",
    )
    similarity_value: float = Field(
        None,
        description="If set, any documents retrieved from the RAG must meet a certain match score. Acceptable values are between 0 and 1.",
    )
    rerank: RerankSettings


class SummarizeSettings(BaseModel):
    use_async: bool = Field(
        True,
        description="If set to True, the summarization will be done asynchronously.",
    )


class ClickHouseSettings(BaseModel):
    host: str = Field(
        "localhost",
        description="The server hosting the ClickHouse database",
    )
    port: int = Field(
        8443,
        description="The port on which the ClickHouse database is accessible",
    )
    username: str = Field(
        "default",
        description="The username to use to connect to the ClickHouse database",
    )
    password: str = Field(
        "",
        description="The password to use to connect to the ClickHouse database",
    )
    database: str = Field(
        "__default__",
        description="The default database to use for connections",
    )
    secure: bool | str = Field(
        False,
        description="Use https/TLS for secure connection to the server",
    )
    interface: str | None = Field(
        None,
        description="Must be either 'http' or 'https'. Determines the protocol to use for the connection",
    )
    settings: dict[str, Any] | None = Field(
        None,
        description="Specific ClickHouse server settings to be used with the session",
    )
    connect_timeout: int | None = Field(
        None,
        description="Timeout in seconds for establishing a connection",
    )
    send_receive_timeout: int | None = Field(
        None,
        description="Read timeout in seconds for http connection",
    )
    verify: bool | None = Field(
        None,
        description="Verify the server certificate in secure/https mode",
    )
    ca_cert: str | None = Field(
        None,
        description="Path to Certificate Authority root certificate (.pem format)",
    )
    client_cert: str | None = Field(
        None,
        description="Path to TLS Client certificate (.pem format)",
    )
    client_cert_key: str | None = Field(
        None,
        description="Path to the private key for the TLS Client certificate",
    )
    http_proxy: str | None = Field(
        None,
        description="HTTP proxy address",
    )
    https_proxy: str | None = Field(
        None,
        description="HTTPS proxy address",
    )
    server_host_name: str | None = Field(
        None,
        description="Server host name to be checked against the TLS certificate",
    )


class PostgresSettings(BaseModel):
    host: str = Field(
        "localhost",
        description="The server hosting the Postgres database",
    )
    port: int = Field(
        5432,
        description="The port on which the Postgres database is accessible",
    )
    user: str = Field(
        "postgres",
        description="The user to use to connect to the Postgres database",
    )
    password: str = Field(
        "postgres",
        description="The password to use to connect to the Postgres database",
    )
    database: str = Field(
        "postgres",
        description="The database to use to connect to the Postgres database",
    )
    schema_name: str = Field(
        "public",
        description="The name of the schema in the Postgres database to use",
    )


class QdrantSettings(BaseModel):
    location: str | None = Field(
        None,
        description=(
            "If `:memory:` - use in-memory Qdrant instance.\n"
            "If `str` - use it as a `url` parameter.\n"
        ),
    )
    url: str | None = Field(
        None,
        description=(
            "Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'."
        ),
    )
    port: int | None = Field(6333, description="Port of the REST API interface.")
    grpc_port: int | None = Field(6334, description="Port of the gRPC interface.")
    prefer_grpc: bool | None = Field(
        False,
        description="If `true` - use gRPC interface whenever possible in custom methods.",
    )
    https: bool | None = Field(
        None,
        description="If `true` - use HTTPS(SSL) protocol.",
    )
    api_key: str | None = Field(
        None,
        description="API key for authentication in Qdrant Cloud.",
    )
    prefix: str | None = Field(
        None,
        description=(
            "Prefix to add to the REST URL path. "
            "Example: `service/v1` will result in "
            "'http://localhost:6333/service/v1/{qdrant-endpoint}' for REST API."
        ),
    )
    timeout: float | None = Field(
        None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: str | None = Field(
        None,
        description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
    )
    path: str | None = Field(None, description="Persistence path for QdrantLocal.")
    force_disable_check_same_thread: bool | None = Field(
        True,
        description=(
            "For QdrantLocal, force disable check_same_thread. Default: `True`. "
            "Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient."
        ),
    )


class MilvusSettings(BaseModel):
    uri: str = Field(
        "local_data/private_gpt/milvus/milvus_local.db",
        description="The URI of the Milvus instance. For example: 'local_data/private_gpt/milvus/milvus_local.db' for Milvus Lite.",
    )
    token: str = Field(
        "",
        description=(
            "A valid access token to access the specified Milvus instance. "
            "This can be used as a recommended alternative to setting user and password separately. "
        ),
    )
    collection_name: str = Field(
        "make_this_parameterizable_per_api_call",
        description="The name of the collection in Milvus. Default is 'make_this_parameterizable_per_api_call'.",
    )
    overwrite: bool = Field(
        True, description="Overwrite the previous collection schema if it exists."
    )


class Settings(BaseModel):
    server: ServerSettings
    data: DataSettings
    ui: UISettings
    llm: LLMSettings
    embedding: EmbeddingSettings
    llamacpp: LlamaCPPSettings
    huggingface: HuggingFaceSettings
    sagemaker: SagemakerSettings
    openai: OpenAISettings
    gemini: GeminiSettings
    ollama: OllamaSettings
    azopenai: AzureOpenAISettings
    vectorstore: VectorstoreSettings
    nodestore: NodeStoreSettings
    rag: RagSettings
    summarize: SummarizeSettings
    qdrant: QdrantSettings | None = None
    postgres: PostgresSettings | None = None
    clickhouse: ClickHouseSettings | None = None
    milvus: MilvusSettings | None = None


"""
|
||||
This is visible just for DI or testing purposes.
|
||||
|
||||
Use dependency injection or `settings()` method instead.
|
||||
"""
|
||||
unsafe_settings = load_active_settings()
|
||||
|
||||
"""
|
||||
This is visible just for DI or testing purposes.
|
||||
|
||||
Use dependency injection or `settings()` method instead.
|
||||
"""
|
||||
unsafe_typed_settings = Settings(**unsafe_settings)
|
||||
|
||||
|
||||
def settings() -> Settings:
|
||||
"""Get the current loaded settings from the DI container.
|
||||
|
||||
This method exists to keep compatibility with the existing code,
|
||||
that require global access to the settings.
|
||||
|
||||
For regular components use dependency injection instead.
|
||||
"""
|
||||
from private_gpt.di import global_injector
|
||||
|
||||
return global_injector.get(Settings)
|
||||
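
For illustration, a minimal sketch of consuming the typed settings above (not part of the commit; it assumes a default profile can be loaded and the DI container can be built):

from private_gpt.settings.settings import settings

s = settings()
print(s.server.port)  # e.g. 8001 from the active profile
print(s.llm.mode)     # "llamacpp", "openai", "ollama", ...
if s.vectorstore.database == "qdrant" and s.qdrant is not None:
    print(s.qdrant.url or s.qdrant.host or "localhost")
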
57
pgpt/private_gpt/settings/settings_loader.py
Normal file
@@ -0,0 +1,57 @@
import functools
import logging
import os
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Any

from pydantic.v1.utils import deep_update, unique_list

from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.settings.yaml import load_yaml_with_envvars

logger = logging.getLogger(__name__)

_settings_folder = os.environ.get("PGPT_SETTINGS_FOLDER", PROJECT_ROOT_PATH)

# if running in unittest, use the test profile
_test_profile = ["test"] if "tests.fixtures" in sys.modules else []

active_profiles: list[str] = unique_list(
    ["default"]
    + [
        item.strip()
        for item in os.environ.get("PGPT_PROFILES", "").split(",")
        if item.strip()
    ]
    + _test_profile
)


def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
    return functools.reduce(deep_update, settings, {})


def load_settings_from_profile(profile: str) -> dict[str, Any]:
    if profile == "default":
        profile_file_name = "settings.yaml"
    else:
        profile_file_name = f"settings-{profile}.yaml"

    path = Path(_settings_folder) / profile_file_name
    with Path(path).open("r") as f:
        config = load_yaml_with_envvars(f)
    if not isinstance(config, dict):
        raise TypeError(f"Config file has no top-level mapping: {path}")
    return config


def load_active_settings() -> dict[str, Any]:
    """Load active profiles and merge them."""
    logger.info("Starting application with profiles=%s", active_profiles)
    loaded_profiles = [
        load_settings_from_profile(profile) for profile in active_profiles
    ]
    merged: dict[str, Any] = merge_settings(loaded_profiles)
    return merged
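
A hedged sketch of how profile layering behaves, assuming settings.yaml and a hypothetical settings-local.yaml exist in the settings folder; note that PGPT_PROFILES must be set before the module is imported, since active_profiles is computed at import time:

import os

os.environ["PGPT_PROFILES"] = "local"  # hypothetical profile name

from private_gpt.settings.settings_loader import load_active_settings

merged = load_active_settings()  # "default" is loaded first, then "local" overrides it via deep_update
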
41
pgpt/private_gpt/settings/yaml.py
Normal file
@@ -0,0 +1,41 @@
import os
import re
import typing
from typing import Any, TextIO

from yaml import SafeLoader

_env_replace_matcher = re.compile(r"\$\{(\w|_)+:?.*}")


@typing.no_type_check  # pyaml does not have good hints, everything is Any
def load_yaml_with_envvars(
    stream: TextIO, environ: dict[str, Any] = os.environ
) -> dict[str, Any]:
    """Load yaml file with environment variable expansion.

    The pattern ${VAR} or ${VAR:default} will be replaced with
    the value of the environment variable.
    """
    loader = SafeLoader(stream)

    def load_env_var(_, node) -> str:
        """Extract the matched value, expand env variable, and replace the match."""
        value = str(node.value).removeprefix("${").removesuffix("}")
        split = value.split(":", 1)
        env_var = split[0]
        value = environ.get(env_var)
        default = None if len(split) == 1 else split[1]
        if value is None and default is None:
            raise ValueError(
                f"Environment variable {env_var} is not set and no default was provided"
            )
        return value or default

    loader.add_implicit_resolver("env_var_replacer", _env_replace_matcher, None)
    loader.add_constructor("env_var_replacer", load_env_var)

    try:
        return loader.get_single_data()
    finally:
        loader.dispose()
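
A small round-trip of the environment-variable expansion (the YAML content and the variable name are illustrative):

import io

from private_gpt.settings.yaml import load_yaml_with_envvars

doc = io.StringIO("api_key: ${MY_API_KEY:dummy-default}\n")
config = load_yaml_with_envvars(doc, environ={})  # MY_API_KEY is unset, so the default applies
assert config == {"api_key": "dummy-default"}
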
1
pgpt/private_gpt/ui/__init__.py
Normal file
@@ -0,0 +1 @@
"""Gradio based UI."""
BIN
pgpt/private_gpt/ui/avatar-bot.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 15 KiB |
1
pgpt/private_gpt/ui/images.py
Normal file
@@ -0,0 +1 @@
logo_svg = "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iODYxIiBoZWlnaHQ9Ijk4IiB2aWV3Qm94PSIwIDAgODYxIDk4IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgo8cGF0aCBkPSJNNDguMTM0NSAwLjE1NzkxMUMzNi44Mjk5IDEuMDM2NTQgMjYuMTIwNSA1LjU1MzI4IDE3LjYyNTYgMTMuMDI1QzkuMTMwNDYgMjAuNDk2NyAzLjMxMTcgMzAuNTE2OSAxLjA0OTUyIDQxLjU3MDVDLTEuMjEyNzMgNTIuNjIzOCAwLjIwNDQxOSA2NC4xMDk0IDUuMDg2MiA3NC4yOTA1QzkuOTY4NjggODQuNDcxNiAxOC4wNTAzIDkyLjc5NDMgMjguMTA5OCA5OEwzMy43MDI2IDgyLjU5MDdMMzUuNDU0MiA3Ny43NjU2QzI5LjgzODcgNzQuMTY5MiAyNS41NDQ0IDY4Ljg2MDcgMjMuMjE0IDYyLjYzNDRDMjAuODgyMiA1Ni40MDg2IDIwLjYzOSA0OS41OTkxIDIyLjUyMDQgNDMuMjI0M0MyNC40MDI5IDM2Ljg0OTUgMjguMzA5NiAzMS4yNTI1IDMzLjY1NjEgMjcuMjcwNkMzOS4wMDIgMjMuMjg4MyA0NS41MDAzIDIxLjEzNSA1Mi4xNzg5IDIxLjEzM0M1OC44NTczIDIxLjEzMDMgNjUuMzU3MSAyMy4yNzgzIDcwLjcwNjUgMjcuMjU1OEM3Ni4wNTU0IDMxLjIzNCA3OS45NjY0IDM2LjgyNzcgODEuODU0MyA0My4yMDA2QzgzLjc0MjkgNDkuNTczNiA4My41MDYyIDU2LjM4MzYgODEuMTgwMSA2Mi42MTE3Qzc4Ljg1NDUgNjguODM5NiA3NC41NjUgNzQuMTUxNCA2OC45NTI5IDc3Ljc1MjhMNzAuNzA3NCA4Mi41OTA3TDc2LjMwMDIgOTcuOTk3MUM4Ni45Nzg4IDkyLjQ3MDUgOTUuNDA4OCA4My40NDE5IDEwMC4xNjMgNzIuNDQwNEMxMDQuOTE3IDYxLjQzOTQgMTA1LjcwNCA0OS4xNDE3IDEwMi4zODkgMzcuNjNDOTkuMDc0NiAyNi4xMTc5IDkxLjg2MjcgMTYuMDk5MyA4MS45NzQzIDkuMjcwNzlDNzIuMDg2MSAyLjQ0MTkxIDYwLjEyOTEgLTAuNzc3MDg2IDQ4LjEyODYgMC4xNTg5MzRMNDguMTM0NSAwLjE1NzkxMVoiIGZpbGw9IiMxRjFGMjkiLz4KPGcgY2xpcC1wYXRoPSJ1cmwoI2NsaXAwXzVfMTkpIj4KPHBhdGggZD0iTTIyMC43NzIgMTIuNzUyNEgyNTIuNjM5QzI2Ny4yNjMgMTIuNzUyNCAyNzcuNzM5IDIxLjk2NzUgMjc3LjczOSAzNS40MDUyQzI3Ny43MzkgNDYuNzg3IDI2OS44ODEgNTUuMzUwOCAyNTguMzE0IDU3LjQxMDdMMjc4LjgzIDg1LjM3OTRIMjYxLjM3TDI0Mi4wNTQgNTcuOTUzM0gyMzUuNTA2Vjg1LjM3OTRIMjIwLjc3NEwyMjAuNzcyIDEyLjc1MjRaTTIzNS41MDQgMjYuMzAyOFY0NC40MDdIMjUyLjYzMkMyNTguOTYyIDQ0LjQwNyAyNjIuOTk5IDQwLjgyOTggMjYyLjk5OSAzNS40MTAyQzI2Mi45OTkgMjkuODgwOSAyNTguOTYyIDI2LjMwMjggMjUyLjYzMiAyNi4zMDI4SDIzNS41MDRaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik0yOTUuMTc2IDg1LjM4NDRWMTIuNzUyNEgzMDkuOTA5Vjg1LjM4NDRIMjk1LjE3NloiIGZpbGw9IiMxRjFGMjkiLz4KPHBhdGggZD0iTTM2My43OTUgNjUuNzYzTDM4NS42MiAxMi43NTI0SDQwMS40NDRMMzcxLjIxNSA4NS4zODQ0SDM1Ni40ODNMMzI2LjI1NCAxMi43NTI0SDM0Mi4wNzhMMzYzLjc5NSA2NS43NjNaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik00NDguMzI3IDcyLjA1MDRINDE1LjY5OEw0MTAuMjQxIDg1LjM4NDRIMzk0LjQxOEw0MjQuNjQ3IDEyLjc1MjRINDM5LjM3OUw0NjkuNjA4IDg1LjM4NDRINDUzLjc4M0w0NDguMzI3IDcyLjA1MDRaTTQ0Mi43NjEgNTguNUw0MzIuMDY2IDMyLjM3NDhMNDIxLjI2MiA1OC41SDQ0Mi43NjFaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik00NjUuMjIxIDEyLjc1MjRINTMwLjU5MlYyNi4zMDI4SDUwNS4yNzVWODUuMzg0NEg0OTAuNTM5VjI2LjMwMjhINDY1LjIyMVYxMi43NTI0WiIgZmlsbD0iIzFGMUYyOSIvPgo8cGF0aCBkPSJNNTk1LjE5MyAxMi43NTI0VjI2LjMwMjhINTYyLjEyOFY0MS4xNTUxSDU5NS4xOTNWNTQuNzA2NUg1NjIuMTI4VjcxLjgzNEg1OTUuMTkzVjg1LjM4NDRINTQ3LjM5NVYxMi43NTI0SDU5NS4xOTNaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik0xNjcuMjAxIDU3LjQxNThIMTg2LjUzNkMxOTAuODg2IDU3LjQ2NjIgMTk1LjE2OCA1Ni4zMzQ4IDE5OC45MTggNTQuMTQzN0MyMDIuMTc5IDUyLjIxOTkgMjA0Ljg2OSA0OS40NzM2IDIwNi43MTYgNDYuMTgzNUMyMDguNTYyIDQyLjg5MzQgMjA5LjUgMzkuMTc2NiAyMDkuNDMzIDM1LjQxMDJDMjA5LjQzMyAyMS45Njc1IDE5OC45NTggMTIuNzU3NCAxODQuMzM0IDEyLjc1NzRIMTUyLjQ2OFY4NS4zODk0SDE2Ny4yMDFWNTcuNDIwN1Y1Ny40MTU4Wk0xNjcuMjAxIDI2LjMwNThIMTg0LjMyOUMxOTAuNjU4IDI2LjMwNTggMTk0LjY5NiAyOS44ODQgMTk0LjY5NiAzNS40MTMzQzE5NC42OTYgNDAuODMyOSAxOTAuNjU4IDQ0LjQwOTkgMTg0LjMyOSA0NC40MDk5SDE2Ny4yMDFWMjYuMzA1OFoiIGZpbGw9IiMxRjFGMjkiLz4KPHBhdGggZD0iTTc5NC44MzUgMTIuNzUyNEg4NjAuMjA2VjI2LjMwMjhIODM0Ljg4OVY4NS4zODQ0SDgyMC4xNTZWMjYuMzAyOEg3OTQuODM1VjEyLjc1MjRaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik03NDEuOTA3IDU3LjQxNThINzYxLjI0MUM3NjUuNTkyIDU3LjQ2NjEgNzY5Ljg3NCA1Ni4zMzQ3IDc3My42MjQgNTQuMTQzN0M3NzYuO
Dg0IDUyLjIxOTkgNzc5LjU3NSA0OS40NzM2IDc4MS40MjEgNDYuMTgzNUM3ODMuMjY4IDQyLjg5MzQgNzg0LjIwNiAzOS4xNzY2IDc4NC4xMzkgMzUuNDEwMkM3ODQuMTM5IDIxLjk2NzUgNzczLjY2NCAxMi43NTc0IDc1OS4wMzkgMTIuNzU3NEg3MjcuMTc1Vjg1LjM4OTRINzQxLjkwN1Y1Ny40MjA3VjU3LjQxNThaTTc0MS45MDcgMjYuMzA1OEg3NTkuMDM1Qzc2NS4zNjUgMjYuMzA1OCA3NjkuNDAzIDI5Ljg4NCA3NjkuNDAzIDM1LjQxMzNDNzY5LjQwMyA0MC44MzI5IDc2NS4zNjUgNDQuNDA5OSA3NTkuMDM1IDQ0LjQwOTlINzQxLjkwN1YyNi4zMDU4WiIgZmlsbD0iIzFGMUYyOSIvPgo8cGF0aCBkPSJNNjgxLjA2OSA0Ny4wMTE1VjU5LjAxMjVINjk1LjM3OVY3MS42NzE5QzY5Mi41MjYgNzMuNDM2OCA2ODguNTI0IDc0LjMzMTkgNjgzLjQ3NyA3NC4zMzE5QzY2Ni4wMDMgNzQuMzMxOSA2NTguMDQ1IDYxLjgxMjQgNjU4LjA0NSA1MC4xOEM2NTguMDQ1IDMzLjk2MDUgNjcxLjAwOCAyNS40NzMyIDY4My44MTIgMjUuNDczMkM2OTAuNDI1IDI1LjQ2MjggNjk2LjkwOSAyNy4yODA0IDcwMi41NDEgMzAuNzIyNkw3MDMuMTU3IDMxLjEyNTRMNzA1Ljk1OCAxOC4xODZMNzA1LjY2MyAxNy45OTc3QzcwMC4wNDYgMTQuNDAwNCA2OTEuMjkxIDEyLjI1OSA2ODIuMjUxIDEyLjI1OUM2NjMuMTk3IDEyLjI1OSA2NDIuOTQ5IDI1LjM5NjcgNjQyLjk0OSA0OS43NDVDNjQyLjk0OSA2MS4wODQ1IDY0Ny4yOTMgNzAuNzE3NCA2NTUuNTExIDc3LjYwMjlDNjYzLjIyNCA4My44MjQ1IDY3Mi44NzQgODcuMTg5IDY4Mi44MDkgODcuMTIwMUM2OTQuMzYzIDg3LjEyMDEgNzAzLjA2MSA4NC42NDk1IDcwOS40MDIgNzkuNTY5Mkw3MDkuNTg5IDc5LjQxODFWNDcuMDExNUg2ODEuMDY5WiIgZmlsbD0iIzFGMUYyOSIvPgo8L2c+CjxkZWZzPgo8Y2xpcFBhdGggaWQ9ImNsaXAwXzVfMTkiPgo8cmVjdCB3aWR0aD0iNzA3Ljc3OCIgaGVpZ2h0PSI3NC44NjExIiBmaWxsPSJ3aGl0ZSIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoMTUyLjQ0NCAxMi4yNSkiLz4KPC9jbGlwUGF0aD4KPC9kZWZzPgo8L3N2Zz4K"
596
pgpt/private_gpt/ui/ui.py
Normal file
@@ -0,0 +1,596 @@
"""This file should be imported if and only if you want to run the UI locally."""

import base64
import logging
import time
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Any

import gradio as gr  # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate  # type: ignore
from injector import inject, singleton
from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole
from llama_index.core.types import TokenGen
from pydantic import BaseModel

from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.di import global_injector
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg

logger = logging.getLogger(__name__)

THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
# Should be "private_gpt/ui/avatar-bot.ico"
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico"

UI_TAB_TITLE = "My Private GPT"

SOURCES_SEPARATOR = "<hr>Sources: \n"


class Modes(str, Enum):
    RAG_MODE = "RAG"
    SEARCH_MODE = "Search"
    BASIC_CHAT_MODE = "Basic"
    SUMMARIZE_MODE = "Summarize"


MODES: list[Modes] = [
    Modes.RAG_MODE,
    Modes.SEARCH_MODE,
    Modes.BASIC_CHAT_MODE,
    Modes.SUMMARIZE_MODE,
]


class Source(BaseModel):
    file: str
    page: str
    text: str

    class Config:
        frozen = True

    @staticmethod
    def curate_sources(sources: list[Chunk]) -> list["Source"]:
        curated_sources = []

        for chunk in sources:
            doc_metadata = chunk.document.doc_metadata

            file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
            page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"

            source = Source(file=file_name, page=page_label, text=chunk.text)
            curated_sources.append(source)
        curated_sources = list(
            dict.fromkeys(curated_sources).keys()
        )  # Unique sources only

        return curated_sources


@singleton
class PrivateGptUi:
    @inject
    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
        chunks_service: ChunksService,
        summarize_service: SummarizeService,
    ) -> None:
        self._ingest_service = ingest_service
        self._chat_service = chat_service
        self._chunks_service = chunks_service
        self._summarize_service = summarize_service

        # Cache the UI blocks
        self._ui_block = None

        self._selected_filename = None

        # Initialize system prompt based on default mode
        default_mode_map = {mode.value: mode for mode in Modes}
        self._default_mode = default_mode_map.get(
            settings().ui.default_mode, Modes.RAG_MODE
        )
        self._system_prompt = self._get_default_system_prompt(self._default_mode)

    def _chat(
        self, message: str, history: list[list[str]], mode: Modes, *_: Any
    ) -> Any:
        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
            full_response: str = ""
            stream = completion_gen.response
            for delta in stream:
                if isinstance(delta, str):
                    full_response += str(delta)
                elif isinstance(delta, ChatResponse):
                    full_response += delta.delta or ""
                yield full_response
                time.sleep(0.02)

            if completion_gen.sources:
                full_response += SOURCES_SEPARATOR
                cur_sources = Source.curate_sources(completion_gen.sources)
                sources_text = "\n\n\n"
                used_files = set()
                for index, source in enumerate(cur_sources, start=1):
                    if f"{source.file}-{source.page}" not in used_files:
                        sources_text = (
                            sources_text
                            + f"{index}. {source.file} (page {source.page}) \n\n"
                        )
                        used_files.add(f"{source.file}-{source.page}")
                sources_text += "<hr>\n\n"
                full_response += sources_text
            yield full_response

        def yield_tokens(token_gen: TokenGen) -> Iterable[str]:
            full_response: str = ""
            for token in token_gen:
                full_response += str(token)
                yield full_response

        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = []

            for interaction in history:
                history_messages.append(
                    ChatMessage(content=interaction[0], role=MessageRole.USER)
                )
                if len(interaction) > 1 and interaction[1] is not None:
                    history_messages.append(
                        ChatMessage(
                            # Remove from history content the Sources information
                            content=interaction[1].split(SOURCES_SEPARATOR)[0],
                            role=MessageRole.ASSISTANT,
                        )
                    )

            # max 20 messages to try to avoid context overflow
            return history_messages[:20]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )
        match mode:
            case Modes.RAG_MODE:
                # Use only the selected file for the query
                context_filter = None
                if self._selected_filename is not None:
                    docs_ids = []
                    for ingested_document in self._ingest_service.list_ingested():
                        if (
                            ingested_document.doc_metadata["file_name"]
                            == self._selected_filename
                        ):
                            docs_ids.append(ingested_document.doc_id)
                    context_filter = ContextFilter(docs_ids=docs_ids)

                query_stream = self._chat_service.stream_chat(
                    messages=all_messages,
                    use_context=True,
                    context_filter=context_filter,
                )
                yield from yield_deltas(query_stream)
            case Modes.BASIC_CHAT_MODE:
                llm_stream = self._chat_service.stream_chat(
                    messages=all_messages,
                    use_context=False,
                )
                yield from yield_deltas(llm_stream)

            case Modes.SEARCH_MODE:
                response = self._chunks_service.retrieve_relevant(
                    text=message, limit=4, prev_next_chunks=0
                )

                sources = Source.curate_sources(response)

                yield "\n\n\n".join(
                    f"{index}. **{source.file} "
                    f"(page {source.page})**\n "
                    f"{source.text}"
                    for index, source in enumerate(sources, start=1)
                )
            case Modes.SUMMARIZE_MODE:
                # Summarize the given message, optionally using selected files
                context_filter = None
                if self._selected_filename:
                    docs_ids = []
                    for ingested_document in self._ingest_service.list_ingested():
                        if (
                            ingested_document.doc_metadata["file_name"]
                            == self._selected_filename
                        ):
                            docs_ids.append(ingested_document.doc_id)
                    context_filter = ContextFilter(docs_ids=docs_ids)

                summary_stream = self._summarize_service.stream_summarize(
                    use_context=True,
                    context_filter=context_filter,
                    instructions=message,
                )
                yield from yield_tokens(summary_stream)

    # On initialization and on mode change, this function sets the system prompt
    # to the default prompt based on the mode (and user settings).
    @staticmethod
    def _get_default_system_prompt(mode: Modes) -> str:
        p = ""
        match mode:
            # For query chat mode, obtain default system prompt from settings
            case Modes.RAG_MODE:
                p = settings().ui.default_query_system_prompt
            # For chat mode, obtain default system prompt from settings
            case Modes.BASIC_CHAT_MODE:
                p = settings().ui.default_chat_system_prompt
            # For summarization mode, obtain default system prompt from settings
            case Modes.SUMMARIZE_MODE:
                p = settings().ui.default_summarization_system_prompt
            # For any other mode, clear the system prompt
            case _:
                p = ""
        return p

    @staticmethod
    def _get_default_mode_explanation(mode: Modes) -> str:
        match mode:
            case Modes.RAG_MODE:
                return "Get contextualized answers from selected files."
            case Modes.SEARCH_MODE:
                return "Find relevant chunks of text in selected files."
            case Modes.BASIC_CHAT_MODE:
                return "Chat with the LLM using its training data. Files are ignored."
            case Modes.SUMMARIZE_MODE:
                #######
                # Modification by SPC
                #
                # return "Generate a summary of the selected files. Prompt to customize the result."
                return "Générer l'analyse selon les conditions présentées dans le prompt système."
                #
                # End of modification by SPC
                #######
            case _:
                return ""

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input

    def _set_explanation_mode(self, explanation_mode: str) -> None:
        self._explanation_mode = explanation_mode

    def _set_current_mode(self, mode: Modes) -> Any:
        self.mode = mode
        self._set_system_prompt(self._get_default_system_prompt(mode))
        self._set_explanation_mode(self._get_default_mode_explanation(mode))
        interactive = self._system_prompt is not None
        return [
            gr.update(placeholder=self._system_prompt, interactive=interactive),
            gr.update(value=self._explanation_mode),
        ]

    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]

        # remove all existing Documents with name identical to a new file upload:
        file_names = [path.name for path in paths]
        doc_ids_to_delete = []
        for ingested_document in self._ingest_service.list_ingested():
            if (
                ingested_document.doc_metadata
                and ingested_document.doc_metadata["file_name"] in file_names
            ):
                doc_ids_to_delete.append(ingested_document.doc_id)
        if len(doc_ids_to_delete) > 0:
            logger.info(
                "Uploading file(s) which were already ingested: %s document(s) will be replaced.",
                len(doc_ids_to_delete),
            )
            for doc_id in doc_ids_to_delete:
                self._ingest_service.delete(doc_id)

        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

    def _delete_all_files(self) -> Any:
        ingested_files = self._ingest_service.list_ingested()
        logger.debug("Deleting count=%s files", len(ingested_files))
        for ingested_document in ingested_files:
            self._ingest_service.delete(ingested_document.doc_id)
        return [
            gr.List(self._list_ingested_files()),
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _delete_selected_file(self) -> Any:
        logger.debug("Deleting selected %s", self._selected_filename)
        # Note: keep looping for pdf's (each page became a Document)
        for ingested_document in self._ingest_service.list_ingested():
            if (
                ingested_document.doc_metadata
                and ingested_document.doc_metadata["file_name"]
                == self._selected_filename
            ):
                self._ingest_service.delete(ingested_document.doc_id)
        return [
            gr.List(self._list_ingested_files()),
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _deselect_selected_file(self) -> Any:
        self._selected_filename = None
        return [
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _selected_a_file(self, select_data: gr.SelectData) -> Any:
        self._selected_filename = select_data.value
        return [
            gr.components.Button(interactive=True),
            gr.components.Button(interactive=True),
            gr.components.Textbox(self._selected_filename),
        ]

    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "background-color: #C7BAFF;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }"
            "hr { margin-top: 1em; margin-bottom: 1em; border: 0; border-top: 1px solid #FFF; }"
            ".avatar-image { background-color: antiquewhite; border-radius: 2px; }"
            ".footer { text-align: center; margin-top: 20px; font-size: 14px; display: flex; align-items: center; justify-content: center; }"
            ".footer-zylon-link { display:flex; margin-left: 5px; text-decoration: auto; color: var(--body-text-color); }"
            ".footer-zylon-link:hover { color: #C7BAFF; }"
            ".footer-zylon-ico { height: 20px; margin-left: 5px; background-color: antiquewhite; border-radius: 2px; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"<div class='logo'><img src={logo_svg} alt=PrivateGPT></div>")

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    default_mode = self._default_mode
                    mode = gr.Radio(
                        [mode.value for mode in MODES],
                        label="Mode",
                        value=default_mode,
                    )
                    explanation_mode = gr.Textbox(
                        placeholder=self._get_default_mode_explanation(default_mode),
                        show_label=False,
                        max_lines=3,
                        interactive=False,
                    )
                    upload_button = gr.components.UploadButton(
                        "Upload File(s)",
                        type="filepath",
                        file_count="multiple",
                        size="sm",
                    )
                    ingested_dataset = gr.List(
                        self._list_ingested_files,
                        headers=["File name"],
                        label="Ingested Files",
                        height=235,
                        interactive=False,
                        render=False,  # Rendered under the button
                    )
                    upload_button.upload(
                        self._upload_file,
                        inputs=upload_button,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.change(
                        self._list_ingested_files,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.render()
                    deselect_file_button = gr.components.Button(
                        "De-select selected file", size="sm", interactive=False
                    )
                    selected_text = gr.components.Textbox(
                        "All files", label="Selected for Query or Deletion", max_lines=1
                    )
                    delete_file_button = gr.components.Button(
                        "🗑️ Delete selected file",
                        size="sm",
                        visible=settings().ui.delete_file_button_enabled,
                        interactive=False,
                    )
                    delete_files_button = gr.components.Button(
                        "⚠️ Delete ALL files",
                        size="sm",
                        visible=settings().ui.delete_all_files_button_enabled,
                    )
                    deselect_file_button.click(
                        self._deselect_selected_file,
                        outputs=[
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    ingested_dataset.select(
                        fn=self._selected_a_file,
                        outputs=[
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    delete_file_button.click(
                        self._delete_selected_file,
                        outputs=[
                            ingested_dataset,
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    delete_files_button.click(
                        self._delete_all_files,
                        outputs=[
                            ingested_dataset,
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    system_prompt_input = gr.Textbox(
                        placeholder=self._system_prompt,
                        label="System Prompt",
                        lines=2,
                        interactive=True,
                        render=False,
                    )
                    # When mode changes, set the default system prompt and related state
                    mode.change(
                        self._set_current_mode,
                        inputs=mode,
                        outputs=[system_prompt_input, explanation_mode],
                    )
                    # On blur, set system prompt to use in queries
                    system_prompt_input.blur(
                        self._set_system_prompt,
                        inputs=system_prompt_input,
                    )

                    def get_model_label() -> str | None:
                        """Get model label from llm mode setting YAML.

                        Raises:
                            ValueError: If an invalid 'llm_mode' is encountered.

                        Returns:
                            str: The corresponding model label.
                        """
                        # Get model label from llm mode setting YAML
                        # Labels: local, openai, openailike, sagemaker, mock, ollama
                        config_settings = settings()
                        if config_settings is None:
                            raise ValueError("Settings are not configured.")

                        # Get llm_mode from settings
                        llm_mode = config_settings.llm.mode

                        # Mapping of 'llm_mode' to corresponding model labels
                        model_mapping = {
                            "llamacpp": config_settings.llamacpp.llm_hf_model_file,
                            "openai": config_settings.openai.model,
                            "openailike": config_settings.openai.model,
                            "azopenai": config_settings.azopenai.llm_model,
                            "sagemaker": config_settings.sagemaker.llm_endpoint_name,
                            "mock": llm_mode,
                            "ollama": config_settings.ollama.llm_model,
                            "gemini": config_settings.gemini.model,
                        }

                        if llm_mode not in model_mapping:
                            print(f"Invalid 'llm mode': {llm_mode}")
                            return None

                        return model_mapping[llm_mode]

                with gr.Column(scale=7, elem_id="col"):
                    # Determine the model label based on the value of PGPT_PROFILES
                    model_label = get_model_label()
                    if model_label is not None:
                        label_text = (
                            f"LLM: {settings().llm.mode} | Model: {model_label}"
                        )
                    else:
                        label_text = f"LLM: {settings().llm.mode}"

                    _ = gr.ChatInterface(
                        self._chat,
                        chatbot=gr.Chatbot(
                            label=label_text,
                            show_copy_button=True,
                            elem_id="chatbot",
                            render=False,
                            avatar_images=(
                                None,
                                AVATAR_BOT,
                            ),
                        ),
                        additional_inputs=[mode, upload_button, system_prompt_input],
                    )

            with gr.Row():
                avatar_byte = AVATAR_BOT.read_bytes()
                f_base64 = f"data:image/png;base64,{base64.b64encode(avatar_byte).decode('utf-8')}"
                gr.HTML(
                    f"<div class='footer'><a class='footer-zylon-link' href='https://zylon.ai/'>Maintained by Zylon <img class='footer-zylon-ico' src='{f_base64}' alt=Zylon></a></div>"
                )

        return blocks

    def get_ui_blocks(self) -> gr.Blocks:
        if self._ui_block is None:
            self._ui_block = self._build_ui_blocks()
        return self._ui_block

    def mount_in_app(self, app: FastAPI, path: str) -> None:
        blocks = self.get_ui_blocks()
        blocks.queue()
        logger.info("Mounting the gradio UI, at path=%s", path)
        gr.mount_gradio_app(app, blocks, path=path, favicon_path=AVATAR_BOT)


if __name__ == "__main__":
    ui = global_injector.get(PrivateGptUi)
    _blocks = ui.get_ui_blocks()
    _blocks.queue()
    _blocks.launch(debug=False, show_api=False)
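
A minimal sketch of embedding this UI in an existing FastAPI app, mirroring the __main__ block above; the "/ui" path is an assumption (the real mount path comes from the ui.path setting):

from fastapi import FastAPI

from private_gpt.di import global_injector
from private_gpt.ui.ui import PrivateGptUi

app = FastAPI()
ui = global_injector.get(PrivateGptUi)  # the injector resolves the required services
ui.mount_in_app(app, path="/ui")
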
1
pgpt/private_gpt/utils/__init__.py
Normal file
@@ -0,0 +1 @@
"""general utils."""
122
pgpt/private_gpt/utils/eta.py
Normal file
@@ -0,0 +1,122 @@
import datetime
import logging
import math
import time
from collections import deque
from typing import Any

logger = logging.getLogger(__name__)


def human_time(*args: Any, **kwargs: Any) -> str:
    def timedelta_total_seconds(timedelta: datetime.timedelta) -> float:
        return (
            timedelta.microseconds
            + 0.0
            + (timedelta.seconds + timedelta.days * 24 * 3600) * 10**6
        ) / 10**6

    secs = float(timedelta_total_seconds(datetime.timedelta(*args, **kwargs)))
    # We want (ms) precision below 2 seconds
    if secs < 2:
        return f"{secs * 1000}ms"
    units = [("y", 86400 * 365), ("d", 86400), ("h", 3600), ("m", 60), ("s", 1)]
    parts = []
    for unit, mul in units:
        if secs / mul >= 1 or mul == 1:
            if mul > 1:
                n = int(math.floor(secs / mul))
                secs -= n * mul
            else:
                # >2s we drop the (ms) component.
                n = int(secs)
            if n:
                parts.append(f"{n}{unit}")
    return " ".join(parts)


def eta(iterator: list[Any]) -> Any:
    """Report an ETA after 30s and every 60s thereafter."""
    total = len(iterator)
    _eta = ETA(total)
    _eta.needReport(30)
    for processed, data in enumerate(iterator, start=1):
        yield data
        _eta.update(processed)
        if _eta.needReport(60):
            logger.info(f"{processed}/{total} - ETA {_eta.human_time()}")


class ETA:
    """Predict how long something will take to complete."""

    def __init__(self, total: int):
        self.total: int = total  # Total expected records.
        self.rate: float = 0.0  # per second
        self._timing_data: deque[tuple[float, int]] = deque(maxlen=100)
        self.secondsLeft: float = 0.0
        self.nexttime: float = 0.0

    def human_time(self) -> str:
        if self._calc():
            return f"{human_time(seconds=self.secondsLeft)} @ {int(self.rate * 60)}/min"
        return "(computing)"

    def update(self, count: int) -> None:
        # count should be in the range 0 to self.total
        assert count > 0
        assert count <= self.total
        self._timing_data.append((time.time(), count))  # (X,Y) for pearson

    def needReport(self, whenSecs: int) -> bool:
        now = time.time()
        if now > self.nexttime:
            self.nexttime = now + whenSecs
            return True
        return False

    def _calc(self) -> bool:
        # A sample before a prediction. Need two points to compute slope!
        if len(self._timing_data) < 3:
            return False

        # http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
        # Calculate means and standard deviations.
        samples = len(self._timing_data)
        # column wise sum of the timing tuples to compute their mean.
        mean_x, mean_y = (
            sum(i) / samples for i in zip(*self._timing_data, strict=False)
        )
        std_x = math.sqrt(
            sum(pow(i[0] - mean_x, 2) for i in self._timing_data) / (samples - 1)
        )
        std_y = math.sqrt(
            sum(pow(i[1] - mean_y, 2) for i in self._timing_data) / (samples - 1)
        )

        # Calculate coefficient.
        sum_xy, sum_sq_v_x, sum_sq_v_y = 0.0, 0.0, 0
        for x, y in self._timing_data:
            x -= mean_x
            y -= mean_y
            sum_xy += x * y
            sum_sq_v_x += pow(x, 2)
            sum_sq_v_y += pow(y, 2)
        pearson_r = sum_xy / math.sqrt(sum_sq_v_x * sum_sq_v_y)

        # Calculate regression line.
        # y = mx + b where m is the slope and b is the y-intercept.
        m = self.rate = pearson_r * (std_y / std_x)
        y = self.total
        b = mean_y - m * mean_x
        x = (y - b) / m

        # Calculate fitted line (transformed/shifted regression line horizontally).
        fitted_b = self._timing_data[-1][1] - (m * self._timing_data[-1][0])
        fitted_x = (y - fitted_b) / m
        _, count = self._timing_data[-1]  # adjust last data point progress count
        adjusted_x = ((fitted_x - x) * (count / self.total)) + x
        eta_epoch = adjusted_x

        self.secondsLeft = max([eta_epoch - time.time(), 0])
        return True
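
A usage sketch for the ETA helper above (assumes INFO logging is enabled so the periodic reports are visible):

import logging

from private_gpt.utils.eta import eta

logging.basicConfig(level=logging.INFO)
for item in eta(list(range(1_000_000))):
    _ = item * item  # placeholder for real per-item work
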
95
pgpt/private_gpt/utils/ollama.py
Normal file
@@ -0,0 +1,95 @@
import logging
from collections import deque
from collections.abc import Iterator, Mapping
from typing import Any

from httpx import ConnectError
from tqdm import tqdm  # type: ignore

from private_gpt.utils.retry import retry

try:
    from ollama import Client, ResponseError  # type: ignore
except ImportError as e:
    raise ImportError(
        "Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
    ) from e

logger = logging.getLogger(__name__)

_MAX_RETRIES = 5
_JITTER = (3.0, 10.0)


@retry(
    is_async=False,
    exceptions=(ConnectError, ResponseError),
    tries=_MAX_RETRIES,
    jitter=_JITTER,
    logger=logger,
)
def check_connection(client: Client) -> bool:
    try:
        client.list()
        return True
    except (ConnectError, ResponseError) as e:
        raise e
    except Exception as e:
        logger.error(f"Failed to connect to Ollama: {type(e).__name__}: {e!s}")
        return False


def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
    progress_bars = {}
    queue = deque()  # type: ignore

    def create_progress_bar(dgt: str, total: int) -> Any:
        return tqdm(
            total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True
        )

    current_digest = None

    for chunk in generator:
        digest = chunk.get("digest")
        completed_size = chunk.get("completed", 0)
        total_size = chunk.get("total")

        if digest and total_size is not None:
            if digest not in progress_bars and completed_size > 0:
                progress_bars[digest] = create_progress_bar(digest, total=total_size)
                if current_digest is None:
                    current_digest = digest
                else:
                    queue.append(digest)

            if digest in progress_bars:
                progress_bar = progress_bars[digest]
                progress = completed_size - progress_bar.n
                if completed_size > 0 and total_size >= progress != progress_bar.n:
                    if digest == current_digest:
                        progress_bar.update(progress)
                        if progress_bar.n >= total_size:
                            progress_bar.close()
                            current_digest = queue.popleft() if queue else None
                    else:
                        # Store progress for later update
                        progress_bars[digest].total = total_size
                        progress_bars[digest].n = completed_size

    # Close any remaining progress bars at the end
    for progress_bar in progress_bars.values():
        progress_bar.close()


def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
    try:
        installed_models = [model["name"] for model in client.list().get("models", {})]
        if model_name not in installed_models:
            logger.info(f"Pulling model {model_name}. Please wait...")
            process_streaming(client.pull(model_name, stream=True))
            logger.info(f"Model {model_name} pulled successfully")
    except Exception as e:
        logger.error(f"Failed to pull model {model_name}: {e!s}")
        if raise_error:
            raise e
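A minimal usage sketch of these helpers (assuming an Ollama server reachable on the default port; an illustration, not part of this commit):

from ollama import Client

from private_gpt.utils.ollama import check_connection, pull_model

client = Client(host="http://localhost:11434")  # default Ollama endpoint
if check_connection(client):
    # pull_model() only downloads the model if it is not installed yet.
    pull_model(client, "mxbai-embed-large")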
31  pgpt/private_gpt/utils/retry.py  Normal file
@@ -0,0 +1,31 @@
import logging
from collections.abc import Callable
from typing import Any

from retry_async import retry as retry_untyped  # type: ignore

retry_logger = logging.getLogger(__name__)


def retry(
    exceptions: Any = Exception,
    *,
    is_async: bool = False,
    tries: int = -1,
    delay: float = 0,
    max_delay: float | None = None,
    backoff: float = 1,
    jitter: float | tuple[float, float] = 0,
    logger: logging.Logger = retry_logger,
) -> Callable[..., Any]:
    wrapped = retry_untyped(
        exceptions=exceptions,
        is_async=is_async,
        tries=tries,
        delay=delay,
        max_delay=max_delay,
        backoff=backoff,
        jitter=jitter,
        logger=logger,
    )
    return wrapped  # type: ignore
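A sketch of how this typed wrapper is meant to be applied (hypothetical function; the keyword arguments mirror the signature above):

import logging

from private_gpt.utils.retry import retry

logger = logging.getLogger(__name__)

@retry(exceptions=(ConnectionError,), tries=3, delay=1.0, backoff=2.0, logger=logger)
def fetch_status() -> str:
    # Retried up to 3 times, doubling the delay after each ConnectionError.
    return "ok"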
5  pgpt/private_gpt/utils/typing.py  Normal file
@@ -0,0 +1,5 @@
from typing import TypeVar

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
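These TypeVars are plain generic placeholders; a sketch of typical usage (hypothetical helper, not part of this commit):

from private_gpt.utils.typing import K, V

def invert(mapping: dict[K, V]) -> dict[V, K]:
    # Swap keys and values; the values must themselves be hashable.
    return {value: key for key, value in mapping.items()}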
197  pgpt/pyproject.toml  Normal file
@@ -0,0 +1,197 @@
[tool.poetry]
name = "private-gpt"
version = "0.6.2"
description = "Private GPT"
authors = ["Zylon <hi@zylon.ai>"]

[tool.poetry.dependencies]
python = ">=3.11,<3.12"
# PrivateGPT
fastapi = { extras = ["all"], version = "^0.115.0" }
python-multipart = "^0.0.10"
injector = "^0.22.0"
pyyaml = "^6.0.2"
watchdog = "^4.0.1"
transformers = "^4.44.2"
docx2txt = "^0.8"
cryptography = "^3.1"
# LlamaIndex core libs
llama-index-core = ">=0.11.2,<0.12.0"
llama-index-readers-file = "*"
# Optional LlamaIndex integration libs
llama-index-llms-llama-cpp = {version = "*", optional = true}
llama-index-llms-openai = {version = "*", optional = true}
llama-index-llms-openai-like = {version = "*", optional = true}
llama-index-llms-ollama = {version = "*", optional = true}
llama-index-llms-azure-openai = {version = "*", optional = true}
llama-index-llms-gemini = {version = "*", optional = true}
llama-index-embeddings-ollama = {version = "*", optional = true}
llama-index-embeddings-huggingface = {version = "*", optional = true}
llama-index-embeddings-openai = {version = "*", optional = true}
llama-index-embeddings-azure-openai = {version = "*", optional = true}
llama-index-embeddings-gemini = {version = "*", optional = true}
llama-index-embeddings-mistralai = {version = "*", optional = true}
llama-index-vector-stores-qdrant = {version = "*", optional = true}
llama-index-vector-stores-milvus = {version = "*", optional = true}
llama-index-vector-stores-chroma = {version = "*", optional = true}
llama-index-vector-stores-postgres = {version = "*", optional = true}
llama-index-vector-stores-clickhouse = {version = "*", optional = true}
llama-index-storage-docstore-postgres = {version = "*", optional = true}
llama-index-storage-index-store-postgres = {version = "*", optional = true}
# Postgres
psycopg2-binary = {version = "^2.9.9", optional = true}
asyncpg = {version = "^0.29.0", optional = true}

# ClickHouse
clickhouse-connect = {version = "^0.7.19", optional = true}

# Optional Sagemaker dependency
boto3 = {version = "^1.35.26", optional = true}

# Optional Reranker dependencies
torch = {version = "^2.4.1", optional = true}
sentence-transformers = {version = "^3.1.1", optional = true}

# Optional UI
gradio = {version = "^4.44.0", optional = true}
ffmpy = {version = "^0.4.0", optional = true}

# Optional HF Transformers
einops = {version = "^0.8.0", optional = true}
retry-async = "^0.1.4"

[tool.poetry.extras]
ui = ["gradio", "ffmpy"]
llms-llama-cpp = ["llama-index-llms-llama-cpp"]
llms-openai = ["llama-index-llms-openai"]
llms-openai-like = ["llama-index-llms-openai-like"]
llms-ollama = ["llama-index-llms-ollama"]
llms-sagemaker = ["boto3"]
llms-azopenai = ["llama-index-llms-azure-openai"]
llms-gemini = ["llama-index-llms-gemini"]
embeddings-ollama = ["llama-index-embeddings-ollama"]
embeddings-huggingface = ["llama-index-embeddings-huggingface", "einops"]
embeddings-openai = ["llama-index-embeddings-openai"]
embeddings-sagemaker = ["boto3"]
embeddings-azopenai = ["llama-index-embeddings-azure-openai"]
embeddings-gemini = ["llama-index-embeddings-gemini"]
embeddings-mistral = ["llama-index-embeddings-mistralai"]
vector-stores-qdrant = ["llama-index-vector-stores-qdrant"]
vector-stores-clickhouse = ["llama-index-vector-stores-clickhouse", "clickhouse_connect"]
vector-stores-chroma = ["llama-index-vector-stores-chroma"]
vector-stores-postgres = ["llama-index-vector-stores-postgres"]
vector-stores-milvus = ["llama-index-vector-stores-milvus"]
storage-nodestore-postgres = ["llama-index-storage-docstore-postgres", "llama-index-storage-index-store-postgres", "psycopg2-binary", "asyncpg"]
rerank-sentence-transformers = ["torch", "sentence-transformers"]

[tool.poetry.group.dev.dependencies]
black = "^24"
mypy = "^1.11"
pre-commit = "^3"
pytest = "^8"
pytest-cov = "^5"
ruff = "^0"
pytest-asyncio = "^0.24.0"
types-pyyaml = "^6.0.12.20240917"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

# Packages configs

## coverage

[tool.coverage.run]
branch = true

[tool.coverage.report]
skip_empty = true
precision = 2

## black

[tool.black]
target-version = ['py311']

## ruff
# Recommended ruff config for now, to be updated as we go along.
[tool.ruff]
target-version = 'py311'

# See all rules at https://beta.ruff.rs/docs/rules/
lint.select = [
    "E",   # pycodestyle
    "W",   # pycodestyle
    "F",   # Pyflakes
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "D",   # pydocstyle
    "I",   # isort
    "SIM", # flake8-simplify
    "TCH", # flake8-type-checking
    "TID", # flake8-tidy-imports
    "Q",   # flake8-quotes
    "UP",  # pyupgrade
    "PT",  # flake8-pytest-style
    "RUF", # Ruff-specific rules
]

lint.ignore = [
    "E501",   # "Line too long"
    # -> line length already regulated by black
    "PT011",  # "pytest.raises() should specify expected exception"
    # -> would imply to update tests every time you update exception message
    "SIM102", # "Use a single `if` statement instead of nested `if` statements"
    # -> too restrictive
    "D100",
    "D101",
    "D102",
    "D103",
    "D104",
    "D105",
    "D106",
    "D107"
    # -> "Missing docstring in public function" too restrictive
]

[tool.ruff.lint.pydocstyle]
# Automatically disable rules that are incompatible with Google docstring convention
convention = "google"

[tool.ruff.lint.pycodestyle]
max-doc-length = 88

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.lint.flake8-type-checking]
strict = true
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
# Pydantic needs to be able to evaluate types at runtime
# see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
# see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation

[tool.ruff.lint.per-file-ignores]
# Allow missing docstrings for tests
"tests/**/*.py" = ["D1"]

## mypy

[tool.mypy]
python_version = "3.11"
strict = true
check_untyped_defs = false
explicit_package_bases = true
warn_unused_ignores = false
exclude = ["tests"]

[tool.mypy-llama-index]
ignore_missing_imports = true

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
addopts = [
    "--import-mode=importlib",
]
42  pgpt/settings-docker.yaml  Normal file
@@ -0,0 +1,42 @@
server:
  env_name: ${APP_ENV:prod}
  port: ${PORT:8080}

llm:
  mode: ${PGPT_MODE:mock}

embedding:
  mode: ${PGPT_EMBED_MODE:mock}

llamacpp:
  llm_hf_repo_id: ${PGPT_HF_REPO_ID:lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF}
  llm_hf_model_file: ${PGPT_HF_MODEL_FILE:Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf}

huggingface:
  embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:nomic-ai/nomic-embed-text-v1.5}

sagemaker:
  llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
  embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}

ollama:
  #llm_model: ${PGPT_OLLAMA_LLM_MODEL:llama3:8b-instruct-q4_K_M}
  #llm_model: llama3:8b-instruct-q4_K_M
  llm_model: qwen3:14b
  context_window: 5000
  # llm_model: gemma3:12b
  # llm_model: deepseek-r1:14b
  embedding_model: ${PGPT_OLLAMA_EMBEDDING_MODEL:mxbai-embed-large}
  api_base: ${PGPT_OLLAMA_API_BASE:http://ollama:11434}
  embedding_api_base: ${PGPT_OLLAMA_EMBEDDING_API_BASE:http://ollama:11434}
  tfs_z: ${PGPT_OLLAMA_TFS_Z:1.0}
  top_k: ${PGPT_OLLAMA_TOP_K:40}
  top_p: ${PGPT_OLLAMA_TOP_P:0.9}
  repeat_last_n: ${PGPT_OLLAMA_REPEAT_LAST_N:64}
  repeat_penalty: ${PGPT_OLLAMA_REPEAT_PENALTY:1.2}
  request_timeout: ${PGPT_OLLAMA_REQUEST_TIMEOUT:6000.0}
  autopull_models: ${PGPT_OLLAMA_AUTOPULL_MODELS:true}

ui:
  enabled: true
  path: /
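Throughout these YAML files, ${VAR:default} placeholders are resolved from environment variables when the settings are loaded. A minimal sketch of that substitution pattern (an illustration of the syntax, not private-gpt's actual loader):

import os
import re

_PLACEHOLDER = re.compile(r"\$\{(\w+):([^}]*)\}")

def resolve(value: str) -> str:
    # Replace each ${VAR:default} with os.environ["VAR"], or the default.
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(resolve("${PORT:8080}"))  # -> "8080" unless PORT is set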
31  pgpt/settings-ollama.yaml  Normal file
@@ -0,0 +1,31 @@
server:
  env_name: ${APP_ENV:ollama}

llm:
  mode: ollama
  max_new_tokens: 512
  context_window: 3900
  temperature: 0.1  # The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)

embedding:
  mode: ollama

ollama:
  llm_model: qwen3:14b
  context_window: 5000
  embedding_model: mxbai-embed-large
  api_base: http://ollama:11434
  embedding_api_base: http://ollama:11434  # change if your embedding model runs on another ollama
  keep_alive: 5m
  tfs_z: 2.0  # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.
  top_k: 40  # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
  top_p: 0.9  # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
  repeat_last_n: -1  # Sets how far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
  repeat_penalty: 1.5  # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
  request_timeout: 120.0  # Time elapsed until ollama times out the request. Default is 120s. Format is float.

vectorstore:
  database: qdrant

qdrant:
  path: local_data/private_gpt/qdrant
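The top_k and top_p comments above describe standard sampling filters; a minimal sketch of the two filtering steps (an illustration of the concept, not Ollama's implementation):

def top_k_filter(probs: dict[str, float], k: int) -> dict[str, float]:
    # Keep the k most probable tokens, then renormalize.
    kept = dict(sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k])
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}

def top_p_filter(probs: dict[str, float], p: float) -> dict[str, float]:
    # Keep the smallest high-probability set whose cumulative mass reaches p.
    kept: dict[str, float] = {}
    cumulative = 0.0
    for tok, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = prob
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(kept.values())
    return {tok: q / total for tok, q in kept.items()}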
158  pgpt/settings.yaml  Normal file
@@ -0,0 +1,158 @@
# The default configuration file.
# More information about configuration can be found in the documentation: https://docs.privategpt.dev/
# Syntax in `private_gpt/settings/settings.py`
server:
  env_name: ${APP_ENV:prod}
  port: ${PORT:8001}
  cors:
    enabled: true
    allow_origins: ["*"]
    allow_methods: ["*"]
    allow_headers: ["*"]
  auth:
    enabled: false
    # python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())'
    # 'secret' is the username and 'key' is the password for basic auth by default
    # If the auth is enabled, this value must be set in the "Authorization" header of the request.
    secret: "Basic c2VjcmV0OmtleQ=="

#data:
#  local_ingestion:
#    enabled: ${LOCAL_INGESTION_ENABLED:false}
#    allow_ingest_from: ["*"]
#  local_data_folder: local_data/Corpus/private_gpt
data:
  local_ingestion:
    enabled: true
    allow_ingest_from: ["*"]
  local_data_folder: local_data/private_gpt
ui:
  enabled: true
  path: /
  # "RAG", "Search", "Basic", or "Summarize"
  default_mode: "RAG"
  default_chat_system_prompt: >
    Vous ne devez répondre aux questions qu'à partir des données du contexte.
    Si vous connaissez la réponse, mais qu'elle n'est pas basée sur le contexte,
    faites-en une suggestion et non une réponse. Annoncez-le clairement.
  default_query_system_prompt: >
    Vous ne devez répondre aux questions qu'à partir des données du contexte.
    Si vous connaissez la réponse, mais qu'elle n'est pas basée sur le contexte,
    faites-en une suggestion et non une réponse. Annoncez-le clairement.
  default_summarization_system_prompt: >
    Vous ne devez répondre aux questions qu'à partir des données du contexte.
    Si vous connaissez la réponse, mais qu'elle n'est pas basée sur le contexte,
    faites-en une suggestion et non une réponse. Annoncez-le clairement.
  delete_file_button_enabled: true
  delete_all_files_button_enabled: true

llm:
  mode: llamacpp
  prompt_style: "llama3"
  # Should be matching the selected model
  max_new_tokens: 512
  context_window: 10000
  # Select your tokenizer. Llama-index tokenizer is the default.
  # tokenizer: meta-llama/Meta-Llama-3.1-8B-Instruct
  temperature: 0.1  # The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)

rag:
  similarity_top_k: 15
  # This value controls how many "top" documents the RAG returns to use in the context.
  # similarity_value: 0.9
  # This value is disabled by default. If you enable this setting, the RAG will only use articles that meet a certain percentage score.
  rerank:
    enabled: false
    model: cross-encoder/ms-marco-MiniLM-L-2-v2
    top_n: 1

summarize:
  use_async: false

clickhouse:
  host: localhost
  port: 8443
  username: admin
  password: clickhouse
  database: embeddings

llamacpp:
  llm_hf_repo_id: lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF
  llm_hf_model_file: Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
  tfs_z: 2.0  # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
  top_k: 10  # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
  top_p: 0.3  # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
  repeat_penalty: 1.1  # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)

#embedding:
#  # Should be matching the value above in most cases
#  mode: huggingface
#  ingest_mode: simple
#  embed_dim: 768  # 768 is for nomic-ai/nomic-embed-text-v1.5

embedding:
  mode: ollama
  model: mxbai-embed-large
  ingest_mode: simple
  embed_dim: 1536

huggingface:
  embedding_hf_model_name: nomic-ai/nomic-embed-text-v1.5
  access_token: ${HF_TOKEN:}
  # Warning: Enabling this option will allow the model to download and execute code from the internet.
  # Nomic AI requires this option to be enabled to use the model, be aware if you are using a different model.
  trust_remote_code: true

nodestore:
  database: simple

milvus:
  uri: local_data/private_gpt/milvus/milvus_local.db
  collection_name: milvus_db
  overwrite: false

vectorstore:
  database: qdrant

qdrant:
  path: local_data/private_gpt/qdrant

postgres:
  host: localhost
  port: 5432
  database: postgres
  user: postgres
  password: postgres
  schema_name: private_gpt

sagemaker:
  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479

openai:
  api_key: ${OPENAI_API_KEY:}
  model: gpt-3.5-turbo
  embedding_api_key: ${OPENAI_API_KEY:}

ollama:
  llm_model: qwen3:14b
  embedding_model: mxbai-embed-large
  api_base: http://localhost:11434
  embedding_api_base: http://localhost:11434  # change if your embedding model runs on another ollama
  keep_alive: 5m
  request_timeout: 1200.0
  autopull_models: true

azopenai:
  api_key: ${AZ_OPENAI_API_KEY:}
  azure_endpoint: ${AZ_OPENAI_ENDPOINT:}
  embedding_deployment_name: ${AZ_OPENAI_EMBEDDING_DEPLOYMENT_NAME:}
  llm_deployment_name: ${AZ_OPENAI_LLM_DEPLOYMENT_NAME:}
  api_version: "2023-05-15"
  embedding_model: text-embedding-ada-002
  llm_model: gpt-35-turbo

gemini:
  api_key: ${GOOGLE_API_KEY:}
  model: models/gemini-pro
  embedding_model: models/embedding-001
96  pgpt/start-private-gpt-custom.sh  Executable file
@@ -0,0 +1,96 @@
#!/bin/bash

# Startup script for Private-GPT with a custom Ollama setup
# Makes it easy to launch private-gpt with custom AI models

set -e

# Directory containing this script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

# Colors for messages
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# List the available models
list_models() {
    echo -e "${BLUE}Modèles Ollama disponibles:${NC}"
    docker exec -it pgpt-ollama-cpu-1 ollama list
}

# Download a new model
download_model() {
    if [ -z "$1" ]; then
        echo -e "${RED}Erreur: Veuillez spécifier un nom de modèle${NC}"
        echo "Exemple: $0 download llama3"
        exit 1
    fi

    echo -e "${YELLOW}Téléchargement du modèle $1...${NC}"
    docker exec -it pgpt-ollama-cpu-1 ollama pull "$1"
    echo -e "${GREEN}Modèle $1 téléchargé avec succès${NC}"
}

# Stop the services
stop_services() {
    echo -e "${YELLOW}Arrêt des services Private-GPT...${NC}"
    docker compose -f docker-compose.yaml down
    echo -e "${GREEN}Services arrêtés${NC}"
}

# Start the services
start_services() {
    local profile="$1"

    if [ -z "$profile" ]; then
        profile="ollama-cpu"
    fi

    echo -e "${YELLOW}Démarrage des services Private-GPT avec le profil $profile...${NC}"
    docker compose -f docker-compose.yaml --profile "$profile" up -d
    echo -e "${GREEN}Services démarrés avec succès${NC}"
    echo -e "${BLUE}Interface web disponible sur: ${GREEN}http://localhost:8001${NC}"
}

# Show the help text
show_help() {
    echo -e "${BLUE}Script de démarrage pour Private-GPT avec Ollama${NC}"
    echo ""
    echo -e "Usage: $0 [commande]"
    echo ""
    echo "Commandes:"
    echo "  start [profile]    Démarrer les services (profile: ollama-cpu (défaut), ollama-cuda)"
    echo "  stop               Arrêter les services"
    echo "  list               Lister les modèles Ollama disponibles"
    echo "  download <modèle>  Télécharger un nouveau modèle Ollama"
    echo "  help               Afficher cette aide"
    echo ""
    echo "Exemples:"
    echo "  $0 start             Démarrer avec CPU (par défaut)"
    echo "  $0 start ollama-cuda Démarrer avec GPU (CUDA)"
    echo "  $0 download mistral  Télécharger le modèle Mistral"
    echo ""
}

# Command dispatch
case "$1" in
    start)
        start_services "$2"
        ;;
    stop)
        stop_services
        ;;
    list)
        list_models
        ;;
    download)
        download_model "$2"
        ;;
    *)
        show_help
        ;;
esac
1  pgpt/version.txt  Normal file
@@ -0,0 +1 @@
0.6.2
@@ -1421,7 +1421,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
 subgraph cluster_ProcedeDUV {
 label="ProcedeDUV";
 fillcolor="#ffd699";
-ProcedeDUV [fillcolor="#ffd699", label="Photolitographie DUV", niveau="1000"];
+ProcedeDUV [fillcolor="#ffd699", label="Procédé DUV", niveau="1000"];

 // Relations sortantes
 ProcedeDUV -> Assemblage_ProcedeDUV [];
@@ -1476,7 +1476,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
 subgraph cluster_ProcedeEUV {
 label="ProcedeEUV";
 fillcolor="#ffd699";
-ProcedeEUV [fillcolor="#ffd699", label="Photolitographie EUV", niveau="1000"];
+ProcedeEUV [fillcolor="#ffd699", label="Procédé EUV", niveau="1000"];

 // Relations sortantes
 ProcedeEUV -> Assemblage_ProcedeEUV [];
@@ -4250,7 +4250,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
 subgraph cluster_CreusetGraphite {
 label="CreusetGraphite";
 fillcolor="#ffd699";
-CreusetGraphite [fillcolor="#ffd699", label="Creuset en graphite - Pour fusion de métaux", niveau="1001"];
+CreusetGraphite [fillcolor="#ffd699", label="Creuset graphite", niveau="1001"];

 // Relations sortantes
 CreusetGraphite -> Graphite [cout="0.5", delai="0.3", ics="0.44", technique="0.5"];
@@ -4259,7 +4259,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
 subgraph cluster_CreusetQuartz {
 label="CreusetQuartz";
 fillcolor="#ffd699";
-CreusetQuartz [fillcolor="#ffd699", label="Creuset en quartz - Pour silicium monocristallin", niveau="1001"];
+CreusetQuartz [fillcolor="#ffd699", label="Creuset quartz", niveau="1001"];

 // Relations sortantes
 CreusetQuartz -> Verre [];