mnist作为最基础的图片数据集,在以后的cnn,rnn任务中都会用到
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#数据集存放地址,采用0-1编码
mnist = input_data.read_data_sets( 'F:/mnist/data/' ,one_hot = True )
print (mnist.train.num_examples)
print (mnist.test.num_examples)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
#打印相关信息
print ( type (trainimg))
print (trainimg.shape,)
print (trainlabel.shape,)
print (testimg.shape,)
print (testlabel.shape,)
nsample = 5
randidx = np.random.randint(trainimg.shape[ 0 ],size = nsample)
#输出几张数字的图
for i in randidx:
curr_img = np.reshape(trainimg[i,:],( 28 , 28 ))
curr_label = np.argmax(trainlabel[i,:])
plt.matshow(curr_img,cmap = plt.get_cmap( 'gray' ))
plt.title(" "+str(i)+" th Training Data "+" label is " + str (curr_label))
print (" "+str(i)+" th Training Data "+" label is " + str (curr_label))
plt.show()
|
程序运行结果如下:
1
2
3
4
5
6
7
8
9
10
11
12
|
Extracting F: / mnist / data / train - images - idx3 - ubyte.gz
Extracting F: / mnist / data / train - labels - idx1 - ubyte.gz
Extracting F: / mnist / data / t10k - images - idx3 - ubyte.gz
Extracting F: / mnist / data / t10k - labels - idx1 - ubyte.gz
55000
10000
< class 'numpy.ndarray' >
( 55000 , 784 )
( 55000 , 10 )
( 10000 , 784 )
( 10000 , 10 )
52636th
|
输出的图片如下:
Training Datalabel is9
下面还有四张其他的类似图片
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/Missayaaa/article/details/80056103