Building a U-Net Generator with TensorFlow 2.0


While trying to implement the U-Net architecture in TensorFlow 2.0, I ran into a problem:

The Sequential API cannot express skip connections. Here is the attempt:

def make_generator_model():
    model = tf.keras.Sequential()
    # Encoder: eight stride-2 convolutions, 256x256 -> 1x1
    model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=128, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # Decoder: eight stride-2 transposed convolutions, 1x1 -> 256x256
    model.add(tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.Conv2DTranspose(3, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())

    return model
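In a Sequential model each layer feeds only into the next one, so a decoder layer has no way to reference an earlier encoder output. The functional API instead treats every layer's output as a tensor that can be reused anywhere downstream. As a minimal sketch of the idea (the layer sizes here are illustrative, not the full network):

inp = tf.keras.Input(shape=(256, 256, 3))
e1 = tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding='same')(inp)          # 128x128x64
e2 = tf.keras.layers.Conv2D(128, kernel_size=4, strides=2, padding='same')(e1)          # 64x64x128
d1 = tf.keras.layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same')(e2)  # 128x128x64
d1 = tf.keras.layers.Concatenate()([d1, e1])   # skip connection: reuse the encoder tensor e1
out = tf.keras.layers.Conv2DTranspose(3, kernel_size=4, strides=2, padding='same')(d1)  # 256x256x3
toy_model = tf.keras.Model(inp, out)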

Using the functional API solves the skip-connection problem; the code above can be rewritten as:

gen_input = tf.keras.Input(shape=(256, 256, 3), name='train_img')    # input
c1 = tf.keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same', use_bias=False)(gen_input)
b1 = batch_norm(c1)
# first conv layer, output shape [1, 128, 128, 64]
c2 = tf.keras.layers.Conv2D(filters=128, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b1))
b2 = batch_norm(c2)
# second conv layer, output shape [1, 64, 64, 128]
c3 = tf.keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b2))
b3 = batch_norm(c3)
# third conv layer, output shape [1, 32, 32, 256]
c4 = tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b3))
b4 = batch_norm(c4)
# fourth conv layer, output shape [1, 16, 16, 512]
c5 = tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b4))
b5 = batch_norm(c5)
# fifth conv layer, output shape [1, 8, 8, 512]
c6 = tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b5))
b6 = batch_norm(c6)
# sixth conv layer, output shape [1, 4, 4, 512]
c7 = tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b6))
b7 = batch_norm(c7)
# seventh conv layer, output shape [1, 2, 2, 512]
c8 = tf.keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, padding='same', use_bias=False)(lrelu(b7))
b8 = batch_norm(c8)
# eighth conv layer, output shape [1, 1, 1, 512]

d1 = tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False)(b8)
d1 = tf.nn.dropout(d1, 0.5)
d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), b7], 3)  # skip connection
# first deconv layer, output shape [1, 2, 2, 512]; 1024 channels after the concat
d2 = tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d1))
d2 = tf.nn.dropout(d2, 0.5)
d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), b6], 3)  # skip connection
# second deconv layer, output shape [1, 4, 4, 512]; 1024 channels after the concat
d3 = tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d2))
d3 = tf.nn.dropout(d3, 0.5)
d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), b5], 3)  # skip connection
# third deconv layer, output shape [1, 8, 8, 512]; 1024 channels after the concat
d4 = tf.keras.layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d3))
d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), b4], 3)  # skip connection
# fourth deconv layer, output shape [1, 16, 16, 512]; 1024 channels after the concat
d5 = tf.keras.layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d4))
d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), b3], 3)  # skip connection
# fifth deconv layer, output shape [1, 32, 32, 256]; 512 channels after the concat
d6 = tf.keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d5))
d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), b2], 3)  # skip connection
# sixth deconv layer, output shape [1, 64, 64, 128]; 256 channels after the concat
d7 = tf.keras.layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d6))
d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), b1], 3)  # skip connection
# seventh deconv layer, output shape [1, 128, 128, 64]; 128 channels after the concat
d8 = tf.keras.layers.Conv2DTranspose(3, kernel_size=4, strides=2, padding='same', use_bias=False)(tf.nn.relu(d7))
gen_out = tf.nn.tanh(d8)     # output
# eighth deconv layer, output shape [1, 256, 256, 3]
gen_model = tf.keras.Model(inputs=gen_input, outputs=gen_out, name='gen_model')
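A quick sanity check (a minimal sketch; it assumes the batch_norm and lrelu helpers below are defined before the graph-building code above runs) is to push a dummy batch through the model and confirm that the output shape matches the input:

dummy = tf.random.normal([1, 256, 256, 3])
print(gen_model(dummy).shape)    # expected: (1, 256, 256, 3)
gen_model.summary()              # per-layer output shapes, including the concats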

# The batch_norm function can be called directly

def batch_norm(inp, name=None):
    # Each call builds a fresh BatchNormalization layer (no weight sharing
    # between calls); training=True always normalizes with batch statistics.
    return tf.keras.layers.BatchNormalization(name=name)(inp, training=True)

# Define the lrelu activation function

def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak * x)
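For reference, TensorFlow 2.0 also ships equivalent built-ins, so the helper above is a stylistic choice rather than a necessity; either of the following (with x standing for any input tensor) behaves the same for leak=0.2:

y = tf.nn.leaky_relu(x, alpha=0.2)            # functional op
y = tf.keras.layers.LeakyReLU(alpha=0.2)(x)   # Keras layer form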


Adapted from: https://blog.csdn.net/jiongnima/article/details/80209239

