Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Artist Style Embedding - Dataset & DataLoader | |
| """ | |
| import os | |
| import random | |
| import warnings | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple | |
| from collections import defaultdict | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader, Sampler | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| from tqdm import tqdm | |
| # PIL 경고 억제 | |
| warnings.filterwarnings('ignore', category=UserWarning, module='PIL') | |
class ArtistDataset(Dataset):
    """Multi-branch artist dataset.

    Every item bundles three views of one artist: a full image taken from
    this split's sample list, plus one randomly chosen face crop and one
    randomly chosen eye crop of the same artist. When an artist has no crop
    for a branch, or the chosen crop cannot be read, a zero-filled
    placeholder tensor is returned and the matching ``has_*`` flag is False.

    Args:
        dataset_root: Directory containing one sub-directory per artist;
            full images are read from ``dataset_root / artist / file_name``.
        dataset_face_root: Root directory of the face crops. Only stored on
            the instance; crops are read from the paths in ``face_paths``.
        dataset_eyes_root: Root directory of the eye crops. Only stored on
            the instance; crops are read from the paths in ``eye_paths``.
        artist_to_idx: Mapping from artist name to integer class label.
        image_paths: Full-image paths of this split, keyed by artist. Only
            the file name of each path is kept.
        face_paths: Face-crop paths of this split, keyed by artist.
        eye_paths: Eye-crop paths of this split, keyed by artist.
        image_size: Side length of the square tensors that are produced.
        is_training: Selects the augmenting pipeline when True, otherwise
            the deterministic resize-only pipeline.
    """

    def __init__(
        self,
        dataset_root: str,
        dataset_face_root: str,
        dataset_eyes_root: str,
        artist_to_idx: Dict[str, int],
        image_paths: Dict[str, List[str]],
        face_paths: Dict[str, List[str]],
        eye_paths: Dict[str, List[str]],
        image_size: int = 224,
        is_training: bool = True,
    ):
        self.dataset_root = Path(dataset_root)
        self.dataset_face_root = Path(dataset_face_root)
        self.dataset_eyes_root = Path(dataset_eyes_root)
        self.artist_to_idx = artist_to_idx
        self.image_size = image_size
        self.is_training = is_training

        # Flat list of (artist, file name) pairs, one entry per full image.
        self.samples: List[Tuple[str, str]] = [
            (artist, os.path.basename(img_path))
            for artist, paths in image_paths.items()
            for img_path in paths
        ]

        self.transform = self._get_transforms()
        self.transform_eval = self._get_eval_transforms()

        # Face/eye crop paths per artist (already restricted to this split).
        self._face_cache = {
            artist: [Path(p) for p in paths] for artist, paths in face_paths.items()
        }
        self._eye_cache = {
            artist: [Path(p) for p in paths] for artist, paths in eye_paths.items()
        }

    def _get_transforms(self):
        """Build the training pipeline.

        Resizes to ``image_size + 32``, takes a random ``image_size`` crop,
        applies a random horizontal flip and mild colour jitter, converts to
        a normalized tensor, and finally applies random erasing.
        """
        return transforms.Compose([
            transforms.Resize((self.image_size + 32, self.image_size + 32)),
            transforms.RandomCrop(self.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
        ])

    def _get_eval_transforms(self):
        """Build the evaluation pipeline: resize, to-tensor, normalize."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_image(self, path: Path):
        """Load ``path`` as an RGB PIL image.

        Images that may carry transparency (modes ``RGBA``, ``LA`` and ``P``)
        are composited onto a white background; every other mode is simply
        converted to RGB.

        Returns:
            The RGB image, or ``None`` when the file cannot be opened or
            decoded.
        """
        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the image that is returned.
            with Image.open(path) as img:
                if img.mode in ('RGBA', 'LA', 'P'):
                    # Go through RGBA for all three modes so that an alpha
                    # channel is always available as the paste mask; pasting
                    # an LA image without a mask would drop its transparency.
                    rgba = img.convert('RGBA')
                    background = Image.new('RGB', rgba.size, (255, 255, 255))
                    background.paste(rgba, mask=rgba.split()[-1])
                    return background
                return img.convert('RGB')
        except Exception:
            # Deliberately best-effort: an unreadable file is reported as
            # None so that callers can skip it or substitute a placeholder.
            return None

    def _get_placeholder(self) -> torch.Tensor:
        """Return an all-zero ``(3, image_size, image_size)`` tensor."""
        return torch.zeros(3, self.image_size, self.image_size)

    def _load_random_crop(self, candidates: List[Path], transform):
        """Pick one crop at random and transform it.

        Returns:
            ``(tensor, True)`` when a crop was loaded, otherwise
            ``(placeholder, False)`` (no candidates, or unreadable file).
        """
        if candidates:
            crop = self._load_image(random.choice(candidates))
            if crop is not None:
                return transform(crop), True
        return self._get_placeholder(), False

    def __len__(self):
        """Return the number of full images in this split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """Return the sample at ``idx``.

        If the full image at ``idx`` cannot be read, the following samples
        are tried in order (wrapping around) until one loads.

        Returns:
            Dict with keys ``full``, ``face``, ``eye`` (tensors),
            ``has_face``, ``has_eye`` (bools), ``label`` (int) and
            ``artist`` (str).

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no full image in the dataset can be loaded.
        """
        total = len(self.samples)
        # Index directly first so an out-of-range idx still raises IndexError.
        artist, img_name = self.samples[idx]
        full_img = self._load_image(self.dataset_root / artist / img_name)

        # Bounded fallback: visit each remaining sample at most once instead
        # of recursing without limit when files are unreadable.
        attempts = 1
        while full_img is None:
            if attempts >= total:
                raise RuntimeError(
                    f"No readable full image found in dataset ({total} samples tried)"
                )
            idx = (idx + 1) % total
            artist, img_name = self.samples[idx]
            full_img = self._load_image(self.dataset_root / artist / img_name)
            attempts += 1

        label = self.artist_to_idx[artist]
        transform = self.transform if self.is_training else self.transform_eval

        full_tensor = transform(full_img)
        face_tensor, has_face = self._load_random_crop(
            self._face_cache.get(artist, []), transform
        )
        eye_tensor, has_eye = self._load_random_crop(
            self._eye_cache.get(artist, []), transform
        )

        return {
            'full': full_tensor,
            'face': face_tensor,
            'eye': eye_tensor,
            'has_face': has_face,
            'has_eye': has_eye,
            'label': label,
            'artist': artist,
        }
class PKSampler(Sampler):
    """P classes, K samples per class sampler for metric learning.

    Each iteration visits every class once in random order, draws ``k``
    sample indices from it, and groups ``p`` consecutive classes into one
    batch. The last batch holds fewer classes when the number of classes is
    not a multiple of ``p``. Batches are yielded as lists of dataset
    indices, which is the shape expected by ``DataLoader(batch_sampler=...)``.

    Args:
        dataset: Object exposing ``samples`` (sequence of
            ``(artist, file_name)`` pairs) and ``artist_to_idx``.
        p: Number of classes per batch; must be at least 1.
        k: Number of samples drawn per class; must be at least 1.

    Raises:
        ValueError: If ``p`` or ``k`` is smaller than 1.
    """

    def __init__(self, dataset: "ArtistDataset", p: int = 32, k: int = 4):
        if p < 1 or k < 1:
            raise ValueError(f"p and k must be >= 1, got p={p}, k={k}")
        self.dataset = dataset
        self.p = p
        self.k = k

        # Group dataset indices by integer class label.
        self.class_to_indices = defaultdict(list)
        for idx, (artist, _) in enumerate(dataset.samples):
            label = dataset.artist_to_idx[artist]
            self.class_to_indices[label].append(idx)
        self.classes = list(self.class_to_indices.keys())

    def _draw(self, indices):
        """Draw ``k`` indices from one class.

        Indices are taken from a shuffled copy; when the class has fewer
        than ``k`` samples the copy is reshuffled and reused, so an index
        can appear more than once.
        """
        pool = random.sample(indices, len(indices))
        picked = []
        ptr = 0
        for _ in range(self.k):
            if ptr >= len(pool):
                ptr = 0
                random.shuffle(pool)
            picked.append(pool[ptr])
            ptr += 1
        return picked

    def __iter__(self):
        """Yield one list of dataset indices per batch, in random order."""
        class_order = self.classes.copy()
        random.shuffle(class_order)

        batches = []
        for start in range(0, len(class_order), self.p):
            batch = []
            for cls in class_order[start:start + self.p]:
                batch.extend(self._draw(self.class_to_indices[cls]))
            batches.append(batch)

        random.shuffle(batches)
        yield from batches

    def __len__(self):
        """Return the number of batches per iteration.

        Uses ceiling division because the trailing partial batch is yielded
        as well.
        """
        return (len(self.classes) + self.p - 1) // self.p
def _list_images(directory: Path) -> List[str]:
    """Return the sorted .jpg/.png/.webp paths directly inside ``directory``.

    A missing directory yields an empty list. Sorting removes the
    filesystem-dependent ordering of ``glob`` so that a seeded shuffle of
    the result is reproducible.
    """
    if not directory.exists():
        return []
    return sorted(
        str(p)
        for pattern in ("*.jpg", "*.png", "*.webp")
        for p in directory.glob(pattern)
    )


def _split_paths(
    paths: List[str], train_ratio: float, val_ratio: float
) -> Tuple[List[str], List[str], List[str]]:
    """Shuffle ``paths`` in place and cut it into (train, val, test) lists.

    Train always receives at least one item when ``paths`` is non-empty;
    val receives at least one item when there are two or more; test gets
    whatever remains.
    """
    if not paths:
        return [], [], []
    random.shuffle(paths)
    n = len(paths)
    n_train = max(1, int(n * train_ratio))
    n_val = max(1, int(n * val_ratio)) if n > 1 else 0
    return paths[:n_train], paths[n_train:n_train + n_val], paths[n_train + n_val:]


def build_dataset_splits(
    dataset_root: str,
    dataset_face_root: str,
    dataset_eyes_root: str,
    min_images: int = 3,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[Dict[str, int], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]]]:
    """Scan the dataset directories and build per-artist train/val/test splits.

    Every sub-directory of ``dataset_root`` is treated as one artist. Artists
    with fewer than ``min_images`` full images are dropped. Full, face and
    eye images are each shuffled and split independently with the same
    ratios; the test split receives whatever train and val leave over.

    Directories, file lists and artists are processed in sorted order, so
    the result depends only on the directory contents and ``seed``, not on
    the order in which the filesystem enumerates entries.

    Note: this seeds the global ``random`` and ``numpy.random`` generators.

    Args:
        dataset_root: Root directory of the full images.
        dataset_face_root: Root directory of the face crops, with the same
            per-artist sub-directory names.
        dataset_eyes_root: Root directory of the eye crops, with the same
            per-artist sub-directory names.
        min_images: Minimum number of full images an artist needs to be kept.
        train_ratio: Fraction of each artist's images assigned to train.
        val_ratio: Fraction of each artist's images assigned to val.
        seed: Seed for the random generators.

    Returns:
        artist_to_idx: Artist name -> class index (names in sorted order).
        full_splits: {'train': {artist: [paths]}, 'val': {...}, 'test': {...}}
        face_splits: Same structure, for face crops.
        eye_splits: Same structure, for eye crops.
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset_path = Path(dataset_root)
    face_path = Path(dataset_face_root)
    eyes_path = Path(dataset_eyes_root)

    artist_images = {}
    artist_faces = {}
    artist_eyes = {}

    print("Scanning dataset...")
    artists = sorted(d for d in dataset_path.iterdir() if d.is_dir())
    for artist_dir in tqdm(artists, desc="Loading artists"):
        artist_name = artist_dir.name
        images = _list_images(artist_dir)
        if len(images) >= min_images:
            artist_images[artist_name] = images
            # Crops are only collected for artists that are kept.
            artist_faces[artist_name] = _list_images(face_path / artist_name)
            artist_eyes[artist_name] = _list_images(eyes_path / artist_name)

    print(f"Found {len(artist_images)} artists with >= {min_images} images")

    artists_sorted = sorted(artist_images.keys())
    artist_to_idx = {name: idx for idx, name in enumerate(artists_sorted)}

    split_names = ('train', 'val', 'test')
    full_splits = {name: {} for name in split_names}
    face_splits = {name: {} for name in split_names}
    eye_splits = {name: {} for name in split_names}

    # Iterate in sorted order so the shared RNG is consumed in the same
    # sequence on every machine.
    for artist in artists_sorted:
        per_branch = (
            (full_splits, artist_images[artist]),
            (face_splits, artist_faces[artist]),
            (eye_splits, artist_eyes[artist]),
        )
        for target, paths in per_branch:
            parts = _split_paths(paths, train_ratio, val_ratio)
            for name, part in zip(split_names, parts):
                target[name][artist] = part

    # Print summary statistics.
    for split_name in split_names:
        total_full = sum(len(imgs) for imgs in full_splits[split_name].values())
        total_face = sum(len(imgs) for imgs in face_splits[split_name].values())
        total_eye = sum(len(imgs) for imgs in eye_splits[split_name].values())
        print(f"{split_name}: {total_full} full, {total_face} face, {total_eye} eye images")

    return artist_to_idx, full_splits, face_splits, eye_splits
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    The ``full``, ``face`` and ``eye`` tensors are stacked along a new
    leading dimension; ``has_face``, ``has_eye`` and ``label`` become 1-D
    tensors; ``artist`` stays a plain Python list.
    """
    def column(key):
        # Gather one field from every sample, preserving batch order.
        return [sample[key] for sample in batch]

    collated = {}
    for image_key in ('full', 'face', 'eye'):
        collated[image_key] = torch.stack(column(image_key))
    for scalar_key in ('has_face', 'has_eye', 'label'):
        collated[scalar_key] = torch.tensor(column(scalar_key))
    collated['artist'] = column('artist')
    return collated
def create_dataloaders(
    config,
    artist_to_idx: Dict[str, int],
    full_splits: Dict[str, Dict[str, List[str]]],
    face_splits: Dict[str, Dict[str, List[str]]],
    eye_splits: Dict[str, Dict[str, List[str]]],
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create the train, validation and test dataloaders.

    The train loader draws batches from a ``PKSampler`` (P classes with K
    samples each) over an augmenting dataset. The val and test loaders
    iterate their non-augmenting datasets in order with a fixed batch size.

    Args:
        config: Configuration object. Reads ``config.data.dataset_root``,
            ``dataset_face_root``, ``dataset_eyes_root``, ``image_size``,
            ``num_workers`` and ``pin_memory``, plus
            ``config.train.batch_size`` and ``samples_per_class``.
        artist_to_idx: Artist name -> class index mapping.
        full_splits: Full-image paths as ``{split: {artist: [paths]}}``.
        face_splits: Face-crop paths, same structure.
        eye_splits: Eye-crop paths, same structure.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``config.train.samples_per_class`` is smaller than 1.
    """
    def make_dataset(split: str, is_training: bool) -> ArtistDataset:
        # All three datasets differ only in the split and the training flag.
        return ArtistDataset(
            dataset_root=config.data.dataset_root,
            dataset_face_root=config.data.dataset_face_root,
            dataset_eyes_root=config.data.dataset_eyes_root,
            artist_to_idx=artist_to_idx,
            image_paths=full_splits[split],
            face_paths=face_splits[split],
            eye_paths=eye_splits[split],
            image_size=config.data.image_size,
            is_training=is_training,
        )

    def make_eval_loader(dataset: ArtistDataset) -> DataLoader:
        # Sequential, unshuffled loader used for validation and test.
        return DataLoader(
            dataset,
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.data.num_workers,
            pin_memory=config.data.pin_memory,
            collate_fn=collate_fn,
        )

    train_dataset = make_dataset('train', is_training=True)

    # Derive P and K from the batch size: batch_size = P * K, where K is
    # fixed to samples_per_class (e.g. batch_size=256 with K=4 gives P=64).
    k = config.train.samples_per_class
    if k < 1:
        raise ValueError(f"config.train.samples_per_class must be >= 1, got {k}")
    p = config.train.batch_size // k
    p = min(p, len(artist_to_idx))  # never ask for more classes than exist
    # Keep at least one class per batch: P=0 (batch_size < K, or no artists)
    # would make the sampler produce an empty batch.
    p = max(1, p)
    print(f"PKSampler: P={p} classes × K={k} samples = {p*k} batch size")

    train_sampler = PKSampler(
        train_dataset,
        p=p,
        k=k,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
        collate_fn=collate_fn,
    )

    val_loader = make_eval_loader(make_dataset('val', is_training=False))
    test_loader = make_eval_loader(make_dataset('test', is_training=False))

    return train_loader, val_loader, test_loader