VGG16-CF-VGG11实验报告

时间:2024-04-14 22:52:38

说明:VGG16和CF-VGG11是论文《A 3D Fluorescence Classification and Component Prediction Method Based on VGG Convolutional Neural Network and PARAFAC Analysis Method》使用的两种主要模型。其对应代码仓库提供了实验使用的数据集、平行因子分析结果和CNN模型。论文和代码仓库是本文实验使用的基本材料。

目录

  • 论文摘要
  • 数据集信息
  • 环境配置
  • 分类实验(工作目录:代码仓库/VGG16)
    • 修改代码
    • 实验流程
    • 实验结果
    • train.py笔记
      • 数据流
  • 组分拟合实验(工作目录:代码仓库/CF-VGG11)
    • 修改代码
    • 实验流程
    • 实验结果
    • 笔记
      • train.py数据流分析
      • FVGG11.py
      • FAlexNet.py
      • SimpleCNN.py

论文摘要

  • 三维荧光的研究目前主要采用平行因子分析(PARAFAC)、荧光区域积分(FRI)和主成分分析(PCA)等方法。
  • 目前结合卷积神经网络(CNN)的研究也很多,但在CNN与三维荧光分析相结合的方法中,还没有一种方法被认为是最有效的
  • 本文在已有研究基础上,从实际环境中采集了一些样品进行三维荧光数据的测量,并从互联网中获得了一批公开数据集
  • 首先对数据进行预处理(包括PARAFAC分析和CNN数据集生成两步),然后提出了基于VGG16和VGG11卷积神经网络的三维荧光分类方法和分量拟合方法
  • 使用VGG16网络对三维荧光数据进行分类,训练准确率为99.6%(与PCA + SVM方法同样准确)。
  • 对于分量图拟合网络,我们综合比较了改进的LeNet网络、改进的AlexNet网络和改进的VGG11网络,PCA + SVM改进的VGG11网络。
  • 在改进的VGG11网络训练中,我们使用MSE损失函数和余弦相似度来判断模型的优劣,网络训练的MSE损失达到4.6×10−4,训练结果的余弦相似度达到0.99。(由此可见,)网络性能非常出色
  • 实验表明,CNN在三维荧光分析中具有很大的应用价值

数据集信息

以下表格中的Samples,Number,Train,Validate,Test,Total Samples after Expansion列来自论文的Table 3。

Samples Number Train Validate VGG16/main VGG11/train Test VGG11/test VGG16/test Total Samples after Expansion
FU 45 27 9 35 35 9 7 35 135
F 105 63 21 81 81 21 21 81 315
P 206 124 41 161 161 41 42 161 618
PU 60 36 12 45 45 12 12 45 180

论文中的数据扩充说明:在实际训练过程中,我们通过色域失真和镜像翻转来扩展图像。

表格分析:Train+Validate与2个网络的训练集大小相近,VGG16的测试集扩充了4倍,Total Samples after Expansion=3*Number。

环境配置

  • 安装CUDA 12.1
  • 安装cuDnn
  • 新建环境:conda create -n 3deem python=3.10
  • 安装torch-2.2.1+cu121-cp310-cp310-win_amd64.whl
  • 在虚拟环境中安装matplotlib,opencv
  • pip3 install torchvision --index-url https://download.pytorch.org/whl/cu121

分类实验(工作目录:代码仓库/VGG16)

修改代码

  1. 新增annotation_generator.py
import os
from utils.utils import get_classes

classes_path = 'model_data/cls_classes.txt' 
img_root_path = 'datasets/main/'
txt_path = 'model_data/cls_train.txt'

assert os.path.exists(img_root_path)

txt = open(txt_path,'w')
class_names, num_classes = get_classes(classes_path)
for tag_index in range(0,num_classes):
    class_name = class_names[tag_index]
    img_path = img_root_path + class_name + '/'
    files = os.listdir(img_path)
    for img_file in files:
        line = str(tag_index) + ';' + img_path + img_file + '\n'
        txt.write(li