音效素材网提供各类素材,打造精品素材网站!

站内导航 站长工具 投稿中心 手机访问

音效素材

基于PyTorch实现一个简单的CNN图像分类器
日期:2021-09-08 14:31:26   来源:脚本之家

pytorch中文网:https://www.pytorchtutorial.com/
pytorch官方文档:https://pytorch.org/docs/stable/index.html

一. 加载数据

Pytorch的数据加载一般是用torch.utils.data.Dataset与torch.utils.data.Dataloader两个类联合进行。我们需要继承Dataset来定义自己的数据集类,然后在训练时用Dataloader加载自定义的数据集类。

1. 继承Dataset类并重写关键方法

pytorch的dataset类有两种:Map-style datasets和Iterable-style datasets。前者是我们常用的结构,而后者是当数据集难以(或不可能)进行随机读取时使用。在这里我们实现Map-style dataset。
继承torch.utils.data.Dataset后,需要重写的方法有:__len__与__getitem__方法,其中__len__方法需要返回所有数据的数量,而__getitem__则是要依照给出的数据索引获取对应的tensor类型的Sample,除了这两个方法以外,一般还需要实现__init__方法来初始化一些变量。话不多说,直接上代码。

'''
包括了各种数据集的读取处理,以及图像相关处理方法
'''
from torch.utils.data import Dataset
import torch
import os
import cv2
from Config import mycfg
import random
import numpy as np


class ImageClassifyDataset(Dataset):
    def __init__(self, imagedir, labelfile, classify_num, train=True):
    	'''
    	这里进行一些初始化操作。
    	'''
        self.imagedir = imagedir
        self.labelfile = labelfile
        self.classify_num = classify_num
        self.img_list = []
        # 读取标签
        with open(self.labelfile, 'r') as fp:
            lines = fp.readlines()
            for line in lines:
                filepath = os.path.join(self.imagedir, line.split(";")[0].replace('\\', '/'))
                label = line.split(";")[1].strip('\n')
                self.img_list.append((filepath, label))
        if not train:
            self.img_list = random.sample(self.img_list, 50)

    def __len__(self):
        return len(self.img_list)
        
    def __getitem__(self, item):
	    '''
	    这个函数是关键,通过item(索引)来取数据集中的数据,
	    一般来说在这里才将图像数据加载入内存,之前存的是图像的保存路径
	    '''
        _int_label = int(self.img_list[item][1])	# label直接用0,1,2,3,4...表示不同类别
        label = torch.tensor(_int_label,dtype=torch.long)
        img = self.ProcessImgResize(self.img_list[item][0])
        return img, label

    def ProcessImgResize(self, filename):
    	'''
    	对图像进行一些预处理
    	'''
        _img = cv2.imread(filename)
        _img = cv2.resize(_img, (mycfg.IMG_WIDTH, mycfg.IMG_HEIGHT), interpolation=cv2.INTER_CUBIC)
        _img = _img.transpose((2, 0, 1))
        _img = _img / 255
        _img = torch.from_numpy(_img)
        _img = _img.to(torch.float32)
        return _img

有一些的数据集类一般还会传入一个transforms函数来构造一个图像预处理序列,传入transforms函数的一个好处是作为参数传入的话可以对一些非本地数据集中的数据进行操作(比如直接通过torchvision获取的一些预存数据集CIFAR10等等),除此之外就是torchvision.transforms里面有一些预定义的图像操作函数,可以直接像拼积木一样拼成一个图像处理序列,很方便。我这里因为是用我自己下载到本地的数据集,而且比较简单就直接用自己的函数来操作了。

2. 使用Dataloader加载数据

实例化自定义的数据集类ImageClassifyDataset后,将其传给DataLoader作为参数,得到一个可遍历的数据加载器。可以通过参数batch_size控制批处理大小,shuffle控制是否乱序读取,num_workers控制用于读取数据的线程数量。

from torch.utils.data import DataLoader
from MyDataset import ImageClassifyDataset

dataset = ImageClassifyDataset(imagedir, labelfile, 10)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True,num_workers=5)
for index, data in enumerate(dataloader):
	print(index)	# batch索引
	print(data)		# 一个batch的{img,label}

二. 模型设计

在这里只讨论深度学习模型的设计,pytorch中的网络结构是一层一层叠出来的,pytorch中预定义了许多可以通过参数控制的网络层结构,比如Linear、CNN、RNN、Transformer等等具体可以查阅官方文档中的torch.nn部分。
设计自己的模型结构需要继承torch.nn.Module这个类,然后实现其中的forward方法,一般在__init__中设定好网络模型的一些组件,然后在forward方法中依据输入输出顺序拼装组件。

