J3 Study Check-in

Time: 2024-11-04 11:15:45
  • This post is a learning-record blog from the 365-day deep learning training camp
  • Original author: K同学啊

DenseNet Model

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, initializers

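# DenseLayer: the basic DenseNet unit. Two BN -> ReLU -> Conv stages:
# a 1x1 "bottleneck" conv (4 * growth_rate filters) to limit computation,
# then a 3x3 conv producing growth_rate new feature maps, which are
# concatenated onto the layer's input along the channel axis.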
class DenseLayer(layers.Layer):
    def __init__(self, growth_rate, drop_rate=0.1):
        super(DenseLayer, self).__init__()

        self.drop_rate = drop_rate
        self.norm1 = layers.BatchNormalization()
        self.relu1 = layers.ReLU()
        self.conv1 = layers.Conv2D(filters=4 * growth_rate, kernel_size=1, use_bias=False)
        self.norm2 = layers.BatchNormalization()
        self.relu2 = layers.ReLU()
        self.conv2 = layers.Conv2D(filters=growth_rate, kernel_size=3, padding='same', use_bias=False)

        # Create the Dropout and Concatenate layers once in __init__, not in call
        if self.drop_rate > 0:
            self.dropout = layers.Dropout(self.drop_rate)
        self.concat = layers.Concatenate()

    def call(self, inputs, training=False):
        x = self.conv1(self.relu1(self.norm1(inputs, training=training)))
        x = self.conv2(self.relu2(self.norm2(x, training=training)))

        if self.drop_rate > 0:
            x = self.dropout(x, training=training)  # Use the predefined dropout layer

        return self.concat([inputs, x])  # concatenate input and new features along the channel axis


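# DenseBlock: stacks num_layers DenseLayers; because each layer concatenates
# its growth_rate new channels onto its input, the block's output has
# input_channels + num_layers * growth_rate channels.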
class DenseBlock(layers.Layer):
    def __init__(self, num_layers, growth_rate, drop_rate):
        super(DenseBlock, self).__init__()
        self.num_layers = num_layers
        self.layers_list = [DenseLayer(growth_rate, drop_rate) for _ in range(num_layers)]

    def call(self, x, training=False):
        for layer in self.layers_list:
            x = layer(x, training=training)
        return x


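# TransitionLayer: placed between dense blocks to keep the model compact;
# a 1x1 conv compresses the channel count and 2x2 average pooling halves
# the spatial resolution.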
class TransitionLayer(layers.Layer):
    def __init__(self, num_output_features):
        super(TransitionLayer, self).__init__()
        self.norm = layers.BatchNormalization()
        self.relu = layers.ReLU()
        self.conv = layers.Conv2D(num_output_features, kernel_size=1, strides=1, use_bias=False)
        self.pool = layers.AveragePooling2D(pool_size=2, strides=2)

    def call(self, x, training=False):
        x = self.conv(self.relu(self.norm(x, training=training)))
        x = self.pool(x)
        return x


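# DenseNet: a 7x7/2 stem conv + max pooling, then alternating DenseBlocks and
# TransitionLayers. The default block_config (6, 12, 24, 16) with growth_rate=32
# matches the DenseNet-121 layout.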
class DenseNet(tf.keras.Model):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 compression_rate=0.5, drop_rate=0.1, num_classes=4):
        super(DenseNet, self).__init__()

        # Initial Conv Layer
        self.features = models.Sequential([
            layers.Conv2D(num_init_features, kernel_size=7, strides=2, padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPooling2D(pool_size=3, strides=2, padding='same'),
        ])

        # DenseBlocks and Transitions
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            self.features.add(DenseBlock(num_layers, growth_rate, drop_rate))
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                out_features = int(num_features * compression_rate)
                self.features.add(TransitionLayer(out_features))
                num_features = out_features

        # Final Batch Norm and ReLU
        self.features.add(layers.BatchNormalization())
        self.features.add(layers.ReLU())

        # Global average pooling and classification layer (created once here, not in call)
        self.avgpool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes, kernel_initializer=initializers.he_normal())

    def call(self, x, training=False):
        x = self.features(x, training=training)  # Pass the 'training' argument
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

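# Sanity check (hypothetical usage): building the model on a dummy batch
# would show the expected output shape, e.g.
#   model = DenseNet(num_classes=4)
#   print(model(tf.zeros((1, 224, 224, 3))).shape)  # -> (1, 4)
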
# Set the data path
data_dir = r"C:\Users\11054\Desktop\kLearning\J1_learning\bird_photos"

batch_size = 8
img_height = 224
img_width = 224


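# Load the bird_photos dataset from disk with an 80/20 train/validation split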
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)
plt.figure(figsize=(10, 5))  # figure is 10 wide and 5 tall
plt.suptitle("微信公众号:K同学啊")

for images, labels in train_ds.take(1):
    for i in range(8):

        ax = plt.subplot(2, 4, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])

        plt.axis("off")


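# Inspect one batch: images are (batch, height, width, channels), labels are (batch,)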
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
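# Configure the tf.data input pipeline: cache decoded images, shuffle the
# training set, and prefetch so preprocessing overlaps with training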
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# Instantiate the custom DenseNet model
model = DenseNet(num_classes=4)

# Compile the model. The classifier outputs raw logits (no softmax activation),
# so the loss must be constructed with from_logits=True.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # sparse CE for integer labels
              metrics=['accuracy'])

# Train the model
epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

# Plot training/validation loss and accuracy
plt.figure(figsize=(12, 4))

# Training and validation loss
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Training and validation accuracy
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
Found 565 files belonging to 4 classes.
Using 452 files for training.
Found 565 files belonging to 4 classes.
Using 113 files for validation.
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
(8, 224, 224, 3)
(8,)
Epoch 1/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - accuracy: 0.4129 - loss: 3.3313 - val_accuracy: 0.2655 - val_loss: 11.8390
Epoch 2/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 111s 2s/step - accuracy: 0.5344 - loss: 2.9348 - val_accuracy: 0.2655 - val_loss: 11.3703
Epoch 3/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 94s 2s/step - accuracy: 0.3674 - loss: 2.6228 - val_accuracy: 0.2655 - val_loss: 11.6901
Epoch 4/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 94s 2s/step - accuracy: 0.4173 - loss: 2.0029 - val_accuracy: 0.2743 - val_loss: 8.5733
Epoch 5/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 116s 2s/step - accuracy: 0.5138 - loss: 1.6268 - val_accuracy: 0.2743 - val_loss: 10.3662
Epoch 6/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 134s 2s/step - accuracy: 0.3255 - loss: 2.6131 - val_accuracy: 0.2655 - val_loss: 8.9590
Epoch 7/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 142s 2s/step - accuracy: 0.3427 - loss: 3.4939 - val_accuracy: 0.3805 - val_loss: 7.2092
Epoch 8/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 132s 2s/step - accuracy: 0.3013 - loss: 3.0391 - val_accuracy: 0.2832 - val_loss: 1.3802
Epoch 9/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 135s 2s/step - accuracy: 0.3722 - loss: 2.9314 - val_accuracy: 0.2478 - val_loss: 10.2079
Epoch 10/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 140s 2s/step - accuracy: 0.3580 - loss: 1.3863 - val_accuracy: 0.1770 - val_loss: 5.2081


C:\Users\11054\.conda\envs\tf39\lib\site-packages\IPython\core\pylabtools.py:152: UserWarning: Glyph 24494 (\N{CJK UNIFIED IDEOGRAPH-5FAE}) missing from font(s) DejaVu Sans.
  fig.canvas.print_figure(bytes_io, **kw)
(The same UserWarning repeats for every CJK character in the suptitle "微信公众号:K同学啊", because the default DejaVu Sans font has no Chinese glyphs.)

[Figure: sample batch of 8 training images with class-name titles]
[Figure: training/validation loss and accuracy curves over 10 epochs]

Personal Summary

  1. Completed the conversion of the code from PyTorch to TensorFlow, but the TensorFlow run shows a large loss and the model performs poorly. A likely contributing factor is that the images are fed to the network as raw 0-255 pixel values without any rescaling; a normalization sketch follows this list.
  2. DenseNet uses dense connections (Dense Connection): the input to each layer is the concatenation of the outputs of all preceding layers, rather than their sum. This design lets every layer directly access the feature maps of all earlier layers, which strengthens feature reuse and mitigates the vanishing-gradient problem (see the concatenation sketch after this list).
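
As a possible fix for point 1, here is a minimal sketch (assuming the train_ds/val_ds pipeline built above, and that the installed Keras provides the standard Rescaling preprocessing layer) that maps pixel values from 0-255 into [0, 1] before they reach the network:

# Minimal normalization sketch (assumption: applied before model.fit)
normalization_layer = tf.keras.layers.Rescaling(1. / 255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y),
                        num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y),
                    num_parallel_calls=tf.data.AUTOTUNE)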
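
To make point 2 concrete, this small sketch (shapes are hypothetical) shows how concatenation grows the channel dimension, in contrast to the element-wise sum used by ResNet-style shortcuts:

# Dense connectivity in action: each DenseLayer appends growth_rate channels,
# so features accumulate instead of being summed away.
x = tf.zeros((1, 56, 56, 64))                 # hypothetical 64-channel input
print(DenseLayer(growth_rate=32)(x).shape)    # (1, 56, 56, 96): 64 + 32
block = DenseBlock(num_layers=6, growth_rate=32, drop_rate=0.0)
print(block(x).shape)                         # (1, 56, 56, 256): 64 + 6 * 32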