Python 板


LINE

如题 这个程式训练一些照片 最後把训练的监别网路权重参数结果存在TESTgen/discriminator_weights.h5中 但後来要载入TESTgen/discriminator_weights.h5这个参数监别网路时却不断说discrimi nator_weights.h5 里有问题 我打开discriminator_weights.h5中看起来是网路参数 跟float32浮点数格式 但要载入用来辨识其他照片时却说无法载入HTF5格式 我用的是tensrflow GPU 求跪强者们开示 谢谢 import os import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam #缺Keras HDF5 格式 # 设定图像参数 img_rows = 28 img_cols = 28 channels = 1 # 设定生成器 def build_generator(): ? ? noise_shape = (100,) ? ? model = Sequential() ? ? model.add(Dense(256, input_shape=noise_shape)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(512)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(1024)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ? ? model.add(Reshape((img_rows, img_cols, channels))) ? ? model.summary() ? ? noise = Input(shape=noise_shape) ? ? img = model(noise) ? ? return Model(noise, img) # 设定监别器 def build_discriminator(): ? ? model = Sequential() ? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i mg_cols, channels))) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(Dense(int((img_rows * img_cols * channels) / 2))) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(Dense(1, activation='sigmoid')) ? ? model.summary() ? ? img = Input(shape=(img_rows, img_cols, channels)) ? ? validity = model(img) ? ? return Model(img, validity) # 设定生成器和对抗器 generator = build_generator() discriminator = build_discriminator() # 编译监别器 discriminator.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5), ? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy']) # 建立结合模型 z = Input(shape=(100,)) img = generator(z) discriminator.trainable = False validity = discriminator(img) combined = Model(z, validity) combined.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5)) # 载入并预处理MNIST资料集 (X_train, _), (_, _) = mnist.load_data() X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) # 定义训练参数 epochs = 3000 batch_size = 128 save_interval = 100 # 定义图像标签 valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) # 训练生成器和监别器 for epoch in range(epochs): ? ? # 训练监别器 ? ? idx = np.random.randint(0, X_train.shape[0], batch_size) ? ? imgs = X_train[idx] ? ? noise = np.random.normal(0, 1, (batch_size, 100)) ? ? gen_imgs = generator.predict(noise) ? ? d_loss_real = discriminator.train_on_batch(imgs, valid) ? ? d_loss_fake = discriminator.train_on_batch(gen_imgs, fake) ? ? d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) ? ? # 训练生成器 ? ? noise = np.random.normal(0, 1, (batch_size, 100)) ? ? g_loss = combined.train_on_batch(noise, valid) ? ? # 显示训练进度 ? ? if epoch % save_interval == 0: ? ? ? ? print(f"Epoch {epoch}/{epochs}, D loss: {d_loss[0]}, acc.: {100 * d_lo ss[1]}, G loss: {g_loss}") ? ? ? ? # 显示生成的图像 ? ? ? ? r, c = 2, 2 ? ? ? ? noise = np.random.normal(0, 1, (r * c, 100)) ? ? ? ? gen_imgs = generator.predict(noise) ? ? ? ? fig, axs = plt.subplots(r, c) ? ? ? ? cnt = 0 ? ? ? ? for i in range(r): ? ? ? ? ? ? for j in range(c): ? ? ? ? ? ? ? ? axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray') ? ? ? ? ? ? ? ? axs[i, j].axis('off') ? ? ? ? ? ? ? ? cnt += 1 ? ? ? ? plt.show() ? ? # 将生成网路和监别器的参数保存到TESTgen资料夹中 os.makedirs("TESTgen", exist_ok=True) generator.save_weights("TESTgen/generator_weights.h5") discriminator.save_weights("TESTgen/discriminator_weights.h5", save_format="h5 ") with open("TESTgen.txt", "w") as f: ? ? f.write("Generator and discriminator parameters saved.") print("训练完成并保存生成网路和监别器参数。") ? ? ? ? ? ? ? import os import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam # 汇入所需的库和模组 import os import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam # 设定图像参数 img_rows = 28 img_cols = 28 channels = 1 # 设定生成器 def build_generator(): ? ? noise_shape = (100,) ? ? model = Sequential() ? ? model.add(Dense(256, input_shape=noise_shape)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(512)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(1024)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ? ? model.add(Reshape((img_rows, img_cols, channels))) ? ? model.summary() ? ? noise = Input(shape=noise_shape) ? ? img = model(noise) ? ? return Model(noise, img) # 建立生成器模型 def build_generator(): ? ? noise_shape = (100,) ? ? model = Sequential() ? ? model.add(Dense(256, input_shape=noise_shape)) ?# 全连接层,输入是噪音 ? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函数 ? ? model.add(BatchNormalization(momentum=0.8)) ?# BatchNormalization 正规化 ? ? model.add(Dense(512)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(1024)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ?# 生 成器输出,使用 tanh 激活函数 ? ? model.add(Reshape((img_rows, img_cols, channels))) ?# 重塑输出形状 ? ? model.summary() ? ? noise = Input(shape=noise_shape) ?# 噪音输入 ? ? img = model(noise) ?# 使用模型生成图像 ? ? return Model(noise, img) ?# 返回噪音和生成图像模型 # 设定监别器 def build_discriminator(): ? ? model = Sequential() ? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ?# 将图像展 平为一维 ? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i mg_cols, channels))) ?# 全连接层 ? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函数 ? ? model.add(Dense(int((img_rows * img_cols * channels) / 2))) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(Dense(1, activation='sigmoid')) ?# 预测真假的输出,使用 sigmoid 激活函数 ? ? model.summary() ? ? img = Input(shape=(img_rows, img_cols, channels)) ?# 图像输入 ? ? validity = model(img) ?# 使用模型判断真假 ? ? return Model(img, validity) ?# 返回图像和判断真假模型 # 建立生成器和监别器 generator = build_generator() ?# 创建生成器模型 discriminator = build_discriminator() ?# 创建监别器模型 # 编译监别器 discriminator.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5), ? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy']) # 建立结合模型 z = Input(shape=(100,)) img = generator(z) discriminator.trainable = False ?# 在结合模型中,监别器权重冻结 validity = discriminator(img) combined = Model(z, validity) ?# 创建结合模型,输入噪音,输出真假 combined.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5)) # 载入并预处理MNIST资料集 (X_train, _), (_, _) = mnist.load_data() ?# 载入MNIST数据集 X_train = (X_train.astype(np.float32) - 127.5) / 127.5 ?# 正规化数据到-1到1之 间 X_train = np.expand_dims(X_train, axis=3) ?# 增加一个维度(通道) # 定义训练参数 epochs = 3000 ?# 训练迭代次数 batch_size = 128 ?# 批次大小 save_interval = 100 ?# 每隔多少个迭代保存模型 # 定义图像标签 valid = np.ones((batch_size, 1)) ?# 真实标签 fake = np.zeros((batch_size, 1)) ?# 假标签 # 训练生成器和监别器 for epoch in range(epochs): ? ? # 训练监别器 ? ? idx = np.random.randint(0, X_train.shape[0], batch_size) ? ? imgs = X_train[idx] ?# 随机选取真实图像 ? ? noise = np.random.normal(0, 1, (batch_size, 100)) ?# ? (X_tr X_tra 间 X_tra--



