import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

# Per-pixel mean / standard deviation used to normalize MNIST images.
# plot_image() uses the same values to undo the normalization for display.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


# ---------------------------------------------------------------------------
# Small helpers
# ---------------------------------------------------------------------------

def plot_curve(data):
    """Plot a sequence of scalar values against their step index.

    Args:
        data: Sequence of numbers, e.g. the per-batch training losses.
    """
    plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):
    """Show up to six images from a batch in a 2x3 grid, titled with labels.

    Args:
        img: Batch of normalized images indexed as ``img[i][0]`` (sample i,
            channel 0).
        label: Tensor of per-sample labels; ``label[i].item()`` is shown in
            the title of image i.
        name: Prefix used in every subplot title (``"<name>:<label>"``).
    """
    plt.figure()
    # Never index past the end of a batch that holds fewer than six samples.
    for i in range(min(6, len(img))):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # Undo the normalization so pixel intensities are displayed correctly.
        plt.imshow(img[i][0] * MNIST_STD + MNIST_MEAN,
                   cmap='gray', interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        # Hide the tick marks on both axes.
        plt.xticks([])
        plt.yticks([])

    plt.show()


def one_hot(label, depth=10):
    """Convert integer class labels into one-hot encoded rows.

    Args:
        label: Integer class indices (tensor or sequence), one per sample.
        depth: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Float tensor of shape ``[num_labels, depth]`` with a single 1 per row.
    """
    idx = torch.as_tensor(label, dtype=torch.long).view(-1, 1)
    # Allocate on the labels' device so scatter_ does not hit a device mismatch.
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    out.scatter_(dim=1, index=idx, value=1)
    return out


# Number of images loaded per batch.
batch_size = 512

# Step 1: load the dataset.
# Both splits share the same preprocessing: convert to tensor, then normalize.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)


# Step 2: build the network.
class Net(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 64 -> 10."""

    def __init__(self):
        super().__init__()

        # Each layer computes x @ w + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the 10 raw output scores for each flattened input image.

        Args:
            x: Tensor whose last dimension is 28*28; callers flatten the
                ``[b, 1, 28, 28]`` image batch before calling the network.

        Returns:
            Tensor with last dimension 10 (no activation on the final layer).
        """
        # h1 = relu(x @ w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 @ w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2 @ w3 + b3
        x = self.fc3(x)

        return x


net = Net()
# net.parameters() yields [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Loss value of every training batch, kept for plotting afterwards.
train_loss = []

# Step 3: train.
for epoch in range(3):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [b]
        # Flatten each image: [b, 1, 28, 28] --> [b, 784]
        x = x.view(x.size(0), 28 * 28)
        # --> [b, 10]
        out = net(x)
        # --> [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Compute gradients.
        loss.backward()
        # Update the parameters: w' = w - lr * grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)

# At this point the network holds a reasonably good [w1, b1, w2, b2, w3, b3].


# Step 4: measure accuracy on the test set.
total_correct = 0
# No gradients are needed for evaluation; skip building the autograd graph.
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] --> pred: [b]
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

# Step 5: visual sanity check on the first test batch.
x, y = next(iter(test_loader))
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')