卷積生成對抗網絡(DCGAN)—生成手寫数字

5{icon} {views}

深度卷積生成對抗網絡(DCGAN)

—- 生成 MNIST 手寫圖片

1、基本原理

生成對抗網絡(GAN)由2個重要的部分構成:

  • 生成器(Generator):通過機器生成數據(大部分情況下是圖像),目的是“騙過”判別器
  • 判別器(Discriminator):判斷這張圖像是真實的還是機器生成的,目的是找出生成器做的“假數據”

訓練過程

  • 1、固定判別器,讓生成器不斷生成假數據,給判別器判別,開始生成器很弱,但是隨着不斷的訓練,生成器不斷提升,最終騙過判別器。此時判別器判斷假數據的概率為50%
  • 2、固定生成器,訓練判別器。判別器經過訓練,提高鑒別能力,最終能準確判斷雖有的假圖片
  • 3、循環上兩個階段,最終生成器和判別器都越來越強。然後就可以使用生成器來生成我們想要的圖片了

2、相關數學原理

  • 判別器在這裡是一種分類器,用於區分樣本的真偽,因此我們常常使用交叉熵(cross entropy)來進行判別分佈的相似性

\[H(p, q) := -\sum_i p_i \log q_i \]

公式中 \(p_i\)\(q_i\) 為真實的樣本分佈和生成器的生成分佈

假定 \(y_1\) 為正確樣本分佈,那麼對應的( \(1-y_1\) )就是生成樣本的分佈。\(D\) 表示判別器,則 \(D(x_1)\) 表示判別樣本為正確的概率, \(1-D(x_1)\) 則對應着判別為錯誤樣本的概率。則有如下式子(這裏僅僅是對當前情況下的交叉熵損失的具體化)。

\[H((x_i, y_i)_{i=1}^N, D) = – \sum_{i=1}^N y_i\log D(x_i) – \sum_{i=1}^N(1-y_i)\log (1 – D(x_i)) \]

對於GAN中的樣本點 \(x_i\) ,對應於兩個出處,要麼來自於真實樣本,要麼來自於生成器生成的樣本 $\tilde{x} – G(z) $ ( 這裏的 \(z\) 是服從於投到生成器中噪聲的分佈)。

對於來自於真實的樣本,我們要判別為正確的分佈 \(y_i\) 。來自於生成的樣本我們要判別其為錯誤分佈( \(1-y_i\) )。將上面式子進一步使用概率分佈的期望形式寫出(為了表達無限的樣本情況,相當於無限樣本求和情況),並且讓 \(y_i\) 為 1/2 且使用 \(G(z)\) 表示生成樣本可以得到如下公式:

\[H \left( (x_i, y_i)_{i=1}^\infty, D \right) = -\frac{1}{2}E_{x-p_{data}}\left[ \log D(x) \right] – \frac{1}{2}E_z\left[ \log (1-D(G(z))) \right] \\\ GAN損失函數期望形式 \]

對於論文中的公式

\[min_G max_D V(D, G) = E_{x-p_{data}(x)}\left[ \log D(x) \right] + E_{z-p_z(z)}\left[ \log (1-D(G(z))) \right] \\\ GAN損失函數的 min max表達 \]

其實是與上面公式一樣的,下面做解釋

  • 這裏的 \(V(D, G)\) 相當於表示真實樣本和生成樣本的差異程度。
  • \(max_D V(D, G)\) 的意思是固定生成器 \(G\), 盡可能地讓判別器能夠最大化地判別出樣本來自於真實數據還是生成的數據。
  • 再將後面的 $L = max_D V(D, G) $ 看成整體,對於 \(min_G L\)這裡是在固定判別器\(D\)的條件下得到生成器 \(G\),這個 \(G\) 要求能夠最小化真實樣本與生成樣本的差異。
  • 通過上述 \(min\) \(max\) 的博弈過程,理想情況下會收斂於生成分佈擬合於真實分佈。

3、卷積對抗生成網絡

卷積對抗生成網絡(DCGAN)是在GAN的基礎上加入了CNN,主要是改進了網絡結構,在訓練過程中狀態穩定,並且可以有效實現高質量圖片的生成以及相關的生成模型應用。DCGAN的生成器網絡結構如下圖:

DCGAN的改進:

  • 使用步長卷積代替上採樣層,卷積在提取圖像特徵上具有很好的作用,並且使用卷積代替全連接層
  • 生成器G和判別器D中幾乎每一層都使用batchnorm層,將特徵層的輸出歸一化到一起,加速了訓練,提升了訓練的穩定性。
  • 在判別器中使用leakrelu激活函數,而不是RELU,防止梯度稀疏,生成器中仍然採用relu,但是輸出層採用tanh。

4、DCGAN代碼實現

