Python 板


LINE

如題 這個程式訓練一些照片 最後把訓練的鑑別網路權重參數結果存在TESTgen/discriminator_weights.h5中 但後來要載入TESTgen/discriminator_weights.h5這個參數鑑別網路時卻不斷說discrimi nator_weights.h5 裡有問題 我打開discriminator_weights.h5中看起來是網路參數 跟float32浮點數格式 但要載入用來辨識其他照片時卻說無法載入HTF5格式 我用的是tensrflow GPU 求跪強者們開示 謝謝 import os import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam #缺Keras HDF5 格式 # 設定圖像參數 img_rows = 28 img_cols = 28 channels = 1 # 設定生成器 def build_generator(): ? ? noise_shape = (100,) ? ? model = Sequential() ? ? model.add(Dense(256, input_shape=noise_shape)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(512)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(1024)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ? ? model.add(Reshape((img_rows, img_cols, channels))) ? ? model.summary() ? ? noise = Input(shape=noise_shape) ? ? img = model(noise) ? ? return Model(noise, img) # 設定鑑別器 def build_discriminator(): ? ? model = Sequential() ? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i mg_cols, channels))) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(Dense(int((img_rows * img_cols * channels) / 2))) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(Dense(1, activation='sigmoid')) ? ? model.summary() ? ? img = Input(shape=(img_rows, img_cols, channels)) ? ? validity = model(img) ? ? return Model(img, validity) # 設定生成器和對抗器 generator = build_generator() discriminator = build_discriminator() # 編譯鑑別器 discriminator.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5), ? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy']) # 建立結合模型 z = Input(shape=(100,)) img = generator(z) discriminator.trainable = False validity = discriminator(img) combined = Model(z, validity) combined.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5)) # 載入並預處理MNIST資料集 (X_train, _), (_, _) = mnist.load_data() X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) # 定義訓練參數 epochs = 3000 batch_size = 128 save_interval = 100 # 定義圖像標籤 valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) # 訓練生成器和鑑別器 for epoch in range(epochs): ? ? # 訓練鑑別器 ? ? idx = np.random.randint(0, X_train.shape[0], batch_size) ? ? imgs = X_train[idx] ? ? noise = np.random.normal(0, 1, (batch_size, 100)) ? ? gen_imgs = generator.predict(noise) ? ? d_loss_real = discriminator.train_on_batch(imgs, valid) ? ? d_loss_fake = discriminator.train_on_batch(gen_imgs, fake) ? ? d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) ? ? # 訓練生成器 ? ? noise = np.random.normal(0, 1, (batch_size, 100)) ? ? g_loss = combined.train_on_batch(noise, valid) ? ? # 顯示訓練進度 ? ? if epoch % save_interval == 0: ? ? ? ? print(f"Epoch {epoch}/{epochs}, D loss: {d_loss[0]}, acc.: {100 * d_lo ss[1]}, G loss: {g_loss}") ? ? ? ? # 顯示生成的圖像 ? ? ? ? r, c = 2, 2 ? ? ? ? noise = np.random.normal(0, 1, (r * c, 100)) ? ? ? ? gen_imgs = generator.predict(noise) ? ? ? ? fig, axs = plt.subplots(r, c) ? ? ? ? cnt = 0 ? ? ? ? for i in range(r): ? ? ? ? ? ? for j in range(c): ? ? ? ? ? ? ? ? axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray') ? ? ? ? ? ? ? ? axs[i, j].axis('off') ? ? ? ? ? ? ? ? cnt += 1 ? ? ? ? plt.show() ? ? # 將生成網路和鑑別器的參數保存到TESTgen資料夾中 os.makedirs("TESTgen", exist_ok=True) generator.save_weights("TESTgen/generator_weights.h5") discriminator.save_weights("TESTgen/discriminator_weights.h5", save_format="h5 ") with open("TESTgen.txt", "w") as f: ? ? f.write("Generator and discriminator parameters saved.") print("訓練完成並保存生成網路和鑑別器參數。") ? ? ? ? ? ? ? import os import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam # 匯入所需的庫和模組 import os import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam # 設定圖像參數 img_rows = 28 img_cols = 28 channels = 1 # 設定生成器 def build_generator(): ? ? noise_shape = (100,) ? ? model = Sequential() ? ? model.add(Dense(256, input_shape=noise_shape)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(512)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(1024)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ? ? model.add(Reshape((img_rows, img_cols, channels))) ? ? model.summary() ? ? noise = Input(shape=noise_shape) ? ? img = model(noise) ? ? return Model(noise, img) # 建立生成器模型 def build_generator(): ? ? noise_shape = (100,) ? ? model = Sequential() ? ? model.add(Dense(256, input_shape=noise_shape)) ?# 全連接層,輸入是噪音 ? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函數 ? ? model.add(BatchNormalization(momentum=0.8)) ?# BatchNormalization 正規化 ? ? model.add(Dense(512)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(1024)) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(BatchNormalization(momentum=0.8)) ? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ?# 生 成器輸出,使用 tanh 激活函數 ? ? model.add(Reshape((img_rows, img_cols, channels))) ?# 重塑輸出形狀 ? ? model.summary() ? ? noise = Input(shape=noise_shape) ?# 噪音輸入 ? ? img = model(noise) ?# 使用模型生成圖像 ? ? return Model(noise, img) ?# 返回噪音和生成圖像模型 # 設定鑑別器 def build_discriminator(): ? ? model = Sequential() ? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ?# 將圖像展 平為一維 ? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i mg_cols, channels))) ?# 全連接層 ? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函數 ? ? model.add(Dense(int((img_rows * img_cols * channels) / 2))) ? ? model.add(LeakyReLU(alpha=0.2)) ? ? model.add(Dense(1, activation='sigmoid')) ?# 預測真假的輸出,使用 sigmoid 激活函數 ? ? model.summary() ? ? img = Input(shape=(img_rows, img_cols, channels)) ?# 圖像輸入 ? ? validity = model(img) ?# 使用模型判斷真假 ? ? return Model(img, validity) ?# 返回圖像和判斷真假模型 # 建立生成器和鑑別器 generator = build_generator() ?# 創建生成器模型 discriminator = build_discriminator() ?# 創建鑑別器模型 # 編譯鑑別器 discriminator.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5), ? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy']) # 建立結合模型 z = Input(shape=(100,)) img = generator(z) discriminator.trainable = False ?# 在結合模型中,鑑別器權重凍結 validity = discriminator(img) combined = Model(z, validity) ?# 創建結合模型,輸入噪音,輸出真假 combined.compile(loss='binary_crossentropy', ? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5)) # 載入並預處理MNIST資料集 (X_train, _), (_, _) = mnist.load_data() ?# 載入MNIST數據集 X_train = (X_train.astype(np.float32) - 127.5) / 127.5 ?# 正規化數據到-1到1之 間 X_train = np.expand_dims(X_train, axis=3) ?# 增加一個維度(通道) # 定義訓練參數 epochs = 3000 ?# 訓練迭代次數 batch_size = 128 ?# 批次大小 save_interval = 100 ?# 每隔多少個迭代保存模型 # 定義圖像標籤 valid = np.ones((batch_size, 1)) ?# 真實標籤 fake = np.zeros((batch_size, 1)) ?# 假標籤 # 訓練生成器和鑑別器 for epoch in range(epochs): ? ? # 訓練鑑別器 ? ? idx = np.random.randint(0, X_train.shape[0], batch_size) ? ? imgs = X_train[idx] ?# 隨機選取真實圖像 ? ? noise = np.random.normal(0, 1, (batch_size, 100)) ?# ? (X_tr X_tra 間 X_tra--



