# Scrape artifacts from the original page (kept as comments so the file parses):
# Spaces:
# Runtime error
# Runtime error
# This file is based on the GauGAN by Rakshit et al.
# https://keras.io/examples/generative/gaugan/
import tensorflow as tf
import tensorflow_addons as tfa
class SPADE(tf.keras.layers.Layer):
    """Spatially-adaptive (de)normalization layer.

    Normalizes the input feature map over batch and spatial axes, then
    re-modulates it with per-pixel scale (gamma) and shift (beta) maps
    predicted from a segmentation mask.
    """

    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # Shared trunk embedding the mask, plus two heads predicting the
        # per-pixel modulation parameters.
        self.conv = tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = tf.keras.layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = tf.keras.layers.Conv2D(filters, 3, padding="same")

    def build(self, input_shape):
        # Spatial size (H, W) the mask must be resized to.
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        resized = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
        embedded = self.conv(resized)
        gamma = self.conv_gamma(embedded)
        beta = self.conv_beta(embedded)
        # Moments over batch + spatial dimensions, kept broadcastable.
        mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        normalized = (input_tensor - mean) / tf.sqrt(var + self.epsilon)
        return gamma * normalized + beta
class ResBlock(tf.keras.layers.Layer):
    """Residual block whose normalization layers are SPADE layers."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.spade_1 = SPADE(input_filter)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        # When the channel count changes, the shortcut needs its own
        # SPADE + conv so the residual addition shapes match.
        self.learned_skip = self.filters != input_filter
        if self.learned_skip:
            self.spade_3 = SPADE(input_filter)
            self.conv_3 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        h = self.spade_1(input_tensor, mask)
        h = self.conv_1(tf.nn.leaky_relu(h, 0.2))
        h = self.spade_2(h, mask)
        h = self.conv_2(tf.nn.leaky_relu(h, 0.2))
        if self.learned_skip:
            shortcut = self.spade_3(input_tensor, mask)
            shortcut = self.conv_3(tf.nn.leaky_relu(shortcut, 0.2))
        else:
            shortcut = input_tensor
        return shortcut + h
class GaussianSampler(tf.keras.layers.Layer):
    """Reparameterization-trick sampler: z = mean + exp(0.5 * logvar) * eps."""

    def __init__(self, batch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.latent_dim = latent_dim

    def call(self, inputs):
        means, variance = inputs
        # `variance` is a predicted log-variance, so exp(0.5 * variance)
        # is the standard deviation.
        noise = tf.random.normal(
            shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
        )
        return means + tf.exp(0.5 * variance) * noise
def downsample(
    channels,
    kernels,
    strides=2,
    apply_norm=True,
    apply_activation=True,
    apply_dropout=False,
):
    """Build a stride-2 downsampling block.

    Sequence: Conv2D -> [InstanceNormalization] -> [LeakyReLU(0.2)]
    -> [Dropout(0.5)], with each optional stage controlled by a flag.
    """
    stages = [
        tf.keras.layers.Conv2D(
            channels,
            kernels,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
        )
    ]
    if apply_norm:
        stages.append(tfa.layers.InstanceNormalization())
    if apply_activation:
        stages.append(tf.keras.layers.LeakyReLU(0.2))
    if apply_dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    return tf.keras.Sequential(stages)
def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
    """Build the VAE-style image encoder.

    Returns a model mapping an image to the (mean, log-variance) vectors
    of the latent posterior.
    """
    input_image = tf.keras.Input(shape=image_shape)
    x = input_image
    # Channel multipliers for the successive stride-2 stages; the first
    # stage skips instance normalization.
    for index, multiplier in enumerate((1, 2, 4, 8, 8, 8, 16)):
        x = downsample(
            multiplier * encoder_downsample_factor, 3, apply_norm=(index != 0)
        )(x)
    x = tf.keras.layers.Flatten()(x)
    mean = tf.keras.layers.Dense(latent_dim, name="mean")(x)
    variance = tf.keras.layers.Dense(latent_dim, name="variance")(x)
    return tf.keras.Model(input_image, [mean, variance], name="encoder")
def build_generator(mask_shape, latent_dim=256):
    """Build the SPADE generator.

    Maps (latent vector, one-hot label mask) to an RGB image with
    sigmoid-activated pixels. Eight ResBlock + 2x-upsampling stages grow
    the 4x4 seed to a 1024x1024 output (the last two stages were added
    on top of the original 256x256 architecture).
    """
    # FIX: `shape` must be a tuple -- `(latent_dim)` is just an int in parens.
    latent = tf.keras.Input(shape=(latent_dim,))
    mask = tf.keras.Input(shape=mask_shape)
    # Project the latent vector to a 4x4x1024 feature map (16384 units).
    x = tf.keras.layers.Dense(4 * 4 * 1024)(latent)
    x = tf.keras.layers.Reshape((4, 4, 1024))(x)
    # Each stage: SPADE residual block conditioned on the mask, then 2x upsample.
    for filters in (1024, 1024, 1024, 512, 256, 128, 64, 32):
        x = ResBlock(filters=filters)(x, mask)
        x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.nn.leaky_relu(x, 0.2)
    output_image = tf.nn.sigmoid(tf.keras.layers.Conv2D(3, 4, padding="same")(x))
    return tf.keras.Model([latent, mask], output_image, name="generator")
def build_discriminator(image_shape, downsample_factor=64):
    """Build the patch discriminator over an (image A, image B) pair.

    Outputs every intermediate feature map plus the final 1-channel patch
    logits, so the feature-matching loss can compare activations layer by
    layer.
    """
    input_image_A = tf.keras.Input(shape=image_shape, name="discriminator_image_A")
    input_image_B = tf.keras.Input(shape=image_shape, name="discriminator_image_B")
    x = tf.keras.layers.Concatenate()([input_image_A, input_image_B])
    features = []
    # Successive stride-2 stages; the first skips instance normalization.
    for index, multiplier in enumerate((1, 2, 4, 8, 8, 8, 16)):
        x = downsample(multiplier * downsample_factor, 4, apply_norm=(index != 0))(x)
        features.append(x)
    # Final patch prediction map.
    features.append(tf.keras.layers.Conv2D(1, 4)(x))
    return tf.keras.Model([input_image_A, input_image_B], features)
def generator_loss(y):
    """Hinge-style generator loss: the negated mean discriminator score."""
    mean_score = tf.reduce_mean(y)
    return -mean_score
def kl_divergence_loss(mean, variance):
    """KL divergence to N(0, I) for a Gaussian with log-variance `variance`."""
    per_dim = 1 + variance - tf.square(mean) - tf.exp(variance)
    return -0.5 * tf.reduce_sum(per_dim)
class FeatureMatchingLoss(tf.keras.losses.Loss):
    """Sum of L1 distances between paired discriminator feature maps.

    The final element of each input list (the patch prediction) is excluded.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        loss = 0
        # Skip the last output: it is the patch logits, not a feature map.
        for real_feat, fake_feat in zip(y_true[:-1], y_pred[:-1]):
            loss = loss + self.mae(real_feat, fake_feat)
        return loss
