Easy-Classification-手写数字识别

科技网编2023-08-16 11:392140

1.背景

Easy-Classification是一个应用于分类任务的深度学习框架,它集成了众多成熟的分类神经网络模型,可帮助使用者简单快速的构建分类训练任务。

案例源代码:https://github.com/wuya11/easy-classification
Easy-Classification框架介绍:https://www.cnblogs.com/wlandwl/p/deep_learn_class.html

本例基于Easy-Classification框架,快速搭建一个手写数字识别训练任务。项目整体目录如下:

  • 任务输入:一系列手写数字图片,其中每张图片都是28x28的像素矩阵。
  • 任务输出:经过了大小归一化和居中处理,输出对应的0~9的数字标签。

2.数字识别

  1. MNIST数据集是深度学习领域标准,易用的成熟数据集。数据集下载地址:https://github.com/zalandoresearch/fashion-mnist。

手写数字由6万个训练样本和1万个测试样本组成,每个样本都是一张28*28像素的灰度手写数字图片。
data包含三个元素的列表:train_set、val_set、 test_set,包括50 000条训练样本、10000条验证样本、10000条测试样本。每个样本包含手写数字图片和对应的标签。

  • train_set(训练集):用于确定模型参数。
  • val_set(验证集):用于调节模型超参数(如多个网络结构、正则化权重的最优选择)。(验证是否过拟合)
  • test_set(测试集):用于估计应用效果(没有在模型中应用过的数据,更贴近模型在真实场景应用的效果)。

train_set包含两个元素的列表:train_images、train_labels。

  • train_images:[50000, 784]的二维列表,包含50000张图片。每张图片用一个长度为784的向量表示,内容是28*28尺寸的像素灰度值(黑白图片)。
  • train_labels:[50000, ]的列表,表示这些图片对应的分类标签,即0~9之间的一个数字。

下载相关的压刷文件并放在data目录下:

2.2 生成训练数据

在项目根目录下新建data目录用于放置训练集,测试集,验证集数据。执行项目中scripts/make_fashionmnist.py脚本,解压文件,最终得到神经网络的训练数据,参考如图:


说明:

  1. 标签生成为目录,每个目录里面为具体的数字图片。比如8目录的图片均是手写数字为8的图片。
  1. 每个图像解析后,size为28*28。(若后续模型的入参需求为224*224,可以在此处调整图像大小。

2.3 编写训练脚本

训练过程需编写配置文件,自定义DateSet数据加载类,训练过程脚本类。详情请参考对应目录下实现源码。

自定义DateSe部分核心代码说明:

"""
编写自定义Dataset类时,初始化参数需定义为source_img, cfg。否则数据加载通用模块,data_load_service.py模块会报错。

source_img :传入的图像地址信息集合

cfg:传入的配置类信息,针对不同的任务,可能生成的label模式不同,可基于配置类指定label的加载模式,最终为训练的图像初始化label (用户自定义实现)

本例为 简单首先数字10分类(0-9):基于文件夹名称(路径)做label
"""

class TrainDataset(Dataset):
    """
    构建一个 加载原始图片的dataSet对象

    此函数可加载 训练集数据,基于文件夹名称获取label

    若 验证集逻辑与训练集逻辑一样,验证集可使用TrainDataset,不同,则需自定义一个,参考如下EvalDataset
    """

def __init__(self, source_img, cfg):
        self.source_img = source_img
        self.cfg = cfg
        self.transform = createTransform(cfg, TrainImgDeal)
        self.label_dict = getLabels(cfg['train_path'], source_img)

    def __getitem__(self, index):
        img = cv2.imread(self.source_img[index])
        if self.transform is not None:
            img = self.transform(img)
        target = self.label_dict[self.source_img[index]]
        return img, target, self.source_img[index]

    def __len__(self):
        return len(self.source_img)
        

class TrainImgDeal:
    def __init__(self, cfg):
        img_size = cfg['target_img_size']
        self.h = img_size[0]
        self.w = img_size[1]

    def __call__(self, img):
        # 本次图像处理,大小调整
        img = cv2.resize(img, (self.h, self.w))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        return img

def getLabels(label_path, source_img):
    cate_dirs = os.listdir(label_path)
    cate_dirs.sort()
    label_dict = {}
    for i, img_path in enumerate(source_img):
        img_dirs = img_path.replace(label_path, '')
        img_dirs = img_dirs.split(os.sep)[:2]
        img_dir = img_dirs[0] if img_dirs[0] else img_dirs[1]
        y = cate_dirs.index(img_dir)
        label_dict[img_path] = y
    return label_dict
            
def createTransform(cfg, img_deal):
    """
    将图像转换为张量对象
:param cfg: 配置文件对象
:param img_deal: 图像增强处理类(如大小,旋转,高斯函数等处理)
:return: 返回张量信息
    """
# 不同的网络模型,适配获取归一性值
    my_normalize = NormalizeAdapter.getNormalize(cfg['model_name'])
    transform = transforms.Compose([
        img_deal(cfg),
        transforms.ToTensor(),
        my_normalize,
    ])
    return transform

2.4 训练结果展示

训练结果会输出到out目录,输出信息包括acc,loss的过程图,最优训练权重文件。本次训练基于预训练权重文件,输出结果如下:

2.5 预测应用

编写预测类脚本,在配置文件中,配置model_path(为训练好的权重文件路径如:

'model_path':'output/mobilenetv3_e20_0.9800.pth'

),加载预测数据,模型预测后将结果输出到csv文件中。预测代码参考如下:

def predict(cfg):
    initConfig(cfg)
    model = ModelService(cfg)
    data = DataLoadService(cfg)
    test_loader = data.getPredictDataloader(PredictDataset)
    runner = MnistRunnerService(cfg, model)
    modelLoad(model, cfg['model_path'])
    res_dict = runner.predict(test_loader)
    print(len(res_dict))
    # to csv
    res_df = pd.DataFrame.from_dict(res_dict, orient='index', columns=['label'])
    res_df = res_df.reset_index().rename(columns={'index': 'image_id'})
    res_df.to_csv(os.path.join(cfg['save_dir'], 'pre.csv'),
                  index=False, header=True)

if __name__ == '__main__':
    predict(cfg)

模型预测结果输出如下:

随机抽查参与预测的四个图像信息如下:

最终输出预测结果如下:编号为10,1015图片预测分类与实际情况一样。编号为1084,1121的图片预测结果与实际结果不一样。

喜欢请赞赏一下啦^_^

微信赞赏

支付宝赞赏

评论区