import json
from pathlib import Path
from typing import Optional, Union, Dict

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
class ConvVAE(nn.Module):
    """Convolutional VAE for 128x128 RGB images."""

    def __init__(self, latent_size):
        super().__init__()
        # Encoder: four stride-2 convolutions, 128x128 -> 8x8
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),    # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1), # (batch, 512, 8, 8)
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        # Linear heads projecting the flattened feature map to the latent Gaussian
        self.fc_mu = nn.Linear(512 * 8 * 8, latent_size)
        self.fc_logvar = nn.Linear(512 * 8 * 8, latent_size)
        # Projection from latent space back to the decoder's input feature map
        self.fc2 = nn.Linear(latent_size, 512 * 8 * 8)
        # Decoder: four stride-2 transposed convolutions, 8x8 -> 128x128
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),    # (batch, 3, 128, 128)
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu, logvar

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512 * 8 * 8)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # keeping the sampling step differentiable w.r.t. mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.fc2(z)
        x = x.view(-1, 512, 8, 8)  # unflatten to the decoder's input shape
        decoded = self.decoder(x)
        return decoded

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
| """ | |
| Load a pretrained model from a given model ID. | |
| Args: | |
| model_id (str): Identifier of the model to load. | |
| revision (Optional[str]): Specific model revision to use. | |
| cache_dir (Optional[Union[str, Path]]): Directory to store downloaded models. | |
| force_download (bool): Force re-download even if the model exists. | |
| proxies (Optional[Dict]): Proxy configuration for downloads. | |
| resume_download (bool): Resume interrupted downloads. | |
| local_files_only (bool): Use only local files, don't download. | |
| token (Union[str, bool, None]): Token for API authentication. | |
| map_location (str): Device to map model to. Defaults to "cpu". | |
| strict (bool): Enforce strict state_dict loading. | |
| **model_kwargs: Additional keyword arguments for model initialization. | |
| Returns: | |
| An instance of the model loaded from the pretrained weights. | |
| """ | |
        # Treat model_id as a local directory first; otherwise fetch from the Hub.
        model_dir = Path(model_id)
        if not model_dir.exists():
            model_dir = Path(
                snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            )

        # Read latent_size from the repo's config.json.
        config_file = model_dir / "config.json"
        with open(config_file, "r") as f:
            config = json.load(f)
        latent_size = config.get("latent_size")
        if latent_size is None:
            raise ValueError("The configuration file is missing the 'latent_size' key.")

        model = cls(latent_size, **model_kwargs)

        model_file = model_dir / "model_conv_vae_256_epoch_304.pth"
        if not model_file.exists():
            raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")
        state_dict = torch.load(model_file, map_location=map_location)

        # Checkpoints saved from a torch.compile-wrapped module prefix keys with
        # '_orig_mod.'; strip that prefix so the keys match this plain module.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("_orig_mod."):
                new_state_dict[k[len("_orig_mod."):]] = v
            else:
                new_state_dict[k] = v

        model.load_state_dict(new_state_dict, strict=strict)
        model.to(map_location)
        return model
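
# A minimal sketch (not part of the original repo) of the standard VAE training
# objective that forward()'s outputs plug into: pixel reconstruction loss plus
# the KL divergence between the approximate posterior N(mu, sigma^2) and the
# N(0, I) prior. The actual training loss and weighting used for this checkpoint
# may differ.
def vae_loss(decoded, target, mu, logvar, kl_weight=1.0):
    # MSE matches the Tanh output range of [-1, 1]; targets should be scaled to match.
    recon = nn.functional.mse_loss(decoded, target, reduction="sum")
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims and batch.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl_weight * kl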

model = ConvVAE.from_pretrained(
    model_id="BioMike/classical_portrait_vae",
    cache_dir="./model_cache",
    map_location="cpu",
    strict=True,
).eval()
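
# Usage sketch: sample a latent from the standard normal prior and decode it to
# an image. The latent size is read back from the loaded fc_mu head rather than
# hard-coded, since it comes from the repo's config.json.
with torch.no_grad():
    z = torch.randn(1, model.fc_mu.out_features)  # (batch, latent_size)
    generated = model.decode(z)                   # (1, 3, 128, 128), values in [-1, 1]
    image = (generated.squeeze(0) + 1.0) / 2.0    # rescale to [0, 1] for display/saving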