"""
Artist Style Embedding - Dataset & DataLoader
"""
import os
import random
import warnings
from pathlib import Path
from typing import Dict, List, Tuple
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
# Suppress PIL warnings
warnings.filterwarnings('ignore', category=UserWarning, module='PIL')
class ArtistDataset(Dataset):
"""Multi-branch artist dataset"""
def __init__(
self,
dataset_root: str,
dataset_face_root: str,
dataset_eyes_root: str,
artist_to_idx: Dict[str, int],
        image_paths: Dict[str, List[str]],  # full images for this split
        face_paths: Dict[str, List[str]],  # face images for this split
        eye_paths: Dict[str, List[str]],  # eye images for this split
image_size: int = 224,
is_training: bool = True,
):
self.dataset_root = Path(dataset_root)
self.dataset_face_root = Path(dataset_face_root)
self.dataset_eyes_root = Path(dataset_eyes_root)
self.artist_to_idx = artist_to_idx
self.image_size = image_size
self.is_training = is_training
# Flat sample list
self.samples = []
for artist, paths in image_paths.items():
for img_path in paths:
self.samples.append((artist, os.path.basename(img_path)))
self.transform = self._get_transforms()
self.transform_eval = self._get_eval_transforms()
        # Face/eye paths per artist (already split)
self._face_cache = {artist: [Path(p) for p in paths] for artist, paths in face_paths.items()}
self._eye_cache = {artist: [Path(p) for p in paths] for artist, paths in eye_paths.items()}
def _get_transforms(self):
return transforms.Compose([
transforms.Resize((self.image_size + 32, self.image_size + 32)),
transforms.RandomCrop(self.image_size),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
])
def _get_eval_transforms(self):
return transforms.Compose([
transforms.Resize((self.image_size, self.image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def _load_image(self, path: Path):
try:
img = Image.open(path)
            # Convert RGBA/LA/palette images to RGB
            if img.mode in ('RGBA', 'LA', 'P'):
                # Composite transparency onto a white background; converting to
                # RGBA first ensures LA and palette images get an alpha mask too
                img = img.convert('RGBA')
                background = Image.new('RGB', img.size, (255, 255, 255))
                background.paste(img, mask=img.split()[-1])
                return background
return img.convert('RGB')
except Exception:
return None
def _get_placeholder(self) -> torch.Tensor:
return torch.zeros(3, self.image_size, self.image_size)
def __len__(self):
return len(self.samples)
def __getitem__(self, idx: int) -> Dict:
artist, img_name = self.samples[idx]
label = self.artist_to_idx[artist]
transform = self.transform if self.is_training else self.transform_eval
# Full image
full_path = self.dataset_root / artist / img_name
full_img = self._load_image(full_path)
        if full_img is None:
            # Unreadable image: fall back to the next sample
            return self.__getitem__((idx + 1) % len(self))
full_tensor = transform(full_img)
# Face image
face_paths = self._face_cache.get(artist, [])
if face_paths:
face_path = random.choice(face_paths)
face_img = self._load_image(face_path)
            face_tensor = transform(face_img) if face_img is not None else self._get_placeholder()
            has_face = face_img is not None
else:
face_tensor = self._get_placeholder()
has_face = False
# Eye image
eye_paths = self._eye_cache.get(artist, [])
if eye_paths:
eye_path = random.choice(eye_paths)
eye_img = self._load_image(eye_path)
            eye_tensor = transform(eye_img) if eye_img is not None else self._get_placeholder()
            has_eye = eye_img is not None
else:
eye_tensor = self._get_placeholder()
has_eye = False
return {
'full': full_tensor,
'face': face_tensor,
'eye': eye_tensor,
'has_face': has_face,
'has_eye': has_eye,
'label': label,
'artist': artist,
}
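# A minimal usage sketch (hypothetical paths; a real run needs the files on
# disk). Indexing the dataset yields one dict per sample:
#
#   ds = ArtistDataset(
#       dataset_root="data/full", dataset_face_root="data/face",
#       dataset_eyes_root="data/eyes", artist_to_idx={"artistA": 0},
#       image_paths={"artistA": ["data/full/artistA/0001.png"]},
#       face_paths={"artistA": []}, eye_paths={"artistA": []},
#   )
#   sample = ds[0]
#   sample['full'].shape  # torch.Size([3, 224, 224])
#   sample['has_face']    # False -> 'face' is a zero placeholder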
class PKSampler(Sampler):
"""P classes, K samples per class sampler for metric learning"""
def __init__(self, dataset: ArtistDataset, p: int = 32, k: int = 4):
self.dataset = dataset
self.p = p
self.k = k
self.class_to_indices = defaultdict(list)
for idx, (artist, _) in enumerate(dataset.samples):
label = dataset.artist_to_idx[artist]
self.class_to_indices[label].append(idx)
self.classes = list(self.class_to_indices.keys())
def __iter__(self):
class_indices = {
c: random.sample(indices, len(indices))
for c, indices in self.class_to_indices.items()
}
class_pointers = {c: 0 for c in self.classes}
class_order = self.classes.copy()
random.shuffle(class_order)
batches = []
batch = []
classes_in_batch = set()
for cls in class_order:
if len(classes_in_batch) >= self.p:
batches.append(batch)
batch = []
classes_in_batch = set()
indices = class_indices[cls]
ptr = class_pointers[cls]
samples = []
for _ in range(self.k):
if ptr >= len(indices):
ptr = 0
random.shuffle(indices)
samples.append(indices[ptr])
ptr += 1
class_pointers[cls] = ptr
batch.extend(samples)
classes_in_batch.add(cls)
if batch:
batches.append(batch)
random.shuffle(batches)
for batch in batches:
yield batch
    def __len__(self):
        # Number of batches per epoch: each class appears once, up to p classes per batch
        return (len(self.classes) + self.p - 1) // self.p
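# Illustration (toy numbers, not the real config): with p=2, k=3 and classes
# {A, B, C}, one epoch might yield batches like
#   [A3, A1, A5, B2, B4, B1]  and  [C2, C1, C3]
# i.e. up to p classes per batch with k samples each; a class's shuffled pool
# is reshuffled and reused once exhausted, so small classes repeat samples
# rather than run dry.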
def build_dataset_splits(
dataset_root: str,
dataset_face_root: str,
dataset_eyes_root: str,
min_images: int = 3,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
) -> Tuple[Dict[str, int], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]]]:
"""
Returns:
        artist_to_idx: artist name -> index mapping
        full_splits: {'train': {artist: [paths]}, 'val': {...}, 'test': {...}}
        face_splits: same structure as full_splits
        eye_splits: same structure as full_splits
"""
random.seed(seed)
np.random.seed(seed)
dataset_path = Path(dataset_root)
face_path = Path(dataset_face_root)
eyes_path = Path(dataset_eyes_root)
artist_images = {}
artist_faces = {}
artist_eyes = {}
print("Scanning dataset...")
artists = [d for d in dataset_path.iterdir() if d.is_dir()]
for artist_dir in tqdm(artists, desc="Loading artists"):
artist_name = artist_dir.name
# Full images
images = list(artist_dir.glob("*.jpg")) + \
list(artist_dir.glob("*.png")) + \
list(artist_dir.glob("*.webp"))
if len(images) >= min_images:
artist_images[artist_name] = [str(p) for p in images]
# Face images
face_dir = face_path / artist_name
if face_dir.exists():
faces = list(face_dir.glob("*.jpg")) + \
list(face_dir.glob("*.png")) + \
list(face_dir.glob("*.webp"))
artist_faces[artist_name] = [str(p) for p in faces]
else:
artist_faces[artist_name] = []
# Eye images
eye_dir = eyes_path / artist_name
if eye_dir.exists():
eyes = list(eye_dir.glob("*.jpg")) + \
list(eye_dir.glob("*.png")) + \
list(eye_dir.glob("*.webp"))
artist_eyes[artist_name] = [str(p) for p in eyes]
else:
artist_eyes[artist_name] = []
print(f"Found {len(artist_images)} artists with >= {min_images} images")
artists_sorted = sorted(artist_images.keys())
artist_to_idx = {name: idx for idx, name in enumerate(artists_sorted)}
full_splits = {'train': {}, 'val': {}, 'test': {}}
face_splits = {'train': {}, 'val': {}, 'test': {}}
eye_splits = {'train': {}, 'val': {}, 'test': {}}
for artist in artist_images.keys():
        # Split full images
images = artist_images[artist]
random.shuffle(images)
n = len(images)
n_train = max(1, int(n * train_ratio))
n_val = max(1, int(n * val_ratio))
full_splits['train'][artist] = images[:n_train]
full_splits['val'][artist] = images[n_train:n_train + n_val]
full_splits['test'][artist] = images[n_train + n_val:]
        # Split face images (same ratios); n_f >= 1 is guaranteed inside the branch
        faces = artist_faces[artist]
        if faces:
            random.shuffle(faces)
            n_f = len(faces)
            n_f_train = max(1, int(n_f * train_ratio))
            n_f_val = max(1, int(n_f * val_ratio)) if n_f > 1 else 0
face_splits['train'][artist] = faces[:n_f_train]
face_splits['val'][artist] = faces[n_f_train:n_f_train + n_f_val]
face_splits['test'][artist] = faces[n_f_train + n_f_val:]
else:
face_splits['train'][artist] = []
face_splits['val'][artist] = []
face_splits['test'][artist] = []
        # Split eye images (same ratios)
        eyes = artist_eyes[artist]
        if eyes:
            random.shuffle(eyes)
            n_e = len(eyes)
            n_e_train = max(1, int(n_e * train_ratio))
            n_e_val = max(1, int(n_e * val_ratio)) if n_e > 1 else 0
eye_splits['train'][artist] = eyes[:n_e_train]
eye_splits['val'][artist] = eyes[n_e_train:n_e_train + n_e_val]
eye_splits['test'][artist] = eyes[n_e_train + n_e_val:]
else:
eye_splits['train'][artist] = []
eye_splits['val'][artist] = []
eye_splits['test'][artist] = []
    # Print per-split statistics
for split_name in ['train', 'val', 'test']:
total_full = sum(len(imgs) for imgs in full_splits[split_name].values())
total_face = sum(len(imgs) for imgs in face_splits[split_name].values())
total_eye = sum(len(imgs) for imgs in eye_splits[split_name].values())
print(f"{split_name}: {total_full} full, {total_face} face, {total_eye} eye images")
return artist_to_idx, full_splits, face_splits, eye_splits
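# Usage sketch (directory names are assumptions; create_dataloaders below is
# the typical caller):
#
#   artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
#       dataset_root="dataset",
#       dataset_face_root="dataset_face",
#       dataset_eyes_root="dataset_eyes",
#       min_images=3,
#   )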
def collate_fn(batch):
return {
'full': torch.stack([item['full'] for item in batch]),
'face': torch.stack([item['face'] for item in batch]),
'eye': torch.stack([item['eye'] for item in batch]),
'has_face': torch.tensor([item['has_face'] for item in batch]),
'has_eye': torch.tensor([item['has_eye'] for item in batch]),
'label': torch.tensor([item['label'] for item in batch]),
'artist': [item['artist'] for item in batch],
}
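# For a batch of B samples at image_size=S, the collated dict holds
# 'full'/'face'/'eye' tensors of shape (B, 3, S, S), boolean 'has_face' and
# 'has_eye' masks of shape (B,), an integer 'label' tensor of shape (B,), and
# 'artist' as a plain Python list of B strings (strings cannot be stacked).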
def create_dataloaders(
config,
artist_to_idx: Dict[str, int],
full_splits: Dict[str, Dict[str, List[str]]],
face_splits: Dict[str, Dict[str, List[str]]],
eye_splits: Dict[str, Dict[str, List[str]]],
) -> Tuple[DataLoader, DataLoader, DataLoader]:
train_dataset = ArtistDataset(
dataset_root=config.data.dataset_root,
dataset_face_root=config.data.dataset_face_root,
dataset_eyes_root=config.data.dataset_eyes_root,
artist_to_idx=artist_to_idx,
image_paths=full_splits['train'],
face_paths=face_splits['train'],
eye_paths=eye_splits['train'],
image_size=config.data.image_size,
is_training=True,
)
    # Derive P and K from batch_size:
    # batch_size = P * K, with K fixed by samples_per_class
    k = config.train.samples_per_class
    p = config.train.batch_size // k  # e.g. batch_size=256, k=4 -> P=64
    p = min(p, len(artist_to_idx))  # don't exceed the number of classes
print(f"PKSampler: P={p} classes × K={k} samples = {p*k} batch size")
train_sampler = PKSampler(
train_dataset,
p=p,
k=k,
)
train_loader = DataLoader(
train_dataset,
        batch_sampler=train_sampler,  # mutually exclusive with batch_size/shuffle
num_workers=config.data.num_workers,
pin_memory=config.data.pin_memory,
collate_fn=collate_fn,
)
val_dataset = ArtistDataset(
dataset_root=config.data.dataset_root,
dataset_face_root=config.data.dataset_face_root,
dataset_eyes_root=config.data.dataset_eyes_root,
artist_to_idx=artist_to_idx,
image_paths=full_splits['val'],
face_paths=face_splits['val'],
eye_paths=eye_splits['val'],
image_size=config.data.image_size,
is_training=False,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.train.batch_size,
shuffle=False,
num_workers=config.data.num_workers,
pin_memory=config.data.pin_memory,
collate_fn=collate_fn,
)
test_dataset = ArtistDataset(
dataset_root=config.data.dataset_root,
dataset_face_root=config.data.dataset_face_root,
dataset_eyes_root=config.data.dataset_eyes_root,
artist_to_idx=artist_to_idx,
image_paths=full_splits['test'],
face_paths=face_splits['test'],
eye_paths=eye_splits['test'],
image_size=config.data.image_size,
is_training=False,
)
test_loader = DataLoader(
test_dataset,
batch_size=config.train.batch_size,
shuffle=False,
num_workers=config.data.num_workers,
pin_memory=config.data.pin_memory,
collate_fn=collate_fn,
)
return train_loader, val_loader, test_loader
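# Minimal smoke-test sketch. The directory layout and config shape below are
# assumptions for illustration, not the project's actual config object.
if __name__ == "__main__":
    from types import SimpleNamespace
    config = SimpleNamespace(
        data=SimpleNamespace(
            dataset_root="dataset",
            dataset_face_root="dataset_face",
            dataset_eyes_root="dataset_eyes",
            image_size=224,
            num_workers=0,  # keep 0 for a quick single-process check
            pin_memory=False,
        ),
        train=SimpleNamespace(batch_size=256, samples_per_class=4),
    )
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        config.data.dataset_root,
        config.data.dataset_face_root,
        config.data.dataset_eyes_root,
    )
    train_loader, val_loader, test_loader = create_dataloaders(
        config, artist_to_idx, full_splits, face_splits, eye_splits,
    )
    batch = next(iter(train_loader))
    print({k: tuple(v.shape) for k, v in batch.items() if torch.is_tensor(v)})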