1.mnist数据
tensorflow提供一个input_data.py文件,专门用于下载mnist数据:
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
执行完成后,会在当前目录下新建一个文件夹MNIST_data, 下载的数据将放入这个文件夹内。下载的四个文件为:
nput_data文件会调用一个maybe_download函数,确保数据下载成功。这个函数还会判断数据是否已经下载,如果已经下载好了,就不再重复下载。
下载下来的数据集被分三个子集:5.5W行的训练数据集(mnist.train
),5千行的验证数据集(mnist.validation)和1W行的测试数据集(mnist.test
)。因为每张图片为28x28的黑白图片,所以每行为784维的向量。
每个子集都由两部分组成:图片部分(images)和标签部分(labels), 我们可以用下面的代码来查看 :
print mnist.train.images.shape
print mnist.train.labels.shape
print mnist.validation.images.shape
print mnist.validation.labels.shape
print mnist.test.images.shape
print mnist.test.labels.shape
如果想查看具体数值,可以将这些数据提取为变量来查看,如:
train_data = mnist.train.images
train_label = mnist.train.labels
val_data = mnist.validation.images
val_label = mnist.validation.labels
test_data = mnist.test.images
test_label = mnist.test.labels
print train_data,train_label
print val_data,val_label
print test_data,test_label
2.CSV数据
除了mnist手写字体图片数据,tf还提供了几个csv的数据供大家练习,存放路径为:
/root/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/data/text_train.csv
如果要将这些数据读出来,可用代码:
import tensorflow.contrib.learn.python.learn.datasets.base as base
iris_data,iris_label=base.load_iris()
house_data,house_label=base.load_boston()
前者为iris鸢尾花卉数据集,后者为波士顿房价数据。
3.cifar10数据
tf提供了cifar10数据的下载和读取的函数,我们直接调用就可以了。执行下列代码:
import tensorflow.models.image.cifar10.cifar10 as cifar10
cifar10.maybe_download_and_extract()
images, labels = cifar10.distorted_inputs()
print images
print labels
就可以将cifar10下载并读取出来。