本文实例为大家分享了 PyTorch 入门之 MNIST 分类的具体代码，供大家参考，具体内容如下：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Script metadata kept from the original post: author name and a timestamp string.
__author__, __time__ = 'denny', '2017-9-9 9:03'
import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data
# Load the MNIST train/test splits from ./mnist, converting each image with ToTensor().
train_data = torchvision.datasets.MNIST(
    './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
# download=True is also passed here so the test split does not rely on the line
# above having fetched the files first; torchvision skips the download when the
# files are already present.
test_data = torchvision.datasets.MNIST(
    './mnist', train=False, transform=torchvision.transforms.ToTensor(), download=True
)
# `.data` / `.targets` are the current torchvision attribute names; the old
# `train_data` / `train_labels` / `test_data` names are deprecated aliases that
# only warn and forward to these.
print("train_data:", train_data.data.size())
print("train_labels:", train_data.targets.size())
print("test_data:", test_data.data.size())
# Batches of 64; only the training set is shuffled.
train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)
class Net(torch.nn.Module):
    """Small CNN classifier: three conv stages followed by a two-layer dense head.

    Each stage is Conv2d(kernel 3, stride 1, padding 1) -> ReLU -> MaxPool2d(2).
    The padding keeps the spatial size through the convolution, and the pool
    halves it. The head expects 64 * 3 * 3 flattened features, i.e. a 3x3 map
    after the three halvings. That is consistent with 28x28 MNIST digits
    (assumed from the dataset used above, not checked here).

    The head ends in Linear(128, 10) with no softmax. It therefore returns 10
    raw scores per sample.
    """

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one Conv2d(3, 1, 1) -> ReLU -> MaxPool2d(2) stage."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )

    def __init__(self):
        super(Net, self).__init__()
        # Registration order (conv1, conv2, conv3, dense) fixes the state_dict key order.
        self.conv1 = self._stage(1, 32)
        self.conv2 = self._stage(32, 64)
        self.conv3 = self._stage(64, 64)
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64 * 3 * 3, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )

    def forward(self, x):
        """Run the conv stages, flatten per sample, and apply the dense head."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        flat = x.view(x.size(0), -1)
        return self.dense(flat)
# Instantiate the network and print its layer structure.
model = Net()
print(model)
# Net's head ends in a plain Linear layer, so the loss receives raw scores.
loss_func = torch.nn.CrossEntropyLoss()
# Adam with its default hyper-parameters (no learning rate is passed).
optimizer = torch.optim.Adam(params=model.parameters())
def _run_epoch(net, loader, criterion, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, the network is put in training mode and its
    parameters are updated after every batch. Otherwise it is put in eval mode
    and autograd is disabled for the pass. This replaces the removed
    ``Variable(..., volatile=True)`` idiom.

    Args:
        net: the model to run.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function returning the batch-mean loss
            (e.g. ``CrossEntropyLoss`` with its default reduction).
        optimizer: optimizer to step, or ``None`` for evaluation.

    Returns:
        The per-sample mean loss and the fraction of correctly classified
        samples. Both are 0.0 if ``loader`` yields no batches.
    """
    training = optimizer is not None
    # Set the mode explicitly on every pass. The old script called
    # model.eval() once and never switched back to training mode.
    net.train(training)
    total_loss = 0.
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in loader:
            out = net(batch_x)
            loss = criterion(out, batch_y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = batch_y.size(0)
            # The criterion averages over the batch; re-weight by batch size so
            # the epoch figure is a true per-sample mean, including a short last batch.
            # .item() replaces `.data[0]`, which fails on 0-dim tensors.
            total_loss += loss.item() * batch_size
            total_correct += (out.argmax(dim=1) == batch_y).sum().item()
            total_seen += batch_size
    if total_seen == 0:
        return 0., 0.
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(10):
    print('epoch {}'.format(epoch + 1))
    # training-----------------------------
    train_loss, train_acc = _run_epoch(model, train_loader, loss_func, optimizer)
    print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss, train_acc))
    # evaluation--------------------------------
    eval_loss, eval_acc = _run_epoch(model, test_loader, loss_func)
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss, eval_acc))
以上就是本文的全部内容，希望对大家的学习有所帮助，也希望大家多多支持服务器之家。
原文链接：http://www.cnblogs.com/denny402/p/7506523.html