深度学习项目——基于 LSTM 的糖尿病预测探究(PyTorch 实现)——模型训练

时间:2025-01-25 16:20:13

1、创建训练函数

def train(dataloader, model, loss_fn, opt, device=None):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes a
            ``dataset`` attribute (e.g. ``torch.utils.data.DataLoader``).
        model: Module whose output is reshaped to ``(-1, 2)`` class logits.
        loss_fn: Criterion called as ``loss_fn(pred, y)``.
        opt: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to. When ``None`` (the default),
            the device of the model's first parameter is used.

    Returns:
        Tuple ``(train_acc, train_loss)``: the fraction of correctly
        classified samples over the whole dataset, and the loss averaged
        over the number of batches.
    """
    if device is None:
        device = next(model.parameters()).device

    size = len(dataloader.dataset)
    num_batch = len(dataloader)

    # Switch on training-time behaviour (dropout, batch-norm statistics
    # updates); the model may have been left in eval mode by other code.
    model.train()

    train_acc, train_loss = 0.0, 0.0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Flatten the model output into (N, 2) logits for the two classes.
        pred = model(X).view(-1, 2)
        loss = loss_fn(pred, y)

        # Backpropagate and update the weights.
        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss += loss.item()
        # Count samples whose highest-scoring class matches the label.
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    train_acc /= size
    train_loss /= num_batch

    return train_acc, train_loss
        

2、创建测试函数

def test(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes a
            ``dataset`` attribute (e.g. ``torch.utils.data.DataLoader``).
        model: Module whose output is reshaped to ``(-1, 2)`` class logits.
        loss_fn: Criterion called as ``loss_fn(pred, y)``.
        device: Device each batch is moved to. When ``None`` (the default),
            the device of the model's first parameter is used.

    Returns:
        Tuple ``(test_acc, test_loss)``: the fraction of correctly
        classified samples over the whole dataset, and the loss averaged
        over the number of batches.

    The model is put in eval mode for the duration of the call (dropout
    off, batch-norm uses running statistics). The caller's train/eval
    mode is restored afterwards, even if an exception is raised.
    """
    if device is None:
        device = next(model.parameters()).device

    size = len(dataloader.dataset)
    num_batch = len(dataloader)

    test_acc, test_loss = 0.0, 0.0

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                # Flatten the model output into (N, 2) logits for the two classes.
                pred = model(X).view(-1, 2)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                # Count samples whose highest-scoring class matches the label.
                test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        # Restore the caller's mode so a following training epoch is unaffected.
        model.train(was_training)

    test_acc /= size
    test_loss /= num_batch

    return test_acc, test_loss


# The name starts with "test", so pytest would otherwise collect this
# evaluation helper as a test case and fail on its missing fixtures.
test.__test__ = False

3、设置超参数(学习率)、优化器与损失函数

# Training configuration: an Adam optimiser over the model's parameters
# with a fixed learning rate, plus a cross-entropy loss criterion.
learn_rate = 0.0001
opt = torch.optim.Adam(params=model.parameters(), lr=learn_rate)
loss_fn = nn.CrossEntropyLoss()