pytorch 6 batch_train 批训练操作
看代码吧~
import torch
import torch.utils.data as Data
torch.manual_seed(1) # reproducible
# BATCH_SIZE = 5
BATCH_SIZE = 8 # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10) # this is x data (torch tensor)
y = torch.linspace(10, 1, 10) # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=False, # 设置不随机打乱数据 random shuffle for training
num_workers=2, # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
for epoch in range(3): # 全部的数据使用3遍,train entire dataset 3 times
for step, (batch_x, batch_y) in enumerate(loader): # for each training step
# train your data...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
show_batch()
BATCH_SIZE = 8 , 所有数据利用三次
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
补充:pytorch批训练bug
问题描述:
在进行pytorch神经网络批训练的时候,有时会出现报错
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
解决办法:
第一步:
检查(重点!!!!!):
train_dataset = Data.TensorDataset(train_x, train_y)
train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable
可以这样将数据变为tensor类:
train_x = torch.FloatTensor(train_x)
第二步:
train_loader = Data.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True
)
实例化一个DataLoader对象
第三步:
for epoch in range(epochs):
for step, (batch_x, batch_y) in enumerate(train_loader):
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
这样就可以批训练了
需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable
(资源库 www.zyku.net)
原文链接:https://www.cnblogs.com/yangzhaonan/p/10439839.html
栏 目:Python教程
下一篇:pytorch 如何把图像数据集进行划分成train,test和val
本文标题:pytorch 6 batch_train 批训练操作
本文地址:https://www.zyku.net/python/9815.html
您可能感兴趣的文章
- 02-08python munch库的使用解析
- 05-07华为mate40pro皮套操作开启教程
- 07-05Linux ulimit命令
- 10-11realmegtneo2如何关闭夜间自动更新
- 02-18一加9pro开启LHDC流式传输教程
- 09-23B站动态删除教程介绍
- 12-28超田智趣+-超田智趣+应用软件功能介绍
- 01-16电子日程表-电子日程表应用软件功能介
- 03-07荣耀50pro设置动态壁纸教程
- 02-27帝国CMS调用包含指定关键词文章列表的
- 03-04oppoencoair设置游戏模式教程
- 01-13FLOW冥想-FLOW冥想应用软件功能介绍
- 12-16一加7pro专业拍照模式在哪里
- 02-26帝国CMS用PHP代码实现灵动标签的技巧
- 05-11FastAdmin 在 IIS 环境下伪静态如何配
- 01-13中医执业助理题库-中医执业助理题库应
- 03-20csv导入mysql中文乱码等问题解决方法
- 01-11全众云物业-全众云物业应用软件功能介
- 12-26编程绘画-编程绘画应用软件功能介绍
- 01-12嘉御健康-嘉御健康应用软件功能介绍
最近更新
阅读排行
猜你喜欢
- 02-09python 命令行传参方法总结
- 01-10云小团-云小团应用软件功能介绍
- 03-03小米11青春版开启隐藏应用教程
- 01-12Anna Card安娜请柬-Anna Card安娜请柬
- 02-20制作网页中设计段落缩进的方法
- 10-09公考雷达如何进行职位匹配
- 02-10Pytorch GPU内存占用很高,但是利用率
- 01-28小米note9全面屏手势关闭方法
- 11-25小米手机应用智能省电功能在哪
- 02-24小米10s开启与关闭桌面透明壁纸教程