Keras-WGAN Critic和Generator的准确率为0
我正在尝试在Keras中实现WGAN。我以David Foster的《生成式深度学习》一书和一份公开的参考实现作为参照,写下了这段简单的代码。然而,每当我开始训练模型时,准确率始终为0,Critic和Generator的损失都停留在~0附近。
无论训练多少个epoch,这些数值都停在原地不动。我尝试了各种网络结构和不同的超参数,但结果几乎没有变化,搜索也没能找到答案。我无法确定这种行为的来源。
这是我写的代码。
from os.path import expanduser
import os
import struct as st
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K
def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: the mean of label * critic score.

    With real samples labelled -1 and generated samples labelled +1,
    minimizing this pushes the critic's scores for the two
    distributions apart, approximating the Wasserstein distance.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)
class WGAN:
    """Wasserstein GAN (Arjovsky et al., 2017) for 28x28x1 MNIST images.

    Note on the reported metrics: the critic (called ``mDiscriminator``
    here) ends in a *linear* Dense(1) layer and outputs an unbounded
    realness score, not a probability.  Keras' binary ``accuracy``
    metric compares that raw score against the +/-1 labels, so for a
    WGAN it is meaningless and will sit at ~0 regardless of training
    progress.  Judge training by the critic (Wasserstein) loss instead.
    """

    def __init__(self):
        # Data params.
        self.genInput = 100          # dimensionality of the generator's noise vector
        self.imChannels = 1
        self.imShape = (28, 28, 1)
        # Critic weight-clipping bound c from Algorithm 1 of the WGAN paper.
        self.clipValue = 0.01
        # Build models; critic and generator must exist before the
        # combined GAN model can be assembled.
        self.onBuildDiscriminator()
        self.onBuildGenerator()
        self.onBuildGAN()

    def onBuildGAN(self):
        """Stack generator -> frozen critic into the combined model used
        to train the generator.

        Raises:
            Exception: if either sub-model has not been built yet.
        """
        if self.mGenerator is None or self.mDiscriminator is None:
            raise Exception('Generator Or Descriminator Uninitialized.')
        # Freeze the critic *inside the combined model only*: the critic
        # was compiled stand-alone before this flag changed, so its own
        # train_on_batch() still updates its weights.
        self.mDiscriminator.trainable = False
        self.mGAN = Sequential()
        self.mGAN.add(self.mGenerator)
        self.mGAN.add(self.mDiscriminator)
        ganOptimizer = RMSprop(lr=0.00005)
        self.mGAN.compile(loss=wasserstein_loss, optimizer=ganOptimizer,
                          metrics=['accuracy'])
        print('GAN Model')
        self.mGAN.summary()

    def onBuildGenerator(self):
        """Build the generator: 100-d noise -> 7x7x128 feature map ->
        two 2x upsamplings -> 28x28x1 image in [-1, 1]."""
        self.mGenerator = Sequential()
        self.mGenerator.add(Dense(128 * 7 * 7, activation="relu",
                                  input_dim=self.genInput))
        self.mGenerator.add(Reshape((7, 7, 128)))
        self.mGenerator.add(UpSampling2D())          # 7x7 -> 14x14
        self.mGenerator.add(Conv2D(128, kernel_size=4, padding="same"))
        self.mGenerator.add(BatchNormalization(momentum=0.8))
        self.mGenerator.add(Activation("relu"))
        self.mGenerator.add(UpSampling2D())          # 14x14 -> 28x28
        self.mGenerator.add(Conv2D(64, kernel_size=4, padding="same"))
        self.mGenerator.add(BatchNormalization(momentum=0.8))
        self.mGenerator.add(Activation("relu"))
        self.mGenerator.add(Conv2D(self.imChannels, kernel_size=4, padding="same"))
        # tanh matches the [-1, 1] scaling applied to the training data.
        self.mGenerator.add(Activation("tanh"))
        print('Generator Model')
        self.mGenerator.summary()

    def onBuildDiscriminator(self):
        """Build and compile the critic: 28x28x1 image -> one unbounded
        score (linear output, no sigmoid, as the WGAN loss requires)."""
        self.mDiscriminator = Sequential()
        self.mDiscriminator.add(Conv2D(16, kernel_size=3, strides=2,
                                       input_shape=self.imShape, padding="same"))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        self.mDiscriminator.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Flatten())
        # Linear output: a critic score, not a probability.
        self.mDiscriminator.add(Dense(1))
        disOptimizer = RMSprop(lr=0.00005)
        self.mDiscriminator.compile(loss=wasserstein_loss, optimizer=disOptimizer,
                                    metrics=['accuracy'])
        print('Discriminator Model')
        self.mDiscriminator.summary()

    def fit(self, trainData, nEpochs=1000, batchSize=64):
        """Adversarial training loop: 5 critic updates (each followed by
        weight clipping) per generator update, per the WGAN paper.

        Args:
            trainData: images scaled to [-1, 1], shape (N, 28, 28, 1).
            nEpochs: number of generator updates to run.
            batchSize: samples per batch.
        """
        lblForReal = -np.ones((batchSize, 1))   # real -> -1 (y*pred loss convention)
        lblForGene = np.ones((batchSize, 1))    # fake -> +1
        for ep in range(1, nEpochs + 1):
            for __ in range(5):                 # n_critic = 5
                # Sample a batch of real images.
                validImages = trainData[np.random.randint(0, trainData.shape[0], batchSize)]
                # Generate a batch of fake images.
                noiseForGene = np.random.normal(0, 1, size=(batchSize, self.genInput))
                geneImages = self.mGenerator.predict(noiseForGene)
                # Train critic on valid and generated images with labels
                # -1 and 1 respectively.
                disValidLoss = self.mDiscriminator.train_on_batch(validImages, lblForReal)
                disGeneLoss = self.mDiscriminator.train_on_batch(geneImages, lblForGene)
                # WGAN weight clipping — clip only the *trainable*
                # parameters.  BUGFIX: the previous code clipped every
                # array returned by get_weights(), which for
                # BatchNormalization layers includes the non-trainable
                # moving mean/variance; forcing the moving variance into
                # [-0.01, 0.01] corrupts the normalization statistics.
                for layer in self.mDiscriminator.layers:
                    for w in layer.trainable_weights:
                        K.set_value(w, np.clip(K.get_value(w),
                                               -self.clipValue, self.clipValue))
            # Train the generator through the combined model (critic frozen).
            geneLoss = self.mGAN.train_on_batch(noiseForGene, lblForReal)
            print(' Epoch', ep, 'Critic Valid Loss,Acc', disValidLoss,
                  'Critic Generated Loss,Acc', disGeneLoss,
                  'Generator Loss,Acc', geneLoss)
