from flax import linen as nn
import jax
import jax.numpy as jnp

from local_response_norm import LocalResponseNorm

EPSILON = 1e-8            # small constant for numerical stability
MAX_DISC_FEATURES = 128   # channel cap for discriminator blocks
MAX_GEN_FEATURES = 512    # channel cap for generator blocks
LATENT_DIM = 512          # size of the latent input vector
MAX_LAYERS = 7            # deepest stack: 4 * 2 ** (MAX_LAYERS - 1) = 256x256 output


def get_gen_layers(layer):
    """One generator stage: bilinear upsample, 3x3 transposed conv, ReLU."""
    resolution = int(4 * 2 ** layer)  # 4x4 at stage 0, doubling each stage
    features = min(int(32 * 2 ** (MAX_LAYERS - 1 - layer)), MAX_GEN_FEATURES)
    layers = []
    # Bilinear resize to the stage's target resolution (inputs are NHWC).
    layers.append(lambda x: jax.image.resize(
        x, shape=(x.shape[0], resolution, resolution, x.shape[3]), method="linear"))
    # Explicit names keep parameter keys stable across different num_layers.
    layers.append(lambda x: nn.ConvTranspose(
        features=features, kernel_size=(3, 3),
        name=f"ConvTranspose_{resolution}_{features}")(x))
    layers.append(nn.relu)
    return layers
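
# The per-stage schedule implied by the two formulas above (MAX_LAYERS = 7):
#   stage 0 -> 4x4, 512      stage 3 -> 32x32, 256     stage 6 -> 256x256, 32
#   stage 1 -> 8x8, 512      stage 4 -> 64x64, 128
#   stage 2 -> 16x16, 512    stage 5 -> 128x128, 64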


def get_initial_gen_layers(num_layers):
    """Reshape the flat latent batch (B, LATENT_DIM) into a 1x1 NHWC map."""
    del num_layers  # unused here; kept for symmetry with the other builders
    layers = []
    layers.append(lambda x: x.reshape(x.shape[0], 1, 1, -1))
    return layers


def get_final_gen_layers(num_layers):
    """Final 3x3 transposed conv projecting features down to 3 RGB channels."""
    resolution = int(4 * 2 ** (num_layers - 1))  # resolution of the last stage
    layers = []
    layers.append(lambda x: nn.ConvTranspose(
        features=3, kernel_size=(3, 3), name=f"ConvTranspose_{resolution}_3")(x))
    return layers


class Generator(nn.Module):
    """Maps a batch of latent vectors to RGB images: reshape to a 1x1
    feature map, then double the resolution num_layers times."""

    num_layers: int

    def setup(self):
        # The stack holds plain callables; the ConvTranspose submodules they
        # create are only instantiated inside the compact __call__ below.
        layers = []
        layers.extend(get_initial_gen_layers(self.num_layers))
        for layer in range(self.num_layers):
            layers.extend(get_gen_layers(layer))
        layers.extend(get_final_gen_layers(self.num_layers))
        self.layers = layers

    @nn.compact
    def __call__(self, x):
        result = x
        for layer in self.layers:
            result = layer(result)
        return result
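

# Minimal usage sketch (illustrative; the batch size, seed, and num_layers=3
# below are assumptions, not values from the original code). With three
# stages the generator upsamples 4 -> 8 -> 16, so the output is 16x16 RGB.
if __name__ == "__main__":
    model = Generator(num_layers=3)
    rng = jax.random.PRNGKey(0)
    init_rng, sample_rng = jax.random.split(rng)
    z = jax.random.normal(sample_rng, (8, LATENT_DIM))  # batch of latent vectors
    variables = model.init(init_rng, z)                 # initialize all parameters
    images = model.apply(variables, z)                  # forward pass
    print(images.shape)                                 # (8, 16, 16, 3)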