Python 神经网络项目常用语法-2. 类相关

时间:2024-11-19 07:21:51

1. 类的定义及初始化

__init__ 是 Python 中的类初始化方法,也叫构造方法。它用于在类的对象被创建时初始化对象的状态(即设置对象的属性)

__init__ 方法会在类的实例化时自动调用,并且在对象创建后执行

class ClassName:
    def __init__(self, parameters):
        # 初始化属性或执行其他必要的操作
        self.attribute = value
        # 其他代码
  • __init__(self, parameters)__init__ 方法接受至少一个参数,通常是 self,它表示类的实例对象。parameters 是该方法接受的其他参数,用于在初始化时传递值
  • self.attribute:self 表示当前对象实例,可以通过它访问类的属性和方法
  • value 是初始化时为属性赋的值,可以是常量、变量或通过其他逻辑生成的值。

示例 1:基础初始化方法

class Person:
    def __init__(self, name, age):
        # 初始化时将名字和年龄赋值给对象的属性
        self.name = name
        self.age = age

    def introduce(self):
        print(f"Hello, my name is {self.name} and I am {self.age} years old.")

在这个例子中,Person 类的 __init__ 方法接受两个其他参数:name 和 age,并将它们赋值给实例对象的属性 self.nameself.age

# 创建对象时,初始化属性
person1 = Person("Alice", 30)
person2 = Person("Bob", 25)

# 调用方法
person1.introduce()  # 输出: Hello, my name is Alice and I am 30 years old.
person2.introduce()  # 输出: Hello, my name is Bob and I am 25 years old.

示例 2:使用默认参数

class Car:
    def __init__(self, make, model, year=2020):
        # 初始化时,year 如果未传递将默认为 2020
        self.make = make
        self.model = model
        self.year = year

    def display_info(self):
        print(f"{self.year} {self.make} {self.model}")

在这个例子中,year 参数具有默认值 2020。如果在创建 Car 实例时未传递 year 参数,它会自动使用默认值 2020。

# 创建时传递所有参数
car1 = Car("Toyota", "Camry", 2021)

# 创建时只传递 make 和 model,year 会使用默认值 2020
car2 = Car("Honda", "Civic")

# 输出汽车信息
car1.display_info()  # 输出: 2021 Toyota Camry
car2.display_info()  # 输出: 2020 Honda Civic

2. 类的实例化及函数调用

# train_script.py

# Unet3D_with_Conv3D、GaussianDiffusion、Trainer 是从模块中导入的类
from model.video_diffusion_pytorch.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from diffusion.diffusion_2d_jellyfish import GaussianDiffusion, Trainer

if __name__ == "__main__":
	# 解析命令行参数并将其存储在 FLAGS 对象中
	FLAGS = parser.parse_args()

	# 创建 Unet3D_with_Conv3D 模型实例
	model = Unet3D_with_Conv3D(
		dim = 64,  # 设置模型的基础维度大小
		out_dim = 1 if FLAGS.only_vis_pressure else 3,  # 根据命令行参数 only_vis_pressure 决定输出维度
		dim_mults = (1, 2, 4),  # 传递一个元组作为参数,用于指定每个网络层维度的倍数
		channels=5 if FLAGS.only_vis_pressure else 7  # 根据命令行参数 only_vis_pressure 决定通道数
		)
        
	# 创建 GaussianDiffusion 实例
	diffusion = GaussianDiffusion(
		model,
        image_size = 64,
        frames=FLAGS.frames,
        cond_steps=FLAGS.cond_steps,
        timesteps = 1000,           # 设置扩散步骤数
        sampling_timesteps = 250,   # 采样步骤数
        loss_type = 'l2',           # 设置损失函数类型:L1 or L2
        objective = "pred_noise",
        device =device              # 模型运行的设备(CPU/GPU)
    )
    
	# 创建 Trainer 类的实例,该类用于管理模型的训练
	trainer = Trainer(
        diffusion,
        FLAGS.dataset,
        FLAGS.dataset_path,
        FLAGS.frames,
        FLAGS.traj_len,
        FLAGS.ts,
        FLAGS.log_path,
        train_batch_size = FLAGS.batch_size,  # 训练的批次大小
        train_lr = 1e-3,                  # 学习率
        train_num_steps = 400000,         # 总训练步数
        gradient_accumulate_every = 1,    # 指定进行梯度累积的次数
        ema_decay = 0.995,                # 用于模型参数的指数移动平均值的衰减因子
        save_and_sample_every = 4000,     # 每 4000 步保存模型和进行采样
        results_path = results_path,
        amp = False,                      # 是否使用混合精度训练
        calculate_fid = False,            # 训练过程中是否计算 fid
        is_testdata = FLAGS.is_testdata,
        only_vis_pressure = FLAGS.only_vis_pressure,
        model_type = FLAGS.model_type
    	)
    
	trainer.train()  # 调用 Trainer 类的 train 方法,启动模型的训练过程

这段代码展示了如何定义和使用深度学习模型的训练流程,包括模型定义、模型实例化、训练参数设置,以及如何通过面向对象编程实现模块化if __name__ == "__main__" 使得这段代码在直接运行脚本时会执行训练逻辑,而在导入时不会执行,从而提高了代码的复用性和模块化水平。

  • if __name__ == "__main__":它是模块和脚本的运行入口。该语句下的代码仅在该脚本作为主程序运行时才会被执行

    • __name__ 变量:每个 Python 模块都有一个内置属性 __name__,其值决定了模块是被导入还是直接运行
    • __main__:当一个 Python 文件被直接运行时,__name__ 的值会被设置为 __main__
    • 导入时的行为:如果该模块被其他脚本导入__name__ 的值是该模块的文件名(不带路径和 .py 扩展名)。
  • model = ClassName(arguments)创建类的实例,通过类构造函数 ClassName 初始化对象 model

  • out_dim = 1 if FLAGS.only_vis_pressure else 3:使用条件表达式(类似三元运算符)来设置输出维度,如果 FLAGS.only_vis_pressure 为真,out_dim 为 1,否则为 3。

示例:

# main_script.py
if __name__ == "__main__":
    print("This will only run when main_script.py is executed directly.")

运行结果:

  • 如果运行 python main_script.py,将输出 This will only run when main_script.py is executed directly.
  • 如果 main_script.py 被其他脚本导入,如 import main_script,这行代码不会被执行。