312 lines
11 KiB
Python
312 lines
11 KiB
Python
from typing import List, Optional, Tuple, Dict, Set
|
|
import streamlit as st
|
|
import networkx as nx
|
|
from utils.translations import _
|
|
from utils.widgets import html_expander
|
|
|
|
from utils.graph_utils import (
|
|
extraire_chemins_depuis,
|
|
extraire_chemins_vers
|
|
)
|
|
|
|
from batch_ia.batch_utils import soumettre_batch, statut_utilisateur, nettoyage_post_telechargement
|
|
|
|
niveau_labels = {
|
|
0: "Produit final",
|
|
1: "Composant",
|
|
2: "Minerai",
|
|
10: "Opération",
|
|
11: "Pays d'opération",
|
|
12: "Acteur d'opération",
|
|
99: "Pays géographique"
|
|
}
|
|
|
|
inverse_niveau_labels = {v: k for k, v in niveau_labels.items()}
|
|
|
|
|
|
def preparer_graphe(
|
|
G: nx.DiGraph,
|
|
) -> Tuple[nx.DiGraph, Dict[str, int]]:
|
|
"""
|
|
Nettoie et prépare le graphe pour l'analyse.
|
|
|
|
Args:
|
|
G (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
|
|
Returns:
|
|
Tuple[nx.DiGraph, Dict[str, int]]: Un tuple contenant :
|
|
- Le graphe NetworkX proprement configuré
|
|
- Un dictionnaire des niveaux associés aux nœuds
|
|
"""
|
|
niveaux_temp = {
|
|
node: int(str(attrs.get("niveau")).strip('"'))
|
|
for node, attrs in G.nodes(data=True)
|
|
if attrs.get("niveau") and str(attrs.get("niveau")).strip('"').isdigit()
|
|
}
|
|
G.remove_nodes_from([n for n in G.nodes() if n not in niveaux_temp])
|
|
G.remove_nodes_from(
|
|
[n for n in G.nodes() if niveaux_temp.get(n) == 10 and 'Reserves' in n])
|
|
return G, niveaux_temp
|
|
|
|
|
|
def selectionner_minerais(
|
|
G: nx.DiGraph,
|
|
) -> Optional[List[str]]:
|
|
"""
|
|
Interface pour sélectionner les minerais si nécessaire.
|
|
|
|
Args:
|
|
G (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
|
|
Returns:
|
|
Optional[List[str]]: La liste des minerais si une sélection a été effectuée,
|
|
- None sinon
|
|
"""
|
|
minerais_selection = None
|
|
|
|
st.markdown(f"## {str(_('pages.ia_nalyse.select_minerals'))}")
|
|
# Tous les nœuds de niveau 2 (minerai)
|
|
minerais_nodes = sorted([
|
|
n for n, d in G.nodes(data=True)
|
|
if d.get("niveau") and int(str(d.get("niveau")).strip('"')) == 2
|
|
])
|
|
|
|
minerais_selection = st.multiselect(
|
|
str(_("pages.ia_nalyse.filter_by_minerals")),
|
|
minerais_nodes,
|
|
key="analyse_minerais"
|
|
)
|
|
|
|
return minerais_selection
|
|
|
|
|
|
def selectionner_noeuds(
|
|
G: nx.DiGraph,
|
|
niveaux_temp: Dict[str, int],
|
|
niveau_depart: int,
|
|
) -> Tuple[Optional[List[str]], List[str]]:
|
|
"""
|
|
Interface pour sélectionner les nœuds spécifiques de départ et d'arrivée.
|
|
|
|
Args:
|
|
G (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
niveaux_temp (Dict[str, int]): Dictionnaire associant chaque nœud à son niveau.
|
|
niveau_depart (int): Le niveau de départ sélectionné.
|
|
|
|
Returns:
|
|
Tuple[Optional[List[str]], List[str]]: Un tuple contenant :
|
|
- La liste des nœuds de départ si une sélection a été effectuée,
|
|
- None sinon
|
|
- La liste des nœuds d'arrivée
|
|
"""
|
|
st.markdown("---")
|
|
st.markdown(f"## {str(_('pages.ia_nalyse.fine_selection'))}")
|
|
|
|
depart_nodes = [n for n in G.nodes() if niveaux_temp.get(n) == niveau_depart]
|
|
noeuds_arrivee = [n for n in G.nodes() if niveaux_temp.get(n) == 99]
|
|
|
|
noeuds_depart = st.multiselect(str(_("pages.ia_nalyse.filter_start_nodes")),
|
|
sorted(depart_nodes),
|
|
key="analyse_noeuds_depart")
|
|
|
|
noeuds_depart = noeuds_depart if noeuds_depart else None
|
|
|
|
return noeuds_depart, noeuds_arrivee
|
|
|
|
def extraire_niveaux(
|
|
G: nx.DiGraph,
|
|
) -> Dict[str, int]:
|
|
"""
|
|
Extrait les niveaux des nœuds du graphe.
|
|
|
|
Args:
|
|
G (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
|
|
Returns:
|
|
Dict[str, int]: Un dictionnaire associant chaque nœud à son niveau.
|
|
"""
|
|
niveaux = {}
|
|
for node, attrs in G.nodes(data=True):
|
|
niveau_str = attrs.get("niveau")
|
|
if niveau_str:
|
|
niveaux[node] = int(str(niveau_str).strip('"'))
|
|
return niveaux
|
|
|
|
def extraire_chemins_selon_criteres(
|
|
G: nx.DiGraph,
|
|
niveaux: Dict[str, int],
|
|
niveau_depart: int,
|
|
noeuds_depart: Optional[List[str]],
|
|
noeuds_arrivee: Optional[List[str]],
|
|
minerais: Optional[List[str]],
|
|
) -> List[Tuple[str, ...]]:
|
|
"""
|
|
Extrait les chemins selon les critères spécifiés.
|
|
|
|
Args:
|
|
G (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
niveaux (Dict[str, int]): Dictionnaire associant chaque nœud à son niveau.
|
|
niveau_depart (int): Le niveau de départ sélectionné.
|
|
noeuds_depart (Optional[List[str]]): Les nœuds de départ si sélectionnés.
|
|
noeuds_arrivee (Optional[List[str]]): Les nœuds d'arrivée si sélectionnés.
|
|
minerais (Optional[List[str]]): Les minerais si sélectionnés.
|
|
|
|
Returns:
|
|
List[Tuple[str, ...]]: Liste des chemins valides selon les critères spécifiés.
|
|
"""
|
|
chemins = []
|
|
if noeuds_depart and noeuds_arrivee:
|
|
for nd in noeuds_depart:
|
|
for na in noeuds_arrivee:
|
|
tous_chemins = extraire_chemins_depuis(G, nd)
|
|
chemins.extend([chemin for chemin in tous_chemins if na in chemin])
|
|
elif noeuds_depart:
|
|
for nd in noeuds_depart:
|
|
chemins.extend(extraire_chemins_depuis(G, nd))
|
|
elif noeuds_arrivee:
|
|
for na in noeuds_arrivee:
|
|
chemins.extend(extraire_chemins_vers(G, na, niveau_depart))
|
|
else:
|
|
sources_depart = [n for n in G.nodes() if niveaux.get(n) == niveau_depart]
|
|
for nd in sources_depart:
|
|
chemins.extend(extraire_chemins_depuis(G, nd))
|
|
|
|
if minerais:
|
|
chemins = [chemin for chemin in chemins if any(n in minerais for n in chemin)]
|
|
|
|
return chemins
|
|
|
|
def exporter_graphe_filtre(
|
|
G: nx.DiGraph,
|
|
liens_chemins: Set[Tuple[str, str]],
|
|
) -> nx.DiGraph|None:
|
|
"""
|
|
Gère l'export du graphe filtré au format DOT.
|
|
|
|
Args:
|
|
G (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
liens_chemins (Set[Tuple[str, str]]): Ensemble des paires de nœuds liés.
|
|
|
|
Returns:
|
|
nx.DiGraph: le graphe exporté
|
|
- Sinon aucun résultat (None)
|
|
"""
|
|
from utils.persistance import get_champ_statut
|
|
if get_champ_statut("login") == "" or not liens_chemins:
|
|
return
|
|
|
|
G_export = nx.DiGraph()
|
|
for u, v in liens_chemins:
|
|
G_export.add_node(u, **G.nodes[u])
|
|
G_export.add_node(v, **G.nodes[v])
|
|
data = G.get_edge_data(u, v)
|
|
if isinstance(data, dict) and all(isinstance(k, int) for k in data):
|
|
G_export.add_edge(u, v, **data[0])
|
|
elif isinstance(data, dict):
|
|
G_export.add_edge(u, v, **data)
|
|
else:
|
|
G_export.add_edge(u, v)
|
|
|
|
return(G_export)
|
|
|
|
def extraire_liens_filtres(
|
|
chemins: List[Tuple[str, ...]],
|
|
niveaux: Dict[str, int],
|
|
niveau_depart: int,
|
|
niveau_arrivee: int,
|
|
niveaux_speciaux: List[int]
|
|
) -> Set[Tuple[str, str]]:
|
|
"""
|
|
Extrait les liens des chemins en respectant les niveaux.
|
|
|
|
Args:
|
|
chemins (List[Tuple[str, ...]]): Liste initiale des chemins validés.
|
|
niveaux (Dict[str, int]): Dictionnaire associant chaque nœud à son niveau.
|
|
niveau_depart (int): Le niveau de départ sélectionné.
|
|
niveau_arrivee (int): Le niveau d'arrivée sélectionné.
|
|
niveaux_speciaux (List[int]): Les niveaux spéciaux à inclure dans l'extraction.
|
|
|
|
Returns:
|
|
Set[Tuple[str, str]]: Ensemble des paires de nœuds liés après filtrage.
|
|
"""
|
|
liens = set()
|
|
for chemin in chemins:
|
|
for i in range(len(chemin) - 1):
|
|
u, v = chemin[i], chemin[i + 1]
|
|
niveau_u = niveaux.get(u, 999)
|
|
niveau_v = niveaux.get(v, 999)
|
|
if (
|
|
(niveau_depart <= niveau_u <= niveau_arrivee or niveau_u in niveaux_speciaux)
|
|
and (niveau_depart <= niveau_v <= niveau_arrivee or niveau_v in niveaux_speciaux)
|
|
):
|
|
liens.add((u, v))
|
|
return liens
|
|
|
|
def interface_ia_nalyse(
|
|
G_temp: nx.DiGraph,
|
|
) -> None:
|
|
"""
|
|
Fonction principale qui s'occupe de la création du graphe pour analyse.
|
|
|
|
Args:
|
|
G_temp (nx.DiGraph): Le graphe NetworkX contenant les données des produits.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
st.markdown(f"# {str(_('pages.ia_nalyse.title'))}")
|
|
html_expander(f"{str(_('pages.ia_nalyse.help'))}", content="\n".join(_("pages.ia_nalyse.help_content")), open_by_default=False, details_class="details_introduction")
|
|
st.markdown("---")
|
|
|
|
from utils.persistance import get_champ_statut
|
|
resultat = statut_utilisateur(get_champ_statut("login"))
|
|
if resultat:
|
|
st.info(resultat["message"])
|
|
|
|
if resultat and resultat["statut"] is None:
|
|
# Préparation du graphe
|
|
G_temp, niveaux_temp = preparer_graphe(G_temp)
|
|
|
|
# Sélection des niveaux
|
|
niveau_depart = 0
|
|
niveau_arrivee = 99
|
|
|
|
# Sélection fine des noeuds
|
|
noeuds_depart, noeuds_arrivee = selectionner_noeuds(G_temp, niveaux_temp, niveau_depart)
|
|
|
|
# Sélection des minerais si nécessaire
|
|
minerais = selectionner_minerais(G_temp)
|
|
|
|
# Étape 1 : Extraction des niveaux des nœuds
|
|
niveaux = extraire_niveaux(G_temp)
|
|
|
|
# Étape 2 : Extraction des chemins selon les critères
|
|
chemins = extraire_chemins_selon_criteres(G_temp, niveaux, niveau_depart, noeuds_depart, noeuds_arrivee, minerais)
|
|
|
|
niveaux_speciaux = [1000, 1001, 1002, 1010, 1011, 1012]
|
|
# Extraction des liens sans filtrage
|
|
liens_chemins = extraire_liens_filtres(chemins, niveaux, niveau_depart, niveau_arrivee, niveaux_speciaux)
|
|
|
|
if liens_chemins:
|
|
G_final = exporter_graphe_filtre(G_temp, liens_chemins)
|
|
if st.button(str(_("pages.ia_nalyse.submit_request")), icon=":material/send:"):
|
|
soumettre_batch(get_champ_statut("login"), G_final)
|
|
st.rerun()
|
|
else:
|
|
st.info(str(_("pages.ia_nalyse.empty_graph")))
|
|
|
|
elif resultat and resultat["statut"] == "terminé" and resultat["telechargement"]:
|
|
if not st.session_state.get("telechargement_confirme"):
|
|
st.download_button(str(_("buttons.download")), resultat["telechargement"], file_name="analyse.zip", icon=":material/download:")
|
|
if st.button(str(_("pages.ia_nalyse.confirm_download")), icon=":material/task_alt:"):
|
|
nettoyage_post_telechargement(get_champ_statut("login"))
|
|
st.session_state["telechargement_confirme"] = True
|
|
st.rerun()
|
|
else:
|
|
st.success("Résultat supprimé. Vous pouvez relancer une nouvelle analyse.")
|
|
if st.button(str(_("buttons.refresh")), icon=":material/refresh:"):
|
|
st.rerun()
|
|
else:
|
|
if st.button(str(_("buttons.refresh")), icon=":material/refresh:"):
|
|
st.rerun()
|