Aurel-test's picture
Upload folder using huggingface_hub
c3c0d39 verified
from .database import Database
from .query import TargetProfile, WeightedLeximaxOptimizer, WeightedSumOptimizer
from .utils import str_to_date
from .emotions import EmotionWheel
import argparse
import logging
import os
import sys
class CliOptions:
    """
    A class used to manage the CLI options.

    Instantiating the class parses the command line and immediately runs the
    handler of the selected sub-command (``check-csv``, ``query`` or
    ``list-emotions``).
    """

    def __init__(self):
        """
        Loads the options using argparse and checks them, then runs the
        selected sub-command.

        Syntax errors on the command line make argparse exit with status 2;
        parameter errors detected afterwards (unreadable input file, invalid
        date, negative limit) exit with status 3 via ``exit_on_param_error``.
        """
        parser = self._build_parser()
        args = parser.parse_args()
        args.func(args)

    def _build_parser(self):
        """
        Builds the argument parser. Each sub-command stores its handler (a
        bound method of this instance) in ``args.func``.
        """
        parser = argparse.ArgumentParser(
            prog="art_pieces_db", description="Manage the database of art pieces"
        )
        # An explicit dest gives argparse a name to report when no sub-command
        # is given; without it some Python versions crash with a TypeError
        # instead of printing a usage error.
        subparsers = parser.add_subparsers(dest="command", required=True)

        check_parser = subparsers.add_parser(
            "check-csv", help="checks the CSV input data and exit"
        )
        check_parser.add_argument("input_csv")
        check_parser.set_defaults(func=self.execute_check_csv)

        query_parser = subparsers.add_parser(
            "query",
            help="query the CSV input data to find the pieces that are the closest to a target profile",
        )
        query_parser.add_argument("input_csv")
        query_parser.add_argument("--name")
        query_parser.add_argument("--date")
        query_parser.add_argument("--emotion")
        query_parser.add_argument("--place")
        query_parser.add_argument(
            "--aggregator", choices=["sum", "leximax"], default="leximax"
        )
        query_parser.add_argument("--limit", type=int, default=10)
        # One weight option per similarity criterion, e.g. --weight-name.
        for criterion in ("name", "date", "emotion", "place"):
            query_parser.add_argument(
                f"--weight-{criterion}",
                type=float,
                default=1.0,
                help=f"Weight for {criterion} similarity (default: 1.0)",
            )
        query_parser.set_defaults(func=self.execute_query)

        emotions_parser = subparsers.add_parser(
            "list-emotions", help="list all valid emotions from Plutchik's wheel"
        )
        emotions_parser.set_defaults(func=self.execute_list_emotions)
        return parser

    def execute_check_csv(self, args):
        """
        Checks an input CSV file by loading it into a Database.

        Exits with status 3 if the file cannot be read.
        """
        CliOptions.load_database(args)

    def execute_query(self, args):
        """
        Query the CSV input data to find the pieces that are the closest to a
        target profile, and write the best ``args.limit`` ones to stdout as CSV.
        """
        # head(-n) would return every row *except* the last n, which is never
        # what a negative limit is meant to express; reject it up front.
        if args.limit < 0:
            logging.fatal(f"--limit must be non-negative, got {args.limit}")
            CliOptions.exit_on_param_error()
        database = CliOptions.load_database(args)
        optimizer = CliOptions.create_optimizer(args)
        df = optimizer.optimize_max(database).head(args.limit)
        df.index.name = "result_index"
        df.to_csv(
            sys.stdout,
            columns=[
                "database_id",
                "related_names",
                "related_dates",
                "related_places",
                "related_emotions",
                "score",
            ],
        )

    def execute_list_emotions(self, args):
        """
        List all valid emotions from Plutchik's wheel: the intensity levels of
        each primary emotion, the pairs of opposite emotions, and the dyads.
        """
        wheel = EmotionWheel()
        print("\nPlutchik's Wheel of Emotions")
        print("=" * 50)
        print("\nPrimary Emotions with Intensity Levels:")
        print("-" * 50)
        for primary, emotion in wheel.emotions.items():
            print(f"\n{primary.value.upper()}:")
            print(f" Mild: {emotion.mild}")
            print(f" Basic: {emotion.basic}")
            print(f" Intense: {emotion.intense}")
        print("\n\nEmotion Opposites:")
        print("-" * 50)
        # The opposites mapping may list each pair in both directions; print
        # every unordered pair only once.
        shown = set()
        for e1, e2 in wheel.opposites.items():
            pair = tuple(sorted([e1.value, e2.value]))
            if pair not in shown:
                print(f" {e1.value} <-> {e2.value}")
                shown.add(pair)
        print("\n\nEmotion Combinations (Dyads):")
        print("-" * 50)
        for (e1, e2), result in sorted(wheel.dyads.items()):
            print(f" {e1} + {e2} = {result}")

    @staticmethod
    def load_database(args):
        """
        Loads the CSV file named by ``args.input_csv`` into a Database.

        Exits with status 3 when the path is a directory or is not readable.
        """
        csv_file = os.path.abspath(args.input_csv)
        # os.access(R_OK) is also true for directories, which cannot be
        # loaded as CSV, so reject them explicitly.
        if os.path.isdir(csv_file) or not os.access(csv_file, os.R_OK):
            logging.fatal(f"cannot read input file {csv_file}")
            CliOptions.exit_on_param_error()
        logging.info(f"reading CSV file {csv_file}")
        database = Database(csv_file)
        logging.info(f"read a database with {database.n_pieces()} art pieces")
        return database

    @staticmethod
    def create_optimizer(args):
        """
        Builds the optimizer described by the query options.

        The target profile is filled from ``--name``, ``--date``, ``--emotion``
        (lower-cased) and ``--place`` when given. The per-criterion weights
        come from the ``--weight-*`` options. ``--aggregator`` selects a
        weighted-sum or weighted-leximax optimizer.

        Exits with status 3 if ``--date`` cannot be converted into a date.
        """
        profile = TargetProfile()
        if args.name is not None:
            profile.set_target_name(args.name)
        if args.date is not None:
            try:
                target_date = str_to_date(args.date)
            except ValueError:
                logging.fatal(
                    f'cannot translate argument "{args.date}" into a date (type e.g. 25/12/2025)'
                )
                CliOptions.exit_on_param_error()
            profile.set_target_date(target_date)
        if args.emotion is not None:
            profile.set_target_emotion(args.emotion.lower())
        if args.place is not None:
            profile.set_target_place(args.place)
        logging.info(f"target profile is {profile}")
        weights = {
            "related_names": args.weight_name,
            "related_dates": args.weight_date,
            "related_emotions": args.weight_emotion,
            "related_places": args.weight_place,
        }
        if args.aggregator == "sum":
            logging.info("aggregator is Sum")
            return WeightedSumOptimizer(profile, weights)
        if args.aggregator == "leximax":
            logging.info("aggregator is Leximax")
            return WeightedLeximaxOptimizer(profile, weights)
        # Unreachable through the CLI (argparse restricts the choices), but
        # kept as a guard for programmatic callers.
        logging.fatal(f'unknown aggregator "{args.aggregator}"')
        CliOptions.exit_on_param_error()

    @staticmethod
    def exit_on_param_error():
        """
        Terminates the program with exit status 3, the status used for
        invalid parameters.

        ``sys.exit`` is used rather than ``os._exit`` so that buffered
        stdout/stderr are flushed and cleanup handlers (atexit, logging
        shutdown) still run.
        """
        sys.exit(3)