# NOTE: Hugging Face Spaces page-status residue ("Spaces: Sleeping") captured
# with the source; kept as a comment so the file remains valid Python.
#!/usr/bin/env python3
"""
Web Demo v2 for the artwork database - Secured and Optimized version.

Multi-step interface with matching based on first name, date, city and
emotions. Performance-optimized with caching and indexing; secured with
input validation and clean session-state management.
"""
import gradio as gr
import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import random
import re
import json
import uuid
import time
from datetime import datetime
from typing import List, Dict, Tuple, Optional, Any, Set
from collections import Counter, defaultdict
from functools import lru_cache
from dataclasses import dataclass, field, asdict
from pathlib import Path
import pandas as pd

# Main logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Optional import for persistent saving on HF Spaces; the app degrades to
# local-only logging when huggingface_hub is not installed.
try:
    from huggingface_hub import CommitScheduler

    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    logger.warning(
        "huggingface_hub non installé - Les logs ne seront pas sauvegardés dans un dataset HF"
    )

# Session logging configuration
SESSION_LOG_FILE = "session_logs.jsonl"
STATS_LOG_FILE = "statistics.json"

# HF dataset configuration for persistence (override via environment variables)
HF_DATASET_ID = os.environ.get(
    "HF_DATASET_ID", "ClickMons/art-matcher-logs"
)  # Replace with your dataset
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # HF token for authentication
LOGS_UPLOAD_INTERVAL = 5  # Upload every 5 minutes

# Create a rotating handler for the local session log file
if not os.path.exists("logs"):
    os.makedirs("logs")

session_file_handler = RotatingFileHandler(
    filename=os.path.join("logs", SESSION_LOG_FILE),
    maxBytes=10 * 1024 * 1024,  # 10MB
    backupCount=5,
    encoding="utf-8",
)
session_file_handler.setLevel(logging.INFO)
session_logger = logging.getLogger("session_logger")
session_logger.addHandler(session_file_handler)
session_logger.setLevel(logging.INFO)

# Make the local "src" package importable, then pull in the project modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
from art_pieces_db.database import Database
from art_pieces_db.query import TargetProfile, WeightedLeximaxOptimizer, Optimizer
from art_pieces_db.emotions import EmotionWheel
from art_pieces_db.utils import str_to_date
class ScoringWeights:
    """Groups every scoring constant in one place to avoid magic numbers."""

    # Criterion weights applied during the pre-selection phase
    PRESELECTION_NAME_WEIGHT: float = 3.0
    PRESELECTION_DATE_WEIGHT: float = 1.0
    PRESELECTION_PLACE_WEIGHT: float = 2.0
    PRESELECTION_EMOTION_WEIGHT: float = 0.0
    # Minimum number of artworks kept after pre-selection
    MIN_PRESELECTION_COUNT: int = 20
    # Number of images shown per selection round
    MAX_IMAGES_PER_SELECTION: int = 3
    # Number of rounds before the final recommendation
    TOTAL_ROUNDS: int = 3
@dataclass
class SessionState:
    """Per-user session state for the multi-step matching flow.

    Fix: the class used ``dataclasses.field(default_factory=...)`` defaults
    but was missing the ``@dataclass`` decorator, so attributes held raw
    ``Field`` objects shared by every instance instead of fresh lists and
    UUIDs. Restoring the decorator gives each ``SessionState()`` its own
    mutable defaults while keeping the no-argument constructor callers use.
    """

    firstname: str = ""
    birthday: str = ""
    city: str = ""
    current_round: int = 0
    # Artwork ids picked by the user across previous rounds
    selected_images: List[str] = field(default_factory=list)
    # Artwork ids currently displayed for this round
    current_image_ids: List[str] = field(default_factory=list)
    preselected_pieces: Optional[pd.DataFrame] = None
    # Tracking properties
    session_id: str = field(
        default_factory=lambda: str(uuid.uuid4())
    )  # Unique session ID
    session_start_time: float = field(default_factory=time.time)
    recommendation_type: str = ""  # "name_date_place" or "emotions"
    final_artwork: str = ""

    def reset(self) -> None:
        """Reset the session state for a brand-new visitor."""
        self.firstname = ""
        self.birthday = ""
        self.city = ""
        self.current_round = 0
        self.selected_images = []
        self.current_image_ids = []
        self.preselected_pieces = None
        self.session_id = str(uuid.uuid4())  # New session ID
        self.session_start_time = time.time()
        self.recommendation_type = ""
        self.final_artwork = ""

    def is_complete(self) -> bool:
        """Return True once all selection rounds have been played."""
        return self.current_round >= ScoringWeights.TOTAL_ROUNDS