※ 发信站: 批踢踢实业坊(ptt.cc), 来自: 114.140.112.83 (台湾)
※ 文章网址: https://webptt.com/cn.aspx?n=bbs/Python/M.1693274066.A.760.html ※ 编辑: psw (114.140.112.83 台湾), 08/29/2023 09:55:19
1F:嘘 lycantrope: 问GPT,不经大脑复制贴上,也没写你是怎麽载入h508/29 10:52
来了 #打开.h5了 import h5py import numpy as np from keras.models import load_model from PIL import Image from keras.preprocessing import image import cv2 # 1. 安装必要的库:Keras, h5py, PIL, OpenCV # 2. 加载模型 model_filename = 'TESTgen/discriminator_weights.h5' model = load_model(model_filename) # 3. 图像预处理 def preprocess_image(image_path): ? ? img = Image.open(image_path).convert('L') ?# 将图像转换为灰度 ? ? img = img.resize((28, 28)) ?# 调整图像大小为模型所需大小 ? ? img = image.img_to_array(img) ? ? img = np.expand_dims(img, axis=0) ? ? img /= 255.0 ?# 像素标准化 ? ? return img # 4. 进行预测 def predict_image(image_path, model): ? ? img = preprocess_image(image_path) ? ? prediction = model.predict(img) ? ? return prediction # 图像路径 image_path = 'TESTgen/P_20230623_133745.jpg' # 进行预测 prediction = predict_image(image_path, model) # 输出预测结果 print("Prediction:", prediction) # 关闭模型 model.close() ※ 编辑: psw (36.239.156.85 台湾), 08/29/2023 11:34:35
2F:→ tsoahans: save_weights对应load_weights model.save对load_model08/29 15:04
感谢大神
3F:→ lycantrope: 同楼上,从model= build_discriminator()产生model後08/29 15:18
4F:→ lycantrope: model.load_weights才对08/29 15:19
感谢大神 ※ 编辑: psw (27.53.112.160 台湾), 08/29/2023 18:05:57
5F:→ chang1248w: 超热心ww 09/14 22:58