if __name__ == '__main__':
    # Only the training images are needed; labels and the test split are
    # discarded.
    (images, __), (__, __) = mnist.load_data()
    # Rescale uint8 pixels from [0, 255] to [-1, 1] to match the
    # generator's tanh output range.
    images = (images.astype(np.float32) / 127.5) - 1
    # Append the single channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    images = images[..., np.newaxis]
    wgan = WGAN()
    wgan.fit(images)
对于我尝试的所有配置,我得到的输出与以下内容非常相似。
Epoch 1 Critic Valid Loss,Acc [-0.00016362152, 0.0] Critic Generated Loss,Acc [0.0003417502, 0.0] Generator Loss,Acc [-0.00016735379, 0.0]
Epoch 2 Critic Valid Loss,Acc [-0.0001719332, 0.0] Critic Generated Loss,Acc [0.0003365979, 0.0] Generator Loss,Acc [-0.00017250411, 0.0]
Epoch 3 Critic Valid Loss,Acc [-0.00017473527, 0.0] Critic Generated Loss,Acc [0.00032945914, 0.0] Generator Loss,Acc [-0.00017612436, 0.0]
Epoch 4 Critic Valid Loss,Acc [-0.00017181305, 0.0] Critic Generated Loss,Acc [0.0003266656, 0.0] Generator Loss,Acc [-0.00016987178, 0.0]
Epoch 5 Critic Valid Loss,Acc [-0.0001683443, 0.0] Critic Generated Loss,Acc [0.00032702673, 0.0] Generator Loss,Acc [-0.00016638976, 0.0]
Epoch 6 Critic Valid Loss,Acc [-0.00017005506, 0.0] Critic Generated Loss,Acc [0.00032805002, 0.0] Generator Loss,Acc [-0.00017040147, 0.0]
Epoch 7 Critic Valid Loss,Acc [-0.00017353195, 0.0] Critic Generated Loss,Acc [0.00033711304, 0.0] Generator Loss,Acc [-0.00017537423, 0.0]
Epoch 8 Critic Valid Loss,Acc [-0.00017059325, 0.0] Critic Generated Loss,Acc [0.0003263024, 0.0] Generator Loss,Acc [-0.00016974319, 0.0]
Epoch 9 Critic Valid Loss,Acc [-0.00017530039, 0.0] Critic Generated Loss,Acc [0.00032463064, 0.0] Generator Loss,Acc [-0.00017845634, 0.0]
Epoch 10 Critic Valid Loss,Acc [-0.00017530067, 0.0] Critic Generated Loss,Acc [0.00033131015, 0.0] Generator Loss,Acc [-0.00017526663, 0.0]
转载请注明出处:http://www.xgclsm.com/article/20230526/2401658.html