class SessionLogger:
    """Persists session logs, with optional HF ``CommitScheduler`` syncing.

    Sessions are appended as JSON lines to a per-process file under
    ``art_matcher_sessions``; aggregate statistics live in a local JSON file
    under ``art_matcher_stats``. On a HF Space with ``huggingface_hub``
    available, the sessions folder is periodically committed to the
    ``HF_DATASET_ID`` dataset.

    Fixes vs. the previous version: the bare ``except:`` in the fallback
    logging path now catches ``Exception`` and reports the failure, and
    placeholder-free f-strings were turned into plain strings.
    """

    def __init__(self):
        # Detect the HF Spaces environment (SPACE_ID is set by the platform)
        self.is_hf_space = os.environ.get("SPACE_ID") is not None
        # Sessions destined for the HF dataset (session logs only)
        self.sessions_dir = Path("art_matcher_sessions")
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        # Statistics stay local only
        self.local_stats_dir = Path("art_matcher_stats")
        self.local_stats_dir.mkdir(parents=True, exist_ok=True)
        # Each restart creates a new sessions file
        self.sessions_file = self.sessions_dir / f"train-{uuid.uuid4()}.json"
        self.stats_file = self.local_stats_dir / "global_statistics.json"
        # Kept for compatibility with older code
        self.data_dir = self.sessions_dir
        # Initialize the CommitScheduler when running on HF Spaces
        self.scheduler = None
        if self.is_hf_space and HF_HUB_AVAILABLE:
            try:
                # Make sure the dataset ID is defined
                if not HF_DATASET_ID:
                    raise ValueError("HF_DATASET_ID n'est pas défini")
                logger.info(
                    f"Tentative d'initialisation du CommitScheduler pour {HF_DATASET_ID}..."
                )
                self.scheduler = CommitScheduler(
                    repo_id=HF_DATASET_ID,
                    repo_type="dataset",
                    folder_path=str(self.sessions_dir),  # Sessions only!
                    path_in_repo="data",
                    every=LOGS_UPLOAD_INTERVAL,
                )
                logger.info(
                    f"✅ CommitScheduler initialisé avec succès pour {HF_DATASET_ID}"
                )
                logger.info(
                    f"📁 Dossier des sessions (HF dataset): {self.sessions_dir}"
                )
                logger.info(
                    f"📊 Dossier des stats (local seulement): {self.local_stats_dir}"
                )
                logger.info(f"📝 Fichier de session actuel: {self.sessions_file.name}")
                logger.info(f"⏱️ Upload toutes les {LOGS_UPLOAD_INTERVAL} minutes")
            except Exception as e:
                logger.error(
                    f"❌ Erreur lors de l'initialisation du CommitScheduler: {e}"
                )
                logger.info("Les données seront stockées localement uniquement")
                self.scheduler = None
        else:
            if not self.is_hf_space:
                logger.info(
                    "🏠 Environnement local détecté - pas de synchronisation HF"
                )
            if not HF_HUB_AVAILABLE:
                logger.warning("📦 huggingface_hub n'est pas installé")

    def log_session(self, state: SessionState, recommendation_system: str):
        """Record one finished session (thread-safe via the scheduler's lock).

        Args:
            state: session whose outcome should be logged.
            recommendation_system: strategy that produced the final artwork
                ("name_date_place", "emotions", ...).
        """
        session_duration = time.time() - state.session_start_time
        # Use the state's unique session_id, not the instance id
        entry = {
            "session_id": state.session_id,  # Unique id of this session
            "datetime": datetime.now().isoformat(),
            "duration_seconds": round(session_duration, 2),
            "recommended_artwork": state.final_artwork,
            "recommendation_type": recommendation_system,
        }
        # Use the scheduler's lock for thread safety when available
        try:
            if self.scheduler and hasattr(self.scheduler, "lock"):
                with self.scheduler.lock:
                    self._write_session(entry)
                    self._update_stats(entry)
                logger.info("✅ Session écrite avec lock du scheduler")
            else:
                # Without a scheduler, write directly
                self._write_session(entry)
                self._update_stats(entry)
                logger.info("📝 Session écrite sans scheduler")
            logger.info(
                f"Session enregistrée - ID: {entry['session_id'][:8]}... - Durée: {entry['duration_seconds']}s"
            )
            logger.info(f"📁 Fichier: {self.sessions_file.name}")
            session_logger.info(json.dumps(entry, ensure_ascii=False))
        except Exception as e:
            logger.error(f"❌ Erreur lors de l'enregistrement de la session: {e}")
            # Best effort: still try to log to the local file, but report
            # (instead of silently swallowing) a secondary failure.
            try:
                session_logger.info(json.dumps(entry, ensure_ascii=False))
            except Exception as fallback_error:
                logger.error(f"❌ Echec du log local de secours: {fallback_error}")

    def _write_session(self, entry: dict):
        """Append one session entry to the JSON file (newline-delimited)."""
        try:
            self.sessions_dir.mkdir(parents=True, exist_ok=True)
            # Append mode, one JSON object per line
            with self.sessions_file.open("a", encoding="utf-8") as f:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
                f.flush()  # Force the write to disk
            # Sanity check that the file exists and has content
            if self.sessions_file.exists():
                size = self.sessions_file.stat().st_size
                logger.debug(
                    f"📊 Fichier {self.sessions_file.name} - Taille: {size} octets"
                )
        except Exception as e:
            logger.error(f"❌ Erreur lors de l'écriture dans {self.sessions_file}: {e}")

    def _update_stats(self, session_entry: dict):
        """Update the global statistics file with one session entry."""
        # Load existing stats (tolerate a corrupt/absent file)
        stats = {}
        if self.stats_file.exists():
            try:
                with self.stats_file.open("r", encoding="utf-8") as f:
                    stats = json.load(f)
            except json.JSONDecodeError:
                stats = {}
        # Initialize the structure on first run
        if "total_sessions" not in stats:
            stats = {
                "total_sessions": 0,
                "total_duration_seconds": 0,
                "average_duration_seconds": 0,
                "artworks_recommended": {},
                "recommendation_types": {
                    "name_date_place": 0,
                    "emotions": 0,
                    "none": 0,
                },
                "first_session": session_entry["datetime"],
                "last_session": session_entry["datetime"],
            }
        # Update the counters
        stats["total_sessions"] += 1
        stats["total_duration_seconds"] += session_entry.get("duration_seconds", 0)
        stats["average_duration_seconds"] = (
            stats["total_duration_seconds"] / stats["total_sessions"]
        )
        stats["last_session"] = session_entry["datetime"]
        # Count recommendation types (unknown types are ignored)
        rec_type = session_entry.get("recommendation_type", "none")
        if rec_type in stats["recommendation_types"]:
            stats["recommendation_types"][rec_type] += 1
        # Count recommended artworks (skip the "nothing found" sentinel)
        artwork = session_entry.get("recommended_artwork")
        if artwork and artwork != "Aucune œuvre trouvée":
            if artwork not in stats["artworks_recommended"]:
                stats["artworks_recommended"][artwork] = 0
            stats["artworks_recommended"][artwork] += 1
        # Track the most popular artwork
        if stats["artworks_recommended"]:
            most_popular = max(
                stats["artworks_recommended"].items(), key=lambda x: x[1]
            )
            stats["most_popular_artwork"] = {
                "title": most_popular[0],
                "count": most_popular[1],
                "percentage": (most_popular[1] / stats["total_sessions"]) * 100,
            }
        # Usage percentages per recommendation type
        total = stats["total_sessions"]
        if total > 0:
            stats["recommendation_percentages"] = {
                k: (v / total) * 100 for k, v in stats["recommendation_types"].items()
            }
        stats["last_updated"] = datetime.now().isoformat()
        # Persist the updated stats
        with self.stats_file.open("w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

    def get_statistics(self) -> dict:
        """Return the global statistics, or {} when unavailable."""
        if self.stats_file.exists():
            try:
                with self.stats_file.open("r", encoding="utf-8") as f:
                    return json.load(f)
            except Exception as e:
                logger.error(f"Erreur lecture stats: {e}")
        return {}
# Module-level singleton: session tracker shared by all request handlers
session_tracker = SessionLogger()
class SecurityValidator:
    """Centralizes input-security validation (path traversal, charset, dates).

    Fixes: every helper takes ``cls`` and is invoked directly on the class
    elsewhere in this file (``SecurityValidator.validate_filename(...)``,
    ``SecurityValidator.sanitize_input(...)``), but the ``@classmethod``
    decorators were missing — calls bound the argument to ``cls`` and failed.
    The log messages also contained a redacted "(unknown)" placeholder; they
    now report the offending filename.
    """

    PATH_TRAVERSAL_PATTERN = re.compile(r"\.\.|\.\/")
    VALID_FILENAME_PATTERN = re.compile(r"^[\w\-\.\s]+$")
    VALID_INPUT_PATTERN = re.compile(
        r"^[\w\-\s\'\.,àâäéèêëïîôûùüÿæœçÀÂÄÉÈÊËÏÎÔÛÙÜŸÆŒÇ]+$", re.UNICODE
    )
    DATE_PATTERN = re.compile(r"^\d{1,2}/\d{1,2}$")

    @classmethod
    def validate_filename(cls, filename: str) -> bool:
        """Return True when *filename* is safe (no traversal, clean charset)."""
        if not filename:
            return False
        # Reject path-traversal attempts ("..", "./")
        if cls.PATH_TRAVERSAL_PATTERN.search(filename):
            logger.warning(f"Tentative de path traversal détectée: {filename}")
            return False
        # The base name may only contain whitelisted characters
        base_name = os.path.basename(filename)
        if not cls.VALID_FILENAME_PATTERN.match(base_name):
            logger.warning(f"Nom de fichier invalide: {filename}")
            return False
        return True

    @classmethod
    def sanitize_input(cls, input_str: str, max_length: int = 100) -> str:
        """Trim, truncate and strip unauthorized characters from user input."""
        if not input_str:
            return ""
        # Truncate overly long input before validating
        input_str = input_str[:max_length].strip()
        if not cls.VALID_INPUT_PATTERN.match(input_str):
            # Keep only the valid characters
            cleaned = "".join(c for c in input_str if cls.VALID_INPUT_PATTERN.match(c))
            logger.info(f"Input sanitized: '{input_str}' -> '{cleaned}'")
            return cleaned
        return input_str

    @classmethod
    def validate_date(cls, date_str: str) -> Tuple[bool, Optional[datetime]]:
        """Validate and parse a DD/MM date.

        Returns:
            (True, datetime) on success, (False, None) otherwise. Year 2000
            is a leap-year placeholder — only day/month matter.
        """
        if not date_str:
            return False, None
        if not cls.DATE_PATTERN.match(date_str):
            return False, None
        try:
            day, month = map(int, date_str.split("/"))
            if not (1 <= day <= 31 and 1 <= month <= 12):
                return False, None
            date_obj = datetime(year=2000, month=month, day=day)
            return True, date_obj
        except ValueError as e:  # was `(ValueError, Exception)` — redundant union
            logger.error(f"Erreur de parsing de date: {e}")
            return False, None
class ImageIndexer:
    """Indexes image files on disk and maps CSV image names to real filenames.

    The CSV's ``name_image`` values come in several inconsistent formats
    (extensions, ``_medium`` suffixes, MAR-BVM inventory numbers written with
    dots or dashes), so every on-disk filename is normalized into a canonical
    key at startup and lookups fall through progressively fuzzier strategies.
    """

    # Constants for better maintainability
    IMAGE_EXTENSIONS = (".jpg", ".png")
    COMMON_SUFFIXES = [".jpg", ".png", "_medium"]
    MAR_BVM_TEST_SUFFIXES = ["-001", "-002", "-003"]

    def __init__(self, images_dir: str):
        self.images_dir = os.path.abspath(images_dir)
        self.available_files = set()
        self.image_lookup = {}  # normalized_name -> filename
        self.mar_bvm_lookup = {}  # Special handling for MAR-BVM files
        self._build_index()

    def _strip_file_extensions(self, filename: str) -> str:
        """Remove file extensions from filename (result is lowercased)."""
        base_name = filename.lower()
        if base_name.endswith("_medium.jpg"):
            # len("_medium.jpg") == 11
            return base_name[:-11]
        elif base_name.endswith((".jpg", ".png")):
            return base_name[:-4]
        return base_name

    def _normalize_basic_patterns(self, name: str) -> str:
        """Apply basic normalization patterns."""
        # Remove trailing comma and normalize whitespace
        normalized = name.lower().strip().rstrip(",")
        # Remove common suffixes
        for suffix in self.COMMON_SUFFIXES:
            if normalized.endswith(suffix):
                normalized = normalized[: -len(suffix)]
        # Normalize spaces and underscores to dashes
        return re.sub(r"[\s_]+", "-", normalized)

    def _normalize_mar_bvm_format(self, name: str) -> str:
        """Handle MAR-BVM specific normalization."""
        if "mar-bvm" not in name:
            return name
        # Replace .0. with -0- and remaining dots with dashes
        return name.replace(".0.", "-0-").replace(".", "-")

    def _normalize_name(self, name: str) -> str:
        """Normalize a name into the canonical comparison key."""
        normalized = self._normalize_basic_patterns(name)
        # Special handling for MAR-BVM format
        if "mar-bvm" in normalized:
            normalized = self._normalize_mar_bvm_format(normalized)
        # For files starting with year (like 2022.0.86), keep dots
        elif not normalized.startswith("20"):
            normalized = normalized.replace(".", "-")
        return normalized

    def _create_mar_bvm_lookups(self, normalized: str, filename: str):
        """Create additional lookup entries for MAR-BVM files."""
        if "mar-bvm" not in normalized:
            return
        parts = normalized.split("-")
        for i, part in enumerate(parts):
            if part.isdigit() and i >= 5:  # After mar-bvm-7-2022-0
                base_key = "-".join(parts[:6])  # mar-bvm-7-2022-0-22
                if base_key not in self.mar_bvm_lookup:
                    self.mar_bvm_lookup[base_key] = []
                self.mar_bvm_lookup[base_key].append(filename)
                break

    def _process_image_file(self, filename: str):
        """Process a single image file for indexing.

        NOTE(review): ``SecurityValidator.validate_filename`` is called on the
        class here, which requires it to be a classmethod — verify.
        """
        if not SecurityValidator.validate_filename(filename):
            logger.warning(f"Fichier ignoré pour raison de sécurité: (unknown)")
            return
        if not filename.lower().endswith(self.IMAGE_EXTENSIONS):
            return
        self.available_files.add(filename)
        base_name = self._strip_file_extensions(filename)
        normalized = self._normalize_name(base_name)
        self.image_lookup[normalized] = filename
        self._create_mar_bvm_lookups(normalized, filename)

    def _build_index(self):
        """Build the index of available images (single disk scan at startup)."""
        try:
            all_files = os.listdir(self.images_dir)
            for filename in all_files:
                self._process_image_file(filename)
            logger.info(
                f"Index des images construit: {len(self.available_files)} fichiers disponibles, "
                f"{len(self.image_lookup)} entrées normalisées"
            )
        except Exception as e:
            logger.error(f"Erreur lors de la construction de l'index: {e}")
            self.available_files = set()

    def _clean_input_name(self, image_name: str) -> str:
        """Clean and prepare input name for processing."""
        # Basic cleaning
        cleaned = image_name.strip().rstrip(",").rstrip("-").strip()
        # Remove spaces before -001, -002, etc.
        return re.sub(r"\s+(-\d)", r"\1", cleaned)

    def _normalize_mar_bvm_input(self, image_name: str) -> str:
        """Handle MAR-BVM specific input normalization."""
        if "MAR-BVM" not in image_name:
            return image_name
        # Handle missing "7-" in MAR-BVM-2022-0-153
        if "MAR-BVM-2022-0-" in image_name:
            image_name = image_name.replace("MAR-BVM-2022-0-", "MAR-BVM-7-2022-0-")
        # Convert .0. to -0-
        if ".0." in image_name:
            image_name = image_name.replace(".0.", "-0-")
        # Handle .001, .002 at the end (convert to -001, -002)
        image_name = re.sub(r"\.(\d{3})$", r"-\1", image_name)
        # Handle .1 or .2 suffix
        if image_name.endswith(".1"):
            image_name = image_name[:-2] + "-1"
        elif image_name.endswith(".2"):
            image_name = image_name[:-2] + "-2"
        # Replace any remaining dots with dashes (but be careful not to mess up already processed parts)
        return image_name.replace(".", "-")

    def _try_mar_bvm_lookups(self, normalized: str) -> Optional[str]:
        """Try various MAR-BVM specific lookup strategies."""
        # Check special MAR-BVM lookup
        if normalized in self.mar_bvm_lookup and self.mar_bvm_lookup[normalized]:
            return self.mar_bvm_lookup[normalized][0]
        # Try with suffix variations
        for suffix in self.MAR_BVM_TEST_SUFFIXES:
            test_pattern = f"{normalized}{suffix}"
            if test_pattern in self.image_lookup:
                return self.image_lookup[test_pattern]
        return None

    def _try_year_format_lookup(self, image_name: str) -> Optional[str]:
        """Handle special case for files starting with year."""
        if not image_name.startswith("20"):
            return None
        test_name = image_name.lower().replace(" ", "-")
        return self.image_lookup.get(test_name)

    def _try_partial_matching(self, normalized: str) -> Optional[str]:
        """Try partial matching as last resort.

        NOTE(review): iterates the whole index and returns the first hit —
        result order depends on dict insertion order (the os.listdir scan).
        """
        for key, filename in self.image_lookup.items():
            if key.startswith(normalized) or normalized in key:
                return filename
        return None

    def _split_multiple_names(self, image_name: str) -> List[str]:
        """Split image names that contain multiple names separated by commas or slashes."""
        # First try comma separation
        if "," in image_name:
            return [name.strip() for name in image_name.split(",") if name.strip()]
        # Then try slash separation
        if "/" in image_name:
            return [name.strip() for name in image_name.split("/") if name.strip()]
        # Handle " - " separation (for cases like "MAR-BVM-7-2022.0.81 - 2022.0.81")
        if " - " in image_name and image_name.count(" - ") == 1:
            parts = [name.strip() for name in image_name.split(" - ")]
            # Only use the first part if they look like duplicates
            if len(parts) == 2:
                first, second = parts
                # Check if second part is a suffix of the first (like duplicate year)
                if first.endswith(second) or second in first:
                    return [first]
                return parts
        return [image_name]

    def find_image(self, image_name: str) -> Optional[str]:
        """Find the on-disk image file matching the given CSV name, if any."""
        if not image_name:
            return None
        # Handle multiple image names in one field
        possible_names = self._split_multiple_names(image_name)
        # Try each name individually
        for name in possible_names:
            result = self._find_single_image(name)
            if result:
                return result
        return None

    def _find_single_image(self, image_name: str) -> Optional[str]:
        """Find a single image by name (exact, MAR-BVM, year, then partial)."""
        # Clean and normalize the input
        cleaned_name = self._clean_input_name(image_name)
        processed_name = self._normalize_mar_bvm_input(cleaned_name)
        normalized = self._normalize_name(processed_name)
        # Try direct lookup first
        if normalized in self.image_lookup:
            return self.image_lookup[normalized]
        # Try MAR-BVM specific lookups
        if "mar-bvm" in normalized:
            result = self._try_mar_bvm_lookups(normalized)
            if result:
                return result
        # Try year format lookup
        result = self._try_year_format_lookup(image_name)
        if result:
            return result
        # Try partial matching as last resort
        return self._try_partial_matching(normalized)

    def get_all_files(self) -> Set[str]:
        """Return a copy of all available image filenames."""
        return self.available_files.copy()
class ArtMatcherV2:
    """Main class for artwork matching."""

    def __init__(self, csv_path: str, images_dir: str):
        """Initialize the system with the database and the images directory.

        Args:
            csv_path: path to the artwork CSV database.
            images_dir: directory holding the artwork image files.
        """
        self.db = Database(csv_path)
        self.images_dir = os.path.abspath(images_dir)
        self.emotion_wheel = EmotionWheel()
        self.weights = ScoringWeights()
        # Helper instance used only for its similarity methods (empty profile)
        self.optimizer_helper = WeightedLeximaxOptimizer(TargetProfile(), {})
        self.image_indexer = ImageIndexer(images_dir)
        df = self.db.get_dataframe()
        # Keep only rows that actually reference an image file
        self.df_with_images = df[
            df["name_image"].notna()
            & (df["name_image"] != "")
            & (df["name_image"].str.strip() != "")
        ].copy()
        self.df_with_images["database_id_str"] = self.df_with_images[
            "database_id"
        ].astype(str)
        # database_id (as str) -> dataframe row index, for O(1) lookups
        self.id_to_index = {
            str(row["database_id"]): idx for idx, row in self.df_with_images.iterrows()
        }
        self.artwork_images = self._build_artwork_image_index()
        # Lightweight Database wrapper around the filtered frame; bypasses
        # __init__ on purpose — assumes Database stores its data in a
        # ``dataframe`` attribute (TODO confirm against art_pieces_db.database)
        self.temp_db_with_images = Database.__new__(Database)
        self.temp_db_with_images.dataframe = self.df_with_images
        logger.info(f"Base de données chargée: {self.db.n_pieces()} œuvres")
        logger.info(f"Œuvres avec images: {len(self.df_with_images)}")
        logger.info(f"Index des images: {len(self.artwork_images)} œuvres mappées")
| def _sanitize_input(self, input_str: str) -> str: | |
| """Nettoie et valide une entrée utilisateur""" | |
| return SecurityValidator.sanitize_input(input_str) | |
| def _parse_date(self, date_str: str) -> Optional[datetime]: | |
| """Parse une date avec validation""" | |
| is_valid, date_obj = SecurityValidator.validate_date(date_str) | |
| return date_obj if is_valid else None | |
    def _build_artwork_image_index(self) -> Dict[str, List[str]]:
        """Build an artwork_id -> [image_paths] index at startup.

        Only artworks whose ``name_image`` resolves to at least one real file
        end up in the index, so downstream code can assume non-empty lists.
        """
        artwork_images = {}
        for idx, row in self.df_with_images.iterrows():
            artwork_id = str(row["database_id"])
            image_paths = []
            if row["name_image"] and str(row["name_image"]).strip():
                # Parse the image names - handle special separators
                image_string = str(row["name_image"]).strip().strip('"')
                # Handle cases with " / " or " - " separators
                if " / " in image_string:
                    # Take first part before the slash
                    image_string = image_string.split(" / ")[0].strip()
                # Special case: if it has " - 2022" it's a separator, not part of the name
                if " - 2022" in image_string:
                    # Take the part before " - 2022"
                    image_string = image_string.split(" - 2022")[0].strip()
                elif " - " in image_string and "MAR-BVM-7-2022-0-" not in image_string:
                    # For other MAR-BVM formats with " - " separator
                    parts = image_string.split(" - ")
                    if "MAR-BVM" in parts[0]:
                        image_string = parts[0].strip()
                # Clean up trailing " -" or spaces before "-001"
                image_string = re.sub(
                    r"\s+-\s*$", "", image_string
                )  # Remove trailing " -"
                image_string = re.sub(
                    r"\s+(-\d)", r"\1", image_string
                )  # Remove spaces before -001
                # Parse comma-separated list
                images = [
                    img.strip()
                    for img in re.split(r"[,/]", image_string)
                    if img.strip()
                ]
                for img_name in images:
                    # Find the actual file for this image name
                    matched_file = self.image_indexer.find_image(img_name)
                    if matched_file:
                        img_path = os.path.join(self.images_dir, matched_file)
                        image_paths.append(img_path)
            if image_paths:
                artwork_images[artwork_id] = image_paths
        return artwork_images
    def preselect_artworks(
        self, firstname: str, birthday: str, city: str
    ) -> pd.DataFrame:
        """
        Pre-select artworks following the hierarchy: first name > date > city.

        Inputs are sanitized before scoring; a weighted leximax optimizer
        ranks all artworks that have images, and at least
        ``MIN_PRESELECTION_COUNT`` rows are always returned.
        """
        logger.info("=== DÉBUT PRÉ-SÉLECTION ===")
        # Sanitize the user-provided inputs
        firstname = self._sanitize_input(firstname)
        city = self._sanitize_input(city)
        logger.info(
            f"Critères de pré-sélection: prénom='{firstname}', date='{birthday}', ville='{city}'"
        )
        birth_date = self._parse_date(birthday)
        if birth_date:
            logger.info(f"Date convertie: {birth_date.strftime('%d/%m')}")
        profile = TargetProfile()
        profile.set_target_name(firstname)
        profile.set_target_date(birth_date)
        profile.set_target_place(city)
        weights = {
            "related_names": self.weights.PRESELECTION_NAME_WEIGHT,
            "related_dates": self.weights.PRESELECTION_DATE_WEIGHT,
            "related_places": self.weights.PRESELECTION_PLACE_WEIGHT,
            "related_emotions": self.weights.PRESELECTION_EMOTION_WEIGHT,
        }
        logger.info(
            f"Poids utilisés: nom={weights['related_names']}, date={weights['related_dates']}, lieu={weights['related_places']}, émotions={weights['related_emotions']}"
        )
        optimizer = WeightedLeximaxOptimizer(profile, weights)
        result = optimizer.optimize_max(self.temp_db_with_images)
        # "score" appears to be tuple-valued (leximax); keep rows strictly
        # better than the zero score — TODO confirm against the query module
        preselected = result[result["score"] > (0, 0, 0)]
        logger.info(f"Œuvres avec score > 0: {len(preselected)}")
        if len(preselected) < self.weights.MIN_PRESELECTION_COUNT:
            preselected = result.head(self.weights.MIN_PRESELECTION_COUNT)
            logger.info(f"Ajustement au minimum requis: {len(preselected)} œuvres")
        # Debug trace: detail why the top 5 artworks scored as they did
        logger.info("Top 5 pré-sélections:")
        for i, (idx, piece) in enumerate(preselected.head(5).iterrows()):
            logger.info(
                f" {i+1}. Œuvre #{piece['database_id']} - Score: {piece['score']}"
            )
            if firstname and piece["related_names"]:
                name_score = Optimizer.name_similarity(
                    firstname, piece["related_names"]
                )
                if name_score > 0:
                    logger.info(
                        f" → Nom: {piece['related_names']} (score: {name_score:.2f})"
                    )
            if birth_date and piece["related_dates"]:
                date_score = Optimizer.date_similarity(
                    birth_date, piece["related_dates"]
                )
                if date_score > 0:
                    logger.info(
                        f" → Dates: {[d.strftime('%d/%m') for d in piece['related_dates']]} (score: {date_score:.2f})"
                    )
            if city and piece["related_places"]:
                place_score = self.optimizer_helper.place_similarity(
                    city, piece["related_places"]
                )
                if place_score > 0:
                    logger.info(
                        f" → Lieux: {piece['related_places']} (score: {place_score:.2f})"
                    )
        logger.info("=== FIN PRÉ-SÉLECTION ===")
        return preselected
    def get_random_images_for_selection(
        self, round_num: int, already_selected: Optional[List[str]] = None
    ) -> List[Tuple[str, str]]:
        """
        Return up to MAX_IMAGES_PER_SELECTION random (image_path, artwork_id)
        pairs from the pre-built index, excluding artworks already picked in
        previous rounds. Falls back to raw .jpg files (artwork id "0") when
        too few indexed artworks remain.
        """
        logger.info(f"=== SÉLECTION D'IMAGES POUR LE TOUR {round_num} ===")
        if already_selected:
            logger.info(f"Œuvres déjà sélectionnées à exclure: {already_selected}")
        available_artworks = list(self.artwork_images.keys())
        # Exclude artworks already selected in earlier rounds
        if already_selected:
            already_selected_set = set(already_selected)
            available_artworks = [
                a for a in available_artworks if a not in already_selected_set
            ]
        logger.info(
            f"Nombre total d'œuvres avec images disponibles: {len(available_artworks)}"
        )
        if len(available_artworks) < self.weights.MAX_IMAGES_PER_SELECTION:
            logger.warning(
                f"Seulement {len(available_artworks)} œuvres avec images disponibles"
            )
            # Fallback: serve raw files straight from the indexer (id "0")
            direct_images = []
            for filename in list(self.image_indexer.get_all_files())[:10]:
                if filename.endswith(".jpg"):
                    img_path = os.path.join(self.images_dir, filename)
                    direct_images.append((img_path, "0"))
            return direct_images[: self.weights.MAX_IMAGES_PER_SELECTION]
        num_to_select = min(
            self.weights.MAX_IMAGES_PER_SELECTION, len(available_artworks)
        )
        selected_artworks = random.sample(available_artworks, num_to_select)
        logger.info(f"Œuvres sélectionnées aléatoirement: {selected_artworks}")
        selected = []
        for artwork_id in selected_artworks:
            # One random image per artwork (an artwork may have several files)
            img_path = random.choice(self.artwork_images[artwork_id])
            selected.append((img_path, artwork_id))
            if artwork_id in self.id_to_index:
                idx = self.id_to_index[artwork_id]
                artwork = self.df_with_images.loc[idx]
                logger.info(f" Image {len(selected)}: Œuvre #{artwork_id}")
                logger.info(f" Type: {artwork['art_piece_type']}")
                logger.info(f" Émotions: {artwork['related_emotions']}")
        logger.info(f"=== FIN SÉLECTION IMAGES TOUR {round_num} ===")
        return selected
| def extract_emotions_from_image_id(self, database_id: str) -> List[str]: | |
| """ | |
| Extrait les émotions associées à une œuvre via son ID | |
| Utilise l'index pré-calculé pour éviter les conversions répétées | |
| """ | |
| if database_id in self.id_to_index: | |
| idx = self.id_to_index[database_id] | |
| emotions = self.df_with_images.loc[idx, "related_emotions"] | |
| if isinstance(emotions, list): | |
| return emotions | |
| return [] | |
| def _cached_emotion_similarity(self, emotion1: str, emotion2: str) -> float: | |
| """Cache les calculs de similarité émotionnelle""" | |
| return self.emotion_wheel.calculate_emotion_similarity(emotion1, emotion2) | |
| def calculate_emotion_profile(self, selected_ids: List[str]) -> Dict[str, float]: | |
| """ | |
| Calcule le profil émotionnel basé sur les images sélectionnées | |
| """ | |
| logger.info("=== CALCUL DU PROFIL ÉMOTIONNEL ===") | |
| logger.info(f"Images sélectionnées: {selected_ids}") | |
| emotion_counter = Counter() | |
| for db_id in selected_ids: | |
| emotions = self.extract_emotions_from_image_id(db_id) | |
| logger.info(f" Image {db_id}: émotions = {emotions}") | |
| emotion_counter.update(emotions) | |
| total = sum(emotion_counter.values()) | |
| if total > 0: | |
| emotion_profile = { | |
| emotion: count / total for emotion, count in emotion_counter.items() | |
| } | |
| logger.info(f"Profil émotionnel calculé: {emotion_profile}") | |
| else: | |
| emotion_profile = {} | |
| logger.info("Aucune émotion trouvée dans les images sélectionnées") | |
| logger.info("=== FIN CALCUL PROFIL ÉMOTIONNEL ===") | |
| return emotion_profile | |
| def _get_artwork_image(self, artwork) -> Optional[str]: | |
| """Retourne le chemin de l'image pour une œuvre d'art""" | |
| artwork_id = str(artwork["database_id"]) | |
| # Simply return the first image from our pre-built index | |
| if artwork_id in self.artwork_images: | |
| return self.artwork_images[artwork_id][0] | |
| return None | |
    def find_best_match(
        self, firstname: str, birthday: str, city: str, selected_image_ids: List[str]
    ) -> Tuple[Optional[str], str, Dict]:
        """
        Find the best matching artwork following the scenario hierarchy:
        1. Exact match (name/date/city) = automatic winner
        2. If a pre-selection exists: use emotions to break the tie
        3. If there is no pre-selection: use emotions alone
        4. Object type as the final tie-breaking criterion

        Returns:
            (image path or None, human-readable match reason, info dict
            with title/type/place/emotions/explanation keys; empty dict
            when nothing matched).
        """
        # Normalize user input before any comparison.
        firstname = self._sanitize_input(firstname)
        city = self._sanitize_input(city)
        birth_date = self._parse_date(birthday)
        logger.info(
            f"Recherche de correspondance pour: {firstname}, {birthday}, {city}"
        )
        # Candidate pool scored on personal info (name/date/place).
        preselected = self.preselect_artworks(firstname, birthday, city)
        logger.info("=== DÉTECTION DE MATCH EXACT ===")
        # Pass 1 — an exact hit short-circuits everything else. Criteria are
        # checked in priority order per row: name, then date, then city.
        for idx, piece in preselected.iterrows():
            if firstname and piece["related_names"]:
                name_score = Optimizer.name_similarity(
                    firstname, piece["related_names"]
                )
                # 0.95 threshold (not 1.0) tolerates minor spelling variants.
                if name_score >= 0.95:
                    logger.info(
                        f"🎯 MATCH EXACT TROUVÉ: prénom '{firstname}' → œuvre #{piece['database_id']} (score: {name_score:.2f})"
                    )
                    logger.info(f" Noms dans l'œuvre: {piece['related_names']}")
                    match_image = self._get_artwork_image(piece)
                    match_info = {
                        "title": f"Œuvre #{piece['database_id']}",
                        "type": piece["art_piece_type"],
                        "place": piece["art_piece_place"],
                        "emotions": piece["related_emotions"],
                        "explanation": piece["explanation"],
                    }
                    return (
                        match_image,
                        f"Prénom '{firstname}' correspond exactement",
                        match_info,
                    )
            if birth_date and piece["related_dates"]:
                date_score = Optimizer.date_similarity(
                    birth_date, piece["related_dates"]
                )
                # Dates must match exactly (score == 1.0) to count here.
                if date_score == 1.0:
                    logger.info(
                        f"🎯 MATCH EXACT TROUVÉ: date '{birthday}' → œuvre #{piece['database_id']}"
                    )
                    logger.info(
                        f" Dates dans l'œuvre: {[d.strftime('%d/%m/%Y') for d in piece['related_dates']]}"
                    )
                    match_image = self._get_artwork_image(piece)
                    match_info = {
                        "title": f"Œuvre #{piece['database_id']}",
                        "type": piece["art_piece_type"],
                        "place": piece["art_piece_place"],
                        "emotions": piece["related_emotions"],
                        "explanation": piece["explanation"],
                    }
                    return (
                        match_image,
                        f"Date d'anniversaire {birthday} correspond exactement",
                        match_info,
                    )
            if city and piece["related_places"]:
                place_score = self.optimizer_helper.place_similarity(
                    city, piece["related_places"]
                )
                if place_score == 1.0:
                    logger.info(
                        f"🎯 MATCH EXACT TROUVÉ: ville '{city}' → œuvre #{piece['database_id']}"
                    )
                    logger.info(f" Lieux dans l'œuvre: {piece['related_places']}")
                    match_image = self._get_artwork_image(piece)
                    match_info = {
                        "title": f"Œuvre #{piece['database_id']}",
                        "type": piece["art_piece_type"],
                        "place": piece["art_piece_place"],
                        "emotions": piece["related_emotions"],
                        "explanation": piece["explanation"],
                    }
                    return (
                        match_image,
                        f"Ville '{city}' correspond exactement",
                        match_info,
                    )
        logger.info("Aucun match exact trouvé, passage à la sélection par émotions")
        # Pass 2 — emotional matching based on the user's picked images.
        emotion_profile = self.calculate_emotion_profile(selected_image_ids)
        logger.info("=== STRATÉGIE DE MATCHING ===")
        # NOTE(review): "score" appears to hold per-row tuples compared
        # element-wise against (0, 0, 0) to keep rows with any non-zero
        # component — confirm against preselect_artworks' score format.
        valid_preselection = preselected[preselected["score"] > (0, 0, 0)]
        if len(valid_preselection) > 0:
            logger.info(
                f"📋 CAS A: {len(valid_preselection)} œuvres pré-sélectionnées - utilisation des émotions pour départager"
            )
            candidates = valid_preselection
        else:
            logger.info(
                f"📋 CAS B: Aucune pré-sélection valide - recherche par émotions sur {len(self.df_with_images)} œuvres"
            )
            candidates = self.df_with_images
        # Exclude the artworks the user already picked during selection.
        selected_artwork_ids = set(selected_image_ids)
        candidates = candidates[
            ~candidates["database_id"].astype(str).isin(selected_artwork_ids)
        ]
        logger.info(
            f"Après exclusion des œuvres déjà sélectionnées {selected_artwork_ids}: {len(candidates)} candidats restants"
        )
        logger.info("=== CALCUL DES SCORES ÉMOTIONNELS ===")
        # Scoring: each user emotion contributes its best similarity against
        # any of the piece's emotions, weighted by profile frequency, then
        # the sum is normalized by the piece's emotion count.
        best_matches = []
        best_emotion_score = -1
        for idx, piece in candidates.iterrows():
            emotion_score = 0
            if emotion_profile and piece["related_emotions"]:
                for user_emotion, weight in emotion_profile.items():
                    best_similarity = 0
                    for piece_emotion in piece["related_emotions"]:
                        similarity = self._cached_emotion_similarity(
                            user_emotion, piece_emotion
                        )
                        if similarity > best_similarity:
                            best_similarity = similarity
                    emotion_score += best_similarity * weight
                if len(piece["related_emotions"]) > 0:
                    emotion_score /= len(piece["related_emotions"])
            # Keep every piece tied at the current best score for later
            # tie-breaking; zero scores are never appended as ties.
            if emotion_score > best_emotion_score:
                best_emotion_score = emotion_score
                best_matches = [piece]
                logger.info(
                    f" Nouveau meilleur score émotionnel: {emotion_score:.3f} - Œuvre #{piece['database_id']}"
                )
            elif emotion_score == best_emotion_score and emotion_score > 0:
                best_matches.append(piece)
                logger.info(
                    f" Score égal au meilleur: {emotion_score:.3f} - Œuvre #{piece['database_id']}"
                )
        logger.info(
            f"Nombre de meilleures correspondances: {len(best_matches)} avec score {best_emotion_score:.3f}"
        )
        if len(best_matches) > 1:
            # Pass 3 — tie-break on artwork type: prefer the type the user
            # picked most often during the image-selection rounds.
            logger.info("=== DÉPARTAGE PAR TYPE D'OBJET ===")
            selected_types = []
            for img_id in selected_image_ids:
                if img_id in self.id_to_index:
                    idx = self.id_to_index[img_id]
                    selected_types.append(
                        self.df_with_images.loc[idx, "art_piece_type"]
                    )
            selected_types_counter = Counter(selected_types)
            type_scored_matches = []
            best_type_score = -1
            for piece in best_matches:
                type_score = selected_types_counter.get(piece["art_piece_type"], 0)
                if type_score > best_type_score:
                    best_type_score = type_score
                    type_scored_matches = [piece]
                elif type_score == best_type_score:
                    type_scored_matches.append(piece)
            if len(type_scored_matches) > 1:
                # Still tied after the type criterion: pick one at random.
                logger.info(
                    f" {len(type_scored_matches)} œuvres avec le même score de type ({best_type_score}) - sélection aléatoire"
                )
                best_match = random.choice(type_scored_matches)
                match_reason = (
                    "Sélection aléatoire parmi les meilleures correspondances"
                )
            else:
                best_match = type_scored_matches[0]
                match_reason = f"Type d'objet '{best_match['art_piece_type']}' préféré"
                logger.info(
                    f" Type '{best_match['art_piece_type']}' sélectionné avec score {best_type_score}"
                )
        elif len(best_matches) == 1:
            best_match = best_matches[0]
            match_reason = "Meilleure correspondance émotionnelle"
        else:
            logger.info("Aucune correspondance trouvée")
            return None, "Aucune correspondance trouvée", {}
        # Build the human-readable justification for the final pick; the
        # personal-info reasons only apply when the pool came from case A.
        reasons = []
        if len(valid_preselection) > 0:
            if firstname and best_match["related_names"]:
                name_score = Optimizer.name_similarity(
                    firstname, best_match["related_names"]
                )
                if name_score > 0:
                    reasons.append(f"prénom '{firstname}' trouvé")
            if birth_date and best_match["related_dates"]:
                date_score = Optimizer.date_similarity(
                    birth_date, best_match["related_dates"]
                )
                if date_score > 0:
                    reasons.append(
                        f"date {'exacte' if date_score == 1.0 else 'partielle'}"
                    )
            if city and best_match["related_places"]:
                place_score = self.optimizer_helper.place_similarity(
                    city, best_match["related_places"]
                )
                if place_score > 0:
                    reasons.append(f"ville '{city}' trouvée")
        if best_emotion_score > 0:
            reasons.append(f"correspondance émotionnelle")
        if len(reasons) == 0:
            reasons.append(match_reason)
        final_reason = " ; ".join(reasons)
        logger.info(f"\n🏆 RÉSULTAT FINAL: Œuvre #{best_match['database_id']}")
        logger.info(f" Raison: {final_reason}")
        logger.info(f" Type: {best_match['art_piece_type']}")
        logger.info(f" Lieu: {best_match['art_piece_place']}")
        match_image = self._get_artwork_image(best_match)
        match_info = {
            "title": f"Œuvre #{best_match['database_id']}",
            "type": best_match["art_piece_type"],
            "place": best_match["art_piece_place"],
            "emotions": best_match["related_emotions"],
            "explanation": best_match["explanation"],
        }
        return match_image, final_reason, match_info
# --- Module-level bootstrap: data locations and the shared matcher ---------
csv_path = "PP1-Collection_Database_new-cleaned.csv"
images_dir = "pictures_data"
# Missing inputs are logged but not fatal here; the constructor below is
# left to cope with (or fail loudly on) the actual load.
if not os.path.exists(csv_path):
    logger.error(f"Fichier CSV introuvable: {csv_path}")
if not os.path.exists(images_dir):
    logger.error(f"Répertoire images introuvable: {images_dir}")
# Single matcher instance shared by every Gradio callback below.
matcher = ArtMatcherV2(csv_path, images_dir)
def process_user_info(firstname: str, birthday: str, city: str, state: SessionState):
    """Validate and store the step-1 user info; decide which screen is next.

    Returns (info_section_update, selection_section_update,
    results_section_update, message, state). Sanitized values are written
    to the session state before validation so partial input survives.
    """
    firstname = SecurityValidator.sanitize_input(firstname)
    city = SecurityValidator.sanitize_input(city)
    state.firstname = firstname
    state.birthday = birthday
    state.city = city

    def _stay_on_form(msg):
        # Keep the info form visible and report why we cannot advance.
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            msg,
            state,
        )

    if not firstname or not birthday:
        return _stay_on_form(
            "Veuillez remplir au moins votre prénom et date de naissance."
        )
    date_ok, _ = SecurityValidator.validate_date(birthday)
    if not date_ok:
        return _stay_on_form("Format de date invalide. Utilisez JJ/MM (ex: 15/03)")
    # Valid input: swap to the image-selection screen.
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        "Informations enregistrées ! Passons à la sélection d'images.",
        state,
    )
