作者psw (ICK)
看板Python
标题[问题] 生成对抗训练程式
时间Tue Aug 29 09:54:24 2023
如题
这个程式训练一些照片
最後把训练的监别网路权重参数结果存在TESTgen/discriminator_weights.h5中
但後来要载入TESTgen/discriminator_weights.h5这个参数监别网路时却不断说discrimi
nator_weights.h5
里有问题
我打开discriminator_weights.h5中看起来是网路参数
跟float32浮点数格式
但要载入用来辨识其他照片时却说无法载入HTF5格式
我用的是tensrflow GPU
求跪强者们开示
谢谢
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
#缺Keras HDF5 格式
# 设定图像参数
img_rows = 28
img_cols = 28
channels = 1
# 设定生成器
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh'))
? ? model.add(Reshape((img_rows, img_cols, channels)))
? ? model.summary()
? ? noise = Input(shape=noise_shape)
? ? img = model(noise)
? ? return Model(noise, img)
# 设定监别器
def build_discriminator():
? ? model = Sequential()
? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels)))
? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i
mg_cols, channels)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(int((img_rows * img_cols * channels) / 2)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(1, activation='sigmoid'))
? ? model.summary()
? ? img = Input(shape=(img_rows, img_cols, channels))
? ? validity = model(img)
? ? return Model(img, validity)
# 设定生成器和对抗器
generator = build_generator()
discriminator = build_discriminator()
# 编译监别器
discriminator.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5),
? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy'])
# 建立结合模型
z = Input(shape=(100,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5))
# 载入并预处理MNIST资料集
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# 定义训练参数
epochs = 3000
batch_size = 128
save_interval = 100
# 定义图像标签
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
# 训练生成器和监别器
for epoch in range(epochs):
? ? # 训练监别器
? ? idx = np.random.randint(0, X_train.shape[0], batch_size)
? ? imgs = X_train[idx]
? ? noise = np.random.normal(0, 1, (batch_size, 100))
? ? gen_imgs = generator.predict(noise)
? ? d_loss_real = discriminator.train_on_batch(imgs, valid)
? ? d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
? ? d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
? ? # 训练生成器
? ? noise = np.random.normal(0, 1, (batch_size, 100))
? ? g_loss = combined.train_on_batch(noise, valid)
? ? # 显示训练进度
? ? if epoch % save_interval == 0:
? ? ? ? print(f"Epoch {epoch}/{epochs}, D loss: {d_loss[0]}, acc.: {100 * d_lo
ss[1]}, G loss: {g_loss}")
? ? ? ? # 显示生成的图像
? ? ? ? r, c = 2, 2
? ? ? ? noise = np.random.normal(0, 1, (r * c, 100))
? ? ? ? gen_imgs = generator.predict(noise)
? ? ? ? fig, axs = plt.subplots(r, c)
? ? ? ? cnt = 0
? ? ? ? for i in range(r):
? ? ? ? ? ? for j in range(c):
? ? ? ? ? ? ? ? axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
? ? ? ? ? ? ? ? axs[i, j].axis('off')
? ? ? ? ? ? ? ? cnt += 1
? ? ? ? plt.show()
? ?
# 将生成网路和监别器的参数保存到TESTgen资料夹中
os.makedirs("TESTgen", exist_ok=True)
generator.save_weights("TESTgen/generator_weights.h5")
discriminator.save_weights("TESTgen/discriminator_weights.h5", save_format="h5
")
with open("TESTgen.txt", "w") as f:
? ? f.write("Generator and discriminator parameters saved.")
print("训练完成并保存生成网路和监别器参数。") ? ?
? ?
?
? ? import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# 汇入所需的库和模组
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# 设定图像参数
img_rows = 28
img_cols = 28
channels = 1
# 设定生成器
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh'))
? ? model.add(Reshape((img_rows, img_cols, channels)))
? ? model.summary()
? ? noise = Input(shape=noise_shape)
? ? img = model(noise)
? ? return Model(noise, img)
# 建立生成器模型
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape)) ?# 全连接层,输入是噪音
? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函数
? ? model.add(BatchNormalization(momentum=0.8)) ?# BatchNormalization 正规化
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ?# 生
成器输出,使用 tanh 激活函数
? ? model.add(Reshape((img_rows, img_cols, channels))) ?# 重塑输出形状
? ? model.summary()
? ? noise = Input(shape=noise_shape) ?# 噪音输入
? ? img = model(noise) ?# 使用模型生成图像
? ? return Model(noise, img) ?# 返回噪音和生成图像模型
# 设定监别器
def build_discriminator():
? ? model = Sequential()
? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ?# 将图像展
平为一维
? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i
mg_cols, channels))) ?# 全连接层
? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函数
? ? model.add(Dense(int((img_rows * img_cols * channels) / 2)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(1, activation='sigmoid')) ?# 预测真假的输出,使用 sigmoid
激活函数
? ? model.summary()
? ? img = Input(shape=(img_rows, img_cols, channels)) ?# 图像输入
? ? validity = model(img) ?# 使用模型判断真假
? ? return Model(img, validity) ?# 返回图像和判断真假模型
# 建立生成器和监别器
generator = build_generator() ?# 创建生成器模型
discriminator = build_discriminator() ?# 创建监别器模型
# 编译监别器
discriminator.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5),
? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy'])
# 建立结合模型
z = Input(shape=(100,))
img = generator(z)
discriminator.trainable = False ?# 在结合模型中,监别器权重冻结
validity = discriminator(img)
combined = Model(z, validity) ?# 创建结合模型,输入噪音,输出真假
combined.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5))
# 载入并预处理MNIST资料集
(X_train, _), (_, _) = mnist.load_data() ?# 载入MNIST数据集
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 ?# 正规化数据到-1到1之
间
X_train = np.expand_dims(X_train, axis=3) ?# 增加一个维度(通道)
# 定义训练参数
epochs = 3000 ?# 训练迭代次数
batch_size = 128 ?# 批次大小
save_interval = 100 ?# 每隔多少个迭代保存模型
# 定义图像标签
valid = np.ones((batch_size, 1)) ?# 真实标签
fake = np.zeros((batch_size, 1)) ?# 假标签
# 训练生成器和监别器
for epoch in range(epochs):
? ? # 训练监别器
? ? idx = np.random.randint(0, X_train.shape[0], batch_size)
? ? imgs = X_train[idx] ?# 随机选取真实图像
? ? noise = np.random.normal(0, 1, (batch_size, 100)) ?#
?
(X_tr
X_tra
间
X_tra--
※ 发信站: 批踢踢实业坊(ptt.cc), 来自: 114.140.112.83 (台湾)
※ 文章网址: https://webptt.com/cn.aspx?n=bbs/Python/M.1693274066.A.760.html
※ 编辑: psw (114.140.112.83 台湾), 08/29/2023 09:55:19
1F:嘘 lycantrope: 问GPT,不经大脑复制贴上,也没写你是怎麽载入h508/29 10:52
来了
#打开.h5了
import h5py
import numpy as np
from keras.models import load_model
from PIL import Image
from keras.preprocessing import image
import cv2
# 1. 安装必要的库:Keras, h5py, PIL, OpenCV
# 2. 加载模型
model_filename = 'TESTgen/discriminator_weights.h5'
model = load_model(model_filename)
# 3. 图像预处理
def preprocess_image(image_path):
? ? img = Image.open(image_path).convert('L') ?# 将图像转换为灰度
? ? img = img.resize((28, 28)) ?# 调整图像大小为模型所需大小
? ? img = image.img_to_array(img)
? ? img = np.expand_dims(img, axis=0)
? ? img /= 255.0 ?# 像素标准化
? ? return img
# 4. 进行预测
def predict_image(image_path, model):
? ? img = preprocess_image(image_path)
? ? prediction = model.predict(img)
? ? return prediction
# 图像路径
image_path = 'TESTgen/P_20230623_133745.jpg'
# 进行预测
prediction = predict_image(image_path, model)
# 输出预测结果
print("Prediction:", prediction)
# 关闭模型
model.close()
※ 编辑: psw (36.239.156.85 台湾), 08/29/2023 11:34:35
2F:→ tsoahans: save_weights对应load_weights model.save对load_model08/29 15:04
感谢大神
3F:→ lycantrope: 同楼上,从model= build_discriminator()产生model後08/29 15:18
4F:→ lycantrope: model.load_weights才对08/29 15:19
感谢大神
※ 编辑: psw (27.53.112.160 台湾), 08/29/2023 18:05:57
5F:→ chang1248w: 超热心ww 09/14 22:58