※ 發信站: 批踢踢實業坊(ptt.cc), 來自: 114.140.112.83 (臺灣)
※ 文章網址: https://webptt.com/m.aspx?n=bbs/Python/M.1693274066.A.760.html ※ 編輯: psw (114.140.112.83 臺灣), 08/29/2023 09:55:19
1F:噓 lycantrope: 問GPT,不經大腦複製貼上,也沒寫你是怎麼載入h508/29 10:52
來了 #打開.h5了 import h5py import numpy as np from keras.models import load_model from PIL import Image from keras.preprocessing import image import cv2 # 1. 安装必要的库:Keras, h5py, PIL, OpenCV # 2. 加载模型 model_filename = 'TESTgen/discriminator_weights.h5' model = load_model(model_filename) # 3. 图像预处理 def preprocess_image(image_path): ? ? img = Image.open(image_path).convert('L') ?# 将图像转换为灰度 ? ? img = img.resize((28, 28)) ?# 调整图像大小为模型所需大小 ? ? img = image.img_to_array(img) ? ? img = np.expand_dims(img, axis=0) ? ? img /= 255.0 ?# 像素标准化 ? ? return img # 4. 进行预测 def predict_image(image_path, model): ? ? img = preprocess_image(image_path) ? ? prediction = model.predict(img) ? ? return prediction # 图像路径 image_path = 'TESTgen/P_20230623_133745.jpg' # 进行预测 prediction = predict_image(image_path, model) # 输出预测结果 print("Prediction:", prediction) # 关闭模型 model.close() ※ 編輯: psw (36.239.156.85 臺灣), 08/29/2023 11:34:35
2F:→ tsoahans: save_weights對應load_weights model.save對load_model08/29 15:04
感謝大神
3F:→ lycantrope: 同樓上,從model= build_discriminator()產生model後08/29 15:18
4F:→ lycantrope: model.load_weights才對08/29 15:19
感謝大神 ※ 編輯: psw (27.53.112.160 臺灣), 08/29/2023 18:05:57
5F:→ chang1248w: 超熱心ww 09/14 22:58