def load_images_for_round(round_num: int, state: SessionState):
    """Fetch the three candidate images for one selection round.

    Returns (image_paths, artwork_ids, status_message, state). When the
    matcher cannot supply enough images, None placeholders and an
    explanatory message are returned instead.
    """
    pairs = matcher.get_random_images_for_selection(
        round_num, state.selected_images
    )
    if len(pairs) < ScoringWeights.MAX_IMAGES_PER_SELECTION:
        logger.warning(f"Seulement {len(pairs)} images disponibles")
        return (
            [None, None, None],
            [],
            f"Pas assez d'images disponibles (seulement {len(pairs)} trouvées)",
            state,
        )
    paths = [path for path, _ in pairs]
    ids = [artwork_id for _, artwork_id in pairs]
    # Remember the IDs on the state so select_image() can resolve the pick.
    state.current_image_ids = ids
    return (
        paths,
        ids,
        f"Tour {round_num + 1}/{ScoringWeights.TOTAL_ROUNDS} : Sélectionnez l'image qui vous attire le plus",
        state,
    )
def select_image(choice: Optional[int], state: SessionState):
    """Handle one image pick and advance the selection flow.

    Always returns an 8-tuple:
    (img1, img2, img3, image_choice, status_message, state,
     selection_section_visibility, loading_section_visibility)
    so callers never have to branch on tuple arity (the previous version
    returned 6 values on error paths and 8 otherwise). Error paths emit
    no-op gr.update() values for every component they do not touch, which
    the existing 8-value caller branch already handles.
    """
    if choice is None:
        # Nothing picked yet: leave every component unchanged and prompt.
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            "Veuillez sélectionner une image",
            state,
            gr.update(),
            gr.update(),
        )
    if not (state.current_image_ids and len(state.current_image_ids) > choice):
        # Stale or out-of-range index (e.g. after a reload): stay put.
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            "Erreur: image non trouvée",
            state,
            gr.update(),
            gr.update(),
        )
    selected_id = state.current_image_ids[choice]
    state.selected_images.append(selected_id)
    state.current_round += 1
    logger.info(
        f"Tour {state.current_round}: Image {choice+1} sélectionnée (ID: {selected_id})"
    )
    if state.current_round < ScoringWeights.TOTAL_ROUNDS:
        # More rounds to go: load the next triplet and clear the radio.
        new_images, new_ids, message, state = load_images_for_round(
            state.current_round, state
        )
        return (
            gr.update(value=new_images[0]),
            gr.update(value=new_images[1]),
            gr.update(value=new_images[2]),
            gr.update(value=None),
            message,
            state,
            gr.update(visible=True),  # keep selection_section visible
            gr.update(visible=False),  # keep loading_section hidden
        )
    # All rounds done: hide the selection UI and show the loading spinner.
    return (
        gr.update(),  # img1
        gr.update(),  # img2
        gr.update(),  # img3
        gr.update(),  # image_choice
        "",  # empty status_message
        state,
        gr.update(visible=False),  # hide selection_section
        gr.update(visible=True),  # show loading_section
    )
