DCGAN入門
- 前言
- DCGAN介紹
- 所需環境
- 代碼解刨
- 訓練集獲取
- 所需引數構造
- 前期準備作業代碼撰寫
- 日志輸出
- 訓練節點保存
- 訓練節點讀取
- 生成優化器
- 權重初始化
- 影像資料集讀取
- 運行額外引數
- 核心代碼
- 生成器G(x)
- 判別器D(x)
- 主函式
- 結果展示
- 學以致用
前言
根據之前的兩片入門級別的GAN文章,相信各位對GAN有一絲絲了解,
知道對抗網路究竟是干什么的就能讀懂這篇文章了=·=
DCGAN介紹
DCGAN的英文全名為:Deep Convolution Generative Adversarial Networks
顧名思義,DCGAN主要由兩部分組成,即:
- 生成模型 G
- 判別模型 D
其作業的基本原理很簡單,以圖片生成任務為例來說明,生成模型的作用是根據網路輸入的隨機噪聲 z ,來生成一張圖片 G(z) ;而判別模型的作用則是判別網路輸入的圖片 x 是否是"真實"的,即 D(x) ,這里的"真實"意味著輸入的圖片不是由生成模型生成,而是真實存在的,
簡單畫個示例圖吧:

在DCGAN的訓練程序中,生成模型的訓練目標是使得生成的圖片可以很好地欺騙判別模型,使得判別模型認為生成模型生成的圖片是"真實"的;而判別模型的訓練目標則是盡量地正確區分生成模型生成的圖片和真實存在的圖片,于是,這種訓練方式就很自然地產生了生成模型和判別模型之間的"博弈",
在理想情況下,我們希望DCGAN訓練好之后,生成模型 G 生成的圖片是可以以假亂真的,即 D(G(z)) = 0.5 ,
具體思路是,生成器是將一個噪點生成一副假圖片,然后將假圖片傳給判別器進行判斷,如果判別器判斷為真,則代碼生成器性能很好,而判別器是從真實圖片中學習模型,對生成的假圖片進行判斷,如果判斷出來為假則代碼判別器性能很好,
所需環境
- Python 3.7
- torch >= 1.0.0
- torchvision
- argparse
- pillow
代碼解刨
訓練集獲取
本文資料集來自kaggle的tagged-anime-illustrations作為訓練使用,
共包含51222個64×64的動漫頭像,
作者已經為你們打包到專案中供你們使用,
所需引數構造
我們會將引數放到一個py檔案中,方便其他代碼參考一些全域引數,
介紹代碼的時候我會講解全域引數的作用,這里我們先忽略引數意義,
# 潛在空間的維度
NUM_LATENT_DIMS = 100
# 批次大小
BATCH_SIZE = 128
# 圖片尺寸
IMAGE_SIZE = (64, 64)
# 圖片規范化資訊
IMAGE_NORM_INFO = {'means': [0.5, 0.5, 0.5], 'stds': [0.5, 0.5, 0.5]}
# 訓練批次的數量
NUM_EPOCHS = 500
# 保存檢查點之間的間隔
SAVE_INTERVAL = 5
# 圖片路徑
ROOTDIR = os.path.join(os.getcwd(), 'images/*')
# 檢查點保存位置
BACKUP_DIR = os.path.join(os.getcwd(), 'checkpoints')
# 日志保存位置
LOGFILEPATH = {'train': os.path.join(BACKUP_DIR, 'train.log'), 'test': os.path.join(BACKUP_DIR, 'test.log')}
# 優化器配置引數
OPTIMIZER_CFG = {'generator': {'type': 'adam', 'adam': {'lr': 1e-4, 'betas': [0.5, 0.999]}},
'discriminator': {'type': 'adam', 'adam': {'lr': 1e-4, 'betas': [0.5, 0.999]}}}
前期準備作業代碼撰寫
由于是個長時間訓練的深度學習,準備作業不能缺少,在這里主要介紹以下幾點方面:
- 日志輸出
- 訓練節點保存
- 訓練節點讀取
- 生成優化器
- 權重能否正常初始化
- 影像資料集由torch讀取
- 運行額外引數填寫
日志輸出
使用的是Python3自帶的 logging 模塊處理日志,
日志格式為:當前時間 + level等級 + message內容
'''log function.'''
class Logger():
def __init__(self, logfilepath, **kwargs):
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
handlers=[logging.FileHandler(logfilepath),
logging.StreamHandler()])
@staticmethod
def log(level, message):
logging.log(level, message)
@staticmethod
def debug(message):
Logger.log(logging.DEBUG, message)
@staticmethod
def info(message):
Logger.log(logging.INFO, message)
@staticmethod
def warning(message):
Logger.log(logging.WARNING, message)
@staticmethod
def error(message):
Logger.log(logging.ERROR, message)
訓練節點保存
torch.save模塊可以提供模型的保存,
使用這種方法,將會保存模型的引數和結構資訊,
引數一為模型的字典格式特征,引數二為保存的位置路徑,
'''save checkpoints'''
def saveCheckpoints(state_dict, savepath, logger_handle):
logger_handle.info('Saving state_dict in %s...' % savepath)
torch.save(state_dict, savepath)
return True
訓練節點讀取
torch.load模塊可以提供模型的讀取,引數為保存的位置路徑
該讀取為測驗時需要讀取模型,當運行代碼為測驗時,我們必須提供此引數,
'''load checkpoints'''
def loadCheckpoints(checkpointspath, logger_handle):
logger_handle.info('Loading checkpoints from %s...' % checkpointspath)
if torch.cuda.is_available():checkpoints = torch.load(checkpointspath)
else:checkpoints = torch.load(checkpointspath, map_location='cpu')
return checkpoints
生成優化器
torch.optim.Adam()利用系統自帶Adam優化器更新引數,
引數如下:
params (iterable)– 待優化引數的iterable或者是定義了引陣列的dictlr(float, 可選) – 學習率(默認:1e-3),同樣也稱為學習率或步長因子,它控制了權重的更新比率,較大的值在學習率更新前會有更快的初始學習,而較小的值會令訓練收斂到更好的性能,betas(Tuple[float,float], 可選) – 用于計算梯度以及梯度平方的運行平均值的系數(默認:0.9,0.999)eps(float, 可選) – 為了增加數值計算的穩定性而加到分母里的項(默認:1e-8),該引數是非常小的數,其為了防止在實作中除以零,weight_decay(float, 可選) – 權重衰減(L2懲罰)(默認: 0)
'''build optimizer'''
def buildOptimizer(params, cfg):
if cfg['type'] == 'adam':
optimizer = torch.optim.Adam(params, lr=cfg['adam']['lr'], betas=(cfg['adam']['betas'][0], cfg['adam']['betas'][1]))
else:
raise ValueError('Unsupport type %s in buildOptimizer...' % cfg['type'])
return optimizer
權重初始化
首先用self.__class__將實體變數指向類,然后再去呼叫__name__類屬性
兩種情況分別討論:
Conv類中,使w引數服從正態分布,BatchNorm2d類中,首先將w引數服從正態分布,其次將b引數初始化為常數,
torch.nn.init.normal_(tensor, mean=0, std=1)服從正態分布,滿足~N(mean,std)
torch.nn.init.constant_(tensor, val)初始化為常數,初始化整個矩陣為val
'''normal initialization'''
def weightsNormalInit(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
影像資料集讀取
該ImageDataset類繼承torch.utils.data.Dataset,
傳進來的引數一共有三個,
rootdir為影像資料集的位置,需要斷言此引數的最后一個字串為*,資料集不能是單個必須是個整體,imagesize為影像資料集的尺寸大小,可被Resize到相應的尺寸方便處理,img_norm_info為影像資料集的平均值和標準差,方便Normalize進行歸一化處理,
__getitem__魔法為在整個類運行時,出現單方面映射則會呼叫此方法,在此魔法中將讀取每一張圖片給torch傳輸資料做特征處理后回傳給主變數,方便接下來處理,
preprocess函式中用到了以下函式,一一介紹:
torchvision.transforms.Compose()作用是可以將影像預處理操作連起來,torchvision.transforms.Resize()作用是把給定的圖片resize到給定的尺寸,torchvision.transforms.ToTensor()作用是將一個PIL影像轉換為tensor,即,(H × W × C)范圍在[0,255]的PIL影像 轉換為 (CHW)范圍在[0,1]的torch.tensor,torchvision.transforms.Normalize()作用是均值和標準差對影像做歸一化處理,
'''load images'''
class ImageDataset(Dataset):
def __init__(self, rootdir, imagesize, img_norm_info, **kwargs):
assert rootdir.endswith('*')
self.rootdir = rootdir
self.imagesize = imagesize
self.img_norm_info = img_norm_info
self.imagepaths = glob.glob(rootdir)
'''get item'''
def __getitem__(self, index):
image = Image.open(self.imagepaths[index])
return ImageDataset.preprocess(image, self.imagesize, self.img_norm_info)
'''calculate length'''
def __len__(self):
return len(self.imagepaths)
'''preprocess image'''
@staticmethod
def preprocess(image, imagesize, img_norm_info):
means_norm, stds_norm = img_norm_info.get('means'), img_norm_info.get('stds')
transform = torchvision.transforms.Compose([torchvision.transforms.Resize(imagesize),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=means_norm, std=stds_norm)])
return transform(image)
運行額外引數
主要讓代碼知道你運行代碼的需求,究竟是訓練還是測驗,
如果是測驗的話你的檢查點位置又在哪里,
'''parse arguments in command line'''
def parseArgs():
parser = argparse.ArgumentParser(description='use wcgan to generate anime avatar')
parser.add_argument('--mode', dest='mode', help='train or test', default='train', type=str)
parser.add_argument('--checkpointspath', dest='checkpointspath', help='the path of checkpoints', type=str)
args = parser.parse_args()
return args
基礎作業大致已經做完了,接下來就是核心代碼撰寫階段了,
核心代碼
核心代碼分為以下三個階段:
- 生成器G(x)的撰寫
- 判別器D(x)的撰寫
- 主函式main.py的撰寫
生成器G(x)
生成模型 G(x) 由幾個轉置卷積/卷積構成,
nn.Sequential()的作用:一個有序的容器,神經網路模塊將按照在傳入構造器的順序依次被添加到計算圖中執行,同時以神經網路模塊為元素的有序字典也可以作為傳入引數,nn.ConvTranspose2d()的作用:進行反卷積操作,nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)- 引數
in_channels作用:輸入維度, - 引數
out_channels作用:輸出維度, - 引數
kernel_size作用:卷積核大小, - 引數
stride作用:步長大小, - 引數
padding作用:輸入的每一條邊補充0的層數,高寬都增加2*padding, - 引數
output_padding作用:輸出邊補充0的層數,高寬都增加padding, - 引數
groups作用:從輸入通道到輸出通道的阻塞連接數,
- BatchNormalization的目的是使我們的Batch feature map滿足均值為0,方差為1的分布規律,
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- 引數
num_features作用:一般輸入引數為height*width,即為其中特征的數量, - 引數
eps作用:分母中添加的一個值,目的是為了計算的穩定性,避免分母為0, - 引數
momentum作用:一個用于運行程序中均值和方差的一個估計引數, - 引數
affine作用:當設為true時,會給定可以學習的系數矩陣gamma和beta,
ReLU是將所有的負值都設為零,Leaky ReLU是給所有負值賦予一個非零斜率,

