{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e9589635", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ubuntu/Qwen-Image-Edit-Angles\n" ] } ], "source": [ "%cd /home/ubuntu/Qwen-Image-Edit-Angles" ] }, { "cell_type": "code", "execution_count": 2, "id": "029dd0ba", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/lib/python3/dist-packages/sklearn/utils/fixes.py:25: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", " from pkg_resources import parse_version # type: ignore\n", "2025-11-26 13:30:28.676566: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-11-26 13:30:28.690936: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1764163828.708252 3977707 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1764163828.713751 3977707 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "W0000 00:00:1764163828.727448 3977707 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1764163828.727463 3977707 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1764163828.727466 3977707 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1764163828.727467 3977707 computation_placer.cc:177] computation placer already registered. 
Please check linkage and avoid linking the same target more than once.\n", "2025-11-26 13:30:28.731891: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] }, { "ename": "AttributeError", "evalue": "'MessageFactory' object has no attribute 'GetPrototype'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;31mAttributeError\u001b[0m: 'MessageFactory' object has no attribute 'GetPrototype'" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/.local/lib/python3.10/site-packages/google/api_core/_python_version_support.py:266: FutureWarning: You are using a Python version (3.10.12) which Google will stop supporting in new releases of google.api_core once it reaches its end of life (2026-10-04). 
Please upgrade to the latest Python version, or at least Python 3.11, to continue receiving updates for google.api_core past that date.\n", "  warnings.warn(message, FutureWarning)\n" ] }, { "ename": "AttributeError", "evalue": "'MessageFactory' object has no attribute 'GetPrototype'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;31mAttributeError\u001b[0m: 'MessageFactory' object has no attribute 'GetPrototype'" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Skipping import of cpp extensions due to incompatible torch version 2.9.1+cu128 for torchao version 0.14.1 Please see https://github.com/pytorch/ao/issues/2919 for more info\n", "TMA benchmarks will be running without grid constant TMA descriptor.\n", "WARNING:bitsandbytes.cextension:Could not find the bitsandbytes CUDA binary at PosixPath('/usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda128.so')\n", "ERROR:bitsandbytes.cextension:Could not load bitsandbytes native library: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cpu.so)\n", "Traceback (most recent call last):\n", "  File \"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cextension.py\", line 85, in <module>\n", "    lib = get_native_library()\n", "  File \"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cextension.py\", line 72, in get_native_library\n", "    dll = ct.cdll.LoadLibrary(str(binary_path))\n", "  File \"/usr/lib/python3.10/ctypes/__init__.py\", line 452, in LoadLibrary\n", "    return self._dlltype(name)\n", "  File \"/usr/lib/python3.10/ctypes/__init__.py\", line 374, in __init__\n", "    self._handle = _dlopen(self._name, mode)\n", "OSError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cpu.so)\n", "WARNING:bitsandbytes.cextension:\n", "CUDA Setup failed despite CUDA being available. Please run the following command to get more information:\n", "\n", "python -m bitsandbytes\n", "\n", "Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them\n", "to your LD_LIBRARY_PATH. 
If you suspect a bug, please take the information from python -m bitsandbytes\n", "and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ed2a88fbeaf7493ca5d5fea50c1fb927", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " of len17920\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [05:01<00:00, 3.32it/s]\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "import torch\n", "import tqdm\n", "import torchvision.transforms.v2.functional as TF\n", "\n", "from qwenimage.sources import EditingSource\n", "\n", "# NOTE: `foundation` (used below to decode latents) is assumed to be defined by the\n", "# model-loading cell whose \"Fetching 7 files\" output appears above.\n", "\n", "src = EditingSource(\n", "    data_dir=\"/data/CrispEdit\",\n", "    total_per=10,\n", ")\n", "\n", "data_range = 1000\n", "\n", "gen_steps = 50\n", "\n", "data_dir = Path(\"/data/regression_output_1024\")\n", "\n", "\n", "real_ims = []\n", "fake_ims = []\n", "for d_idx in tqdm.tqdm(range(data_range)):\n", "    dpath = data_dir / f\"{d_idx:06d}.pt\"\n", "    out_dict = torch.load(dpath)\n", "\n", "    x_0 = out_dict[\"output\"]\n", "    h, w = out_dict[\"height\"], out_dict[\"width\"]\n", "    lh, lw = h // 16, w // 16\n", "\n", "    model_pred_im = foundation.latents_to_pil(x_0, lh, lw)[0]\n", "    gt_im = src[d_idx][-1]\n", "\n", "    # Match the ground-truth image to the generated resolution\n", "    gt_im = TF.resize(gt_im, (h, w))\n", "\n", "    real_ims.append(gt_im)\n", "    fake_ims.append(model_pred_im)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "efa6cd25", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 4.1450e-05, -3.5806e-05]], device='cuda:0', grad_fn=<AddmmBackward0>),\n", " torch.Size([1, 2]))" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# from urllib.request import urlopen\n", "# from transformers.image_utils import load_image\n", "# from PIL import Image\n", "# import timm\n", "\n", "# img = load_image(\"/home/ubuntu/Qwen-Image-Edit-Angles/scripts/assets/test_images_v1/test_image_1.jpg\")\n", "\n", "# model = timm.create_model(\n", "#     'vit_huge_plus_patch16_dinov3.lvd1689m',\n", "#     pretrained=True,\n", "#     num_classes=2,\n", "# ).to(device=\"cuda\")\n", "# model = model.eval()\n", "\n", "# data_config = timm.data.resolve_model_data_config(model)\n", "# data_config[\"input_size\"] = (3, 1024, 1024)\n", "# transforms = timm.data.create_transform(**data_config, is_training=False)\n", "\n", "# output = model(transforms(img).to(\"cuda\").unsqueeze(0))\n", "\n", "# output, output.shape\n", "# # (tensor([[ 4.1450e-05, -3.5806e-05]], device='cuda:0', grad_fn=<AddmmBackward0>), torch.Size([1, 2]))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7270e152", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Compose(\n", "    Resize(size=1024, interpolation=bicubic, max_size=None, antialias=True)\n", "    CenterCrop(size=(1024, 1024))\n", "    MaybeToTensor()\n", "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n", ")" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [] },
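{ "cell_type": "code", "execution_count": null, "id": "b3f1c2a8", "metadata": {}, "outputs": [], "source": [ "# Quick visual sanity check of the real/fake pairs assembled above. A minimal\n", "# sketch; assumes matplotlib is installed and that at least n_show pairs exist.\n", "import matplotlib.pyplot as plt\n", "\n", "n_show = 4\n", "fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8))\n", "for i in range(n_show):\n", "    # Top row: ground-truth images; bottom row: model reconstructions\n", "    axes[0, i].imshow(real_ims[i])\n", "    axes[0, i].set_title(f\"real {i}\")\n", "    axes[0, i].axis(\"off\")\n", "    axes[1, i].imshow(fake_ims[i])\n", "    axes[1, i].set_title(f\"fake {i}\")\n", "    axes[1, i].axis(\"off\")\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "2f276cea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training samples: 1600, Validation samples: 400\n", "Trainable parameters: 1,281\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import Dataset, DataLoader\n", "import torchvision.transforms.v2 as T\n", "from torchvision.transforms.v2 import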
Normalize\n", "import timm\n", "from PIL import Image\n", "from tqdm import tqdm\n", "\n", "# Custom Dataset for real/fake classification\n", "class RealFakeDataset(Dataset):\n", " def __init__(self, real_images, fake_images, transform=None):\n", " \"\"\"\n", " Args:\n", " real_images: List of PIL Images (real)\n", " fake_images: List of PIL Images (fake)\n", " transform: Optional transform to apply to images\n", " \"\"\"\n", " self.images = real_images + fake_images\n", " # Label 1 for real, -1 for fake (for hinge loss)\n", " self.labels = [1.0] * len(real_images) + [-1.0] * len(fake_images)\n", " self.transform = transform\n", " \n", " def __len__(self):\n", " return len(self.images)\n", " \n", " def __getitem__(self, idx):\n", " image = self.images[idx]\n", " label = self.labels[idx]\n", " \n", " if self.transform:\n", " image = self.transform(image)\n", " \n", " return image, label\n", "\n", "\n", "# Hinge Loss for Discriminator\n", "class HingeLoss(nn.Module):\n", " def __init__(self, r1_gamma=0.0):\n", " \"\"\"\n", " Hinge loss for discriminator training.\n", " \n", " Args:\n", " r1_gamma: Weight for R1 gradient penalty (set > 0 to enable)\n", " \"\"\"\n", " super().__init__()\n", " self.r1_gamma = r1_gamma\n", " \n", " def forward(self, scores, labels, features=None):\n", " \"\"\"\n", " Args:\n", " scores: Discriminator scores (B,)\n", " labels: Ground truth labels (1.0 for real, -1.0 for fake)\n", " features: Optional features for R1 penalty\n", " \n", " Returns:\n", " loss: Hinge loss value\n", " \"\"\"\n", " # Hinge loss: E[max(0, 1 - label * score)]\n", " # For real (label=1): max(0, 1 - score)\n", " # For fake (label=-1): max(0, 1 + score)\n", " hinge_loss = torch.mean(torch.clamp(1.0 - labels * scores, min=0.0))\n", " \n", " loss = hinge_loss\n", " \n", " # R1 gradient penalty (optional)\n", " if self.r1_gamma > 0 and features is not None:\n", " # Only compute for real images\n", " real_mask = labels > 0\n", " if real_mask.any():\n", " real_scores = scores[real_mask]\n", " real_features = features[real_mask]\n", " \n", " # Compute gradients\n", " gradients = torch.autograd.grad(\n", " outputs=real_scores.sum(),\n", " inputs=real_features,\n", " create_graph=True,\n", " retain_graph=True,\n", " only_inputs=True\n", " )[0]\n", " \n", " # R1 penalty: ||∇D||^2\n", " r1_penalty = gradients.pow(2).sum(dim=tuple(range(1, len(gradients.shape))))\n", " loss = loss + self.r1_gamma * r1_penalty.mean()\n", " \n", " return loss\n", "\n", "\n", "# Initialize model with single output\n", "model = timm.create_model(\n", " 'vit_huge_plus_patch16_dinov3.lvd1689m',\n", " pretrained=True,\n", " num_classes=1, # Single discriminator score\n", ").to(device=\"cuda\")\n", "\n", "data_config = timm.data.resolve_model_data_config(model)\n", "data_config[\"input_size\"] = (3, 1024, 1024)\n", "\n", "\n", "train_transforms_orig = timm.data.create_transform(**data_config, is_training=True)\n", "val_transforms_orig = timm.data.create_transform(**data_config, is_training=False)\n", "\n", "# Setup custom transforms without resize\n", "train_transforms_custom = T.Compose([\n", " T.ToImage(),\n", " T.RGB(),\n", " T.ToDtype(torch.float32, scale=True),\n", " Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), std=torch.tensor([0.2290, 0.2240, 0.2250]))\n", "])\n", "\n", "val_transforms_custom = T.Compose([\n", " T.ToImage(),\n", " T.RGB(),\n", " T.ToDtype(torch.float32, scale=True),\n", " Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), std=torch.tensor([0.2290, 0.2240, 0.2250]))\n", "])\n", "\n", 
"batch_size = 8\n", "if batch_size > 1:\n", " train_transforms = train_transforms_orig\n", " val_transforms = val_transforms_orig\n", "else:\n", " train_transforms = train_transforms_custom\n", " val_transforms = val_transforms_custom\n", "\n", "# Split data into train/val (80/20 split)\n", "train_size = int(0.8 * len(real_ims))\n", "train_real = real_ims[:train_size]\n", "train_fake = fake_ims[:train_size]\n", "val_real = real_ims[train_size:]\n", "val_fake = fake_ims[train_size:]\n", "\n", "# Create datasets and dataloaders\n", "train_dataset = RealFakeDataset(train_real, train_fake, transform=train_transforms)\n", "val_dataset = RealFakeDataset(val_real, val_fake, transform=val_transforms)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)\n", "\n", "# Loss function and optimizer\n", "# Set r1_gamma > 0 to enable R1 gradient penalty (beneficial for resolutions > 128^2)\n", "criterion = HingeLoss(r1_gamma=0.1)\n", "\n", "# Only train the classifier head, freeze the backbone\n", "for param in model.parameters():\n", " param.requires_grad = False\n", "# Unfreeze the classifier head\n", "for param in model.get_classifier().parameters():\n", " param.requires_grad = True\n", "\n", "optimizer = optim.AdamW(model.get_classifier().parameters(), lr=1e-3, weight_decay=1e-4)\n", "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)\n", "\n", "# Training function\n", "def train_epoch(model, dataloader, criterion, optimizer, device):\n", " model.train()\n", " running_loss = 0.0\n", " correct = 0\n", " total = 0\n", " \n", " pbar = tqdm(dataloader, desc='Training')\n", " for images, labels in pbar:\n", " images, labels = images.to(device), labels.to(device)\n", " \n", " optimizer.zero_grad()\n", " scores = model(images).squeeze(-1) # (B, 1) -> (B,)\n", " loss = criterion(scores, labels)\n", " loss.backward()\n", " optimizer.step()\n", " \n", " running_loss += loss.item()\n", " # Prediction: positive score -> real, negative score -> fake\n", " predicted = (scores > 0).float() * 2 - 1 # Convert to {-1, 1}\n", " total += labels.size(0)\n", " correct += (predicted == labels).sum().item()\n", " \n", " pbar.set_postfix({'loss': loss.item(), 'acc': 100.*correct/total})\n", " \n", " epoch_loss = running_loss / len(dataloader)\n", " epoch_acc = 100. * correct / total\n", " return epoch_loss, epoch_acc\n", "\n", "# Validation function\n", "def validate(model, dataloader, criterion, device):\n", " model.eval()\n", " running_loss = 0.0\n", " correct = 0\n", " total = 0\n", " \n", " with torch.no_grad():\n", " pbar = tqdm(dataloader, desc='Validation')\n", " for images, labels in pbar:\n", " images, labels = images.to(device), labels.to(device)\n", " \n", " scores = model(images).squeeze(-1) # (B, 1) -> (B,)\n", " loss = criterion(scores, labels)\n", " \n", " running_loss += loss.item()\n", " # Prediction: positive score -> real, negative score -> fake\n", " predicted = (scores > 0).float() * 2 - 1 # Convert to {-1, 1}\n", " total += labels.size(0)\n", " correct += (predicted == labels).sum().item()\n", " \n", " pbar.set_postfix({'loss': loss.item(), 'acc': 100.*correct/total})\n", " \n", " epoch_loss = running_loss / len(dataloader)\n", " epoch_acc = 100. 
* correct / total\n", "    return epoch_loss, epoch_acc\n", "\n", "print(f\"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}\")\n", "print(f\"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")\n" ] },
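{ "cell_type": "code", "execution_count": null, "id": "c8d14e72", "metadata": {}, "outputs": [], "source": [ "# Tiny sanity check of the hinge loss on made-up scores (a sketch, not part of training).\n", "# Real samples (label=+1) are penalized when score < 1; fakes (label=-1) when score > -1.\n", "_scores = torch.tensor([2.0, 0.5, -2.0, 0.5])\n", "_labels = torch.tensor([1.0, 1.0, -1.0, -1.0])\n", "# Per-sample hinge terms are [0, 0.5, 0, 1.5], so the mean should be 0.5\n", "print(HingeLoss()(_scores, _labels))\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "ee26768e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1/10\n", "------------------------------------------------------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 200/200 [06:20<00:00, 1.90s/it, loss=0.904, acc=55.2]\n", "Validation: 100%|██████████| 50/50 [01:36<00:00, 1.92s/it, loss=0.743, acc=62] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1 Summary:\n", "Train Loss: 0.9525 | Train Acc: 55.19%\n", "Val Loss: 0.8997 | Val Acc: 62.00%\n", "LR: 0.000976\n", "✓ New best model saved! Val Acc: 62.00%\n", "\n", "Epoch 2/10\n", "------------------------------------------------------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 200/200 [06:20<00:00, 1.90s/it, loss=0.887, acc=63.6]\n", "Validation: 100%|██████████| 50/50 [01:36<00:00, 1.92s/it, loss=1.18, acc=62] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 2 Summary:\n", "Train Loss: 0.8568 | Train Acc: 63.56%\n", "Val Loss: 0.8351 | Val Acc: 62.00%\n", "LR: 0.000905\n", "\n", "Epoch 3/10\n", "------------------------------------------------------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 200/200 [06:19<00:00, 1.90s/it, loss=1.11, acc=63.4] \n", "Validation: 44%|████▍ | 22/50 [00:44<00:57, 2.04s/it, loss=0.838, acc=72.2]\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_3977707/1866398459.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;31m# Validate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0mval_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cuda'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m# Update learning rate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipykernel_3977707/1776458232.py\u001b[0m in \u001b[0;36mvalidate\u001b[0;34m(model, dataloader, criterion, device)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mrunning_loss\u001b[0m \u001b[0;34m+=\u001b[0m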
\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0;31m# Prediction: positive score -> real, negative score -> fake\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0mpredicted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mscores\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;31m# Convert to {-1, 1}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# Training loop\n", "num_epochs = 10\n", "best_val_acc = 0.0\n", "history = {\n", " 'train_loss': [],\n", " 'train_acc': [],\n", " 'val_loss': [],\n", " 'val_acc': []\n", "}\n", "\n", "for epoch in range(num_epochs):\n", " print(f'\\nEpoch {epoch+1}/{num_epochs}')\n", " print('-' * 60)\n", " \n", " # Train\n", " train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, 'cuda')\n", " \n", " # Validate\n", " val_loss, val_acc = validate(model, val_loader, criterion, 'cuda')\n", " \n", " # Update learning rate\n", " scheduler.step()\n", " \n", " # Save history\n", " history['train_loss'].append(train_loss)\n", " history['train_acc'].append(train_acc)\n", " history['val_loss'].append(val_loss)\n", " history['val_acc'].append(val_acc)\n", " \n", " # Print epoch summary\n", " print(f'\\nEpoch {epoch+1} Summary:')\n", " print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')\n", " print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')\n", " print(f'LR: {optimizer.param_groups[0][\"lr\"]:.6f}')\n", " \n", " # Save best model\n", " if val_acc > best_val_acc:\n", " best_val_acc = val_acc\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'val_acc': val_acc,\n", " }, 'best_dinov3_classifier.pt')\n", " print(f'✓ New best model saved! Val Acc: {val_acc:.2f}%')\n", "\n", "print(f'\\nTraining completed! 
Best validation accuracy: {best_val_acc:.2f}%')\n" ] }, { "cell_type": "code", "execution_count": null, "id": "367faf67", "metadata": {}, "outputs": [], "source": [ "# Plot training history\n", "import matplotlib.pyplot as plt\n", "\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", "\n", "# Plot loss\n", "ax1.plot(history['train_loss'], label='Train Loss', marker='o')\n", "ax1.plot(history['val_loss'], label='Val Loss', marker='o')\n", "ax1.set_xlabel('Epoch')\n", "ax1.set_ylabel('Loss')\n", "ax1.set_title('Training and Validation Loss')\n", "ax1.legend()\n", "ax1.grid(True, alpha=0.3)\n", "\n", "# Plot accuracy\n", "ax2.plot(history['train_acc'], label='Train Acc', marker='o')\n", "ax2.plot(history['val_acc'], label='Val Acc', marker='o')\n", "ax2.set_xlabel('Epoch')\n", "ax2.set_ylabel('Accuracy (%)')\n", "ax2.set_title('Training and Validation Accuracy')\n", "ax2.legend()\n", "ax2.grid(True, alpha=0.3)\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Print final statistics\n", "print(f\"\\nFinal Statistics:\")\n", "print(f\"Best Validation Accuracy: {best_val_acc:.2f}%\")\n", "print(f\"Final Train Accuracy: {history['train_acc'][-1]:.2f}%\")\n", "print(f\"Final Validation Accuracy: {history['val_acc'][-1]:.2f}%\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "67c6eea4", "metadata": {}, "outputs": [], "source": [ "# Inference on test images\n", "def predict_image(model, image, transform, device='cuda'):\n", "    \"\"\"\n", "    Predict whether an image is real or fake using the discriminator score.\n", "    Returns: (prediction, score, confidence)\n", "    \"\"\"\n", "    model.eval()\n", "    with torch.no_grad():\n", "        img_tensor = transform(image).unsqueeze(0).to(device)\n", "        score = model(img_tensor).squeeze().item()\n", "\n", "    # Positive score -> real, negative score -> fake\n", "    label = \"Real\" if score > 0 else \"Fake\"\n", "\n", "    # Map |score| through a sigmoid as a rough confidence in (0.5, 1)\n", "    confidence = torch.sigmoid(torch.tensor(abs(score))).item()\n", "\n", "    return label, score, confidence\n", "\n", "# Test on some examples from the validation set\n", "num_test_samples = 5\n", "print(\"\\nTesting on validation samples:\")\n", "print(\"=\" * 60)\n", "\n", "for i in range(min(num_test_samples, len(val_real))):\n", "    # Test real image\n", "    real_img = val_real[i]\n", "    label, score, conf = predict_image(model, real_img, val_transforms)\n", "    print(f\"Real Image {i+1}: Predicted={label}, Score={score:.4f}, Confidence={conf:.2%}\")\n", "\n", "    # Test fake image\n", "    fake_img = val_fake[i]\n", "    label, score, conf = predict_image(model, fake_img, val_transforms)\n", "    print(f\"Fake Image {i+1}: Predicted={label}, Score={score:.4f}, Confidence={conf:.2%}\")\n", "    print()\n" ] },
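{ "cell_type": "code", "execution_count": null, "id": "9a5b3d10", "metadata": {}, "outputs": [], "source": [ "# Reload the best checkpoint written by the training loop; a minimal sketch that\n", "# assumes 'best_dinov3_classifier.pt' exists with the keys saved above.\n", "ckpt = torch.load('best_dinov3_classifier.pt', map_location='cuda')\n", "model.load_state_dict(ckpt['model_state_dict'])\n", "model.eval()\n", "print(f\"Restored epoch {ckpt['epoch']} with val acc {ckpt['val_acc']:.2f}%\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }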