# 1、定义训练函数
def train(dataloader, model, loss_fn, opt, *, num_classes=2):
    """Run one training epoch over ``dataloader`` and return its metrics.

    The model is explicitly switched to training mode first. Without this,
    layers such as dropout / batch-norm would keep whatever mode they were
    last left in (e.g. eval mode after an evaluation pass), and training
    would silently run with inference behaviour.

    Args:
        dataloader: iterable of ``(X, y)`` batches whose ``dataset``
            supports ``len()``; accuracy is normalised by that length.
        model: the network being trained; its output is reshaped with
            ``view(-1, num_classes)`` before the loss is computed.
        loss_fn: criterion called as ``loss_fn(logits, y)``.
        opt: optimizer that updates ``model``'s parameters.
        num_classes: width of the reshaped model output. Defaults to 2,
            the value this function originally hard-coded.

    Returns:
        ``(train_acc, train_loss)``: the fraction of samples whose arg-max
        prediction equals the label, and the average of the per-batch
        ``loss.item()`` values.

    Note:
        Batches are moved to the module-level ``device``, which must be
        defined elsewhere in this file.
    """
    size = len(dataloader.dataset)
    num_batch = len(dataloader)
    model.train()  # make dropout / batch-norm use training behaviour
    train_acc, train_loss = 0.0, 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X).view(-1, num_classes)
        loss = loss_fn(pred, y)

        # Standard update: clear stale gradients, backprop, then step.
        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_acc /= size
    train_loss /= num_batch
    return train_acc, train_loss
# 2、定义测试函数
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batch = len(dataloader)
test_acc, test_loss = 0.0, 0.0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X).view(-1, 2)
loss = loss_fn(pred, y)
test_loss += loss.item()
test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
test_acc /= size
test_loss /= num_batch
return test_acc, test_loss
# 3、设置超参数
# Training hyper-parameters: a fixed Adam learning rate applied to every
# parameter of the (externally defined) ``model``, with cross-entropy as
# the classification criterion.
learn_rate = 0.0001
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(params=model.parameters(), lr=learn_rate)