| | import math |
| | import tensorflow as tf |
| | import numpy as np |
| | import dnnlib.tflib as tflib |
| | from functools import partial |
| |
|
| |
|
| | def create_stub(name, batch_size): |
| | return tf.constant(0, dtype='float32', shape=(batch_size, 0)) |
| |
|
| |
|
| | def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size = 1): |
| | if tiled_dlatent: |
| | low_dim_dlatent = tf.get_variable('learnable_dlatents', |
| | shape=(batch_size, tile_size, 512), |
| | dtype='float32', |
| | initializer=tf.initializers.random_normal()) |
| | return tf.tile(low_dim_dlatent, [1, model_scale // tile_size, 1]) |
| | else: |
| | return tf.get_variable('learnable_dlatents', |
| | shape=(batch_size, model_scale, 512), |
| | dtype='float32', |
| | initializer=tf.initializers.random_normal()) |
| |
|
| |
|
| | class Generator: |
| | def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False): |
| | self.batch_size = batch_size |
| | self.tiled_dlatent=tiled_dlatent |
| | self.model_scale = int(2*(math.log(model_res,2)-1)) |
| |
|
| | if tiled_dlatent: |
| | self.initial_dlatents = np.zeros((self.batch_size, 512)) |
| | model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)), |
| | randomize_noise=randomize_noise, minibatch_size=self.batch_size, |
| | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True), |
| | partial(create_stub, batch_size=batch_size)], |
| | structure='fixed') |
| | else: |
| | self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512)) |
| | if custom_input is not None: |
| | model.components.synthesis.run(self.initial_dlatents, |
| | randomize_noise=randomize_noise, minibatch_size=self.batch_size, |
| | custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)], |
| | structure='fixed') |
| | else: |
| | model.components.synthesis.run(self.initial_dlatents, |
| | randomize_noise=randomize_noise, minibatch_size=self.batch_size, |
| | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale), |
| | partial(create_stub, batch_size=batch_size)], |
| | structure='fixed') |
| |
|
| | self.dlatent_avg_def = model.get_var('dlatent_avg') |
| | self.reset_dlatent_avg() |
| | self.sess = tf.compat.v1.get_default_session() |
| | self.graph = tf.compat.v1.get_default_graph() |
| |
|
| | self.dlatent_variable = next(v for v in tf.compat.v1.global_variables() if 'learnable_dlatents' in v.name) |
| | self._assign_dlatent_ph = tf.compat.v1.placeholder(tf.float32, name="assign_dlatent_ph") |
| | self._assign_dlantent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph) |
| | self.set_dlatents(self.initial_dlatents) |
| |
|
| | def get_tensor(name): |
| | try: |
| | return self.graph.get_tensor_by_name(name) |
| | except KeyError: |
| | return None |
| |
|
| | self.generator_output = get_tensor('G_synthesis_1/_Run/concat:0') |
| | if self.generator_output is None: |
| | self.generator_output = get_tensor('G_synthesis_1/_Run/concat/concat:0') |
| | if self.generator_output is None: |
| | self.generator_output = get_tensor('G_synthesis_1/_Run/concat_1/concat:0') |
| | |
| | if self.generator_output is None: |
| | self.generator_output = get_tensor('G_synthesis/_Run/concat:0') |
| | if self.generator_output is None: |
| | self.generator_output = get_tensor('G_synthesis/_Run/concat/concat:0') |
| | if self.generator_output is None: |
| | self.generator_output = get_tensor('G_synthesis/_Run/concat_1/concat:0') |
| | if self.generator_output is None: |
| | for op in self.graph.get_operations(): |
| | print(op) |
| | raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output") |
| | self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False) |
| | self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8) |
| |
|
| | |
| | |
| | |
| | clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold) |
| | clipped_values = tf.where(clipping_mask, tf.random.normal(shape=self.dlatent_variable.shape), self.dlatent_variable) |
| | self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values) |
| |
|
| | def reset_dlatents(self): |
| | self.set_dlatents(self.initial_dlatents) |
| |
|
| | def set_dlatents(self, dlatents): |
| | if self.tiled_dlatent: |
| | if (dlatents.shape != (self.batch_size, 512)) and (dlatents.shape[1] != 512): |
| | dlatents = np.mean(dlatents, axis=1) |
| | if (dlatents.shape != (self.batch_size, 512)): |
| | dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 512))]) |
| | assert (dlatents.shape == (self.batch_size, 512)) |
| | else: |
| | if (dlatents.shape[1] > self.model_scale): |
| | dlatents = dlatents[:,:self.model_scale,:] |
| | if (isinstance(dlatents.shape[0], int)): |
| | if (dlatents.shape != (self.batch_size, self.model_scale, 512)): |
| | dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))]) |
| | assert (dlatents.shape == (self.batch_size, self.model_scale, 512)) |
| | self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents}) |
| | return |
| | else: |
| | self._assign_dlantent = tf.assign(self.dlatent_variable, dlatents) |
| | return |
| | self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents}) |
| |
|
| | def stochastic_clip_dlatents(self): |
| | self.sess.run(self.stochastic_clip_op) |
| |
|
| | def get_dlatents(self): |
| | return self.sess.run(self.dlatent_variable) |
| |
|
| | def get_dlatent_avg(self): |
| | return self.dlatent_avg |
| |
|
| | def set_dlatent_avg(self, dlatent_avg): |
| | self.dlatent_avg = dlatent_avg |
| |
|
| | def reset_dlatent_avg(self): |
| | self.dlatent_avg = self.dlatent_avg_def |
| |
|
| | def generate_images(self, dlatents=None): |
| | if dlatents is not None: |
| | self.set_dlatents(dlatents) |
| | return self.sess.run(self.generated_image_uint8) |
| |
|