from torch.utils import data
import os
import torch
import numpy as np
import cv2
import random
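
# Hypothetical sanity check (not part of the original code). myDataset assumes
# that each subject folder hair/<name>/ holds frames 0.jpg ... 119.jpg. It also
# assumes that no_hair/<name>/ holds the same frame numbers, that
# ref_hair/<name>/ contains at least one .jpg, and that pose.npy has an entry
# for every frame index. This minimal sketch lists every violation it finds.
# num_views=120 mirrors the sampling range used in __getitem__.
import os

import numpy as np


def find_layout_problems(train_data_dir, num_views=120):
    """Return human-readable descriptions of layout problems in train_data_dir."""
    problems = []
    pose = np.load(os.path.join(train_data_dir, "pose.npy"))
    if len(pose) < num_views:
        problems.append(f"pose.npy has {len(pose)} entries, expected >= {num_views}")
    for name in sorted(os.listdir(os.path.join(train_data_dir, "hair"))):
        # Both the hair folder and the bald folder must contain every frame index.
        for sub in ("hair", "no_hair"):
            folder = os.path.join(train_data_dir, sub, name)
            missing = [i for i in range(num_views)
                       if not os.path.isfile(os.path.join(folder, f"{i}.jpg"))]
            if missing:
                problems.append(f"{sub}/{name}: {len(missing)} missing frame(s)")
        ref_folder = os.path.join(train_data_dir, "ref_hair", name)
        if not os.path.isdir(ref_folder) or not any(
                f.endswith(".jpg") for f in os.listdir(ref_folder)):
            problems.append(f"ref_hair/{name}: no .jpg reference image")
    return problems

# Example: print(find_layout_problems("./data")) before building the DataLoader.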

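class myDataset(data.Dataset):
    """Hair dataset for use with torch.utils.data.DataLoader.

    Expected layout of ``train_data_dir``:
        hair/<name>/<i>.jpg      frames of a subject with hair
        no_hair/<name>/<i>.jpg   bald images of the same subject
        ref_hair/<name>/*.jpg    reference hair image(s); the first .jpg is used
        pose.npy                 pose array indexed by frame number <i>
    """
    def __init__(self, train_data_dir):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "no_hair")
        self.ref_path = os.path.join(train_data_dir, "ref_hair")

        # Sort the subject names so that the index -> subject mapping is
        # deterministic (os.listdir returns entries in arbitrary order).
        self.lists = sorted(os.listdir(self.img_path))
        self.len = len(self.lists)
        # One pose entry per frame index, shared by all subjects.
        self.pose = np.load(self.pose_path)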

    def __getitem__(self, index):
        """Return one training sample for the subject at ``index``.

        Two distinct frame indices are drawn. One selects the image with hair
        and the other selects the bald image, so the two views use different poses.
        """
        # Draw two different frame indices from 0..119 in a single call.
        random_number1, random_number2 = random.sample(range(120), 2)
        name = self.lists[index]

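        hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)
        # Sort so that the same reference image is picked on every call.
        files = sorted(f for f in os.listdir(ref_folder) if f.endswith('.jpg'))
        ref_path = os.path.join(ref_folder, files[0])
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)
        # cv2.imread returns None instead of raising on a missing or unreadable file.
        for path, img in ((hair_path, img_hair), (non_hair_path, img_non_hair), (ref_path, ref_hair)):
            if img is None:
                raise FileNotFoundError(f"Could not read image: {path}")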

        # OpenCV loads images as BGR; convert them to RGB.
        img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

        img_hair = cv2.resize(img_hair, (512, 512))
        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))
        # The hair and reference images are scaled to [-1, 1]; the bald image
        # is kept in [0, 1].
        img_hair = (img_hair / 255.0) * 2 - 1
        img_non_hair = img_non_hair / 255.0
        ref_hair = (ref_hair / 255.0) * 2 - 1

        # HWC -> CHW, cast to float32 (the division above produces float64).
        img_hair = torch.from_numpy(img_hair).permute(2, 0, 1).float()
        img_non_hair = torch.from_numpy(img_non_hair).permute(2, 0, 1).float()
        ref_hair = torch.from_numpy(ref_hair).permute(2, 0, 1).float()

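        pose1 = torch.tensor(self.pose[random_number1])
        pose2 = torch.tensor(self.pose[random_number2])

        return {
            'hair_pose': pose1,            # pose of the frame in img_hair
            'img_hair': img_hair,          # image with hair, values in [-1, 1]
            'bald_pose': pose2,            # pose of the frame in img_non_hair
            'img_non_hair': img_non_hair,  # bald image, values in [0, 1]
            'ref_hair': ref_hair,          # reference hair image, values in [-1, 1]
        }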
    
    def __len__(self):
        return self.len
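

# Hypothetical helper (not part of the original code) for visually inspecting
# samples. It undoes the scaling done in __getitem__. A CHW image with values in
# [-1, 1] (img_hair, ref_hair) maps back to an HWC uint8 RGB array. With
# from_unit_range=True, the same applies to an image in [0, 1] (img_non_hair).
# This is a minimal sketch. It accepts NumPy arrays or CPU tensors that do not
# require grad.
import numpy as np


def to_uint8_image(image, from_unit_range=False):
    """Convert a normalized CHW image to an HWC uint8 RGB array."""
    arr = np.asarray(image, dtype=np.float64)
    if not from_unit_range:
        arr = (arr + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    arr = np.clip(arr, 0.0, 1.0)
    # Round (rather than truncate) so that x / 255 * 2 - 1 maps back to x exactly.
    return np.rint(arr.transpose(1, 2, 0) * 255.0).astype(np.uint8)

# Example: save the first image of a batch with OpenCV, which expects BGR:
#   rgb = to_uint8_image(batch["img_hair"][0])
#   cv2.imwrite("img_hair.png", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))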
    
if __name__ == "__main__":
    # Smoke test: iterate over the data once and print the shape and dtype
    # of every field in each batch.
    train_dataset = myDataset("./data")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=1,
    )

    num_epochs = 1
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            for key in ("hair_pose", "img_hair", "bald_pose", "img_non_hair", "ref_hair"):
                print(f"batch[{key}]:", tuple(batch[key].shape), batch[key].dtype)