def show_results(state: SessionState):
    """Compute and display the final artwork match for a completed session.

    Returns visibility updates for the four UI sections plus the matched
    image, its title, and a markdown explanation. Incomplete sessions are
    sent back to the selection screen untouched.
    """
    if not state.is_complete():
        return (
            gr.update(visible=False),  # info_section
            gr.update(visible=True),  # selection_section
            gr.update(visible=False),  # loading_section
            gr.update(visible=False),  # results_section
            None,
            "",
            "",
        )
    match_image, reason, info = matcher.find_best_match(
        state.firstname,
        state.birthday,
        state.city,
        state.selected_images,
    )
    if match_image:
        # Classify which recommendation path produced the match so the
        # session log can distinguish exact matches from emotional ones.
        is_exact = "correspond exactement" in reason.lower()
        recommendation_type = "name_date_place" if is_exact else "emotions"
        state.final_artwork = info.get("title", "Œuvre inconnue")
        state.recommendation_type = recommendation_type
        session_tracker.log_session(state, recommendation_type)
        explanation = f"""
**Votre œuvre correspondante a été trouvée !**
**Raison du match :** {reason}
**Détails de l'œuvre :**
- Type : {info.get('type', 'Non spécifié')}
- Lieu : {info.get('place', 'Non spécifié')}
- Émotions : {', '.join(info.get('emotions', [])) if info.get('emotions') else 'Non spécifiées'}
**Description :**
{info.get('explanation', 'Aucune description disponible')}
"""
        title = info.get("title", "Œuvre non trouvée")
    else:
        # No artwork found — still log the session for the statistics.
        state.final_artwork = "Aucune œuvre trouvée"
        state.recommendation_type = "none"
        session_tracker.log_session(state, "none")
        explanation = "Désolé, aucune œuvre correspondante n'a pu être trouvée."
        title = "Œuvre non trouvée"
    return (
        gr.update(visible=False),  # info_section
        gr.update(visible=False),  # selection_section
        gr.update(visible=False),  # loading_section
        gr.update(visible=True),  # results_section
        match_image,
        title,
        explanation,
    )