class VGGFeatureMatchingLoss(tf.keras.losses.Loss):
    """Perceptual loss: weighted L1 over VGG19 activations of real vs. fake.

    Instantiating this loss downloads ImageNet VGG19 weights.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        # Deeper layers contribute with larger weights.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
        layer_outputs = [vgg.get_layer(name).output for name in self.encoder_layers]
        self.vgg_model = tf.keras.Model(vgg.input, layer_outputs, name="VGG")
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        # Images arrive in [-1, 1]; rescale to [0, 255] for VGG preprocessing.
        real = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
        fake = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
        real_features = self.vgg_model(real)
        fake_features = self.vgg_model(fake)
        loss = 0
        for weight, real_feat, fake_feat in zip(
            self.weights, real_features, fake_features
        ):
            loss = loss + weight * self.mae(real_feat, fake_feat)
        return loss
class DiscriminatorLoss(tf.keras.losses.Loss):
    """Hinge loss for the discriminator; `is_real` selects the target label."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hinge_loss = tf.keras.losses.Hinge()

    def call(self, y, is_real):
        target = 1.0 if is_real else -1.0
        return self.hinge_loss(target, y)
class GauGAN(tf.keras.Model):
    """GauGAN: VAE-style encoder + SPADE generator + patch discriminator.

    The encoder predicts a latent distribution from the real image, the
    generator synthesizes an image conditioned on the sampled latent and
    the one-hot label map, and the discriminator scores
    (segmentation map, image) pairs.
    """

    def __init__(
        self,
        image_size,
        num_classes,
        batch_size,
        latent_dim,
        feature_loss_coeff=10,
        vgg_feature_loss_coeff=0.1,
        kl_divergence_loss_coeff=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_shape = (image_size, image_size, 3)
        self.mask_shape = (image_size, image_size, num_classes)
        # Loss weights for the generator objective.
        self.feature_loss_coeff = feature_loss_coeff
        self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
        self.kl_divergence_loss_coeff = kl_divergence_loss_coeff

        self.discriminator = build_discriminator(self.image_shape)
        self.generator = build_generator(self.mask_shape, latent_dim=latent_dim)
        self.encoder = build_encoder(self.image_shape, latent_dim=latent_dim)
        self.sampler = GaussianSampler(batch_size, latent_dim)
        self.patch_size, self.combined_model = self.build_combined_generator()

        # Running means reported as training/evaluation metrics.
        self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
        self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
        self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # BUG FIX: `metrics` must be a property. train_step/test_step iterate
        # `self.metrics`; without the decorator that expression is a bound
        # method, and Keras would also fail to reset these trackers per epoch.
        return [
            self.disc_loss_tracker,
            self.gen_loss_tracker,
            self.feat_loss_tracker,
            self.vgg_loss_tracker,
            self.kl_loss_tracker,
        ]

    def build_combined_generator(self):
        """Build the frozen-discriminator model used for generator updates.

        Takes (latent vector, one-hot label map, segmentation image) and
        returns (discriminator outputs on the generated image, generated
        image). The discriminator is frozen inside this model so gradients
        flow through it without updating its weights.
        """
        self.discriminator.trainable = False
        mask_input = tf.keras.Input(shape=self.mask_shape, name="mask")
        image_input = tf.keras.Input(shape=self.image_shape, name="image")
        # FIX: `shape` must be a tuple -- `(self.latent_dim)` is just an int.
        latent_input = tf.keras.Input(shape=(self.latent_dim,), name="latent")
        generated_image = self.generator([latent_input, mask_input])
        discriminator_output = self.discriminator([image_input, generated_image])
        # Spatial size of the final patch-logits map.
        patch_size = discriminator_output[-1].shape[1]
        combined_model = tf.keras.Model(
            [latent_input, mask_input, image_input],
            [discriminator_output, generated_image],
        )
        return patch_size, combined_model

    def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
        """Create the optimizers (TTUR: disc LR is 4x the gen LR) and losses."""
        super().compile(**kwargs)
        self.generator_optimizer = tf.keras.optimizers.Adam(
            gen_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            disc_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_loss = DiscriminatorLoss()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
        """Run one hinge-loss update of the discriminator; return the loss."""
        fake_images = self.generator([latent_vector, labels])
        with tf.GradientTape() as gradient_tape:
            # Only the final patch logits feed the hinge loss.
            pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
            pred_real = self.discriminator([segmentation_map, real_image])[-1]
            loss_fake = self.discriminator_loss(pred_fake, False)
            loss_real = self.discriminator_loss(pred_real, True)
            total_loss = 0.5 * (loss_fake + loss_real)
        # Re-enable training (build_combined_generator froze the discriminator).
        self.discriminator.trainable = True
        gradients = gradient_tape.gradient(
            total_loss, self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_variables)
        )
        return total_loss

    def train_generator(
        self, latent_vector, segmentation_map, labels, image, mean, variance
    ):
        """Run one generator/encoder update; return the loss components."""
        # The generator learns through the signal provided by the
        # discriminator. During backpropagation, we only update the
        # generator-side parameters.
        self.discriminator.trainable = False
        with tf.GradientTape() as tape:
            real_d_output = self.discriminator([segmentation_map, image])
            fake_d_output, fake_image = self.combined_model(
                [latent_vector, labels, segmentation_map]
            )
            pred = fake_d_output[-1]
            # Adversarial, KL, perceptual and feature-matching terms.
            g_loss = generator_loss(pred)
            kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(
                mean, variance
            )
            vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
            total_loss = g_loss + kl_loss + vgg_loss + feature_loss
        gradients = tape.gradient(total_loss, self.combined_model.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients, self.combined_model.trainable_variables)
        )
        return total_loss, feature_loss, vgg_loss, kl_loss

    def train_step(self, data):
        segmentation_map, image, labels = data
        mean, variance = self.encoder(image)
        latent_vector = self.sampler([mean, variance])
        discriminator_loss = self.train_discriminator(
            latent_vector, segmentation_map, image, labels
        )
        # Renamed from `generator_loss` to avoid shadowing the module-level
        # loss function of the same name.
        (total_generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
            latent_vector, segmentation_map, labels, image, mean, variance
        )
        # Report progress.
        self.disc_loss_tracker.update_state(discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        segmentation_map, image, labels = data
        # Obtain the learned moments of the real image distribution.
        mean, variance = self.encoder(image)
        # Sample a latent from the distribution defined by the learned moments.
        latent_vector = self.sampler([mean, variance])
        # Generate the fake images.
        fake_images = self.generator([latent_vector, labels])
        # Discriminator loss on real vs. fake pairs (no weight updates).
        pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
        pred_real = self.discriminator([segmentation_map, image])[-1]
        loss_fake = self.discriminator_loss(pred_fake, False)
        loss_real = self.discriminator_loss(pred_real, True)
        total_discriminator_loss = 0.5 * (loss_fake + loss_real)
        # Generator losses, mirroring train_generator without gradients.
        real_d_output = self.discriminator([segmentation_map, image])
        fake_d_output, fake_image = self.combined_model(
            [latent_vector, labels, segmentation_map]
        )
        pred = fake_d_output[-1]
        g_loss = generator_loss(pred)
        kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
        vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
        feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
            real_d_output, fake_d_output
        )
        total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss
        # Report progress.
        self.disc_loss_tracker.update_state(total_discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs):
        """Inference: generate images from (latent_vectors, labels)."""
        latent_vectors, labels = inputs
        return self.generator([latent_vectors, labels])