示例:如何用pytorh写一个代码

时间:2022-10-07 19:52:38
  1. 库的导入,如:
    import torch
    from torch import nn
    from torch.utils.data import Dataset, DataLoader
    等等

  2. 数据集的定义——最终要有一个dataloader
    1)如果是从接口调用,那么可以直接按接口文档调用。
    2)如果是自定义数据集和加载,那么必须要有__init__、getitem、__len__三个类函数,如下:
    class TrainSet(Dataset):
    def init(self, X, Y):
    # 定义好数据体本身
    self.X, self.Y = X, Y

    def getitem(self, index):
    # 定义得到第index元素的值的方法
    return self.X[index], self.Y[index]

    def len(self):
    # 定义数据集长度
    return len(self.X)

  3. 模型的定义
    1)如果是从接口调用,那么可以直接按接口文档调用。
    2)如果是自定义模型,那么必须要有__init__、forward两个类函数,如下:
    class NeuralNetwork(nn.Module):

    定义模型的层

    def init(self):
    super(NeuralNetwork, self).init()

    定义前向传播函数

    def forward(self, x):

    return y

  4. 优化模型参数,示例:
    def train_loop(dataloader, model, loss_fn, optimizer, device):
    for batch, (X, y) in enumerate(dataloader):
    X = X.to(device)
    y = y.to(device)
    # 前向传播,计算预测值
    pred = model(X)
    # 计算损失
    loss = loss_fn(pred, y)
    # 反向传播,优化参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  5. 测试模型性能,示例:
    def test_loop(dataloader, model, loss_fn, device):
    test_loss, correct = 0, 0

    with torch.no_grad():
    for X, y in dataloader:
    X = X.to(device)
    y = y.to(device)
    # 前向传播,计算预测值
    pred = model(X)
    # 计算损失
    test_loss += loss_fn(pred, y).item()
    # 计算准确率
    correct += (pred.argmax(1) == y).type(torch.float).sum().item()

if name == ‘main’:
6. 得到设备(cuda是gpu)
device = “cuda” if torch.cuda.is_available() else “cpu”
print(f"Using {device} device")
7. 定义模型
model = NeuralNetwork().to(device)
8. 设置超参数
learning_rate = 1e-6
batch_size = 8
epochs = 5
9. 定义损失函数
loss_fn = nn.CrossEntropyLoss()
10. 定义优化器(可能还可以定义学习率策略器lr_scheduler)
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
11. 训练模型
for t in range(epochs):
train_loop(train_dataloader, model, loss_fn, optimizer, device)
test_loop(test_dataloader, model, loss_fn, device)
12. 保存模型
torch.save(model.state_dict(), ‘model_weights.pth’)