with gr.Blocks(
    title="Art Matcher",
    theme=gr.themes.Soft(primary_hue="teal", secondary_hue="teal", neutral_hue="zinc"),
) as demo:
    # UI flow: four mutually exclusive sections (info -> selection ->
    # loading -> results) toggled by the callbacks via gr.update(visible=...).
    gr.Markdown(
        """
# 🎨 Art Matcher
### Découvrez l'œuvre d'art qui vous correspond !
Cette application utilise vos informations personnelles et vos préférences visuelles
pour trouver l'œuvre d'art qui vous correspond le mieux dans notre collection.
"""
    )
    # Per-session state container (one instance per browser session).
    session_state = gr.State(SessionState())
    # --- Step 1: personal information form --------------------------------
    with gr.Group(visible=True) as info_section:
        gr.Markdown("### Étape 1 : Vos informations")
        with gr.Row():
            firstname_input = gr.Textbox(
                label="Prénom", placeholder="Entrez votre prénom", max_lines=1
            )
            birthday_input = gr.Textbox(
                label="Date d'anniversaire (JJ/MM)",
                placeholder="Ex: 25/12",
                max_lines=1,
            )
            city_input = gr.Textbox(
                label="Ville de résidence", placeholder="Ex: Paris", max_lines=1
            )
        submit_info_btn = gr.Button("Valider mes informations", variant="primary")
    # --- Step 2: repeated three-image preference picker --------------------
    with gr.Group(visible=False) as selection_section:
        selection_title = gr.Markdown("### Étape 2 : Sélection d'images")
        with gr.Row():
            img1 = gr.Image(label="Image 1", type="filepath", height=300)
            img2 = gr.Image(label="Image 2", type="filepath", height=300)
            img3 = gr.Image(label="Image 3", type="filepath", height=300)
        # type="index": callbacks receive the 0-based position, not the label.
        image_choice = gr.Radio(
            choices=["Image 1", "Image 2", "Image 3"],
            label="Quelle image vous attire le plus ?",
            type="index",
        )
        select_btn = gr.Button("Valider mon choix", variant="primary")
    # --- Interstitial spinner shown while the final match is computed ------
    with gr.Group(visible=False) as loading_section:
        gr.Markdown("### ⏳ Analyse en cours...")
        gr.HTML(
            """
<div style="text-align: center; padding: 40px;">
    <div style="display: inline-block; width: 60px; height: 60px; border: 6px solid #f3f3f3; border-top: 6px solid #14b8a6; border-radius: 50%; animation: spin 1s linear infinite;"></div>
    <style>
        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
    </style>
    <p style="margin-top: 20px; font-size: 18px; color: #666;">
        <strong>Traitement de vos sélections...</strong><br>
        <span style="font-size: 14px;">Nous analysons votre profil pour trouver l'œuvre parfaite</span>
    </p>
</div>
"""
        )
    # --- Final screen: matched artwork, title and explanation --------------
    with gr.Group(visible=False) as results_section:
        gr.Markdown("### Votre œuvre correspondante")
        with gr.Row():
            with gr.Column(scale=1):
                result_image = gr.Image(label="Votre œuvre", height=400)
                result_title = gr.Markdown("## Titre de l'œuvre")
            with gr.Column(scale=1):
                result_explanation = gr.Markdown("")
        restart_btn = gr.Button("Terminer", variant="primary")
    # Shared status line used by every step for feedback and error text.
    status_message = gr.Markdown("")