'''
包括了各种模型、自定义的loss计算方法、optimizer
'''
import torch.nn as nn


class Simple_CNN(nn.Module):
    def __init__(self, class_num):
        super(Simple_CNN, self).__init__()
        self.class_num = class_num
        self.conv1 = nn.Sequential(
            nn.Conv2d(		# input: 3,400,600
                in_channels=3,
                out_channels=8,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.Conv2d(
                in_channels=8,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.AvgPool2d(2),  # 16,400,600 --> 16,200,300
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.Conv2d(
                in_channels=16,
                out_channels=8,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.AvgPool2d(2),  # 8,200,300 --> 8,100,150
            nn.BatchNorm2d(8),
            nn.LeakyReLU(),
            nn.Conv2d(
                in_channels=8,
                out_channels=8,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.Conv2d(
                in_channels=8,
                out_channels=1,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.AvgPool2d(2),  # 1,100,150 --> 1,50,75
            nn.BatchNorm2d(1),
            nn.LeakyReLU()
        )
        self.line = nn.Sequential(
            nn.Linear(
                in_features=50 * 75,
                out_features=self.class_num
            ),
            nn.Softmax()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(-1, 50 * 75)
        y = self.line(x)
        return y

上面我定义的模型中包括卷积组件conv1和全连接组件line,卷积组件中包括了一些卷积层,一般是按照{卷积层、池化层、激活函数}的顺序拼接,其中我还在激活函数之前添加了一个BatchNorm2d层对上层的输出进行正则化以免传入激活函数的值过小(梯度消失)或过大(梯度爆炸)。
在拼接组件时,由于我全连接层的输入是一个一维向量,所以需要将卷积组件中最后的50 × 75 50\times 7550×75大小的矩阵展平成一维的再传入全连接层(x.view(-1,50*75))

三. 训练

实例化模型后,网络模型的训练需要定义损失函数与优化器,损失函数定义了网络输出与标签的差距,依据不同的任务需要定义不同的合适的损失函数,而优化器则定义了神经网络中的参数如何基于损失来更新,目前神经网络最常用的优化器就是SGD(随机梯度下降算法) 及其变种。
在我这个简单的分类器模型中,直接用的多分类任务最常用的损失函数CrossEntropyLoss()以及优化器SGD。

self.cnnmodel = Simple_CNN(mycfg.CLASS_NUM)
self.criterion = nn.CrossEntropyLoss()	# 交叉熵,标签应该是0,1,2,3...的形式而不是独热的
self.optimizer = optim.SGD(self.cnnmodel.parameters(), lr=mycfg.LEARNING_RATE, momentum=0.9)

训练过程其实很简单,使用dataloader依照batch读出数据后,将input放入网络模型中计算得到网络的输出,然后基于标签通过损失函数计算Loss,并将Loss反向传播回神经网络(在此之前需要清理上一次循环时的梯度),最后通过优化器更新权重。训练部分代码如下:

for each_epoch in range(mycfg.MAX_EPOCH):
            running_loss = 0.0
            self.cnnmodel.train()
            for index, data in enumerate(self.dataloader):
                inputs, labels = data
                outputs = self.cnnmodel(inputs)
                loss = self.criterion(outputs, labels)

                self.optimizer.zero_grad()	# 清理上一次循环的梯度
                loss.backward()	# 反向传播
                self.optimizer.step()	# 更新参数
                running_loss += loss.item()
                if index % 200 == 199:
                    print("[{}] loss: {:.4f}".format(each_epoch, running_loss/200))
                    running_loss = 0.0
            # 保存每一轮的模型
            model_name = 'classify-{}-{}.pth'.format(each_epoch,round(all_loss/all_index,3))
            torch.save(self.cnnmodel,model_name)	# 保存全部模型

四. 测试

测试和训练的步骤差不多,也就是读取模型后通过dataloader获取数据然后将其输入网络获得输出,但是不需要进行反向传播的等操作了。比较值得注意的可能就是准确率计算方面有一些小技巧。

acc = 0.0
count = 0
self.cnnmodel = torch.load('mymodel.pth')
self.cnnmodel.eval()
for index, data in enumerate(dataloader_eval):
	inputs, labels = data   # 5,3,400,600  5,10
	count += len(labels)
	outputs = cnnmodel(inputs)
	_,predict = torch.max(outputs, 1)
	acc += (labels == predict).sum().item()
print("[{}] accurancy: {:.4f}".format(each_epoch, acc / count))

我这里采用的是保存全部模型并加载全部模型的方法,这种方法的好处是在使用模型时可以完全将其看作一个黑盒,但是在模型比较大时这种方法会很费事。此时可以采用只保存参数不保存网络结构的方法,在每一次使用模型时需要读取参数赋值给已经实例化的模型:

torch.save(cnnmodel.state_dict(), "my_resnet.pth")
cnnmodel = Simple_CNN()
cnnmodel.load_state_dict(torch.load("my_resnet.pth"))

结语

至此整个流程就说完了,是一个小白级的图像分类任务流程,因为前段时间一直在做android方面的事,所以有点生疏了,就写了这篇博客记录一下,之后应该还会写一下seq2seq以及image caption任务方面的模型构造与训练过程,完整代码之后也会统一放到github上给大家做参考。

以上就是基于PyTorch实现一个简单的CNN图像分类器的详细内容,更多关于PyTorch实现CNN图像分类器的资料请关注其它相关文章!

    您感兴趣的教程

    在docker中安装mysql详解

    本篇文章主要介绍了在docker中安装mysql详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编...

    详解 安装 docker mysql

    win10中文输入法仅在桌面显示怎么办?

    win10中文输入法仅在桌面显示怎么办?

    win10系统使用搜狗,QQ输入法只有在显示桌面的时候才出来,在使用其他程序输入框里面却只能输入字母数字,win10中...

    win10 中文输入法

    一分钟掌握linux系统目录结构

    这篇文章主要介绍了linux系统目录结构,通过结构图和多张表格了解linux系统目录结构,感兴趣的小伙伴们可以参考一...

    结构 目录 系统 linux

    PHP程序员玩转Linux系列 Linux和Windows安装

    这篇文章主要为大家详细介绍了PHP程序员玩转Linux系列文章,Linux和Windows安装nginx教程,具有一定的参考价值,感兴趣...

    玩转 程序员 安装 系列 PHP

    win10怎么安装杜比音效Doby V4.1 win10安装杜

    第四代杜比®家庭影院®技术包含了一整套协同工作的技术,让PC 发出清晰的环绕声同时第四代杜比家庭影院技术...

    win10杜比音效

    纯CSS实现iOS风格打开关闭选择框功能

    这篇文章主要介绍了纯CSS实现iOS风格打开关闭选择框,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作...

    css ios c

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的办法

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的

    Win7给电脑C盘扩容的办法大家知道吗?当系统分区C盘空间不足时,就需要给它扩容了,如果不管,C盘没有足够的空间...

    Win7 C盘 扩容

    百度推广竞品词的投放策略

    SEM是基于关键词搜索的营销活动。作为推广人员,我们所做的工作,就是打理成千上万的关键词,关注它们的质量度...

    百度推广 竞品词

    Visual Studio Code(vscode) git的使用教程

    这篇文章主要介绍了详解Visual Studio Code(vscode) git的使用,小编觉得挺不错的,现在分享给大家,也给大家做个参考。...

    教程 Studio Visual Code git

    七牛云储存创始人分享七牛的创立故事与

    这篇文章主要介绍了七牛云储存创始人分享七牛的创立故事与对Go语言的应用,七牛选用Go语言这门新兴的编程语言进行...

    七牛 Go语言

    Win10预览版Mobile 10547即将发布 9月19日上午

    微软副总裁Gabriel Aul的Twitter透露了 Win10 Mobile预览版10536即将发布,他表示该版本已进入内部慢速版阶段,发布时间目...

    Win10 预览版

    HTML标签meta总结,HTML5 head meta 属性整理

    移动前端开发中添加一些webkit专属的HTML5头部标签,帮助浏览器更好解析HTML代码,更好地将移动web前端页面表现出来...

    移动端html5模拟长按事件的实现方法

    这篇文章主要介绍了移动端html5模拟长按事件的实现方法的相关资料,小编觉得挺不错的,现在分享给大家,也给大家...

    移动端 html5 长按

    HTML常用meta大全(推荐)

    这篇文章主要介绍了HTML常用meta大全(推荐),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参...

    cdr怎么把图片转换成位图? cdr图片转换为位图的教程

    cdr怎么把图片转换成位图? cdr图片转换为

    cdr怎么把图片转换成位图?cdr中插入的图片想要转换成位图,该怎么转换呢?下面我们就来看看cdr图片转换为位图的...

    cdr 图片 位图

    win10系统怎么录屏?win10系统自带录屏详细教程

    win10系统怎么录屏?win10系统自带录屏详细

    当我们是使用win10系统的时候,想要录制电脑上的画面,这时候有人会想到下个第三方软件,其实可以用电脑上的自带...

    win10 系统自带录屏 详细教程

    + 更多教程 +
    ASP编程JSP编程PHP编程.NET编程python编程