Keras-WGAN Critic和Generator的准确率为0

我正在尝试在Keras中实现WGAN。我正在使用David Foster的生成性深度学习书籍和this code作为参考。我写下了这段简单的代码。然而,每当我开始训练模型时,准确率始终为0,Critic和Discriminator的损失为~0。

无论他们为多少个时期训练,他们都会被困在这些数字上。我尝试了各种网络配置和不同的超参数,但结果似乎没有变化。谷歌也没能帮上多少忙。我无法确定这种行为的来源。

这是我写的代码。

from os.path import expanduser
import os
import struct as st

import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

class WGAN:

    def __init__(self):

        # Data Params
        self.genInput=100
        self.imChannels=1
        self.imShape = (28,28,1)

        # Build Models
        self.onBuildDiscriminator()
        self.onBuildGenerator()
        self.onBuildGAN()

        pass

    def onBuildGAN(self):

        if self.mGenerator is None or self.mDiscriminator is None: raise Exception('Generator Or Descriminator Uninitialized.')

        self.mDiscriminator.trainable=False

        self.mGAN=Sequential()
        self.mGAN.add(self.mGenerator)
        self.mGAN.add(self.mDiscriminator)

        ganOptimizer=RMSprop(lr=0.00005)
        self.mGAN.compile(loss=wasserstein_loss, optimizer=ganOptimizer, metrics=['accuracy'])

        print('GAN Model')
        self.mGAN.summary()
        pass

    def onBuildGenerator(self):

        self.mGenerator=Sequential()

        self.mGenerator.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.genInput))
        self.mGenerator.add(Reshape((7, 7, 128)))
        self.mGenerator.add(UpSampling2D())
        self.mGenerator.add(Conv2D(128, kernel_size=4, padding="same"))
        self.mGenerator.add(BatchNormalization(momentum=0.8))
        self.mGenerator.add(Activation("relu"))
        self.mGenerator.add(UpSampling2D())
        self.mGenerator.add(Conv2D(64, kernel_size=4, padding="same"))
        self.mGenerator.add(BatchNormalization(momentum=0.8))
        self.mGenerator.add(Activation("relu"))
        self.mGenerator.add(Conv2D(self.imChannels, kernel_size=4, padding="same"))
        self.mGenerator.add(Activation("tanh"))

        print('Generator Model')
        self.mGenerator.summary()
        pass

    def onBuildDiscriminator(self):

        self.mDiscriminator = Sequential()

        self.mDiscriminator.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.imShape, padding="same"))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        self.mDiscriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Flatten())
        self.mDiscriminator.add(Dense(1))

        disOptimizer=RMSprop(lr=0.00005)
        self.mDiscriminator.compile(loss=wasserstein_loss, optimizer=disOptimizer, metrics=['accuracy'])

        print('Discriminator Model')
        self.mDiscriminator.summary()

        pass

    def fit(self, trainData, nEpochs=1000, batchSize=64):

        lblForReal = -np.ones((batchSize, 1))
        lblForGene = np.ones((batchSize, 1))

        for ep in range(1, nEpochs+1):

            for __ in range(5):

                # Get Valid Images
                validImages = trainData[ np.random.randint(0, trainData.shape[0], batchSize) ]

                # Get Generated Images
                noiseForGene=np.random.normal(0, 1, size=(batchSize, self.genInput))
                geneImages=self.mGenerator.predict(noiseForGene)

                # Train Critic On Valid And Generated Images With Labels -1 And 1 Respectively
                disValidLoss=self.mDiscriminator.train_on_batch(validImages, lblForReal)
                disGeneLoss=self.mDiscriminator.train_on_batch(geneImages, lblForGene)

                # Perform Critic Weight Clipping
                for l in self.mDiscriminator.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -0.01, 0.01) for w in weights]
                    l.set_weights(weights)

            # Train Generator Using Combined Model
            geneLoss=self.mGAN.train_on_batch(noiseForGene, lblForReal)

            print(' Epoch', ep, 'Critic Valid Loss,Acc', disValidLoss, 'Critic Generated Loss,Acc', disGeneLoss, 'Generator Loss,Acc', geneLoss)
        pass

    pass

if __name__ == '__main__':
    (trainData, __), (__, __) = mnist.load_data()
    trainData = (trainData.astype(np.float32)/127.5) - 1
    trainData = np.expand_dims(trainData, axis=3)

    WGan = WGAN()
    WGan.fit(trainData)

对于我尝试的所有配置,我得到的输出与以下内容非常相似。

 Epoch 1 Critic Valid Loss,Acc [-0.00016362152, 0.0] Critic Generated Loss,Acc [0.0003417502, 0.0] Generator Loss,Acc [-0.00016735379, 0.0]
 Epoch 2 Critic Valid Loss,Acc [-0.0001719332, 0.0] Critic Generated Loss,Acc [0.0003365979, 0.0] Generator Loss,Acc [-0.00017250411, 0.0]
 Epoch 3 Critic Valid Loss,Acc [-0.00017473527, 0.0] Critic Generated Loss,Acc [0.00032945914, 0.0] Generator Loss,Acc [-0.00017612436, 0.0]
 Epoch 4 Critic Valid Loss,Acc [-0.00017181305, 0.0] Critic Generated Loss,Acc [0.0003266656, 0.0] Generator Loss,Acc [-0.00016987178, 0.0]
 Epoch 5 Critic Valid Loss,Acc [-0.0001683443, 0.0] Critic Generated Loss,Acc [0.00032702673, 0.0] Generator Loss,Acc [-0.00016638976, 0.0]
 Epoch 6 Critic Valid Loss,Acc [-0.00017005506, 0.0] Critic Generated Loss,Acc [0.00032805002, 0.0] Generator Loss,Acc [-0.00017040147, 0.0]
 Epoch 7 Critic Valid Loss,Acc [-0.00017353195, 0.0] Critic Generated Loss,Acc [0.00033711304, 0.0] Generator Loss,Acc [-0.00017537423, 0.0]
 Epoch 8 Critic Valid Loss,Acc [-0.00017059325, 0.0] Critic Generated Loss,Acc [0.0003263024, 0.0] Generator Loss,Acc [-0.00016974319, 0.0]
 Epoch 9 Critic Valid Loss,Acc [-0.00017530039, 0.0] Critic Generated Loss,Acc [0.00032463064, 0.0] Generator Loss,Acc [-0.00017845634, 0.0]
 Epoch 10 Critic Valid Loss,Acc [-0.00017530067, 0.0] Critic Generated Loss,Acc [0.00033131015, 0.0] Generator Loss,Acc [-0.00017526663, 0.0]

转载请注明出处:http://www.xgclsm.com/article/20230526/2401658.html