NextGenC commited on
Commit
8add310
·
verified ·
1 Parent(s): 941a7fe

Upload 24 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 7gen_epoch_10.png filter=lfs diff=lfs merge=lfs -text
37
+ 7gen_epoch_100.png filter=lfs diff=lfs merge=lfs -text
38
+ 7gen_epoch_20.png filter=lfs diff=lfs merge=lfs -text
39
+ 7gen_epoch_30.png filter=lfs diff=lfs merge=lfs -text
40
+ 7gen_epoch_40.png filter=lfs diff=lfs merge=lfs -text
41
+ 7gen_epoch_50.png filter=lfs diff=lfs merge=lfs -text
42
+ 7gen_epoch_60.png filter=lfs diff=lfs merge=lfs -text
43
+ 7gen_epoch_70.png filter=lfs diff=lfs merge=lfs -text
44
+ 7gen_epoch_80.png filter=lfs diff=lfs merge=lfs -text
45
+ 7gen_epoch_90.png filter=lfs diff=lfs merge=lfs -text
7gen.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 7Gen - MNIST için Gelişmiş Üretici Model
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torchvision import datasets, transforms
6
+ from torch.utils.data import DataLoader
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ import os
11
+
12
+ print("🚀 7Gen - Gelişmiş MNIST Üretici Sistemi 🚀")
13
+
14
+ # Cihaz ayarları
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ print(f'Kullanılan cihaz: {device}')
17
+
18
+ # Hiperparametreler
19
+ batch_size = 64
20
+ latent_dim = 100
21
+ num_classes = 10
22
+ num_epochs = 100
23
+ lr = 0.0002
24
+
25
+ # Veri yükleme
26
+ transform = transforms.Compose([
27
+ transforms.ToTensor(), # Burayı düzelttim
28
+ transforms.Normalize([0.5], [0.5])
29
+ ])
30
+
31
+ dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
32
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
33
+
34
+ # Generator modeli
35
+ class Generator(nn.Module):
36
+ def __init__(self):
37
+ super(Generator, self).__init__()
38
+
39
+ self.label_emb = nn.Embedding(num_classes, num_classes)
40
+
41
+ self.model = nn.Sequential(
42
+ nn.Linear(latent_dim + num_classes, 256),
43
+ nn.LeakyReLU(0.2),
44
+ nn.BatchNorm1d(256),
45
+
46
+ nn.Linear(256, 512),
47
+ nn.LeakyReLU(0.2),
48
+ nn.BatchNorm1d(512),
49
+
50
+ nn.Linear(512, 1024),
51
+ nn.LeakyReLU(0.2),
52
+ nn.BatchNorm1d(1024),
53
+
54
+ nn.Linear(1024, 784),
55
+ nn.Tanh()
56
+ )
57
+
58
+ def forward(self, noise, labels):
59
+ label_embedding = self.label_emb(labels)
60
+ gen_input = torch.cat((noise, label_embedding), -1)
61
+ img = self.model(gen_input)
62
+ img = img.view(img.size(0), 1, 28, 28)
63
+ return img
64
+
65
+ # Discriminator modeli
66
+ class Discriminator(nn.Module):
67
+ def __init__(self):
68
+ super(Discriminator, self).__init__()
69
+
70
+ self.label_emb = nn.Embedding(num_classes, num_classes)
71
+
72
+ self.model = nn.Sequential(
73
+ nn.Linear(784 + num_classes, 512),
74
+ nn.LeakyReLU(0.2),
75
+ nn.Dropout(0.3),
76
+
77
+ nn.Linear(512, 256),
78
+ nn.LeakyReLU(0.2),
79
+ nn.Dropout(0.3),
80
+
81
+ nn.Linear(256, 1),
82
+ nn.Sigmoid()
83
+ )
84
+
85
+ def forward(self, img, labels):
86
+ img_flat = img.view(img.size(0), -1)
87
+ label_embedding = self.label_emb(labels)
88
+ d_input = torch.cat((img_flat, label_embedding), -1)
89
+ validity = self.model(d_input)
90
+ return validity
91
+
92
+ # Model oluşturma
93
+ generator = Generator().to(device)
94
+ discriminator = Discriminator().to(device)
95
+
96
+ # Loss ve optimizer
97
+ adversarial_loss = nn.BCELoss()
98
+ optimizer_G = optim.Adam(generator.parameters(), lr=lr)
99
+ optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
100
+
101
+ # Klasör oluştur
102
+ os.makedirs('generated_images', exist_ok=True)
103
+
104
+ # Eğitim
105
+ print("\n🔥 7Gen Eğitimi Başlıyor...")
106
+
107
+ for epoch in range(num_epochs):
108
+ for i, (imgs, labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
109
+ imgs = imgs.to(device)
110
+ labels = labels.to(device)
111
+ batch_size = imgs.size(0)
112
+
113
+ # Ground truth'lar
114
+ valid = torch.ones(batch_size, 1).to(device)
115
+ fake = torch.zeros(batch_size, 1).to(device)
116
+
117
+ # Generator eğitimi
118
+ optimizer_G.zero_grad()
119
+ z = torch.randn(batch_size, latent_dim).to(device)
120
+ gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
121
+ gen_imgs = generator(z, gen_labels)
122
+
123
+ g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
124
+ g_loss.backward()
125
+ optimizer_G.step()
126
+
127
+ # Discriminator eğitimi
128
+ optimizer_D.zero_grad()
129
+ real_loss = adversarial_loss(discriminator(imgs, labels), valid)
130
+ fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
131
+ d_loss = (real_loss + fake_loss) / 2
132
+
133
+ d_loss.backward()
134
+ optimizer_D.step()
135
+
136
+ print(f"Epoch {epoch+1}/{num_epochs} - D loss: {d_loss:.4f}, G loss: {g_loss:.4f}")
137
+
138
+ # Her 10 epoch'ta örnek üret
139
+ if (epoch + 1) % 10 == 0:
140
+ with torch.no_grad():
141
+ z = torch.randn(100, latent_dim).to(device)
142
+ labels = torch.tensor([i for i in range(10) for _ in range(10)]).to(device)
143
+ gen_imgs = generator(z, labels)
144
+ gen_imgs = (gen_imgs + 1) / 2
145
+
146
+ fig, axes = plt.subplots(10, 10, figsize=(10, 10))
147
+ for i in range(10):
148
+ for j in range(10):
149
+ idx = i * 10 + j
150
+ axes[i, j].imshow(gen_imgs[idx][0].cpu().numpy(), cmap='gray')
151
+ axes[i, j].axis('off')
152
+ plt.savefig(f'generated_images/7gen_epoch_{epoch+1}.png')
153
+ plt.close()
154
+
155
+ # Model kaydetme
156
+ os.makedirs('models', exist_ok=True)
157
+ torch.save(generator.state_dict(), 'models/7gen_generator.pth')
158
+ torch.save(discriminator.state_dict(), 'models/7gen_discriminator.pth')
159
+
160
+ print("\n✅ 7Gen eğitimi tamamlandı!")
161
+
162
+ # Kullanım örneği
163
+ def generate_digit(digit, num_samples=5):
164
+ generator.eval()
165
+ with torch.no_grad():
166
+ z = torch.randn(num_samples, latent_dim).to(device)
167
+ labels = torch.full((num_samples,), digit).to(device)
168
+ gen_imgs = generator(z, labels)
169
+ gen_imgs = (gen_imgs + 1) / 2
170
+
171
+ plt.figure(figsize=(10, 2))
172
+ for i in range(num_samples):
173
+ plt.subplot(1, num_samples, i+1)
174
+ plt.imshow(gen_imgs[i][0].cpu().numpy(), cmap='gray')
175
+ plt.axis('off')
176
+ plt.savefig(f'generated_images/digit_{digit}_samples.png')
177
+ plt.show()
178
+
179
+ # Test et
180
+ print("\n🎯 Test örnekleri üretiliyor...")
181
+ for digit in range(10):
182
+ generate_digit(digit, num_samples=5)
183
+
184
+ print("\n🎉 7Gen hazır! generated_images klasörüne bak.")
7gen_discriminator.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8214e331ef016153c1afdfbf6538a0936523c3a5acb5a0435da4ec1d1ca0a52
3
+ size 2158141
7gen_epoch_10.png ADDED

Git LFS Details

  • SHA256: 95e9848b744750c1b5e7e3f433b0b67205ea382a2288e6fdf77324546eb5aa67
  • Pointer size: 131 Bytes
  • Size of remote file: 322 kB
7gen_epoch_100.png ADDED

Git LFS Details

  • SHA256: fa669564cd2277870202410b30ffa7836cdf8f0f13de7733a8234437d1cae475
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
7gen_epoch_20.png ADDED

Git LFS Details

  • SHA256: 8cfe3d8b604848a3d8d3983df64385cc6ede5614bf0b3efd987dda71b4ea04e9
  • Pointer size: 131 Bytes
  • Size of remote file: 309 kB
7gen_epoch_30.png ADDED

Git LFS Details

  • SHA256: 045ce691746b8f29989b1fd968b18986368305ff0caedd8dff51878c34bd57a7
  • Pointer size: 131 Bytes
  • Size of remote file: 289 kB
7gen_epoch_40.png ADDED

Git LFS Details

  • SHA256: 80420068dfb9b2ca316c9461d1e12f803043efeee1b45bd7a44e60c88572479c
  • Pointer size: 131 Bytes
  • Size of remote file: 306 kB
7gen_epoch_50.png ADDED

Git LFS Details

  • SHA256: f3e53f38e63ac602fb3d5302606dc805dee8331eea1d484fdccedbd8a8a6299e
  • Pointer size: 131 Bytes
  • Size of remote file: 292 kB
7gen_epoch_60.png ADDED

Git LFS Details

  • SHA256: e4c1127b9d87a49a52a8e70f66af15b186c7489ce3dc458d6783ad6e29f6f24c
  • Pointer size: 131 Bytes
  • Size of remote file: 286 kB
7gen_epoch_70.png ADDED

Git LFS Details

  • SHA256: 889ecd7dc5a16b8e744c3f56b928b5aa2c7eac096403f3bcc04d520f0a107c8e
  • Pointer size: 131 Bytes
  • Size of remote file: 272 kB
7gen_epoch_80.png ADDED

Git LFS Details

  • SHA256: bb4ea5555f2effc041674d72deee41087907f2854c5b7c8a169b55794a3f47c1
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
7gen_epoch_90.png ADDED

Git LFS Details

  • SHA256: 137ddaa3562138d851c2b814f9fff67b74b5bd86db31e7d99320cf8d76092588
  • Pointer size: 131 Bytes
  • Size of remote file: 261 kB
7gen_generator.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0eb28f1505b52f57de9c68bec28935750a1a14cb6dfe7c1e0b04293d301f5082
3
+ size 5992786
7gen_inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 7Gen Inference - Rakam Üretme Arayüzü
2
+ import torch
3
+ import torch.nn as nn
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from PIL import Image
7
+ import os
8
+
9
+ # Model yapısı (eğitimde kullandığımız ile aynı olmalı)
10
+ class Generator(nn.Module):
11
+ def __init__(self):
12
+ super(Generator, self).__init__()
13
+
14
+ self.label_emb = nn.Embedding(10, 10)
15
+
16
+ self.model = nn.Sequential(
17
+ nn.Linear(100 + 10, 256),
18
+ nn.LeakyReLU(0.2),
19
+ nn.BatchNorm1d(256),
20
+
21
+ nn.Linear(256, 512),
22
+ nn.LeakyReLU(0.2),
23
+ nn.BatchNorm1d(512),
24
+
25
+ nn.Linear(512, 1024),
26
+ nn.LeakyReLU(0.2),
27
+ nn.BatchNorm1d(1024),
28
+
29
+ nn.Linear(1024, 784),
30
+ nn.Tanh()
31
+ )
32
+
33
+ def forward(self, noise, labels):
34
+ label_embedding = self.label_emb(labels)
35
+ gen_input = torch.cat((noise, label_embedding), -1)
36
+ img = self.model(gen_input)
37
+ img = img.view(img.size(0), 1, 28, 28)
38
+ return img
39
+
40
+ # 7Gen sınıfı
41
+ class SevenGenInference:
42
+ def __init__(self, model_path='models/7gen_generator.pth'):
43
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
+ self.latent_dim = 100
45
+
46
+ # Modeli yükle
47
+ self.generator = Generator().to(self.device)
48
+ self.generator.load_state_dict(torch.load(model_path, map_location=self.device))
49
+ self.generator.eval()
50
+
51
+ print(f"🚀 7Gen yüklendi! Cihaz: {self.device}")
52
+
53
+ def generate_digit(self, digit, count=5):
54
+ """Belirli bir rakamdan istenen sayıda üret"""
55
+ with torch.no_grad():
56
+ z = torch.randn(count, self.latent_dim).to(self.device)
57
+ labels = torch.full((count,), digit).to(self.device)
58
+
59
+ images = self.generator(z, labels)
60
+ images = (images + 1) / 2 # [-1,1] -> [0,1]
61
+
62
+ return images.cpu()
63
+
64
+ def visualize_digits(self, digit, count=5, save_path=None):
65
+ """Üretilen rakamları görselleştir"""
66
+ images = self.generate_digit(digit, count)
67
+
68
+ fig, axes = plt.subplots(1, count, figsize=(2*count, 2))
69
+ if count == 1:
70
+ axes = [axes]
71
+
72
+ for i, ax in enumerate(axes):
73
+ ax.imshow(images[i][0], cmap='gray')
74
+ ax.axis('off')
75
+ ax.set_title(f'Digit: {digit}')
76
+
77
+ plt.tight_layout()
78
+
79
+ if save_path:
80
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
81
+ print(f"💾 Görsel kaydedildi: {save_path}")
82
+
83
+ plt.show()
84
+
85
+ def generate_grid(self, samples_per_digit=10, save_path=None):
86
+ """Her rakamdan örneklerle 10x10 grid oluştur"""
87
+ all_images = []
88
+
89
+ for digit in range(10):
90
+ images = self.generate_digit(digit, samples_per_digit)
91
+ all_images.append(images)
92
+
93
+ all_images = torch.cat(all_images, dim=0)
94
+
95
+ fig, axes = plt.subplots(10, samples_per_digit, figsize=(15, 15))
96
+
97
+ for i in range(10):
98
+ for j in range(samples_per_digit):
99
+ idx = i * samples_per_digit + j
100
+ axes[i, j].imshow(all_images[idx][0], cmap='gray')
101
+ axes[i, j].axis('off')
102
+
103
+ if j == 0:
104
+ axes[i, j].set_ylabel(f'{i}', rotation=0, size=20, labelpad=20)
105
+
106
+ plt.suptitle('7Gen - Üretilen Rakamlar', size=20)
107
+ plt.tight_layout()
108
+
109
+ if save_path:
110
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
111
+ print(f"💾 Grid kaydedildi: {save_path}")
112
+
113
+ plt.show()
114
+
115
+ def save_as_png(self, digit, count=1, output_dir='output'):
116
+ """Tekil PNG dosyaları olarak kaydet"""
117
+ os.makedirs(output_dir, exist_ok=True)
118
+
119
+ images = self.generate_digit(digit, count)
120
+
121
+ for i in range(count):
122
+ img = images[i][0].numpy()
123
+ img = (img * 255).astype(np.uint8)
124
+
125
+ pil_img = Image.fromarray(img)
126
+ filename = f"{output_dir}/digit_{digit}_{i+1}.png"
127
+ pil_img.save(filename)
128
+
129
+ print(f"💾 Kaydedildi: {filename}")
130
+
131
+ def interactive_generate(self):
132
+ """İnteraktif kullanım"""
133
+ print("\n🎮 7Gen İnteraktif Mod")
134
+ print("Çıkmak için 'q' yazın")
135
+
136
+ while True:
137
+ try:
138
+ digit_input = input("\nHangi rakamı üretmek istersin? (0-9): ")
139
+
140
+ if digit_input.lower() == 'q':
141
+ print("👋 Görüşürüz!")
142
+ break
143
+
144
+ digit = int(digit_input)
145
+ if 0 <= digit <= 9:
146
+ count = int(input("Kaç tane üreteyim? (1-20): "))
147
+ if 1 <= count <= 20:
148
+ self.visualize_digits(digit, count)
149
+ else:
150
+ print("❌ 1-20 arası bir sayı gir!")
151
+ else:
152
+ print("❌ 0-9 arası bir rakam gir!")
153
+
154
+ except ValueError:
155
+ print("❌ Geçerli bir sayı gir!")
156
+ except KeyboardInterrupt:
157
+ print("\n👋 Görüşürüz!")
158
+ break
159
+
160
+ # Ana kullanım
161
+ if __name__ == "__main__":
162
+ # 7Gen'i başlat
163
+ seven_gen = SevenGenInference()
164
+
165
+ # Örnekler
166
+ print("\n📝 Örnek kullanımlar:")
167
+ print("1. Tekil rakam üret")
168
+ seven_gen.visualize_digits(digit=7, count=5)
169
+
170
+ print("\n2. Grid oluştur")
171
+ seven_gen.generate_grid(samples_per_digit=10, save_path='7gen_showcase.png')
172
+
173
+ print("\n3. PNG olarak kaydet")
174
+ seven_gen.save_as_png(digit=5, count=3, output_dir='output')
175
+
176
+ print("\n4. İnteraktif mod")
177
+ seven_gen.interactive_generate()
digit_0_samples.png ADDED
digit_1_samples.png ADDED
digit_2_samples.png ADDED
digit_3_samples.png ADDED
digit_4_samples.png ADDED
digit_5_samples.png ADDED
digit_6_samples.png ADDED
digit_7_samples.png ADDED
digit_8_samples.png ADDED
digit_9_samples.png ADDED