DCGAN增強圖片數據集


DCGAN增強圖片數據集

1.Dependencies

2.DCGAN

步驟:

  • 將圖片數集放在/Anime_GAN/DCGAN/faces

  • 進行如下的命令:

  $ cd Anime_GAN/DCGAN/
$ python main.py --help # 查看默認參數信息,根據需求可進行修改

執行完上述命令會產生相應的一張圖片(存儲位 置:/Anime_GAN/DCGAN/saved/img/xx.png)

  • 調用SegmentePictures.py進行圖片的切割

$ cd DCGAN/saved
$ python SegmentePictures.py   
# encoding:utf-8
from PIL import Image
import sys
import math
import argparse

def fill_image(image):
    """
    將圖片填充為正方形
    :param image:
    :return:
    """
    width, height = image.size
    #選取長和寬中較大值作為新圖片的
    new_image_length = width if width > height else height
    #生成新圖片[白底]
    new_image = Image.new(image.mode, (new_image_length, new_image_length), color='white')
    #將之前的圖粘貼在新圖上,居中
    if width > height:#原圖寬大於高,則填充圖片的豎直維度
        #(x,y)二元組表示粘貼上圖相對下圖的起始位置
        new_image.paste(image, (0, int((new_image_length - height) / 2)))
    else:
        new_image.paste(image,(int((new_image_length - width) / 2),0))

    return new_image


def cut_image(image,cut_num):
    """
    切圖
    :param image:
    :return:
    """
    flag_value = int(math.sqrt(cut_num))
    width, height = image.size
    item_width = int(width / flag_value)
    box_list = []
    for i in range(0,flag_value):
        for j in range(0,flag_value):
            box = (j*item_width,i*item_width,(j+1)*item_width,(i+1)*item_width)
            box_list.append(box)
    image_list = [image.crop(box) for box in box_list]

    return image_list


def save_images(image_list):
    """
    保存
    :param image_list:
    :return:
    """
    index = 1
    for image in image_list:
        image.save('./img_add/'+str(index) + '.png', 'PNG')
        index += 1

def main():
    parse = argparse.ArgumentParser()

    parse.add_argument("--lr", type=float, default=0.0001,
                       help="learning rate of generate and discriminator")
    parse.add_argument("--beta1", type=float, default=0.5,
                       help="adam optimizer parameter")
    parse.add_argument("--batch_size", type=int, default=81,
                       help="number of dataset in every train or test iteration")
    parse.add_argument("--epochs", type=int, default=0,
                       help="number of training epochs")
    parse.add_argument("--loaders", type=int, default=4,
                       help="number of parallel data loading processing")
    parse.add_argument("--size_per_dataset", type=int, default=30000,
                       help="number of training data")


    args = parse.parse_args()

    file_path = "./img/"+args.epochs+".png"   # 圖片路徑
    image = Image.open(file_path)
    image = fill_image(image)
    image_list = cut_image(image,batch_size)
    save_images(image_list)

if __name__ == '__main__':
    main()

需要注意的是:下面的命令中batch_size的數一定要一致

$ python main.py --batch_size=xx 

$ python SegmentePictures.py --batch_size=xx 

3.遇到的問題

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 370 and 667 in dimension 2 at /pytorch/aten/src/TH/generic/THTensor.cpp:711

  • 錯誤分析:使用DataLoader加載圖像,這些圖像中的一些具有3個通道(彩色圖像),而其他圖像可能具有單個通道(BW圖像),由於dim1的尺寸不同,因此無法將它們連接成批次。 嘗試將img = img.convert(‘RGB’)添加到數據集中的getitem中。

  • 將圖片的通道進行統一

    from PIL import Image
    import matplotlib.pyplot as plt
    import os
    ​
    ​
    def GetAllFiles(dir):
        files_ = []
        list = os.listdir(dir)
        for i in range(0, len(list)):
            path = os.path.join(dir, list[i])
            if os.path.isdir(path):
                files_.extend(GetAllFiles(path))
            if os.path.isfile(path):
                files_.append(path)
        return files_
    ​
    def ConvertRGB():
        """
        將圖片轉換為RGB格式
        :return:
        """
        files_ = GetAllFiles(file_path)
        for id,item in enumerate(files_):
            img=Image.open(item)
            gray=img.convert('RGB')
            plt.imshow(gray)
            plt.axis('off')
            save_path = "./save_img"+"\\"+str(id)+".jpg"
            plt.savefig(save_path)
            # plt.show()
    if __name__ == "__main__":
        file_path = "your path"
        ConvertRGB()

     

     

參考鏈接:https://github.com/FangYang970206/Anime_GAN/blob/master/README.md


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM