分类器平均准确率计算:
1
2
3
4
5
6
7
8
9
10
11
12
|
correct = torch.zeros( 1 ).squeeze().cuda()
total = torch.zeros( 1 ).squeeze().cuda()
for i, (images, labels) in enumerate (train_loader):
images = Variable(images.cuda())
labels = Variable(labels.cuda())
output = model(images)
prediction = torch.argmax(output, 1 )
correct + = (prediction = = labels). sum (). float ()
total + = len (labels)
acc_str = 'Accuracy: %f' % ((correct / total).cpu().detach().data.numpy())
|
分类器各个子类准确率计算:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
correct = list ( 0. for i in range (args.class_num))
total = list ( 0. for i in range (args.class_num))
for i, (images, labels) in enumerate (train_loader):
images = Variable(images.cuda())
labels = Variable(labels.cuda())
output = model(images)
prediction = torch.argmax(output, 1 )
res = prediction = = labels
for label_idx in range ( len (labels)):
label_single = label[label_idx]
correct[label_single] + = res[label_idx].item()
total[label_single] + = 1
acc_str = 'Accuracy: %f' % ( sum (correct) / sum (total))
for acc_idx in range ( len (train_class_correct)):
try :
acc = correct[acc_idx] / total[acc_idx]
except :
acc = 0
finally :
acc_str + = '\tclassID:%d\tacc:%f\t' % (acc_idx + 1 , acc)
|
以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/u014657795/article/details/86419197