"""Model definitions: a UNet denoiser and a VQ-VAE autoencoder.

Both models are built at import time and exposed as the module-level objects
``myUnet`` and ``myVQGANModel``. Running this file directly prints the
parameter count of each model.
"""

from modules.unet import UNetModel
from generative.networks.nets import VQVAE
from config import config


def _count_parameters(model):
    """Return the total number of elements across ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())


# The UNet works at ``config.image_size / config.r`` resolution. Use integer
# division so the UNet gets an int size, not a float such as 32.0. Reject
# sizes that do not divide evenly instead of passing a fractional size.
# NOTE(review): ``config.r`` presumably equals the VQ-VAE's spatial
# downsampling factor (three stride-2 stages below, i.e. 8) — confirm.
if config.image_size % config.r != 0:
    raise ValueError(
        f"config.image_size ({config.image_size}) must be divisible by "
        f"config.r ({config.r}) to give an integer UNet image_size"
    )

myUnet = UNetModel(
    image_size=config.image_size // config.r,
    model_channels=128,
    # NOTE(review): 8 input/output channels while the VQ-VAE below uses
    # embedding_dim=4 — presumably the latent is concatenated with a
    # 4-channel condition; verify against the training code.
    in_channels=8,
    out_channels=8,
    num_res_blocks=8,
    num_heads=8,
    # NOTE(review): if modules.unet follows guided-diffusion, these are
    # downsample rates. With 6 channel_mult levels the largest rate is
    # 2**5 = 32, so 64 would never trigger attention — confirm intent.
    attention_resolutions=(64, 32, 16, 8),
    num_heads_upsample=-1,
    num_head_channels=-1,
    resblock_updown=True,
    channel_mult=(1, 1, 2, 2, 4, 4),
    use_scale_shift_norm=True,
    use_new_attention_order=True,
)

# Name kept for backward compatibility with importers. The object is a plain
# VQVAE; no GAN discriminator is defined in this file.
myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 512),
    num_res_channels=512,
    num_res_layers=2,
    # Presumably (stride, kernel_size, dilation, padding) per stage — check
    # against the installed MONAI Generative VQVAE docs.
    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)),
    # Presumably (stride, kernel_size, dilation, padding, output_padding).
    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
    num_embeddings=1024,
    embedding_dim=4,
)

if __name__ == "__main__":
    print("Number of UNet parameters:", _count_parameters(myUnet))
    print("Number of VQ-VAE parameters:", _count_parameters(myVQGANModel))