like.gif 您可能會有興趣的文章
icon.png[問題/行為] 貓晚上進房間會不會有憋尿問題
icon.pngRe: [閒聊] 選了錯誤的女孩成為魔法少女 XDDDDDDDDDD
icon.png[正妹] 瑞典 一張
icon.png[心得] EMS高領長版毛衣.墨小樓MC1002
icon.png[分享] 丹龍隔熱紙GE55+33+22
icon.png[問題] 清洗洗衣機
icon.png[尋物] 窗台下的空間
icon.png[閒聊] 双極の女神1 木魔爵
icon.png[售車] 新竹 1997 march 1297cc 白色 四門
icon.png[討論] 能從照片感受到攝影者心情嗎
icon.png[狂賀] 賀賀賀賀 賀!島村卯月!總選舉NO.1
icon.png[難過] 羨慕白皮膚的女生
icon.png閱讀文章
icon.png[黑特]
icon.png[問題] SBK S1安裝於安全帽位置
icon.png[分享] 舊woo100絕版開箱!!
icon.pngRe: [無言] 關於小包衛生紙
icon.png[開箱] E5-2683V3 RX480Strix 快睿C1 簡單測試
icon.png[心得] 蒼の海賊龍 地獄 執行者16PT
icon.png[售車] 1999年Virage iO 1.8EXi
icon.png[心得] 挑戰33 LV10 獅子座pt solo
icon.png[閒聊] 手把手教你不被桶之新手主購教學
icon.png[分享] Civic Type R 量產版官方照無預警流出
icon.png[售車] Golf 4 2.0 銀色 自排
icon.png[出售] Graco提籃汽座(有底座)2000元誠可議
icon.png[問題] 請問補牙材質掉了還能再補嗎?(台中半年內
icon.png[問題] 44th 單曲 生寫竟然都給重複的啊啊!
icon.png[心得] 華南紅卡/icash 核卡
icon.png[問題] 拔牙矯正這樣正常嗎
icon.png[贈送] 老莫高業 初業 102年版
icon.png[情報] 三大行動支付 本季掀戰火
icon.png[寶寶] 博客來Amos水蠟筆5/1特價五折
icon.pngRe: [心得] 新鮮人一些面試分享
icon.png[心得] 蒼の海賊龍 地獄 麒麟25PT
icon.pngRe: [閒聊] (君の名は。雷慎入) 君名二創漫畫翻譯
icon.pngRe: [閒聊] OGN中場影片:失蹤人口局 (英文字幕)
icon.png[問題] 台灣大哥大4G訊號差
icon.png[出售] [全國]全新千尋侘草LED燈, 水草

請輸入看板名稱,例如:Boy-Girl站內搜尋

TOP