作者wang19980531 (中立评论员)
看板DataScience
标题[问题] batch_size图片大小不一
时间Fri Apr 17 23:34:46 2020
作业系统:win10
问题类别:batch size, gradient, triplet loss
使用工具:python, pytorch
问题描述:
当把图片整批丢入模型时,
要求必须整批图片的size必须相同,
这我能理解 它是[batch size, h, l, w]形式传入
网上查到两种解决方法:
1. resize()成统一尺寸,但小弟的图片不巧图片大小差异很大(有旋转..等),长宽有的
甚至是转置。担心会不会resize後失真太多?
2. batch_size设为1,但这存在一个问题,我的模型设定的损失函数是计算triplet loss,
如果batch_size=1 没办法丢到loss function做计算,到下个batch模型重算时,上次计算
出来的gradient是不是就消失了?(我实验是模型根本找不到梯度)所以想请教大家有没有
保留梯度的方法?
谢谢
--
※ 发信站: 批踢踢实业坊(ptt.cc), 来自: 117.19.228.161 (台湾)
※ 文章网址: https://webptt.com/cn.aspx?n=bbs/DataScience/M.1587137688.A.EC0.html
※ 编辑: wang19980531 (117.19.228.161 台湾), 04/17/2020 23:38:38
1F:→ followwar: pytorch是 b,c,h,w04/18 00:02
2F:→ followwar: 1. 你可以考虑用crop04/18 00:02
3F:→ followwar: 2.你没计算loss没有backward 不会有gradient04/18 00:03
我的架构大概是
for each图片:
y=model(x)
append(y_,y)
loss=triplet_loss(y_,labels)
opt.zero_grad()
loss.backward()
opt.step()
原本将图片这批丢到model(就是append先,回圈外才做model)是可行的
但只有一张图片的y也没有办法算loss跟gradient;
append在一起为何会算不出来(loss降不了)
4F:→ yoyololicon: 做gradient accumulation ㄅ04/18 14:03
5F:推 sxy67230: 1.resize可调整插值方法,试试看是否还有严重失真,或是04/18 14:49
6F:→ sxy67230: crop,或是直接在Convolution layer跟Fully connected l04/18 14:49
7F:→ sxy67230: ayer之间塞入一些特殊池化层,因为尺度固定是来自於FCL04/18 14:49
8F:→ sxy67230: 层便於计算而采用的。04/18 14:49
我没有fully connected layer喔 但是一个batch丢入model算必须要[n,c,l,w] l*w的图片
大小要一致 不然只能一张一张丢 可是如上 会发生找不出梯度的问题?
9F:推 sxy67230: 2. 你直接印出损失值看看是否是损失函数出来的值是否有04/18 14:51
10F:→ sxy67230: 问题,或是你没有更新参数自然没有梯度。04/18 14:51
11F:推 world4jason: gradient accumulation遇到BN会GG 我最近也在思考这04/18 19:57
12F:→ world4jason: 个问题04/18 19:57
13F:推 illegalplan: Padding, crop一起做 就当作data augmentation04/19 14:54
padding似乎不失为一个好方法 谢谢提议
※ 编辑: wang19980531 (117.19.228.161 台湾), 04/20/2020 07:30:48
感谢大家的意见,抱歉这麽晚才做回应!
※ 编辑: wang19980531 (117.19.228.161 台湾), 04/20/2020 07:34:46
14F:→ followwar: 你的triplet loss的用法对吗? 应该是04/20 11:53
15F:→ followwar: (anchor, positive, negative)04/20 11:54
我用的是别人写好的hard都triplet loss
应该是online的就是只要给一串label和向量他会自己去算loss
※ 编辑: wang19980531 (61.221.242.34 台湾), 04/20/2020 12:10:01
16F:推 sxy67230: 如果你是用online triplet记得要去检查一下target是要 04/20 14:32
17F:→ sxy67230: 有重复类别,不然输出矩阵计算後会变成nan。另外还有tri 04/20 14:32
18F:→ sxy67230: plets loss计算是采用欧式距离,可能要注意压缩出来的 04/20 14:32
19F:→ sxy67230: 结果 04/20 14:32
20F:推 janus7799: 长边resize到固定长度,短边padding 04/21 21:01
21F:→ diabolica: padding试试看 04/22 21:46