GAN原理手写数据集生成


GAN原理介绍

  • GAN 来源于博弈论中的零和博弈,博弈双方,分别为生成模型与判别模型。
  • 生成模型G捕捉样本数据的分布,用服从某一分布例如正太,高斯分布的噪声z来生成一个类似真实训练数据的样本,追求的效果是越像真实越好。
  • 判别模型是一个二分类器,判别样本来自于训练数据还是真实数据的概率。如果来自于真实样本输出大概率,如果来自于训练数据,输出小概率。

实例demo

  • 以造小狗的假图片为例。首先生成小狗图片的模型,称之为generator,还有一个判断小狗图片是否是真假的判别模型 discrimator。
  • 首先输入一个的噪声,然后送入生成器,生成器的生成假图
  • 把真图与假图。进行拼接,然后打上标签,真图标签是1,假图标签是0,送入鉴别器,鉴别器输出属于真实样本与训练样本的概率。

实际GAN(手写数据集为例)

数据预处理

导入函数库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

划分数据集

(train_image,train_labels),_=keras.datasets.mnist.load_data()

数据类型转换于归一化[-1,1]

train_images=2*tf.cast(train_image,tf.float32)/255.-1

expand_dims 设置通道,-1 加一维

train_images=2*tf.cast(train_image,tf.float32)/255.-1
# expand_dims 设置通道,-1 加一维
train_images=tf.expand_dims(train_images,-1)
train_images.shape

常用参数设置与数据集生成

Batch_Size=256
# 每回使用256
Buffer_Size=60000 #乱序范围
# 构建demo使用的数据集
dataset=tf.data.Dataset.from_tensor_slices(train_images).shuffle(Buffer_Size).batch(Batch_Size)

GAN模型的生成器与鉴别器构建

