当前位置: 首页 > news >正文

做网站具体收费/360应用商店

做网站具体收费,360应用商店,网站没被百度收录,济宁网架公司在上一篇博客PyTorch学习之路(level1)——训练一个图像分类模型中介绍了如何用PyTorch训练一个图像分类模型,建议先看懂那篇博客后再看这篇博客。在那份代码中,采用torchvision.datasets.ImageFolder这个接口来读取图像数据&#…

在上一篇博客PyTorch学习之路(level1)——训练一个图像分类模型中介绍了如何用PyTorch训练一个图像分类模型,建议先看懂那篇博客后再看这篇博客。在那份代码中,采用torchvision.datasets.ImageFolder这个接口来读取图像数据,该接口默认你的训练数据是按照一个类别存放在一个文件夹下。但是有些情况下你的图像数据不是这样维护的,比如一个文件夹下面各个类别的图像数据都有,同时用一个对应的标签文件,比如txt文件来维护图像和标签的对应关系,在这种情况下就不能用torchvision.datasets.ImageFolder来读取数据了,需要自定义一个数据读取接口。另外这篇博客最后还顺带介绍如何保存模型和多GPU训练。

怎么做呢?

先来看看torchvision.datasets.ImageFolder这个类是怎么写的,主要代码如下,想详细了解的可以看:官方github代码。

看起来很复杂,其实非常简单。继承的类是torch.utils.data.Dataset,主要包含三个方法:初始化__init__,获取图像__getitem__,数据集数量 __len____init__方法中先通过find_classes函数得到分类的类别名(classes)和类别名与数字类别的映射关系字典(class_to_idx)。然后通过make_dataset函数得到imags,这个imags是一个列表,其中每个值是一个tuple,每个tuple包含两个元素:图像路径和标签。剩下的就是一些赋值操作了。在__getitem__方法中最重要的就是 img = self.loader(path)这行,表示数据读取,可以从__init__方法中看出self.loader采用的是default_loader,这个default_loader的核心就是用python的PIL库的Image模块来读取图像数据。

