import torch
import os
import random
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from config import patch_size, scale, dir_HR, dir_valid_HR, device
from realesrgan import RealESRGANDegrader
from autoencoder import get_vqgan

# Module-level singletons for the degradation pipeline and VQGAN, created
# lazily on first use so that importing this module stays cheap.
_degrader = None
_vqgan = None

def get_degrader():
    """Get or create degradation pipeline."""
    global _degrader
    if _degrader is None:
        _degrader = RealESRGANDegrader(scale=scale)
    return _degrader

def get_vqgan_model():
    """Get or create VQGAN model."""
    global _vqgan
    if _vqgan is None:
        _vqgan = get_vqgan(device=device)
    return _vqgan
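
# Note: these module-level singletons are per-process. DataLoader worker
# processes do not share them with the main process; each worker ends up
# holding its own degrader/VQGAN instance in its own address space.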


class SRDatasetOnTheFly(torch.utils.data.Dataset):
    """
    PyTorch Dataset that degrades and VQGAN-encodes image pairs on the fly.

    For each sample, this dataset:
    1. Loads a full HR image.
    2. Crops a random patch_size x patch_size patch (256x256 by default).
    3. Applies the RealESRGAN degradation pipeline to synthesize the LR patch.
    4. Upsamples the LR patch back to patch_size x patch_size with bicubic
       interpolation.
    5. Encodes both the HR and the upsampled LR patch through the VQGAN,
       yielding 64x64 latents for 256x256 inputs.

    Args:
        dir_HR (str): Directory containing high-resolution images.
        scale (int, optional): Super-resolution scale factor. Defaults to config.scale (4).
        patch_size (int, optional): Side length of the square HR patch.
            Defaults to config.patch_size (256).
        max_samples (int, optional): Maximum number of images to use. If None, uses all.

    Returns:
        tuple: (hr_latent, lr_latent), each a torch.Tensor of shape (C, 64, 64)
               holding a VQGAN-encoded latent.
    """
    
    def __init__(self, dir_HR, scale=scale, patch_size=patch_size, max_samples=None):
        super().__init__()
        
        self.dir_HR = dir_HR
        self.scale = scale
        self.patch_size = patch_size
        
        # Get all image files
        self.filenames = sorted([
            f for f in os.listdir(self.dir_HR) 
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])
        
        # Limit to max_samples if specified
        if max_samples is not None:
            self.filenames = self.filenames[:max_samples]
        
        # Initialize degradation and VQGAN (will be loaded on first use)
        self.degrader = None
        self.vqgan = None
    
    def __len__(self):
        return len(self.filenames)
    
    def _load_image(self, img_path):
        """Load an image as an RGB float tensor in [0, 1]."""
        # Context manager ensures the file handle is released promptly.
        with Image.open(img_path) as img:
            img_tensor = TF.to_tensor(img.convert("RGB"))  # (C, H, W) in [0, 1]
        return img_tensor
    
    def _crop_patch(self, img_tensor, patch_size):
        """
        Crop a random patch from image.
        
        Args:
            img_tensor: (C, H, W) tensor
            patch_size: Size of patch to crop
        
        Returns:
            patch: (C, patch_size, patch_size) tensor
        """
        C, H, W = img_tensor.shape
        
        # Pad if the image is smaller than patch_size. Reflection padding
        # requires the pad to be smaller than the input dimension, so fall
        # back to edge replication for very small images.
        if H < patch_size or W < patch_size:
            pad_h = max(0, patch_size - H)
            pad_w = max(0, patch_size - W)
            mode = 'reflect' if pad_h < H and pad_w < W else 'replicate'
            img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode=mode)
            H, W = img_tensor.shape[1], img_tensor.shape[2]
        
        # Random crop
        top = random.randint(0, max(0, H - patch_size))
        left = random.randint(0, max(0, W - patch_size))
        
        patch = img_tensor[:, top:top+patch_size, left:left+patch_size]
        return patch
    
    def _apply_augmentations(self, hr, lr):
        """
        Apply synchronized augmentations to HR and LR.
        
        Args:
            hr: (C, H, W) HR tensor
            lr: (C, H, W) LR tensor
        
        Returns:
            hr_aug, lr_aug: Augmented tensors
        """
        # Horizontal flip
        if random.random() < 0.5:
            hr = torch.flip(hr, dims=[2])
            lr = torch.flip(lr, dims=[2])
        
        # Vertical flip
        if random.random() < 0.5:
            hr = torch.flip(hr, dims=[1])
            lr = torch.flip(lr, dims=[1])
        
        # 180° rotation. Note: this equals a horizontal plus a vertical flip,
        # so it only reweights combinations already reachable above.
        if random.random() < 0.5:
            hr = torch.rot90(hr, k=2, dims=[1, 2])
            lr = torch.rot90(lr, k=2, dims=[1, 2])
        
        return hr, lr
    
    def __getitem__(self, idx):
        # Load HR image
        hr_path = os.path.join(self.dir_HR, self.filenames[idx])
        hr_full = self._load_image(hr_path)  # (C, H, W) in [0, 1]
        
        # Crop 256x256 patch from HR
        hr_patch = self._crop_patch(hr_full, self.patch_size)  # (C, 256, 256)
        
        # Initialize degrader and VQGAN on first use
        if self.degrader is None:
            self.degrader = get_degrader()
        if self.vqgan is None:
            self.vqgan = get_vqgan_model()
        
        # Apply degradation on the fly to generate the LR patch.
        # The degrader expects (C, H, W) and returns (C, H//scale, W//scale).
        # Note: this runs on `device`; with a CUDA device, the DataLoader must
        # use num_workers=0, since CUDA cannot be initialized safely in
        # fork-started worker processes.
        hr_patch_gpu = hr_patch.to(device)  # (C, 256, 256)
        with torch.no_grad():
            lr_patch = self.degrader.degrade(hr_patch_gpu)  # (C, 64, 64) in pixel space
        
        # Upsample LR to 256x256 using bicubic interpolation
        lr_patch_upsampled = F.interpolate(
            lr_patch.unsqueeze(0),  # (1, C, 64, 64)
            size=(self.patch_size, self.patch_size),
            mode='bicubic',
            align_corners=False
        ).squeeze(0)  # (C, 256, 256)
        
        # Apply augmentations (synchronized)
        hr_patch, lr_patch_upsampled = self._apply_augmentations(
            hr_patch.cpu(), 
            lr_patch_upsampled.cpu()
        )
        
        # Encode through VQGAN to get latents (64x64)
        # Move to device for encoding
        hr_patch_gpu = hr_patch.to(device).unsqueeze(0)  # (1, C, 256, 256)
        lr_patch_gpu = lr_patch_upsampled.to(device).unsqueeze(0)  # (1, C, 256, 256)
        
        with torch.no_grad():
            # Encode HR: 256x256 -> 64x64 latent
            hr_latent = self.vqgan.encode(hr_patch_gpu)  # (1, C, 64, 64)
            
            # Encode LR: 256x256 -> 64x64 latent
            lr_latent = self.vqgan.encode(lr_patch_gpu)  # (1, C, 64, 64)
        
        # Remove batch dimension and move to CPU
        hr_latent = hr_latent.squeeze(0).cpu()  # (C, 64, 64)
        lr_latent = lr_latent.squeeze(0).cpu()  # (C, 64, 64)
        
        return hr_latent, lr_latent


# Create datasets using on-the-fly processing
train_dataset = SRDatasetOnTheFly(
    dir_HR=dir_HR,
    scale=scale,
    patch_size=patch_size
)

valid_dataset = SRDatasetOnTheFly(
    dir_HR=dir_valid_HR,
    scale=scale,
    patch_size=patch_size
)

# Mini dataset with 8 images for testing
mini_dataset = SRDatasetOnTheFly(
    dir_HR=dir_HR,
    scale=scale,
    patch_size=patch_size,
    max_samples=8
)
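
# A minimal usage sketch (illustrative, not executed on import): wrap one of
# the datasets above in a DataLoader. Because __getitem__ runs the degrader
# and VQGAN on `device`, keep num_workers=0 when `device` is a GPU; CUDA
# cannot be re-initialized in fork-started worker processes. The batch_size
# below is an arbitrary, assumed value.
#
#   from torch.utils.data import DataLoader
#   train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
#                             num_workers=0)
#   for hr_latent, lr_latent in train_loader:  # each (B, C, 64, 64)
#       ...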

print(f"\nFull training dataset size: {len(train_dataset)}")
print(f"Full validation dataset size: {len(valid_dataset)}")
print(f"Mini dataset size: {len(mini_dataset)}")
print("Using on-the-fly degradation and VQGAN encoding")
print("Output: 64x64 latents (from 256x256 patches)")
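
# A quick smoke test, run only when this file is executed directly: pull one
# sample from the mini dataset and check the latent shapes. The 64x64 spatial
# size assumes the default 256x256 patches and the VQGAN's 4x spatial
# downsampling described in the class docstring.
if __name__ == "__main__":
    hr_latent, lr_latent = mini_dataset[0]
    print(f"hr_latent: {tuple(hr_latent.shape)}")  # expected (C, 64, 64)
    print(f"lr_latent: {tuple(lr_latent.shape)}")  # expected (C, 64, 64)
    assert hr_latent.shape == lr_latent.shape, "HR/LR latent shapes should match"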