除此之外,我们需要观察模型的效果怎么样,所以需要可视化一些数据,从而方便我们调整网络结构以及相应的参数,比如损失值,准确率等。
首先需要将模型输出的张量转化为4个字符的验证码。由于标签是用4个one hot向量拼接在一起的,所以需要按照one hot向量的长度将张量分割开。然后在每个切片中找到数值最大的对应的索引。数值最大的我们就认为是one hot中的1。最后,根据索引去字符集中找到对应的字符即可。
def tensor2captcha(tensor):
tot_len = tensor.shape[1]
assert CAPTCHA_LEN * CHAR_LEN == tot_len
chars = np.array(list(CHAR_SET))
captcha = []
for i in range(CAPTCHA_LEN):
slice = tensor[:, i * CHAR_LEN:(i + 1) * CHAR_LEN]
idx = torch.argmax(slice, dim=1).detach().cpu().numpy()
col = chars[idx][:, np.newaxis]
captcha.append(col)
captcha = list(np.hstack(captcha))
captcha = [''.join(row) for row in captcha]
return captcha
计算准确率就很简单了,预测值和标签一样就认为是预测正确。
def get_acc(label, pred):
assert len(label) == len(pred)
label, pred = np.array(label), np.array(pred)
return sum(label == pred) / len(label)
最后是绘制损失曲线和准确率曲线。
def draw_loss_curve(loss_per_epoch):
plt.figure()
plt.plot(list(range(1, len(loss_per_epoch) + 1)), loss_per_epoch)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.savefig('loss.png')
def draw_acc(train_acc, test_acc):
assert len(train_acc) == len(test_acc)
x = list(range(1, len(train_acc) + 1))
plt.clf()
plt.plot(x, train_acc, label='train_acc')
plt.plot(x, test_acc, label='test_acc')
plt.legend()
plt.title("acc goes by epoch")
plt.xlabel('eopch')
plt.ylabel('acc_value')
plt.savefig('acc.png')