shenduimport numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import optimizers, losses, layers, Sequential, Model
class DCGAN():
    '''
    實現深度對抗神經網絡
    生成 MNIST 手寫数字圖片
    輸入的噪聲為服從正態分佈均值為 0 方差為 1 的分佈, shape:(None, 100)
    生成器(G)輸入 噪聲, 輸出為 (None, 28, 28, 1)的圖片
    分類器(D)輸入為 (None, 28, 28, 1)的圖片,輸出圖片的分類真假
    '''
    def __init__(self):
        self.img_rows = 28 
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = optimizers.Adam(0.0002)

        # 構建編譯分類器
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', 
            optimizer=optimizer,
            metrics=['accuracy'])

        # 構建編譯生成器
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        # 生成器輸入為噪音,生成圖片
        z = layers.Input(shape=(100,))
        img = self.generator(z)

        # 對於整個對抗網絡模型只優化生成器的參數
        self.discriminator.trainable = False

        # 用生成的圖片輸入分類器判斷
        valid = self.discriminator(img)

        # 對於整個對抗網絡 輸入噪音 => 生成圖片 => 決定圖片是否有效
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        
    def build_generator(self):
        '''
        構建生成器
        '''
        noise_shape = (100,)
        
        model = tf.keras.Sequential()
        
        # 添加全連接層
        model.add(layers.Dense(7*7*256, use_bias=False, input_shape=noise_shape))
        # 添加 BatchNormalization 層,對數據進行歸一化
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())

        model.add(layers.Reshape((7, 7, 256)))
        
        # 添加逆卷積層,卷積核大小為 5X5,數量 128, 步長為 1
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
        assert model.output_shape == (None, 7, 7, 128)
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        
        # 添加逆卷積層,卷積核大小為 5X5,數量 64, 步長為 2
        model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        assert model.output_shape == (None, 14, 14, 64)
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        
        # 添加逆卷積層,卷積核大小為 5X5,數量 1, 步長為 2
        model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
        assert model.output_shape == (None, 28, 28, 1)
        
        model.summary()
        noise = layers.Input(shape=noise_shape)
        img = model(noise)
        
        # 返回 Model 對象,輸入為 噪聲, 輸出為 圖像
        return keras.Model(noise, img)

    
    def build_discriminator(self):
        '''
        構建分類器
        '''
        img_shape = (self.img_rows, self.img_cols, self.channels)
        
        model = tf.keras.Sequential()
        
        model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                         input_shape=img_shape))
        model.add(layers.LeakyReLU())
        # 添加 Dropout 層,減少參數數量
        model.add(layers.Dropout(0.3))

        model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.3))
        # 把數據鋪平
        model.add(layers.Flatten())
        model.add(layers.Dense(1))
        
        model.summary()
        
        img = layers.Input(shape=img_shape)
        validity = model(img)
        
        return keras.Model(img, validity)

    
    def train(self, epochs, batch_size=128, save_interval=50):
        '''
        網絡訓練
        '''
        # 加載 數據集
        (X_train, _), (_, _) = keras.datasets.mnist.load_data()

        # 把數據縮放到 [-1, 1]
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        # 添加通道維度
        X_train = np.expand_dims(X_train, axis=3)
        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  訓練分類器
            # ---------------------

            # 隨機的選擇一半的 batch 數量圖片
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # 生成一半 batch 數量的 圖片
            gen_imgs = self.generator.predict(noise)

            # 分類器損失
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


            # ---------------------
            #  訓練生成器
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            # 對於生成器,希望分類器把更多的圖片判為 有效 (用 1 表示)
            valid_y = np.array([1] * batch_size)

            # 訓練生成器
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # 打印訓練進度
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # 每個 save_interval 周期保存一張圖片
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(noise)

        # 把圖片數據縮放到 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("dcgan/images/mnist_%d.png" % epoch)
        plt.close()

if __name__ == '__main__':
    dcgan = DCGAN()
    dcgan.train(epochs=10000, batch_size=32, save_interval=200)

網絡參數信息

5、訓練結果

下面是循環了 10000 次 epoch 后,從開始每隔 2000 個 epoch 生成器生成的圖片

  • 可以看到,剛開始全部都是噪聲,隨着訓練的進行,圖片逐漸清晰

  • 生成的圖片還是不太清晰,一方面的原因是我訓練的 epoch 周期太少,因為自己電腦性能問題,太耗時間,所以訓練的epoch 周期少,如果有條件后提高訓練周期應該會好很多。另一方面或許因為我構建的網絡還有不合理之,後期還需要改進。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

※超省錢租車方案

※別再煩惱如何寫文案,掌握八大原則!

※回頭車貨運收費標準

※教你寫出一流的銷售文案?

※產品缺大量曝光嗎?你需要的是一流包裝設計!

※廣告預算用在刀口上,台北網頁設計公司幫您達到更多曝光效益

網頁設計最專業,超強功能平台可客製化