[TensorFlow 3] The MNIST Dataset: A Comparison with Keras

Date: 2022-09-05 13:54:09

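Keras is bundled with TensorFlow as a built-in API, tf.keras. It has been part of the core package since TF 1.4, so it is available in TF 1.8 and every later release.

The older download statement, which relies on input_data from tensorflow.examples.tutorials.mnist, raises an error in recent versions: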

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
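One possible replacement for that statement is sketched below. It assumes that the goal is the same flattened, [0, 1]-scaled, one-hot data that one_hot=True used to return. The helper names to_one_hot and load_mnist_flat_one_hot are hypothetical and not part of TensorFlow; only tf.keras.datasets.mnist.load_data() is a real library call. The sketch does not reproduce the separate validation split of the old loader.

```python
import numpy as np


def to_one_hot(labels, num_classes=10):
    """Turn integer labels of shape (n,) into one-hot rows of shape (n, num_classes)."""
    labels = np.asarray(labels, dtype=np.int64)
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded


def load_mnist_flat_one_hot():
    """Hypothetical helper: load MNIST via tf.keras, flattened and one-hot encoded."""
    import tensorflow as tf  # imported lazily; the dataset is downloaded on first use

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Flatten each 28x28 image to a 784-vector and scale pixels to [0, 1].
    x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
    x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
    return (x_train, to_one_hot(y_train)), (x_test, to_one_hot(y_test))
```

Calling load_mnist_flat_one_hot() needs TensorFlow and network access on the first run; to_one_hot works on its own with NumPy only.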

The training code for both approaches, Keras and TensorFlow, is given below (with verification code):

Keras:

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense 
from keras.optimizers import SGD

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784)

# Normalize pixel values to the range [0, 1]
x_train = x_train / 255
x_test = x_test / 255

y_train = keras.utils.to_categorical(y_train,10)
y_test = keras.utils.to_categorical(y_test,10)

model = Sequential()
model.add(Dense(512,activation='relu',input_shape=(784,)))
model.add(Dense(256,activation="relu"))
model.add(Dense(10,activation="softmax"))

# Print the network architecture
model.summary()

model.compile(optimizer=SGD(),loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=64,epochs=5,validation_data=(x_test,y_test))

score = model.evaluate(x_test,y_test)

# Print loss and accuracy
print("loss",score[0])
print("accu",score[1])

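# Generate output predictions for the input samples
predictions = model.predict(x_test)

# Predicted digit for test sample 23
print(np.argmax(predictions[23]))

# Show the image to check whether the prediction is correct;
# x_test was flattened to 784 values, so restore the 28x28 shape first
plt.imshow(x_test[23].reshape(28, 28))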
plt.show()
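Checking a single image by eye only verifies one sample. As a rough sketch, the accuracy reported in score[1] can also be recomputed from the prediction array by taking the argmax of each row. The function name accuracy_from_predictions and the toy arrays below are made up for illustration; in the listing above the inputs would be predictions and the one-hot y_test.

```python
import numpy as np


def accuracy_from_predictions(predictions, one_hot_labels):
    """Fraction of rows whose highest-probability class matches the one-hot label."""
    predicted = np.argmax(predictions, axis=1)
    actual = np.argmax(one_hot_labels, axis=1)
    return float(np.mean(predicted == actual))


# Made-up stand-ins for model.predict(x_test) and y_test (4 samples, 3 classes).
toy_predictions = np.array([[0.8, 0.1, 0.1],
                            [0.2, 0.3, 0.5],
                            [0.6, 0.3, 0.1],
                            [0.1, 0.7, 0.2]])
toy_labels = np.eye(3)[[0, 2, 1, 1]]

toy_accuracy = accuracy_from_predictions(toy_predictions, toy_labels)
print("accu", toy_accuracy)
```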

TensorFlow:

The code below comes from the official TensorFlow website.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test, y_test)

# The verification code is the same as above, so it is omitted
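One difference between the two listings is the label format: the Keras version one-hot encodes the labels with to_categorical and uses categorical_crossentropy, while the TensorFlow version keeps the integer labels and uses sparse_categorical_crossentropy. The NumPy sketch below is a simplified re-implementation for illustration, not the library code, and its probabilities are made-up numbers. It suggests that both losses compute the same value from the two label formats.

```python
import numpy as np


def categorical_crossentropy(one_hot_labels, probs, eps=1e-7):
    """Mean cross-entropy for one-hot labels of shape (n, classes)."""
    probs = np.clip(probs, eps, 1.0)
    return float(np.mean(-np.sum(one_hot_labels * np.log(probs), axis=1)))


def sparse_categorical_crossentropy(int_labels, probs, eps=1e-7):
    """Mean cross-entropy for integer labels of shape (n,)."""
    probs = np.clip(probs, eps, 1.0)
    # Pick the predicted probability of the true class in each row.
    picked = probs[np.arange(len(int_labels)), int_labels]
    return float(np.mean(-np.log(picked)))


# Toy softmax outputs for 3 samples over 4 classes (made-up numbers).
probs = np.array([[0.7, 0.1, 0.1, 0.1],
                  [0.2, 0.5, 0.2, 0.1],
                  [0.1, 0.1, 0.1, 0.7]])
int_labels = np.array([0, 1, 3])
one_hot_labels = np.eye(4)[int_labels]

dense_loss = categorical_crossentropy(one_hot_labels, probs)
sparse_loss = sparse_categorical_crossentropy(int_labels, probs)
print(dense_loss, sparse_loss)
```

The sparse form avoids building the one-hot matrix, which is why the TensorFlow listing can pass y_train to fit() unchanged.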

See also the differences between Keras and tf.keras (https://www.zhihu.com/question/313111229/answer/606660552)