使用神经网络训练iris数据集 —— Python数据工程No.7

时间:2024-03-25 16:31:18

数据介绍:

iris数据集的数据有4个属性,分别为:花萼长、花萼宽、花瓣长、花瓣宽
这些数据是对三种鸢尾花——狗尾鸢尾、杂色鸢尾、弗吉尼亚鸢尾——采样生成的。
部分数据如下图所示:
数据特征示例:
使用神经网络训练iris数据集 —— Python数据工程No.7数据标签示例:
使用神经网络训练iris数据集 —— Python数据工程No.7
因此我们输入的数据特征为14矩阵,输出的数据标签为13矩阵分别记为X,Y。由此,我们可以搭建BP神经网络如下图所示:
使用神经网络训练iris数据集 —— Python数据工程No.7此时X = [x0, x1, x2, x3],Y = [y0, y1, y2]。
w = [[w00, w01, w02], [w10, w11, w12], [w20, w21, w22],[w30, w31, w32, w33]]
b = [b0, b1, b2]
数学关系为:X * w + b = Y
我们需要用训练数据训练出权重矩阵w和偏置矩阵b是神经网络取得效果好的拟合能力,在该神经网络运用于测试集时具有效果好的泛化能力。

训练步骤:

  1. 准备数据,包括数据集读入、数据集乱序,把训练集和测试集中的数据配成输入特征和标签对,生成 train 和 test —— 永不相见的训练集和测试集;
  2. 搭建网络,定义神经网络中的所有可训练参数;
  3. 优化这些可训练的参数,利用嵌套循环在 with 结构中求得损失函数 loss对每个可训练参数的偏导数,更改这些可训练参数,为了查看效果,程序中可以加入每遍历一次数据集显示当前准确率,还可以画出准确率 acc 和损失函数 loss的变化曲线图。

参考代码参见 CSDN下载
主要参考资源为清华大学课程