import functools
import hashlib
import logging
import os
import re
import shutil
import time
import traceback
from typing import List, Optional

import pandas as pd
from h2o_wave import Q, ui
from h2o_wave.types import FormCard, ImageCard, MarkupCard, StatListItem, Tab

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.db import Dataset
from llm_studio.app_utils.sections.common import clean_dashboard
from llm_studio.app_utils.sections.experiment import experiment_start
from llm_studio.app_utils.sections.histogram_card import histogram_card
from llm_studio.app_utils.utils import (
    add_model_type,
    azure_download,
    azure_file_options,
    check_valid_upload_content,
    clean_error,
    dir_file_table,
    get_data_dir,
    get_dataset_elements,
    get_datasets,
    get_experiments_status,
    get_frame_stats,
    get_model_types,
    get_problem_types,
    get_unique_dataset_name,
    kaggle_download,
    local_download,
    make_label,
    parse_ui_elements,
    remove_temp_files,
    s3_download,
    s3_file_options,
)
from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.utils.config_utils import (
    load_config_py,
    load_config_yaml,
    save_config_yaml,
)
from llm_studio.src.utils.data_utils import (
    get_fill_columns,
    read_dataframe,
    read_dataframe_drop_missing_labels,
    sanity_check,
)
from llm_studio.src.utils.plot_utils import PlotData

logger = logging.getLogger(__name__)


def file_extension_is_compatible(q):
    cfg = q.client["dataset/import/cfg"]
    allowed_extensions = cfg.dataset._allowed_file_extensions

    is_correct_extension = []
    for mode in ["train", "validation"]:
        dataset_name = q.client[f"dataset/import/cfg/{mode}_dataframe"]
        if dataset_name is None or dataset_name == "None":
            continue
        is_correct_extension.append(dataset_name.endswith(allowed_extensions))
    return all(is_correct_extension)


async def dataset_import(
    q: Q,
    step: int,
    edit: Optional[bool] = False,
    error: Optional[str] = "",
    warning: Optional[str] = "",
    info: Optional[str] = "",
    allow_merge: bool = True,
) -> None:
    """Display dataset import cards.

    Args:
        q: Q
        step: current step of wizard
        edit: whether in edit mode
        error: optional error message
        warning: optional warning message
        info: optional info message
        allow_merge: whether to allow merging dataset when importing
    """

    await clean_dashboard(q, mode="full")
    q.client["nav/active"] = "dataset/import"

    if step == 1:  # select import data source
        q.page["dataset/import"] = ui.form_card(box="content", items=[])
        q.client.delete_cards.add("dataset/import")

        if q.client["dataset/import/source"] is None:
            q.client["dataset/import/source"] = "Upload"

        import_choices = [
            ui.choice("Upload", "Upload"),
            ui.choice("Local", "Local"),
            ui.choice("S3", "AWS S3"),
            ui.choice("Azure", "Azure Datalake"),
            ui.choice("Kaggle", "Kaggle"),
        ]

        items = [
            ui.text_l("Import dataset"),
            ui.dropdown(
                name="dataset/import/source",
                label="Source",
                value=(
                    "Upload"
                    if q.client["dataset/import/source"] is None
                    else q.client["dataset/import/source"]
                ),
                choices=import_choices,
                trigger=True,
                tooltip="Source of dataset import",
            ),
        ]

        if (
            q.client["dataset/import/source"] is None
            or q.client["dataset/import/source"] == "S3"
        ):
            if q.client["dataset/import/s3_bucket"] is None:
                q.client["dataset/import/s3_bucket"] = q.client[
                    "default_aws_bucket_name"
                ]
            if q.client["dataset/import/s3_access_key"] is None:
                q.client["dataset/import/s3_access_key"] = q.client[
                    "default_aws_access_key"
                ]
            if q.client["dataset/import/s3_secret_key"] is None:
                q.client["dataset/import/s3_secret_key"] = q.client[
                    "default_aws_secret_key"
                ]

            files = s3_file_options(
                q.client["dataset/import/s3_bucket"],
                q.client["dataset/import/s3_access_key"],
                q.client["dataset/import/s3_secret_key"],
            )

            if not files:
                ui_filename = ui.textbox(
                    name="dataset/import/s3_filename",
                    label="File name",
                    value="",
                    required=True,
                    tooltip="File name to be imported",
                )
            else:
                if default_cfg.s3_filename in files:
                    default_file = default_cfg.s3_filename
                else:
                    default_file = files[0]
                ui_filename = ui.dropdown(
                    name="dataset/import/s3_filename",
                    label="File name",
                    value=default_file,
                    choices=[ui.choice(x, x.split("/")[-1]) for x in files],
                    required=True,
                    tooltip="File name to be imported",
                )

            items += [
                ui.textbox(
                    name="dataset/import/s3_bucket",
                    label="S3 bucket name",
                    value=q.client["dataset/import/s3_bucket"],
                    trigger=True,
                    required=True,
                    tooltip="S3 bucket name including relative paths",
                ),
                ui.textbox(
                    name="dataset/import/s3_access_key",
                    label="AWS access key",
                    value=q.client["dataset/import/s3_access_key"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Optional AWS access key; empty for anonymous access.",
                ),
                ui.textbox(
                    name="dataset/import/s3_secret_key",
                    label="AWS secret key",
                    value=q.client["dataset/import/s3_secret_key"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Optional AWS secret key; empty for anonymous access.",
                ),
                ui_filename,
            ]

        elif (
            q.client["dataset/import/source"] is None
            or q.client["dataset/import/source"] == "Azure"
        ):
            if q.client["dataset/import/azure_conn_string"] is None:
                q.client["dataset/import/azure_conn_string"] = q.client[
                    "default_azure_conn_string"
                ]
            if q.client["dataset/import/azure_container"] is None:
                q.client["dataset/import/azure_container"] = q.client[
                    "default_azure_container"
                ]

            files = azure_file_options(
                q.client["dataset/import/azure_conn_string"],
                q.client["dataset/import/azure_container"],
            )
            print(files)

            if not files:
                ui_filename = ui.textbox(
                    name="dataset/import/azure_filename",
                    label="File name",
                    value="",
                    required=True,
                    tooltip="File name to be imported",
                )
            else:
                default_file = files[0]
                ui_filename = ui.dropdown(
                    name="dataset/import/azure_filename",
                    label="File name",
                    value=default_file,
                    choices=[ui.choice(x, x.split("/")[-1]) for x in files],
                    required=True,
                    tooltip="File name to be imported",
                )

            items += [
                ui.textbox(
                    name="dataset/import/azure_conn_string",
                    label="Datalake connection string",
                    value=q.client["dataset/import/azure_conn_string"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Azure connection string to connect to Datalake storage",
                ),
                ui.textbox(
                    name="dataset/import/azure_container",
                    label="Datalake container name",
                    value=q.client["dataset/import/azure_container"],
                    trigger=True,
                    required=True,
                    tooltip="Azure Datalake container name including relative paths",
                ),
                ui_filename,
            ]

        elif q.client["dataset/import/source"] == "Upload":
            items += [
                ui.file_upload(
                    name="dataset/import/local_upload",
                    label="Upload!",
                    multiple=False,
                    file_extensions=default_cfg.allowed_file_extensions,
                )
            ]

        elif q.client["dataset/import/source"] == "Local":
            current_path = (
                q.client["dataset/import/local_path_current"]
                if q.client["dataset/import/local_path_current"] is not None
                else os.path.expanduser("~")
            )

            if q.args.__wave_submission_name__ == "dataset/import/local_path_list":
                idx = int(q.args["dataset/import/local_path_list"][0])
                options = q.client["dataset/import/local_path_list_last"]
                new_path = os.path.abspath(os.path.join(current_path, options[idx]))
                if os.path.exists(new_path):
                    current_path = new_path

            results_df = dir_file_table(current_path)
            files_list = results_df[current_path].tolist()
            q.client["dataset/import/local_path_list_last"] = files_list
            q.client["dataset/import/local_path_current"] = current_path

            items += [
                ui.textbox(
                    name="dataset/import/local_path",
                    label="File location",
                    value=current_path,
                    required=True,
                    tooltip="Location of file to be imported",
                ),
                ui_table_from_df(
                    q=q,
                    df=results_df,
                    name="dataset/import/local_path_list",
                    sortables=[],
                    searchables=[],
                    min_widths={current_path: "400"},
                    link_col=current_path,
                    height="calc(65vh)",
                ),
            ]

        elif q.client["dataset/import/source"] == "Kaggle":
            if q.client["dataset/import/kaggle_access_key"] is None:
                q.client["dataset/import/kaggle_access_key"] = q.client[
                    "default_kaggle_username"
                ]
            if q.client["dataset/import/kaggle_secret_key"] is None:
                q.client["dataset/import/kaggle_secret_key"] = q.client[
                    "default_kaggle_secret_key"
                ]

            items += [
                ui.textbox(
                    name="dataset/import/kaggle_command",
                    label="Kaggle API command",
                    value=default_cfg.kaggle_command,
                    required=True,
                    tooltip="Kaggle API command to be executed",
                ),
                ui.textbox(
                    name="dataset/import/kaggle_access_key",
                    label="Kaggle username",
                    value=q.client["dataset/import/kaggle_access_key"],
                    required=True,
                    password=False,
                    tooltip="Kaggle username for API authentication",
                ),
                ui.textbox(
                    name="dataset/import/kaggle_secret_key",
                    label="Kaggle secret key",
                    value=q.client["dataset/import/kaggle_secret_key"],
                    required=True,
                    password=True,
                    tooltip="Kaggle secret key for API authentication",
                ),
            ]

        allowed_types = ", ".join(default_cfg.allowed_file_extensions)
        allowed_types = " or".join(allowed_types.rsplit(",", 1))
        items += [
            ui.message_bar(type="info", text=info + f"Must be a {allowed_types} file."),
            ui.message_bar(type="error", text=error),
            ui.message_bar(type="warning", text=warning),
        ]

        q.page["dataset/import"].items = items

        buttons = [ui.button(name="dataset/list", label="Abort")]
        if q.client["dataset/import/source"] != "Upload":
            buttons.insert(
                0, ui.button(name="dataset/import/2", label="Continue", primary=True)
            )

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer",
            items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

        q.client["dataset/import/id"] = None
        q.client["dataset/import/cfg_file"] = None

    elif step == 2:  # download / import data from source
        q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[])

        try:
            if not q.args["dataset/import/cfg_file"] and not edit:
                if q.client["dataset/import/source"] == "S3":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await s3_download(
                        q,
                        q.client["dataset/import/s3_bucket"],
                        q.client["dataset/import/s3_filename"],
                        q.client["dataset/import/s3_access_key"],
                        q.client["dataset/import/s3_secret_key"],
                    )
                elif q.client["dataset/import/source"] == "Azure":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await azure_download(
                        q,
                        q.client["dataset/import/azure_conn_string"],
                        q.client["dataset/import/azure_container"],
                        q.client["dataset/import/azure_filename"],
                    )
                elif q.client["dataset/import/source"] in ("Upload", "Local"):
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await local_download(q, q.client["dataset/import/local_path"])
                elif q.client["dataset/import/source"] == "Kaggle":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await kaggle_download(
                        q,
                        q.client["dataset/import/kaggle_command"],
                        q.client["dataset/import/kaggle_access_key"],
                        q.client["dataset/import/kaggle_secret_key"],
                    )

            # store if in edit mode
            q.client["dataset/import/edit"] = edit

            # clear dataset triggers from client
            for trigger_key in default_cfg.dataset_trigger_keys:
                if q.client[f"dataset/import/cfg/{trigger_key}"]:
                    del q.client[f"dataset/import/cfg/{trigger_key}"]

            await dataset_import(
                q,
                step=3,
                edit=edit,
                error=error,
                warning=warning,
                allow_merge=allow_merge,
            )
        except Exception:
            logger.error("Dataset error:", exc_info=True)
            error = (
                "Dataset import failed. Please make sure all required "
                "fields are filled correctly."
            )
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=1, error=str(error))

    elif step == 3:  # set dataset configuration
        q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[])
        try:
            if not q.args["dataset/import/cfg_file"] and not edit:
                q.client["dataset/import/name"] = get_unique_dataset_name(
                    q, q.client["dataset/import/name"]
                )

            q.page["dataset/import"] = ui.form_card(box="content", items=[])
            q.client.delete_cards.add("dataset/import")

            wizard = q.page["dataset/import"]
            title = "Configure dataset"

            items = [
                ui.text_l(title),
                ui.textbox(
                    name="dataset/import/name",
                    label="Dataset name",
                    value=q.client["dataset/import/name"],
                    required=True,
                ),
            ]

            choices_problem_types = [
                ui.choice(name, label) for name, label in get_problem_types()
            ]

            if q.client["dataset/import/cfg_file"] is None:
                max_substring_len = 0
                for c in choices_problem_types:
                    problem_type_name = c.name.replace("_config", "")
                    if problem_type_name in q.client["dataset/import/name"]:
                        if len(problem_type_name) > max_substring_len:
                            q.client["dataset/import/cfg_file"] = c.name
                            q.client["dataset/import/cfg_category"] = c.name.split("_")[
                                0
                            ]
                            max_substring_len = len(problem_type_name)
            if q.client["dataset/import/cfg_file"] is None:
                q.client["dataset/import/cfg_file"] = default_cfg.cfg_file
                q.client["dataset/import/cfg_category"] = q.client[  # type: ignore
                    "dataset/import/cfg_file"
                ].split("_")[0]

            # set default value of problem type if no match to category
            if (
                q.client["dataset/import/cfg_category"]
                not in q.client["dataset/import/cfg_file"]
            ):
                q.client["dataset/import/cfg_file"] = get_problem_types(
                    category=q.client["dataset/import/cfg_category"]
                )[0][0]

            model_types = get_model_types(q.client["dataset/import/cfg_file"])
            if len(model_types) > 0:
                # add model type to cfg file name here
                q.client["dataset/import/cfg_file"] = add_model_type(
                    q.client["dataset/import/cfg_file"], model_types[0][0]
                )

            if not edit:
                q.client["dataset/import/cfg"] = load_config_py(
                    config_path=(
                        f"llm_studio/python_configs/"
                        f"{q.client['dataset/import/cfg_file']}"
                    ),
                    config_name="ConfigProblemBase",
                )

            option_items = get_dataset_elements(cfg=q.client["dataset/import/cfg"], q=q)
            items.extend(option_items)

            items.append(ui.message_bar(type="error", text=error))
            items.append(ui.message_bar(type="warning", text=warning))

            if file_extension_is_compatible(q):
                ui_nav_name = "dataset/import/4/edit" if edit else "dataset/import/4"
                buttons = [
                    ui.button(name=ui_nav_name, label="Continue", primary=True),
                    ui.button(name="dataset/list", label="Abort"),
                ]
                if allow_merge:
                    datasets_df = q.client.app_db.get_datasets_df()
                    if datasets_df.shape[0]:
                        label = "Merge With Existing Dataset"
                        buttons.insert(1, ui.button(name="dataset/merge", label=label))
            else:
                problem_type = make_label(
                    re.sub("_config.*", "", q.client["dataset/import/cfg_file"])
                )
                items += [
                    ui.text(
                        "The chosen file extension is not "
                        f"compatible with {problem_type}."
                    )
                ]
                buttons = [
                    ui.button(name="dataset/list", label="Abort"),
                ]

            q.page["dataset/import/footer"] = ui.form_card(
                box="footer", items=[ui.inline(items=buttons, justify="start")]
            )
            wizard.items = items
            q.client.delete_cards.add("dataset/import/footer")

        except Exception as exception:
            logger.error("Dataset error:", exc_info=True)
            error = clean_error(str(exception))
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=1, error=str(error))

    elif step == 4:  # verify if dataset does not exist already
        dataset_name = q.client["dataset/import/name"]
        original_name = q.client["dataset/import/original_name"]  # used in edit mode
        valid_dataset_name = get_unique_dataset_name(q, dataset_name)

        if valid_dataset_name != dataset_name and not (
            q.client["dataset/import/edit"] and dataset_name == original_name
        ):
            err = f"Dataset {dataset_name} already exists"
            q.client["dataset/import/name"] = valid_dataset_name
            await dataset_import(q, 3, edit=edit, error=err)
        else:
            await dataset_import(q, 5, edit=edit)

    elif step == 5:  # visualize dataset
        header = "
                    {traceback.format_exc()}")],
                ),
            ]
        buttons = [
            ui.button(
                name="dataset/import/6", label="Continue", primary=valid_visualization
            ),
            ui.button(
                name="dataset/import/3/edit",
                label="Back",
                primary=not valid_visualization,
            ),
            ui.button(name="dataset/list", label="Abort"),
        ]

        q.page["dataset/import"] = ui.form_card(box="content", items=items)
        q.client.delete_cards.add("dataset/import")
        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")
    elif step == 6:  # create dataset
        if q.client["dataset/import/name"] == "":
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=2, error="Please enter all required fields!")
        else:
            folder_name = q.client["dataset/import/path"].split("/")[-1]
            new_folder = q.client["dataset/import/name"]
            act_path = q.client["dataset/import/path"]
            new_path = new_folder.join(act_path.rsplit(folder_name, 1))
            try:
                shutil.move(q.client["dataset/import/path"], new_path)

                cfg = q.client["dataset/import/cfg"]

                # remap old path to new path
                for k in default_cfg.dataset_folder_keys:
                    old_path = getattr(cfg.dataset, k, None)
                    if old_path is not None:
                        setattr(
                            cfg.dataset,
                            k,
                            old_path.replace(q.client["dataset/import/path"], new_path),
                        )

                # change the default validation strategy if validation df set
                if cfg.dataset.validation_dataframe != "None":
                    cfg.dataset.validation_strategy = "custom"

                cfg_path = f"{new_path}/{q.client['dataset/import/cfg_file']}.yaml"
                save_config_yaml(cfg_path, cfg)

                train_rows = None
                if os.path.exists(cfg.dataset.train_dataframe):
                    train_rows = read_dataframe_drop_missing_labels(
                        cfg.dataset.train_dataframe, cfg
                    ).shape[0]
                validation_rows = None
                if os.path.exists(cfg.dataset.validation_dataframe):
                    validation_rows = read_dataframe_drop_missing_labels(
                        cfg.dataset.validation_dataframe, cfg
                    ).shape[0]

                dataset = Dataset(
                    id=q.client["dataset/import/id"],
                    name=q.client["dataset/import/name"],
                    path=new_path,
                    config_file=cfg_path,
                    train_rows=train_rows,
                    validation_rows=validation_rows,
                )

                if q.client["dataset/import/id"] is not None:
                    q.client.app_db.delete_dataset(dataset.id)
                q.client.app_db.add_dataset(dataset)
                await dataset_list(q)
            except Exception as exception:
                logger.error("Dataset error:", exc_info=True)
                q.client.app_db._session.rollback()
                error = clean_error(str(exception))
                await clean_dashboard(q, mode="full")
                await dataset_import(q, step=2, error=str(error))


async def dataset_merge(q: Q, step, error=""):
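    """Merge the dataset currently being imported into an existing dataset.

    Args:
        q: Q
        step: current step of the merge flow
        error: optional error message
    """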
    if step == 1:  # Select which dataset to merge
        await clean_dashboard(q, mode="full")
        q.client["nav/active"] = "dataset/merge"

        q.page["dataset/merge"] = ui.form_card(box="content", items=[])
        q.client.delete_cards.add("dataset/merge")

        datasets_df = q.client.app_db.get_datasets_df()
        import_choices = [
            ui.choice(x["path"], x["name"]) for idx, x in datasets_df.iterrows()
        ]

        items = [
            ui.text_l("Merge current dataset with an existing dataset"),
            ui.dropdown(
                name="dataset/merge/target",
                label="Dataset",
                value=datasets_df.iloc[0]["path"],
                choices=import_choices,
                trigger=False,
                tooltip="Source of dataset import",
            ),
        ]

        if error:
            items.append(ui.message_bar(type="error", text=error))

        q.page["dataset/merge"].items = items

        buttons = [
            ui.button(name="dataset/merge/action", label="Merge", primary=True),
            ui.button(name="dataset/import/3", label="Back", primary=False),
            ui.button(name="dataset/list", label="Abort"),
        ]
        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

    elif step == 2:  # copy file to dataset and go to edit dataset
        current_dir = q.client["dataset/import/path"]
        target_dir = q.args["dataset/merge/target"]
        if current_dir == target_dir:
            await dataset_merge(q, step=1, error="Cannot merge dataset with itself")
            return

        datasets_df = q.client.app_db.get_datasets_df().set_index("path")
        has_dataset_entry = current_dir in datasets_df.index

        if has_dataset_entry:
            experiment_df = q.client.app_db.get_experiments_df()
            source_id = int(datasets_df.loc[current_dir, "id"])
            has_experiment = any(experiment_df["dataset"].astype(int) == source_id)
        else:
            source_id = None
            has_experiment = False

        current_files = os.listdir(current_dir)
        current_files = [x for x in current_files if not x.endswith(".yaml")]
        target_files = os.listdir(target_dir)
        overlapping_files = list(set(current_files).intersection(set(target_files)))

        rename_map = {}
        for file in overlapping_files:
            tmp_str = file.split(".")
            if len(tmp_str) == 1:
                file_name, extension = file, ""
            else:
                file_name, extension = ".".join(tmp_str[:-1]), f".{tmp_str[-1]}"
            cnt = 1
            while f"{file_name}_{cnt}{extension}" in target_files:
                cnt += 1
            rename_map[file] = f"{file_name}_{cnt}{extension}"
            target_files.append(rename_map[file])

        if len(overlapping_files):
            warning = (
                f"Renamed {', '.join(rename_map.keys())} to "
                f"{', '.join(rename_map.values())} due to duplicated entries."
            )
        else:
            warning = ""

        for file in current_files:
            new_file = rename_map.get(file, file)
            src = os.path.join(current_dir, file)
            dst = os.path.join(target_dir, new_file)
            if has_experiment:
                if os.path.isdir(src):
                    shutil.copytree(src, dst)
                else:
                    shutil.copy(src, dst)
            else:
                shutil.move(src, dst)

        if not has_experiment:
            shutil.rmtree(current_dir)
            if has_dataset_entry:
                q.client.app_db.delete_dataset(source_id)

        dataset_id = int(datasets_df.loc[target_dir, "id"])
        await dataset_edit(q, dataset_id, warning=warning, allow_merge=False)


async def dataset_list_table(
    q: Q,
    show_experiment_datasets: bool = True,
) -> None:
    """Prepare dataset list form card.

    Args:
        q: Q
        show_experiment_datasets: whether to also show datasets linked to experiments
    """

    q.client["dataset/list/df_datasets"] = get_datasets(
        q=q,
        show_experiment_datasets=show_experiment_datasets,
    )

    df_viz = q.client["dataset/list/df_datasets"].copy()

    columns_to_drop = [
        "id",
        "path",
        "config_file",
        "validation dataframe",
    ]
    df_viz = df_viz.drop(columns=columns_to_drop, errors="ignore")
    if "problem type" in df_viz.columns:
        df_viz["problem type"] = df_viz["problem type"].str.replace("Text ", "")

    widths = {
        "name": "200",
        "problem type": "210",
        "train dataframe": "190",
        "train rows": "120",
        "validation rows": "130",
        "labels": "120",
        "actions": "5",
    }

    actions_dict = {
        "dataset/newexperiment": "New experiment",
        "dataset/edit": "Edit dataset",
        "dataset/delete/dialog/single": "Delete dataset",
    }

    q.page["dataset/list"] = ui.form_card(
        box="content",
        items=[
            ui_table_from_df(
                q=q,
                df=df_viz,
                name="dataset/list/table",
                sortables=["train rows", "validation rows"],
                filterables=["name", "problem type"],
                searchables=[],
                min_widths=widths,
                link_col="name",
                height="calc(100vh - 245px)",
                actions=actions_dict,
            ),
            ui.message_bar(type="info", text=""),
        ],
    )
    q.client.delete_cards.add("dataset/list")


async def dataset_list(q: Q, reset: bool = True) -> None:
    """Display all datasets."""
    q.client["nav/active"] = "dataset/list"

    if reset:
        await clean_dashboard(q, mode="full")
        await dataset_list_table(q)

        q.page["dataset/display/footer"] = ui.form_card(
            box="footer",
            items=[
                ui.inline(
                    items=[
                        ui.button(
                            name="dataset/import", label="Import dataset", primary=True
                        ),
                        ui.button(
                            name="dataset/list/delete",
                            label="Delete datasets",
                            primary=False,
                        ),
                    ],
                    justify="start",
                )
            ],
        )
        q.client.delete_cards.add("dataset/display/footer")

    remove_temp_files(q)
    await q.page.save()


async def dataset_newexperiment(q: Q, dataset_id: int):
    """Start a new experiment from given dataset."""
    dataset = q.client.app_db.get_dataset(dataset_id)
    q.client["experiment/start/cfg_file"] = dataset.config_file.split("/")[-1].replace(
        ".yaml", ""
    )
    q.client["experiment/start/cfg_category"] = q.client[
        "experiment/start/cfg_file"
    ].split("_")[0]
    q.client["experiment/start/dataset"] = str(dataset_id)

    await experiment_start(q)


async def dataset_edit(
    q: Q, dataset_id: int, error: str = "", warning: str = "", allow_merge: bool = True
):
    """Edit selected dataset.

    Args:
        q: Q
        dataset_id: dataset id to edit
        error: optional error message
        warning: optional warning message
        allow_merge: whether to allow merging dataset when editing
    """

    dataset = q.client.app_db.get_dataset(dataset_id)

    experiments_df = q.client.app_db.get_experiments_df()
    experiments_df = experiments_df[experiments_df["dataset"] == str(dataset_id)]
    statuses, _ = get_experiments_status(experiments_df)
    num_invalid = len([stat for stat in statuses if stat in ["running", "queued"]])

    if num_invalid:
        info = "s" if num_invalid > 1 else ""
        info_str = (
            f"Dataset {dataset.name} is linked to {num_invalid} "
            f"running or queued experiment{info}. Wait for them to finish or stop them "
            "first before editing the dataset."
        )
        q.page["dataset/list"].items[1].message_bar.text = info_str
        return

    q.client["dataset/import/id"] = dataset_id
    q.client["dataset/import/cfg_file"] = dataset.config_file.split("/")[-1].replace(
        ".yaml", ""
    )
    q.client["dataset/import/cfg_category"] = q.client["dataset/import/cfg_file"].split(
        "_"
    )[0]
    q.client["dataset/import/path"] = dataset.path
    q.client["dataset/import/name"] = dataset.name
    q.client["dataset/import/original_name"] = dataset.name
    q.client["dataset/import/cfg"] = load_config_yaml(dataset.config_file)

    if allow_merge and experiments_df.shape[0]:
        allow_merge = False

    await dataset_import(
        q=q, step=2, edit=True, error=error, warning=warning, allow_merge=allow_merge
    )


async def dataset_list_delete(q: Q):
    """Allow to select multiple datasets for deletion."""
    await dataset_list_table(q, show_experiment_datasets=False)
    q.page["dataset/list"].items[0].table.multiple = True

    info_str = "Only datasets not linked to experiments can be deleted."
    q.page["dataset/list"].items[1].message_bar.text = info_str

    q.page["dataset/display/footer"].items = [
        ui.inline(
            items=[
                ui.button(
                    name="dataset/delete/dialog", label="Delete datasets", primary=True
                ),
                ui.button(name="dataset/list/delete/abort", label="Abort"),
            ]
        )
    ]


async def dataset_delete(q: Q, dataset_ids: List[int]):
    """Delete selected datasets.

    Args:
        q: Q
        dataset_ids: list of dataset ids to delete
    """

    for dataset_id in dataset_ids:
        dataset = q.client.app_db.get_dataset(dataset_id)
        q.client.app_db.delete_dataset(dataset.id)
        try:
            shutil.rmtree(dataset.path)
        except OSError:
            pass


async def dataset_delete_single(q: Q, dataset_id: int):
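    """Delete a single dataset if it is not linked to any experiment.

    Args:
        q: Q
        dataset_id: dataset id to delete
    """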
    dataset = q.client.app_db.get_dataset(dataset_id)

    experiments_df = q.client.app_db.get_experiments_df()
    num_experiments = sum(experiments_df["dataset"] == str(dataset_id))
    if num_experiments:
        info = "s" if num_experiments > 1 else ""
        info_str = (
            f"Dataset {dataset.name} is linked to {num_experiments} "
            f"experiment{info}. Only datasets not linked to experiments can be deleted."
        )
        await dataset_list(q)
        q.page["dataset/list"].items[1].message_bar.text = info_str
    else:
        await dataset_delete(q, [dataset_id])
        await dataset_list(q)


async def dataset_display(q: Q) -> None:
    """Display a selected dataset."""
    dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[
        q.client["dataset/display/id"]
    ]
    dataset: Dataset = q.client.app_db.get_dataset(dataset_id)
    config_filename = dataset.config_file
    cfg = load_config_yaml(config_filename)
    dataset_filename = cfg.dataset.train_dataframe

    if (
        q.client["dataset/display/tab"] is None
        or q.args["dataset/display/data"] is not None
    ):
        q.client["dataset/display/tab"] = "dataset/display/data"

    if q.args["dataset/display/visualization"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/visualization"
    if q.args["dataset/display/statistics"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/statistics"
    if q.args["dataset/display/summary"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/summary"

    await clean_dashboard(q, mode=q.client["dataset/display/tab"])

    items: List[Tab] = [
        ui.tab(name="dataset/display/data", label="Sample Train Data"),
        ui.tab(
            name="dataset/display/visualization", label="Sample Train Visualization"
        ),
        ui.tab(name="dataset/display/statistics", label="Train Data Statistics"),
        ui.tab(name="dataset/display/summary", label="Summary"),
    ]

    q.page["dataset/display/tab"] = ui.tab_card(
        box="nav2",
        link=True,
        items=items,
        value=q.client["dataset/display/tab"],
    )
    q.client.delete_cards.add("dataset/display/tab")

    if q.client["dataset/display/tab"] == "dataset/display/data":
        await show_data_tab(q=q, cfg=cfg, filename=dataset_filename)
    elif q.client["dataset/display/tab"] == "dataset/display/visualization":
        await show_visualization_tab(q, cfg)
    elif q.client["dataset/display/tab"] == "dataset/display/statistics":
        await show_statistics_tab(
            q, dataset_filename=dataset_filename, config_filename=config_filename
        )
    elif q.client["dataset/display/tab"] == "dataset/display/summary":
        await show_summary_tab(q, dataset_id)

    q.page["dataset/display/footer"] = ui.form_card(
        box="footer",
        items=[
            ui.inline(
                items=[
                    ui.button(
                        name="dataset/newexperiment/from_current",
                        label="Create experiment",
                        primary=False,
                        disabled=False,
                        tooltip=None,
                    ),
                    ui.button(name="dataset/list", label="Back", primary=False),
                ],
                justify="start",
            )
        ],
    )
    q.client.delete_cards.add("dataset/display/footer")


async def show_data_tab(q, cfg, filename: str):
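    """Show a sample of the train dataframe in a table card."""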
    fill_columns = get_fill_columns(cfg)
    df = read_dataframe(filename, n_rows=200, fill_columns=fill_columns)

    q.page["dataset/display/data"] = ui.form_card(
        box="first",
        items=[
            ui_table_from_df(
                q=q,
                df=df,
                name="dataset/display/data/table",
                sortables=list(df.columns),
                height="calc(100vh - 265px)",
                cell_overflow="wrap",
            )
        ],
    )
    q.client.delete_cards.add("dataset/display/data")


async def show_visualization_tab(q, cfg):
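    """Show a visual preview of the train data for the dataset."""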
    try:
        plot = cfg.logging.plots_class.plot_data(cfg)
    except Exception as error:
        logger.error(f"Error while plotting data preview: {error}", exc_info=True)
        plot = PlotData("