我就废话不多说了,直接上代码吧!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed( 1 )
np.random.seed( 1 )
BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace( - 1 , 1 ,ART_COMPONENTS) for _ in range (BATCH_SIZE)])
def artist_works():
a = np.random.uniform( 1 , 2 ,size = BATCH_SIZE)[:,np.newaxis]
paintings = a * np.power(PAINT_POINTS, 2 ) + (a - 1 )
paintings = torch.from_numpy(paintings). float ()
return Variable(paintings)
G = nn.Sequential(
nn.Linear(N_IDEAS, 128 ),
nn.ReLU(),
nn.Linear( 128 ,ART_COMPONENTS),
)
D = nn.Sequential(
nn.Linear(ART_COMPONENTS, 128 ),
nn.ReLU(),
nn.Linear( 128 , 1 ),
nn.Sigmoid(),
)
opt_D = torch.optim.Adam(D.parameters(),lr = LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr = LR_G)
plt.ion()
for step in range ( 10000 ):
artist_paintings = artist_works()
G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
G_paintings = G(G_ideas)
prob_artist0 = D(artist_paintings)
prob_artist1 = D(G_paintings)
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log( 1 - prob_artist1))
G_loss = torch.mean(torch.log( 1 - prob_artist1))
opt_D.zero_grad()
D_loss.backward(retain_variables = True )
opt_D.step()
opt_G.zero_grad()
G_loss.backward()
opt_G.step()
if step % 50 = = 0 :
plt.cla()
plt.plot(PAINT_POINTS[ 0 ],G_paintings.data.numpy()[ 0 ],c = '#4ad631' ,lw = 3 ,label = 'Generated painting' ,)
plt.plot(PAINT_POINTS[ 0 ], 2 * np.power(PAINT_POINTS[ 0 ], 2 ) + 1 ,c = '#74BCFF' ,lw = 3 ,label = 'upper bound' ,)
plt.plot(PAINT_POINTS[ 0 ], 1 * np.power(PAINT_POINTS[ 0 ], 2 ) + 0 ,c = '#FF9359' ,lw = 3 ,label = 'lower bound' ,)
plt.text( - . 5 , 2.3 , 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict = { 'size' : 15 })
plt.text( - . 5 , 2 , 'D score= %.2f (-1.38 for G to converge)' % - D_loss.data.numpy(), fontdict = { 'size' : 15 })
plt.ylim(( 0 , 3 ))
plt.legend(loc = 'upper right' , fontsize = 12 )
plt.draw()
plt.pause( 0.01 )
plt.ioff()
plt.show()
|
以上这篇pytorch GAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/pureszgd/article/details/75095332