简化后的模型定义

时间:2024-10-07 07:22:47
class Model(nn.Module):
    def __init__(self, input_size, hidden_layers=[64, 32], activation='relu', batch_size=32, learning_rate=0.001, epochs=10, patience=5, l1_regularization=0, l2_regularization=0, dropout_rate=0.3):
        super(Model, self).__init__()
        self.hidden_layers = nn.ModuleList()  # 定义隐藏层为一个列表,方便添加
        last_size = input_size  # 用于追踪上一个层的输出大小

        for hidden_units in hidden_layers:  # 迭代每一层的单元数
            self.hidden_layers.append(nn.Linear(last_size, hidden_units))
            last_size = hidden_units  # 更新当前层的输出作为下一层的输入

        self.output_layer = nn.Linear(last_size, 1)  # 输出层,输出为 1,适用于回归任务
        self.activation = nn.LeakyReLU() if activation == 'leaky_relu' else nn.ReLU()  # 激活函数
        self.dropout = nn.Dropout(dropout_rate)  # Dropout层
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.patience = patience
        self.l1_regularization = l1_regularization
        self.l2_regularization = l2_regularization

    def forward(self, x):
        try:
            x = x.to(self.hidden_layers[0].weight.device)  # 确保输入数据在与模型同样的设备上
            for layer in self.hidden_layers:
                x = layer(x)
                x = self.activation(x)
                x = self.dropout(x)  # 使用dropout

            x = self.output_layer(x)  # 输出层不需要激活函数
            return x

        except Exception as e:
            print(f"Error during forward pass: {e}")
            return None

主要改进:

  1. hidden_layers 模块化:通过 nn.ModuleList() 将隐藏层用循环定义,使模型更灵活。可以轻松添加或修改隐藏层的数量和单元数。
  2. 减少不必要的参数:为一些参数提供默认值,例如 dropout_rate,减少构造函数的复杂性。
  3. 简化的 forward 函数:前向传播函数中遍历 hidden_layers,无需逐层手动书写,这让网络层数的修改变得更加简便。