Tensorflow 网络基本结构

时间:2022-12-12 09:56:26

网络容器
通过Keras提供的网络容器Sequential将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。
(1)Sequential容器封装为一个网络:

import tensorflow as tf
from tensorflow.keras import  layers,Sequential
model = Sequential([#封装一个网络
    layers.Dense(3,activation=None),#全连接层,不使用激活函数
    layers.ReLU(),#激活函数层
    layers.Dense(2,activation=None),
    layers.ReLU()
])
x = tf.random.normal([4,3])
out = model(x)
out
'''
<tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[0.02129143, 0.        ],
       [0.00915629, 0.        ],
       [0.        , 0.9278006 ],
       [1.190076  , 0.        ]], dtype=float3
'''

(2)Sequential容器通过add()方法:

layers_num = 3#堆叠三层
model = Sequential([])#创建空间的网络容器
for _ in range(layers_num):
    model.add(layers.Dense(3))#添加全连接层
    model.add(layers.ReLU())#添加激活层
model.build(input_shape=(4,4))#创建网络参数
model.summary()#打印出网络结构和参数量
'''
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (4, 3)                    15        
_________________________________________________________________
re_lu_9 (ReLU)               (4, 3)                    0         
_________________________________________________________________
dense_10 (Dense)             (4, 3)                    12        
_________________________________________________________________
re_lu_10 (ReLU)              (4, 3)                    0         
_________________________________________________________________
dense_11 (Dense)             (4, 3)                    12        
_________________________________________________________________
re_lu_11 (ReLU)              (4, 3)                    0         
=================================================================
Total params: 39
Trainable params: 39
Non-trainable params: 0
_________________________________________________________________
'''

(3)所有层的待优化张量列表和全部张量列表:

for p in model.trainable_variables:
    print(p.name,p.shape)
'''
dense_9/kernel:0 (4, 3)
dense_9/bias:0 (3,)
dense_10/kernel:0 (3, 3)
dense_10/bias:0 (3,)
dense_11/kernel:0 (3, 3)
dense_11/bias:0 (3,)
'''

配置模型
Layer类是网络层的母类,定义了网络层的一些常见功能,如添加权值、管理权值列表等。
Model类是网络的母类,除了具有Layer类的功能,还添加了保存模型、加载模型、训练与测试模型等便捷功能。Sequential也是Model的子类,因此具有Model类的所有功能。

model = Sequential([#封装一个网络
    layers.Dense(256,activation='relu'),
    layers.Dense(128,activation='relu'),
    layers.Dense(56,activation='relu'),
    layers.Dense(28,activation='relu'),
    layers.Dense(10),
])
model.build(input_shape=(4,28,28))
model.summary()
#compile()函数中指定的优化器、损失函数等参数也是自行训练时需要设置的参数
from tensorflow.keras import  optimizers,losses
from tensorflow import  metrics
model.compile(optimizer=optimizers.Adam(lr=0.01),
             loss = losses.CategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])#设置测量指标为准确率

模型训练
通过fit()函数送入待训练的数据集和验证用的数据集,这一步称为模型训练

#epochs参数指定训练迭代的Epoch数量;validation_data参数指定用于验证(测试)的数据集和验证的频率validation_freq

history = model.fit(train_x,epochs=10,validation_data=train_val,validation_freq=5)

fit()函数会返回训练过程的数据记录history,其中history.history为字典对象,包含了训练过程中的loss、测量指标等记录项。

history.history()

模型测试
过model.predict(x)方法即可完成模型的预测。

#模型预测
model.predict(x_test)
#测试模型性能
model.evaluate(x_test)

保存模型
(1)张量方式
网络的状态主要体现在网络的结构以及网络层内部张量数据上,因此在拥有网络结构源文件的条件下,直接保存网络张量参数到文件系统上是最轻量级的一种方式。

model.save_weights('weights.ckpt')#保存模型的所有张量数据

需要时,先创建好网络对象,然后调用网络对象的load_weights(path)方法即可将指定的模型文件中保存的张量数值写入到当前网络参数中去。

model.load_weights('weights.ckpt')

这种保存与加载网络的方式最为轻量级,文件中保存的仅是张量参数的数值,并没有其他额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。

(2)网络方式
不需要网络源文件,仅需要模型参数文件即可恢复网络模型的方法。通过Model.save(path)函数可以将模型的结构以及模型的参数保存到path文件上,在不需要网络源文件的条件下,通过keras.models.load_model(path)即可恢复网络结构和网络参数。

#保存模型结构与模型参数到文件
model.save('model.h5')
#从文件恢复网络结构与网络参数
model = tf.keras.model.load_model('model.h5')

model.h5文件除了保存了模型参数外,还应保存了网络结构信息,不需要提前创建模型即可直接从文件中恢复网络model对象。

(3)SavedModel方式
通过tf.saved_model.save(network,path)即可将模型以SavedModel方式保存到path目录中。

#保存模型结构与模型参数到文件
tf.saved_model.save(model,'model_savedmodel')

用户无须关心文件的保存格式,只需要通过tf.saved_model.load函数即可恢复模型对象。

#从文件恢复网络结构与网络参数
model = tf.saved_model.load('model_savedmodel')