Tensorflow直接读取CSV文件

时间:2021-10-16 10:00:14

Tensorflow直接读取CSV文件

整理一下tensorflow读取csv文件的基本流程,主要是官方文档中的例子的记录。
tensorflow读取csv文件相对pandas要复杂一下,基本过程如下:

  1. 产生文件名列表,这里可以一次性用pipline读取一系列csv文件。
  2. 建立阅读器,读取原始数据。
  3. 解析读出的原始数据,转化成数值数据或指定格式的数据。
  4. 开启多线程协调器,启动输入管道。
  5. 读取完毕,停止请求。

    选取iris数据集测试,iris2.csv是完全一样的一个文件,主要为了验证多文件读取的功能。有4个实数型属性和一个字符串性label。

import tensorflow as tf
import numpy as np

filename_queue = tf.train.string_input_producer(["./data/iris.csv", "./data/iris2.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# key返回的是读取文件和行数信息 b'./data/iris.csv:146'
# value是按行读取到的原始字符串,送到下面的decoder去解析

record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Null"]] # 这里的数据类型决定了读取的数据类型,而且必须是list形式
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults) # 解析出的每一个属性都是rank为0的标量
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(200):
        example, label = sess.run([features, col5])
        print (example,col5)

    coord.request_stop()
    coord.join(threads)

这里有几个点需要注意,代码注释中也已经写明。

  1. 文件名列表可以自己生成,也可以手写。
  2. key是文件信息和当前读取的行数,value是原始字符串。
  3. defualt有两个作用,一个是指定当前列的数据类型,一个是替补空值。
  4. 按行读取后读出的每个值都是rank为0的标量。
  5. stack函数可以堆叠,组成一个新的tensor,默认axis=0,因此就会横向把前面的标量串成一个rank为1的tensor。