欢迎来到资源库(www.zyku.net)

Python教程

当前位置:首页 > 网络编程 > Python教程 > batch_train

pytorch 6 batch_train 批训练操作

时间:2022-02-11|栏目:Python教程|点击:|我要投稿

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

    for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

(资源库 www.zyku.net)

原文链接:https://www.cnblogs.com/yangzhaonan/p/10439839.html

上一篇:520必备!这些Python表白代码祝你脱单成功

栏    目:Python教程

下一篇:pytorch 如何把图像数据集进行划分成train,test和val

本文标题:pytorch 6 batch_train 批训练操作

本文地址:https://www.zyku.net/python/9815.html

关于我们 | 版权申明 | 寻求合作 |

重要申明:本站所有的文章、图片、评论等内容,均由网友发表或上传并维护或收集自网络,仅供个人学习交流使用,版权归原作者所有。

如有侵犯您的版权,请与我们联系,我们将在24小时内进行处理、任何非本站因素导致的法律后果,本站均不负任何责任。

联系QQ:95148658 | 邮箱:mb8#qq.com(#换成@)

苏ICP备2020066115号-1

本网站由提供CDN加速/云存储服务