Spaces:
Sleeping
Sleeping
File size: 7,138 Bytes
a7e18d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os
from typing import Any, Literal
from pandas import DataFrame
try:
from trackio.media.media import TrackioMedia
from trackio.utils import MEDIA_DIR
except ImportError:
from media.media import TrackioMedia
from utils import MEDIA_DIR
class Table:
"""
Initializes a Table object. Tables can be used to log tabular data including images, numbers, and text.
Args:
columns (`list[str]`, *optional*):
Names of the columns in the table. Optional if `data` is provided. Not
expected if `dataframe` is provided. Currently ignored.
data (`list[list[Any]]`, *optional*):
2D row-oriented array of values. Each value can be: a number, a string (treated as Markdown and truncated if too long),
or a `Trackio.Image` or list of `Trackio.Image` objects.
dataframe (`pandas.`DataFrame``, *optional*):
DataFrame object used to create the table. When set, `data` and `columns`
arguments are ignored.
rows (`list[list[any]]`, *optional*):
Currently ignored.
optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
Currently ignored.
allow_mixed_types (`bool`, *optional*, defaults to `False`):
Currently ignored.
log_mode: (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
Currently ignored.
"""
TYPE = "trackio.table"
def __init__(
self,
columns: list[str] | None = None,
data: list[list[Any]] | None = None,
dataframe: DataFrame | None = None,
rows: list[list[Any]] | None = None,
optional: bool | list[bool] = True,
allow_mixed_types: bool = False,
log_mode: Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] | None = "IMMUTABLE",
):
# TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
# for now (like `rows`) they are included for API compat but don't do anything.
if dataframe is None:
self.data = DataFrame(data) if data is not None else DataFrame()
else:
self.data = dataframe
def _has_media_objects(self, dataframe: DataFrame) -> bool:
"""Check if dataframe contains any TrackioMedia objects or lists of TrackioMedia objects."""
for col in dataframe.columns:
if dataframe[col].apply(lambda x: isinstance(x, TrackioMedia)).any():
return True
if (
dataframe[col]
.apply(
lambda x: isinstance(x, list)
and len(x) > 0
and isinstance(x[0], TrackioMedia)
)
.any()
):
return True
return False
def _process_data(self, project: str, run: str, step: int = 0):
"""Convert dataframe to dict format, processing any TrackioMedia objects if present."""
df = self.data
if not self._has_media_objects(df):
return df.to_dict(orient="records")
processed_df = df.copy()
for col in processed_df.columns:
for idx in processed_df.index:
value = processed_df.at[idx, col]
if isinstance(value, TrackioMedia):
value._save(project, run, step)
processed_df.at[idx, col] = value._to_dict()
if (
isinstance(value, list)
and len(value) > 0
and isinstance(value[0], TrackioMedia)
):
[v._save(project, run, step) for v in value]
processed_df.at[idx, col] = [v._to_dict() for v in value]
return processed_df.to_dict(orient="records")
@staticmethod
def to_display_format(table_data: list[dict]) -> list[dict]:
"""Convert stored table data to display format for UI rendering. Note
that this does not use the self.data attribute, but instead uses the
table_data parameter, which is is what the UI receives.
Args:
table_data: List of dictionaries representing table rows (from stored _value)
Returns:
Table data with images converted to markdown syntax and long text truncated.
"""
truncate_length = int(os.getenv("TRACKIO_TABLE_TRUNCATE_LENGTH", "250"))
def convert_image_to_markdown(image_data: dict) -> str:
relative_path = image_data.get("file_path", "")
caption = image_data.get("caption", "")
absolute_path = MEDIA_DIR / relative_path
return f'<img src="/gradio_api/file={absolute_path}" alt="{caption}" />'
processed_data = []
for row in table_data:
processed_row = {}
for key, value in row.items():
if isinstance(value, dict) and value.get("_type") == "trackio.image":
processed_row[key] = convert_image_to_markdown(value)
elif (
isinstance(value, list)
and len(value) > 0
and isinstance(value[0], dict)
and value[0].get("_type") == "trackio.image"
):
# This assumes that if the first item is an image, all items are images. Ok for now since we don't support mixed types in a single cell.
processed_row[key] = (
'<div style="display: flex; gap: 10px;">'
+ "".join([convert_image_to_markdown(item) for item in value])
+ "</div>"
)
elif isinstance(value, str) and len(value) > truncate_length:
truncated = value[:truncate_length]
full_text = value.replace("<", "<").replace(">", ">")
processed_row[key] = (
f'<details style="display: inline;">'
f'<summary style="display: inline; cursor: pointer;">{truncated}…<span><em>(truncated, click to expand)</em></span></summary>'
f'<div style="margin-top: 10px; padding: 10px; background: #f5f5f5; border-radius: 4px; max-height: 400px; overflow: auto;">'
f'<pre style="white-space: pre-wrap; word-wrap: break-word; margin: 0;">{full_text}</pre>'
f"</div>"
f"</details>"
)
else:
processed_row[key] = value
processed_data.append(processed_row)
return processed_data
def _to_dict(self, project: str, run: str, step: int = 0):
"""Convert table to dictionary representation.
Args:
project: Project name for saving media files
run: Run name for saving media files
step: Step number for saving media files
"""
data = self._process_data(project, run, step)
return {
"_type": self.TYPE,
"_value": data,
}
|