class ImageFolder(data.Dataset):"""A generic data loader where the images are arranged in this way: ::root/dog/xxx.pngroot/dog/xxy.pngroot/dog/xxz.pngroot/cat/123.pngroot/cat/nsdf3.pngroot/cat/asd932_.pngArgs:root (string): Root directory path.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.loader (callable, optional): A function to load an image given its path.Attributes:classes (list): List of the class names.class_to_idx (dict): Dict with items (class_name, class_index).imgs (list): List of (image path, class_index) tuples"""def __init__(self, root, transform=None, target_transform=None,loader=default_loader):classes, class_to_idx = find_classes(root)imgs = make_dataset(root, class_to_idx)if len(imgs) == 0:raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n""Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))self.root = rootself.imgs = imgsself.classes = classesself.class_to_idx = class_to_idxself.transform = transformself.target_transform = target_transformself.loader = loaderdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is class_index of the target class."""path, target = self.imgs[index]img = self.loader(path)if self.transform is not None:img = self.transform(img)if self.target_transform is not None:target = self.target_transform(target)return img, targetdef __len__(self):return len(self.imgs)

稍微看下default_loader函数,该函数主要分两种情况调用两个函数,一般采用pil_loader函数。

def pil_loader(path):with open(path, 'rb') as f:with Image.open(f) as img:return img.convert('RGB')def accimage_loader(path):import accimagetry:return accimage.Image(path)except IOError:# Potentially a decoding problem, fall back to PIL.Imagereturn pil_loader(path)def default_loader(path):from torchvision import get_image_backendif get_image_backend() == 'accimage':return accimage_loader(path)else:return pil_loader(path)

看懂了ImageFolder这个类,就可以自定义一个你自己的数据读取接口了。

首先在PyTorch中和数据读取相关的类基本都要继承一个基类:torch.utils.data.Dataset。然后再改写其中的__init____len____getitem__等方法即可

下面假设img_path是你的图像文件夹,该文件夹下面放了所有图像数据(包括训练和测试),然后txt_path下面放了train.txt和val.txt两个文件,txt文件中每行都是图像路径,tab键,标签。所以下面代码的__init__方法中self.img_name和self.img_label的读取方式就跟你数据的存放方式有关,你可以根据你实际数据的维护方式做调整。__getitem__方法没有做太大改动,依然采用default_loader方法来读取图像。最后在Transform中将每张图像都封装成Tensor。

class customData(Dataset):def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):with open(txt_path) as input_file:lines = input_file.readlines()self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]self.data_transforms = data_transformsself.dataset = datasetself.loader = loaderdef __len__(self):return len(self.img_name)def __getitem__(self, item):img_name = self.img_name[item]label = self.img_label[item]img = self.loader(img_name)if self.data_transforms is not None:try:img = self.data_transforms[self.dataset](img)except:print("Cannot transform image: {}".format(img_name))return img, label

定义好了数据读取接口后,怎么用呢?

在代码中可以这样调用。

 image_datasets = {x: customData(img_path='/ImagePath',txt_path=('/TxtFile/' + x + '.txt'),data_transforms=data_transforms,dataset=x) for x in ['train', 'val']}

这样返回的image_datasets就和用torchvision.datasets.ImageFolder类返回的数据类型一样,有点狸猫换太子的感觉,这就是在第一篇博客中说的写代码类似搭积木的感觉。

有了image_datasets,然后依然用torch.utils.data.DataLoader类来做进一步封装,将这个batch的图像数据和标签都分别封装成Tensor。

 dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train', 'val']}

另外,每次迭代生成的模型要怎么保存呢?非常简单,那就是用torch.save。输入就是你的模型和要保存的路径及模型名称,如果这个output文件夹没有,可以手动新建一个或者在代码里面新建。

torch.save(model, 'output/resnet_epoch{}.pkl'.format(epoch))

最后,关于多GPU的使用,PyTorch支持多GPU训练模型,假设你的网络是model,那么只需要下面一行代码(调用 torch.nn.DataParallel接口)就可以让后续的模型训练在0和1两块GPU上训练,加快训练速度。

 model = torch.nn.DataParallel(model, device_ids=[0,1])

完整代码请移步:Github

http://www.lbrq.cn/news/1450567.html

相关文章:

  • 无锡网站建设推广服务/seo公司哪家好用
  • 建网站需成本多少钱/竞价推广托管多少钱
  • 学网站开发首先学哪些基础/收录批量查询
  • win7电脑做网站/查指数
  • 杭州建网站/seo优化师
  • 设计相关网站/站长之家新网址
  • 个人网站申请空间/西安高端模板建站
  • 上海个体户注册代办/吉林seo外包
  • 百度做网站续费费用/会员制营销方案
  • 闲鱼网站建设费用/成都网站快速排名优化
  • 简易做海报网站/免费测试seo
  • 做淘宝网站代理/郑州seo优化
  • 网站建设经验典型/微信营销平台哪个好
  • 网站注册系统/爱站网站排行榜
  • 做库房推广哪个网站好/百度数据开放平台
  • 房地产做网站的意义/网站排名优化课程
  • 建筑网格布/网站推广的优化
  • 怎么做二级网站域名/腾讯搜索引擎入口
  • wordpress 重复内容/北京百度seo排名
  • 徐汇科技网站建设/seo优化教程自学网
  • 香港人做evus在哪个网站/网络营销推广计划
  • 建设购物网站/线下推广宣传方式有哪些
  • 张家港普通网站建设/域名服务器ip查询网站
  • wordpress 写文章页面/seo建站
  • 老域名怎么做新网站/现在百度推广有用吗
  • 如何制作单页网站/站外推广方式
  • 怎么做微商的微网站/全网网络营销
  • 目前网站开发有什么缺点/免费推广app平台有哪些
  • 上海房产做哪个网站好/创建软件平台该怎么做
  • 威海有名的做网站/免费制作网站
  • 【Linux】特效爆满的Vim的配置方法 and make/Makefile原理
  • Linux驱动25 --- RkMedia音频API使用增加 USB 音视频设备
  • GPT-5的诞生之痛:AI帝国的现实危机
  • 前端1.0
  • tc 介绍
  • 【工程化】tree-shaking 的作用以及配置