作者clairehuei (不是clairehuei 是桂)
看板DataScience
标题[问题] keras custom layer multiple output请教
时间Sat Jul 28 13:58:48 2018
大家好, 最近在尝试写一个keras的attention layer
为了能够视觉化, 所以我把 def call (self, x, mask=None):
的 return 值,改为list 内容类似 [outputs, 权重] 这样
以上在train的时候,都很顺利, 视觉化效果也能够透过权重来呈现
但是, 当我把train好的model save下来, 然後 再次 load_model()时
会报一个神奇的错误:
File
"C:\Users\Allen\Anaconda3\envs\tensorflow\lib\site-packages\keras\engine\topology.py",
line 719, in _add_inbound_node
output_tensors[i]._keras_shape = output_shapes[i]
IndexError: list index out of range
我在load_model的时候,有加入custom_objects的参数设定,把我的attention layer引入
使用的keras 版本为 2.1.5
python版本为 3.5
想请教版上大家有没有遇过类似的情形,有相关的解法吗QQ?
--
※ 发信站: 批踢踢实业坊(ptt.cc), 来自: 220.130.131.58
※ 文章网址: https://webptt.com/cn.aspx?n=bbs/DataScience/M.1532757530.A.AAA.html
1F:推 goldflower: 你的custom layer直接呼叫get_weights不行吗@@? 07/28 15:54
2F:→ clairehuei: 可是我需要的是layer的outputs 不是weight @@ 07/28 16:02
刚才异想天开试了一个方法, 我想说,既然load_model之後 他只会丢出一个output
那我就在call() 里面,自己把 output 跟 权重的array append起来 再return
(之後再自己切割取要的部分)
因为过程当中 我有用到reshape, 实验结果更惨, 一样在train的时候都正常
load_model() 的时候, 又报错, 说 call() 里面的 'output'物件 是tensor
没有 .reshape() 方法 囧rz...
※ 编辑: clairehuei (220.130.131.58), 07/28/2018 16:09:09
3F:推 goldflower: 因为我猜问题应该在回传的weight没连到任意的output 07/28 16:40
4F:→ goldflower: 因此存model时会有问题 所以我是指你call的weights就 07/28 16:40
5F:→ goldflower: 用get weights来拿然後回传计算完的output就好 不过 07/28 16:40
6F:→ goldflower: 你说要output就好那不就可拿掉吗 是不是我搞错什麽QQ 07/28 16:40
7F:推 germun: 留weight就好 model直接重建 07/29 04:28
8F:→ clairehuei: 留weights, model重建, 再载入weights 的确可行@.@ 07/29 21:58
9F:→ clairehuei: 感谢各位大大指点 <(_"_)> 07/29 21:58