like.gif 您可能会有兴趣的文章
icon.png[问题/行为] 猫晚上进房间会不会有憋尿问题
icon.pngRe: [闲聊] 选了错误的女孩成为魔法少女 XDDDDDDDDDD
icon.png[正妹] 瑞典 一张
icon.png[心得] EMS高领长版毛衣.墨小楼MC1002
icon.png[分享] 丹龙隔热纸GE55+33+22
icon.png[问题] 清洗洗衣机
icon.png[寻物] 窗台下的空间
icon.png[闲聊] 双极の女神1 木魔爵
icon.png[售车] 新竹 1997 march 1297cc 白色 四门
icon.png[讨论] 能从照片感受到摄影者心情吗
icon.png[狂贺] 贺贺贺贺 贺!岛村卯月!总选举NO.1
icon.png[难过] 羡慕白皮肤的女生
icon.png阅读文章
icon.png[黑特]
icon.png[问题] SBK S1安装於安全帽位置
icon.png[分享] 旧woo100绝版开箱!!
icon.pngRe: [无言] 关於小包卫生纸
icon.png[开箱] E5-2683V3 RX480Strix 快睿C1 简单测试
icon.png[心得] 苍の海贼龙 地狱 执行者16PT
icon.png[售车] 1999年Virage iO 1.8EXi
icon.png[心得] 挑战33 LV10 狮子座pt solo
icon.png[闲聊] 手把手教你不被桶之新手主购教学
icon.png[分享] Civic Type R 量产版官方照无预警流出
icon.png[售车] Golf 4 2.0 银色 自排
icon.png[出售] Graco提篮汽座(有底座)2000元诚可议
icon.png[问题] 请问补牙材质掉了还能再补吗?(台中半年内
icon.png[问题] 44th 单曲 生写竟然都给重复的啊啊!
icon.png[心得] 华南红卡/icash 核卡
icon.png[问题] 拔牙矫正这样正常吗
icon.png[赠送] 老莫高业 初业 102年版
icon.png[情报] 三大行动支付 本季掀战火
icon.png[宝宝] 博客来Amos水蜡笔5/1特价五折
icon.pngRe: [心得] 新鲜人一些面试分享
icon.png[心得] 苍の海贼龙 地狱 麒麟25PT
icon.pngRe: [闲聊] (君の名は。雷慎入) 君名二创漫画翻译
icon.pngRe: [闲聊] OGN中场影片:失踪人口局 (英文字幕)
icon.png[问题] 台湾大哥大4G讯号差
icon.png[出售] [全国]全新千寻侘草LED灯, 水草

请输入看板名称,例如:BabyMother站内搜寻

TOP