Update README.md
Browse files
README.md
CHANGED
|
@@ -57,68 +57,7 @@ class ModelColorization(nn.Module, PyTorchModelHubMixin):
|
|
| 57 |
x = self.decoder(x)
|
| 58 |
return x
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
md
|
| 63 |
-
Copy code
|
| 64 |
-
---
|
| 65 |
-
tags:
|
| 66 |
-
- autoencoder
|
| 67 |
-
- image-colorization
|
| 68 |
-
- pytorch
|
| 69 |
-
- pytorch_model_hub_mixin
|
| 70 |
-
---
|
| 71 |
-
|
| 72 |
-
# Model Colorization Autoencoder
|
| 73 |
-
|
| 74 |
-
## Model Description
|
| 75 |
-
|
| 76 |
-
This autoencoder model is designed for image colorization. It takes grayscale images as input and outputs colorized versions of those images. The model architecture consists of an encoder-decoder structure, where the encoder compresses the input image into a latent representation, and the decoder reconstructs the image in color.
|
| 77 |
-
|
| 78 |
-
### Architecture
|
| 79 |
-
|
| 80 |
-
- **Encoder**: The encoder comprises three convolutional layers followed by max pooling and ReLU activations, each paired with batch normalization. It ends with a flattening layer and a fully connected layer to produce a latent vector.
|
| 81 |
-
- **Decoder**: The decoder mirrors the encoder, using linear and transposed convolutional layers with ReLU activations and batch normalization. The final layer outputs a color image using a sigmoid activation function.
|
| 82 |
-
|
| 83 |
-
The architecture details are as follows:
|
| 84 |
-
```python
|
| 85 |
-
class ModelColorization(nn.Module, PyTorchModelHubMixin):
|
| 86 |
-
def __init__(self):
|
| 87 |
-
super(ModelColorization, self).__init__()
|
| 88 |
-
self.encoder = nn.Sequential(
|
| 89 |
-
nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
|
| 90 |
-
nn.MaxPool2d(kernel_size=2, stride=2),
|
| 91 |
-
nn.ReLU(),
|
| 92 |
-
nn.BatchNorm2d(64),
|
| 93 |
-
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
|
| 94 |
-
nn.MaxPool2d(kernel_size=2, stride=2),
|
| 95 |
-
nn.ReLU(),
|
| 96 |
-
nn.BatchNorm2d(32),
|
| 97 |
-
nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
|
| 98 |
-
nn.MaxPool2d(kernel_size=2, stride=2),
|
| 99 |
-
nn.ReLU(),
|
| 100 |
-
nn.BatchNorm2d(16),
|
| 101 |
-
nn.Flatten(),
|
| 102 |
-
nn.Linear(16*45*45, 4000),
|
| 103 |
-
)
|
| 104 |
-
self.decoder = nn.Sequential(
|
| 105 |
-
nn.Linear(4000, 16 * 45 * 45),
|
| 106 |
-
nn.ReLU(),
|
| 107 |
-
nn.Unflatten(1, (16, 45, 45)),
|
| 108 |
-
nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 109 |
-
nn.ReLU(),
|
| 110 |
-
nn.BatchNorm2d(32),
|
| 111 |
-
nn.ConvTranspose2d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 112 |
-
nn.ReLU(),
|
| 113 |
-
nn.BatchNorm2d(64),
|
| 114 |
-
nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 115 |
-
nn.Sigmoid()
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
def forward(self, x):
|
| 119 |
-
x = self.encoder(x)
|
| 120 |
-
x = self.decoder(x)
|
| 121 |
-
return x
|
| 122 |
|
| 123 |
### Training Details
|
| 124 |
The model was trained using PyTorch for 5 epochs. Here are the training and validation losses observed during the training:
|
|
@@ -140,4 +79,4 @@ pip install torch torchvision transformers
|
|
| 140 |
from transformers import AutoModel
|
| 141 |
|
| 142 |
model = AutoModel.from_pretrained("sebastiansarasti/AutoEncoderImageColorization")
|
| 143 |
-
```
|
|
|
|
| 57 |
x = self.decoder(x)
|
| 58 |
return x
|
| 59 |
|
| 60 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
### Training Details
|
| 63 |
The model was trained using PyTorch for 5 epochs. Here are the training and validation losses observed during the training:
|
|
|
|
| 79 |
from transformers import AutoModel
|
| 80 |
|
| 81 |
model = AutoModel.from_pretrained("sebastiansarasti/AutoEncoderImageColorization")
|
| 82 |
+
```python
|