def generator_model():
  # 第一层
  model=tf.keras.Sequential()
  # 100->256
  # 第一层
  model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  # 第二层
  #256->512
  model.add(layers.Dense(512,use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  # 第三层
  #512->28*28
  model.add(layers.Dense(28*28,use_bias=False,activation="tanh"))
  model.add(layers.BatchNormalization())
  model.add(layers.Reshape([28,28,1]))
  
  return model

定义判别器

def discriminator_model():
  model=tf.keras.Sequential()
  # 第一层
  model.add(layers.Flatten())
  model.add(layers.Dense(512,use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  # 第二层
  model.add(layers.Dense(512, use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())
  model.add(layers.Dense(1))
  # 输出一个值
  return model

设置损失函数

定义优化器

generator_opt=tf.keras.optimizers.Adam(0.0001)
discriminator_opt=tf.keras.optimizers.Adam(0.0001)
cross_entropy=keras.losses.BinaryCrossentropy(from_logits=True)

计算判别器损失

def discriminator_loss(real_out,fake_out):
  real_loss=cross_entropy(tf.ones_like(real_out),real_out)
  fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
  return real_loss+fake_loss

计算生成器损失

def generator_loss(fake_out):
    fake_loss = cross_entropy(tf.ones_like(fake_out), fake_out)
    return fake_loss

定义训练step

Epochs=100
input_dim=100
num_exp_to_generate=16
# 生成16*100
seed=tf.random.normal([num_exp_to_generate,input_dim])
# 定义训练步骤
generator=generator_model()
discriminator=discriminator_model()
def train_step(images):
  noise=tf.random.normal([Batch_Size,input_dim])
  with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:
    real_out=discriminator(images)
    gen_img=generator(noise)
    fake_out=discriminator(gen_img)
    dis_loss=discriminator_loss(real_out,fake_out)
    gen_loss=generator_loss(fake_out) 
    # 梯度下降参数计算
    gen_gard=gen_tape.gradient(gen_loss,generator.trainable_variables)
    dis_gard=dis_tape.gradient(dis_loss,discriminator.trainable_variables)
    # 进行参数更新,并反传
    discriminator_opt.apply_gradients(zip(dis_gard,discriminator.trainable_variables))
    generator_opt.apply_gradients(zip(gen_gard, generator.trainable_variables))
# 绘制训练后的图
def genrate_plot_image(gen_model,test_noise):
    pre_images=gen_model(test_noise,training=False)
    fig=plt.figure(figsize=(8,6))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0]+1)/2*255.)
        plt.axis('off')
    plt.show()

epoch设置

def train(dataset,epochs):
for epoch in range(epochs):
print(epoch)
for image_batch in dataset:
train_step(image_batch)
print(epoch)
if epoch%10==0:
print(epoch)
genrate_plot_image(generator,seed)

main函数

if __name__ == '__main__':
    train(dataset,500)
0
0
0

output_34_1

1
1
2
2
3
3
4
4
5
5
6
6
7
7
8
8
9
9
10
10
10

output_34_3

11
11
12
12
13
13
14
14
15
15
16
16
17
17
18
18
19
19
20
20
20

output_34_5

21
21
22
22
23
23
24
24
25
25
26
26
27
27
28
28
29
29
30
30
30

output_34_7

31
31
32
32
33
33
34
34
35
35
36
36
37
37
38
38
39
39
40
40
40

output_34_9

41
41
42
42
43
43
44
44
45
45
46
46
47
47
48
48
49
49
50
50
50

output_34_11

51
51
52
52
53
53
54
54
55
55
56
56
57
57
58
58
59
59
60
60
60

output_34_13

61
61
62
62
63
63
64
64
65
65
66
66
67
67
68
68
69
69
70
70
70

output_34_15

71
71
72
72
73
73
74
74
75
75
76
76
77
77
78
78
79
79
80
80
80

output_34_17

81
81
82
82
83
83
84
84
85
85
86
86
87
87
88
88
89
89
90
90
90

output_34_19

91
91
92
92
93
93
94
94
95
95
96
96
97
97
98
98
99
99
100
100
100

output_34_21

101
101
102
102
103
103
104
104
105
105
106
106
107
107
108
108
109
109
110
110
110

output_34_23

111
111
112
112
113
113
114
114
115
115
116
116
117
117
118
118
119
119
120
120
120

output_34_25

121
121
122
122
123
123
124
124
125
125
126
126
127
127
128
128
129
129
130
130
130

output_34_27

131
131
132
132
133
133
134
134
135
135
136
136
137
137
138
138
139
139
140
140
140

output_34_29

141
141
142
142
143
143
144
144
145
145
146
146
147
147
148
148
149
149
150
150
150

output_34_31

151
151
152
152
153
153
154
154
155
155
156
156
157
157
158
158
159
159
160
160
160

output_34_33

161
161
162
162
163
163
164
164
165
165
166
166
167
167
168
168
169
169
170
170
170

output_34_35

171
171
172
172
173
173
174
174
175
175
176
176
177
177
178
178
179
179
180
180
180

output_34_37

181
181
182
182
183
183
184
184
185
185
186
186
187
187
188
188
189
189
190
190
190

output_34_39

191
191
192
192
193
193
194
194
195
195
196
196
197
197
198
198
199
199
200
200
200

output_34_41

201
201
202
202
203
203
204
204
205
205
206
206
207
207
208
208
209
209
210
210
210

output_34_43

211
211
212
212
213
213
214
214
215
215
216
216
217
217
218
218
219
219
220
220
220

output_34_45

221
221
222
222
223
223
224
224
225
225
226
226
227
227
228
228
229
229
230
230
230

output_34_47

231
231
232
232
233
233
234
234
235
235
236
236
237
237
238
238
239
239
240
240
240

output_34_49

241
241
242
242
243
243
244
244
245
245
246
246
247
247
248
248
249
249
250
250
250

output_34_51

251
251
252
252
253
253
254
254
255
255
256
256
257
257
258
258
259
259
260
260
260

output_34_53

261
261
262
262
263
263
264
264
265
265
266
266
267
267
268
268
269
269
270
270
270

output_34_55

271
271
272
272
273
273
274
274
275
275
276
276
277
277
278
278
279
279
280
280
280

output_34_57

281
281
282
282
283
283
284
284
285
285
286
286
287
287
288
288
289
289
290
290
290

output_34_59

291
291
292
292
293
293
294
294
295
295
296
296
297
297
298
298
299
299
300
300
300

output_34_61

301
301
302
302
303
303
304
304
305
305
306
306
307
307
308
308
309
309
310
310
310

output_34_63

311
311
312
312
313
313
314
314
315
315
316
316
317
317
318
318
319
319
320
320
320

output_34_65

321
321
322
322
323
323
324
324
325
325
326
326
327
327
328
328
329
329
330
330
330

output_34_67

331
331
332
332
333
333
334
334
335
335
336
336
337
337
338
338
339
339
340
340
340

output_34_69

341
341
342
342
343
343
344
344
345
345
346
346
347
347
348
348
349
349
350
350
350

output_34_71

351
351
352
352
353
353
354
354
355
355
356
356
357
357
358
358
359
359
360
360
360

output_34_73

361
361
362
362
363
363
364
364
365
365
366
366
367
367
368
368
369
369
370
370
370

output_34_75

371
371
372
372
373
373
374
374
375
375
376
376
377
377
378
378
379
379
380
380
380

output_34_77

381
381
382
382
383
383
384
384
385
385
386
386
387
387
388
388
389
389
390
390
390

output_34_79

391
391
392
392
393
393
394
394
395
395
396
396
397
397
398
398
399
399
400
400
400

output_34_81

401
401
402
402
403
403
404
404
405
405
406
406
407
407
408
408
409
409
410
410
410

output_34_83

411
411
412
412
413
413
414
414
415
415
416
416
417
417
418
418
419
419
420
420
420

output_34_85

421
421
422
422
423
423
424
424
425
425
426
426
427
427
428
428
429
429
430
430
430

output_34_87

431
431
432
432
433
433
434
434
435
435
436
436
437
437
438
438
439
439
440
440
440

output_34_89

441
441
442
442
443
443
444
444
445
445
446
446
447
447
448
448
449
449
450
450
450

output_34_91

451
451
452
452
453
453
454
454
455
455
456
456
457
457
458
458
459
459
460
460
460

output_34_93

461
461
462
462
463
463
464
464
465
465
466
466
467
467
468
468
469
469
470
470
470

output_34_95

471
471
472
472
473
473
474
474
475
475
476
476
477
477
478
478
479
479
480
480
480

output_34_97

481
481
482
482
483
483
484
484
485
485
486
486
487
487
488
488
489
489
490
490
490

output_34_99

491
491
492
492
493
493
494
494
495
495
496
496
497
497
498
498
499
499

生成器描述

generator.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 256)               25600     
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131072    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 784)               401408    
_________________________________________________________________
batch_normalization_2 (Batch (None, 784)               3136      
_________________________________________________________________
reshape (Reshape)            (None, 28, 28, 1)         0         
=================================================================
Total params: 564,288
Trainable params: 561,184
Non-trainable params: 3,104
_________________________________________________________________


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM