TFboys:使用Tensorflow搭建深层网络分类器

时间:2022-12-14 18:40:22

前言

根据官方文档整理而来的,主要是对Iris数据集进行分类。使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器,

步骤

  1. 导入csv数据
  2. 搭建网络分类器
  3. 训练网络
  4. 计算测试集正确率
  5. 对新样本进行分类

数据

Iris数据集包含150行数据,有三种不同的Iris品种分类。每一行数据给出了四个特征信息和一个分类信息。
现在已经将数据分为训练集和测试集

网络搭建

1. 首先,导入tensorflow 和 numpy

 
 
 
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import tensorflow as tf
  5. import numpy as np

2. 导入数据

 
 
 
  1. # 定义数据地址
  2. IRIS_TRAINING = "iris_training.csv"
  3. IRIS_TEST = "iris_test.csv"
  4. # 导入数据
  5. training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  6. filename=IRIS_TRAINING,
  7. target_dtype=np.int,
  8. features_dtype=np.float32)
  9. test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  10. filename=IRIS_TEST,
  11. target_dtype=np.int,
  12. features_dtype=np.float32)

load_csv_with_header() 有三个参数

  • filename, 数据地址
  • target_dtype, 目标值的numpy datatype(iris的目标值是0,1,2,所以是np.int)
  • features_dtype, 特征值的numpy datatype .

3. 搭建网络结构

 
 
 
  1. # 每行数据4个特征,都是real-value的
  2. feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
  3. # 3层DNN,3分类问题
  4. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
  5. hidden_units=[10, 20, 10],
  6. n_classes=3,
  7. model_dir="iris_model")

参数解释

  • feature_columns 特征值
  • hidden_units=[10, 20, 10]. 3个隐藏层,包含的隐藏神经元依次是10, 20, 10
  • n_classes 类别个数
  • model_dir 模型保存地址

4. 训练数据

 
 
 
  1. classifier.fit(x=training_set.data, y=training_set.target, steps=2000)

steps 为训练次数

5. 计算准确率

 
 
 
  1. accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]
  2. print('Accuracy: {0:f}'.format(accuracy_score))

运行结果是

 
 
 
  1. Accuracy: 0.966667

6. 对新样本进行预测

 
 
 
  1. # Classify two new flower samples.
  2. new_samples = np.array(
  3. [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
  4. y = list(classifier.predict(new_samples, as_iterable=True))
  5. print('Predictions: {}'.format(str(y)))

运行结果为:

 
 
 
  1. Prediction: [1 2]

完整代码

 
 
 
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import tensorflow as tf
  5. import numpy as np
  6. IRIS_TRAINING = "iris_training.csv"
  7. IRIS_TEST = "iris_test.csv"
  8. training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  9. filename=IRIS_TRAINING,
  10. target_dtype=np.int,
  11. features_dtype=np.float32)
  12. test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  13. filename=IRIS_TEST,
  14. target_dtype=np.int,
  15. features_dtype=np.float32)
  16. feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
  17. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
  18. hidden_units=[10, 20, 10],
  19. n_classes=3,
  20. model_dir="iris_model")
  21. classifier.fit(x=training_set.data,
  22. y=training_set.target,
  23. steps=2000)
  24. accuracy_score = classifier.evaluate(x=test_set.data,
  25. y=test_set.target)["accuracy"]
  26. print('Accuracy: {0:f}'.format(accuracy_score))
  27. new_samples = np.array(
  28. [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
  29. y = list(classifier.predict(new_samples, as_iterable=True))
  30. print('Predictions: {}'.format(str(y)))

参考


原文地址: http://www.datalearner.com/blog/1051488938031745