博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch加载数据
阅读量:3938 次
发布时间:2019-05-23

本文共 2538 字,大约阅读时间需要 8 分钟。

数据集

Batch(每次使用一个)和随机梯度下降(每次使用一个样本)

如果我们使用Batch我们可能会遇见鞍点,使得我们的优化停滞不前。但是我们可以利用数组计算的优势来加速我们的运算时间。

如果我们使用单个数据的随机梯度下降,那么我们将可以避免鞍点,但是我们一个一个计算我们很难充分利用cpu和gpu并行处理数据的一个能力,使得运算时间边长。
所以我们为了解决这个问题应当引入mini_batch

mini_batch

这里有三个概念:

epoch:我们到底有多少数据
batch-size:分的那些小块每个是多大的
iteration:分成了多少个组
代码具体如下:

for epoch in range(training_epoch):    for i in range(total_batch)

DataLoader

首先数据集必须是可以索引的(可以用角标访问)、知道具体长度的。

里面有两个重要的参数bactch_size=这个是每个batch是多大的意思、还有一个是shuffle=true就是我们是否需要打乱的问题

具体代码

import numpy as npimport torch as torchfrom torch.utils.data import Datasetfrom torch.utils.data import DataLoader#这里我们注意这个Dataset是一个抽象类,是不可以被实例化的只可以被继承#继承之后我们要注意,我们需要实例化这些函数接口。class DiabetesDataset(Dataset):    def __init__(self,filepath):       xy = np.loadtxt(fillepath,delimiter=',',dtype=np.float32)       #想要搞明白这个代码要先搞明白shape的作用       # 这个shape实际上是产生一个元组{n,9}代表着n行9列,所以我们就能理解为什么这样可以获得长度了。       self.len=xy.shape[0]              #想要顺利的理解这个引入的范围首先我们先明确这个-1是       self.x_data=torch.from_numpy(xy[:,:-1])       self.y_data=torch.from_numpy(xy[:,[-1]])          #这里有两个实现方法:    # 1.全部在init的时候读进来,然后在这里输出    # 2.放在对应的文件里面,形成一个文件列表到时候再去取。    #这里主要是取决于数据集的大小的问题,例如图片数据集每个都很大。    #这里主要是为了节约我们有限的内存    def __getitem__(self,index):        return self.x_data[index],self.y_data[index]        #注意这里返回的是一个元组    def __len__(self):        return(self.len)dataset =DiabetesDataset('diabetes.csv.gz')trian_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)class LogisticRegressionModel(torch.nn.Module):    def __init__(self):        super(LogisticRegressionModel,self).__init__()        self.linear=torch.nn.Linear(1,1)        #因为sigmoid是没有任何参数的,所以我们不需要额外定义参数。        def forward(self,x):        y_pred=func.sigmoid(self.linear(x))        #这里其实就是在定义forward函数的时候增加了一部sigmoid的计算。        return y_predcriterion = torch.nn.BCELoss(size_average=True)model= LogisticRegressionModel()#这里的bce是交叉熵损失函数的问题#这里求均值的时候其实还会影响到学习率的选择问题,因为是不是求平均会影响梯度的大小opitmizer =torch.optim.SGD(model.parameters(),lr=0.01)#最后这个numberworkers是设置我们取数据的时候有几个并行的线程,# 这里我们要结合我们的电脑的具体情况看看我们这是是能给到多少。# #这里是为了处理C语言多线程的接口在windows中和Linux下不同的问题。if__name__=='__main__':    for epoch in range(100):        for i,data in enumerate(trian_loader,0):            inputs,labels=data            #这里我们注意一个事情这个data是一个元组所以传过来也是用一个元组,            # 另外这个DataLoader会自动帮我们把输出调整为一个Tensor,            # 都是向量运算所以下面的内容并没有太大区别。            y_pred=model(inputs)            loss=criterion(y_pred,labels)            opitmizer.zero_grad()            loss.backward()            opitmizer.step()#class module(torch.nn.Module):

转载地址:http://ykywi.baihongyu.com/

你可能感兴趣的文章
PHP中文件读写操作
查看>>
php开发常识b_01
查看>>
PHP单例模式
查看>>
PHP项目设计
查看>>
memcache的安装及管理
查看>>
git 传输
查看>>
创建新项目
查看>>
印刷工艺- 喷墨印刷
查看>>
印刷工艺流程
查看>>
印刷业ERP启蒙
查看>>
Java8 Lambda表达式使用集合(笔记)
查看>>
Java魔法师Unsafe
查看>>
spring cloud java.lang.NoClassDefFoundError: javax/servlet/http/HttpServletRequest
查看>>
Centos系统安装MySQL(整理)
查看>>
postgresql计算两点距离(经纬度地理位置)
查看>>
postgres多边形存储--解决 Points of LinearRing do not form a closed linestring
查看>>
postgresql+postgis空间数据库总结
查看>>
spring 之 Http Cache 和 Etag(转)
查看>>
基于Lucene查询原理分析Elasticsearch的性能(转)
查看>>
HttpClient请求外部服务器NoHttpResponseException
查看>>