最后的激活函式用nn.Tanh()以保證輸出的圖片像素取值范圍為[-1, 1],原因是我們訓練集中的真實圖片在輸入判別模型之前也會先歸一化到[-1, 1],(訓練GAN的話圖片一般都是歸一化到[-1, 1]的)
'''generator'''
class Generator(nn.Module):
def __init__(self, cfg, **kwargs):
super(Generator, self).__init__()
assert cfg.IMAGE_SIZE[0] == cfg.IMAGE_SIZE[1] and cfg.IMAGE_SIZE[0] == 64
self.cfg = cfg
self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=cfg.NUM_LATENT_DIMS, out_channels=64*8, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(64*8),
nn.LeakyReLU(0.2, inplace=True))
self.conv2 = nn.Sequential(nn.ConvTranspose2d(in_channels=64*8, out_channels=64*4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64*4),
nn.LeakyReLU(0.2, inplace=True))
self.conv3 = nn.Sequential(nn.ConvTranspose2d(in_channels=64*4, out_channels=64*2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64*2),
nn.LeakyReLU(0.2, inplace=True))
self.conv4 = nn.Sequential(nn.ConvTranspose2d(in_channels=64*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True))
self.conv5 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True))
self.conv6 = nn.Sequential(nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh())
def forward(self, x):
batch_size = x.size(0)
x = x.view(batch_size, -1, 1, 1)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.conv6(x)
return x
判別器D(x)
判別器前置代碼與生成器類似,請讀者自行理解,
最后的激活函式用nn.Sigmoid(),以預測每張圖是真實圖片的概率,
'''discriminator'''
class Discriminator(nn.Module):
def __init__(self, cfg, **kwargs):
super(Discriminator, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True))
self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64*2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64*2),
nn.LeakyReLU(0.2, inplace=True))
self.conv3 = nn.Sequential(nn.Conv2d(in_channels=64*2, out_channels=64*4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64*4),
nn.LeakyReLU(0.2, inplace=True))
self.conv4 = nn.Sequential(nn.Conv2d(in_channels=64*4, out_channels=64*8, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64*8),
nn.LeakyReLU(0.2, inplace=True))
self.conv5 = nn.Sequential(nn.Conv2d(in_channels=64*8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Sigmoid())
def forward(self, x):
batch_size = x.size(0)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
return x.view(batch_size, -1)
主函式
最最最重要的主函式來了,上面的大風大浪都經歷過來了就沒什么可擔心的了,
雖說主函式并不是特別難,但是主函式擁有著撰寫深度學習中所有的基本方法,
為了防止介紹出錯,我將每一行代碼的作用寫在了下方代碼體中
'''main function'''
def main():
# 決議引數
args = parseArgs()
assert args.mode in ['train', 'test']
if args.mode == 'test': assert os.path.isfile(args.checkpointspath)
# 一些必要的準備作業
checkDir(cfg.BACKUP_DIR)
logger_handle = Logger(cfg.LOGFILEPATH.get(args.mode))
start_epoch = 1
end_epoch = cfg.NUM_EPOCHS + 1
use_cuda = torch.cuda.is_available() # 檢測電腦是否支持CUDA
# 定義資料集
dataset = ImageDataset(rootdir=cfg.ROOTDIR, imagesize=cfg.IMAGE_SIZE, img_norm_info=cfg.IMAGE_NORM_INFO)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
# 定義損失函式
loss_func = nn.BCELoss()
if use_cuda: loss_func = loss_func.cuda()
# 定義模型
net_g = Generator(cfg)
net_d = Discriminator(cfg)
if use_cuda:
net_g = net_g.cuda()
net_d = net_d.cuda()
# 定義優化器
optimizer_g = buildOptimizer(net_g.parameters(), cfg.OPTIMIZER_CFG['generator'])
optimizer_d = buildOptimizer(net_d.parameters(), cfg.OPTIMIZER_CFG['discriminator'])
# 加載檢查點
if args.checkpointspath:
checkpoints = loadCheckpoints(args.checkpointspath, logger_handle)
net_d.load_state_dict(checkpoints['net_d'])
net_g.load_state_dict(checkpoints['net_g'])
optimizer_g.load_state_dict(checkpoints['optimizer_g'])
optimizer_d.load_state_dict(checkpoints['optimizer_d'])
start_epoch = checkpoints['epoch'] + 1
else:
net_d.apply(weightsNormalInit)
net_g.apply(weightsNormalInit)
# 定義浮點張量
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
# 訓練模型
if args.mode == 'train':
for epoch in range(start_epoch, end_epoch):
logger_handle.info('Start epoch %s...' % epoch)
for batch_idx, imgs in enumerate(dataloader):
imgs = imgs.type(FloatTensor)
z = torch.randn(imgs.size(0), cfg.NUM_LATENT_DIMS, 1, 1).type(FloatTensor)
imgs_g = net_g(z)
# 訓練生成器
optimizer_g.zero_grad()
labels = FloatTensor(imgs_g.size(0), 1).fill_(1.0)
loss_g = loss_func(net_d(imgs_g), labels)
loss_g.backward()
optimizer_g.step()
# 訓練判別器
optimizer_d.zero_grad()
labels = FloatTensor(imgs_g.size(0), 1).fill_(1.0)
loss_real = loss_func(net_d(imgs), labels)
labels = FloatTensor(imgs_g.size(0), 1).fill_(0.0)
loss_fake = loss_func(net_d(imgs_g.detach()), labels)
loss_d = loss_real + loss_fake
loss_d.backward()
optimizer_d.step()
# 輸出資訊
logger_handle.info('Epoch %s/%s, Batch %s/%s, Loss_G %f, Loss_D %f' % (epoch, cfg.NUM_EPOCHS, batch_idx+1, len(dataloader), loss_g.item(), loss_d.item()))
# 保存檢查點
if epoch % cfg.SAVE_INTERVAL == 0 or epoch == cfg.NUM_EPOCHS:
state_dict = {
'epoch': epoch,
'net_d': net_d.state_dict(),
'net_g': net_g.state_dict(),
'optimizer_g': optimizer_g.state_dict(),
'optimizer_d': optimizer_d.state_dict()
}
savepath = os.path.join(cfg.BACKUP_DIR, 'epoch_%s.pth' % epoch)
saveCheckpoints(state_dict, savepath, logger_handle)
save_image(imgs_g.data[:25], os.path.join(cfg.BACKUP_DIR, 'images_epoch_%s.png' % epoch), nrow=5, normalize=True)
# 測驗模型
else:
z = torch.randn(cfg.BATCH_SIZE, cfg.NUM_LATENT_DIMS, 1, 1).type(FloatTensor)
net_g.eval()
imgs_g = net_g(z)
save_image(imgs_g.data[:25], 'images.png', nrow=5, normalize=True)
結果展示
下圖為訓練一百批次后生成的影像,看起來還行趴,

學以致用
真慶幸你們能學到最后,也不知道你們掌握了多少,
真的說深度學習零基礎接受對抗網路是有點難,但我感覺我盡力了,
這篇文章就是想帶你們感受一下深度學習的美妙之處,
也希望各位能學業有成,頭發不禿,謝謝各位觀看,
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/246564.html
標籤:python