| def on_info_submit(firstname, birthday, city, state): | |
| info_vis, select_vis, results_vis, message, state = process_user_info( | |
| firstname, birthday, city, state | |
| ) | |
| if select_vis["visible"]: | |
| images, ids, round_message, state = load_images_for_round(0, state) | |
| return ( | |
| info_vis, | |
| select_vis, | |
| results_vis, | |
| images[0] if len(images) > 0 else None, | |
| images[1] if len(images) > 1 else None, | |
| images[2] if len(images) > 2 else None, | |
| round_message, | |
| state, | |
| ) | |
| else: | |
| return (info_vis, select_vis, results_vis, None, None, None, message, state) | |
    # Wire the step-1 button: swaps the visible section and seeds round 0.
    submit_info_btn.click(
        fn=on_info_submit,
        inputs=[firstname_input, birthday_input, city_input, session_state],
        outputs=[
            info_section,
            selection_section,
            results_section,
            img1,
            img2,
            img3,
            status_message,
            session_state,
        ],
    )
| def on_image_select(choice, state): | |
| result = select_image(choice, state) | |
| # La fonction select_image retourne maintenant 8 valeurs | |
| if len(result) == 8: | |
| ( | |
| img1_update, | |
| img2_update, | |
| img3_update, | |
| choice_update, | |
| message, | |
| state, | |
| selection_vis, | |
| loading_vis, | |
| ) = result | |
| return ( | |
| gr.update(), # info_section | |
| selection_vis, # selection_section | |
| loading_vis, # loading_section | |
| gr.update(), # results_section | |
| img1_update, # img1 | |
| img2_update, # img2 | |
| img3_update, # img3 | |
| choice_update, # image_choice | |
| message, # status_message | |
| state, | |
| ) | |
| else: | |
| # Format avec 6 valeurs (cas sans loading) | |
| (img1_update, img2_update, img3_update, choice_update, message, state) = ( | |
| result | |
| ) | |
| return ( | |
| gr.update(), # info_section | |
| gr.update(), # selection_section | |
| gr.update(), # loading_section | |
| gr.update(), # results_section | |
| img1_update, # img1 | |
| img2_update, # img2 | |
| img3_update, # img3 | |
| choice_update, # image_choice | |
| message, # status_message | |
| state, | |
| ) | |
| def handle_final_results(state): | |
| if state.is_complete(): | |
| return show_results(state) | |
| else: | |
| return gr.update(), gr.update(), gr.update(), gr.update(), None, "", "" | |
    # Step-2 button: record the pick first, then (chained) compute and show
    # the final results once the session is complete.
    select_btn.click(
        fn=on_image_select,
        inputs=[image_choice, session_state],
        outputs=[
            info_section,
            selection_section,
            loading_section,
            results_section,
            img1,
            img2,
            img3,
            image_choice,
            status_message,
            session_state,
        ],
    ).then(
        fn=handle_final_results,
        inputs=[session_state],
        outputs=[
            info_section,
            selection_section,
            loading_section,
            results_section,
            result_image,
            result_title,
            result_explanation,
        ],
    )
| def restart_app(state): | |
| state.reset() | |
| return ( | |
| gr.update(visible=True), # info_section | |
| gr.update(visible=False), # selection_section | |
| gr.update(visible=False), # loading_section | |
| gr.update(visible=False), # results_section | |
| "", # firstname_input | |
| "", # birthday_input | |
| "", # city_input | |
| None, # image_choice | |
| "Application réinitialisée. Veuillez entrer vos informations.", # status_message | |
| state, | |
| ) | |
    # "Terminer" button: wipe the session and return to the info form.
    restart_btn.click(
        fn=restart_app,
        inputs=[session_state],
        outputs=[
            info_section,
            selection_section,
            loading_section,
            results_section,
            firstname_input,
            birthday_input,
            city_input,
            image_choice,
            status_message,
            session_state,
        ],
    )
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()