• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

SRCNN分辨率Pytorch实现,代码逐行,附源码

武飞扬头像
Jin、yz
帮助1

目录

1.SRCNN介绍

训练过程

损失函数

个人对SRCNN训练过程的理解

2.实验常见问题和部分解读

1. torch.utils.data.dataloader中DataLoader函数的用法

2.SRCNN图像颜色空间转换原因以及方法?

3. model.parameters()与model.state_dict()的区别

4. .item()函数的用法?

5.最后的测试过程步骤?

6.argparse的使用以及定义

7.unsqueeze与squeeze的使用 

1.unsqueeze用法:在数组原来维度索引i之间增加一个维度

2.Squeeze用法:挤压掉tensor数据中维度特征数为1的维度

8.对Python之if __name__ == ‘__main__‘的理解。

9.自定义数据集步骤?

3.Code部分解读

model.py

dataset.py

prepare.py(制作自定义的训练和验证的h5格式的数据集)

train.py(训练SRCNN模型,得到最优参数)

utils.py(工具包)

test.py

4.实验结果展示



1.SRCNN介绍

超分辨率,就是把低分辨率(LR, Low Resolution)图片放大为高分辨率(HR, High Resolution)的过程。

学新通

图像特征提取层:通过CNN将图像Y 的特征提取出来存到向量中。用一层的CNN以及ReLU去将图像Y 变成一堆堆向量,即feature map。

学新通

非线性映射层:把提取到的特征进一步做非线性映射,加大网络深度,提高网络复杂性。

学新通

重建层:结合了前面得到的补丁来产生最终的高分辨率图像。

学新通

实验步骤

  1. 输入LR图像X,经双三次(bicubic)插值,被放大成目标尺寸(如放大至2倍、3倍、4倍),得到Y ,即低分辨率图像(Low-resolution image)
  2. 通过三层卷积网络拟合非线性映射
  3. 输出HR图像结果F ( Y ) 

注解:

  1. Y:输入图像经过预处理(双三次插值)得到的图像,我们仍将Y 当作是低分辨率图像,但它的size要比输入图像要大。
  2. F ( Y ) :网络最后输出的图像,我们的目标就是通过优化F(Y)和Ground-Truth之间的loss来学会这个函数F (⋅) 。
  3. X:高分辨率图像,即Ground-Truth,它和Y的size是相同的。
  4. 图像被转化为 YCbCr 色彩空间,尽管该网络只使用亮度通道(Y)。然后,网络的输出合并已插值的 CbCr 通道,输出最终彩色图像。我们选择这一步骤是因为我们感兴趣的不是颜色变化(存储在 CbCr 通道中的信息)而只是其亮度(Y 通道);根本原因在于相较于色差,人类视觉对亮度变化更为敏感。

训练过程

1.降低分辨率:

学新通

2.切割图片,补丁之间有重复

学新通

3.训练模型,学习低分辨率 → to→ 高分辨率的映射关系

学新通

学新通

损失函数

 损失哈数:MES(均方误差),选择MSE作为损失函数的一个重要原因是MSE的格式和我们图像失真评价指标PSNR很像

 F(Y;θ):得到的超分辨率图像          X:原高分辨率图像

学新通

激活函数:Relu

PSRN:峰值信噪比,是一种评价图像的客观标准,它具有局限性,一般是用于最大值信号和背景噪音之间的一个工程项目。

MSE与PSNR公式对比:

学新通

        学新通

 这里的MSE是原图像(语音)与处理图像(语音)之间均方误差。

SSIM(另外一种衡量结果的参数)

学新通

 个人对SRCNN训练过程的理解

        1.构建训练集,含有低分辨率图像和高分辨图像,其中图像需要将其从RGB图像转为YCBCR图像,并且对图像进行分割为小块进行存储,高分辨率图像为未下采样前的图像,低分辨率图像为下采样,上采样后的图像。

      2.构建SRCNN模型,即三层卷积模型,设置MES为损失函数,因为MES与评价图像客观指标PSNR计算相似,即最大化PSNR。设置其余常见的神经网络参数(学习率,Batch_size,num-epochs等)。

      3.训练模型SRCNN,即学习低分辨率图像到高分辨率图像的映射关系。根据不同参数的不同PSRN值,保留最大PSNR值对应的模型参数。

2.实验常见问题和部分解读

1. torch.utils.data.dataloader中DataLoader函数的用法

通过查阅资料,翻阅代码实例得到DataLoader()函数参数意义如下:

 1.dataset (Dataset) :决定数据从哪读取或者从何读取;

 2. batch_size (python:int, optional) : 每次处理的数据集大小(默认为1)

 3. shuffle (bool, optional) :每一个 epoch是否为乱序 (default: False);

 4. num_workers (python:int, optional) : 多少个进程读取数据(默认为0);

 5. pin_memory(bool, optional) : 如果为True会将数据放置到GPU上去(默认为false)

 6. drop_last (bool, optional) :当样本数不能被batchsize整除时,最后一批数据是否舍弃(default: False)

Eg:shuffle(bool,optional)表示传入的参数类型为bool类型,并且该参数shuffle是可选参数。

2.SRCNN图像颜色空间转换原因以及方法?

       选择YCbCr的原因:因为我们感兴趣的不是颜色变化(存储在 CbCr 通道中的信息)而只是其亮度(Y 通道);根本原因在于相较于色差,人类视觉对亮度变化更为敏感。

Y only和YCbCr区别:

       ①Y only:基线方法,是一个单通道网络(c=1),只在亮度上进行了训练。对Cb、Cr通道采用双三次插值进行了扩展。②YCbCr:在YCbCr空间的三个通道上进行训练

       代码中三个转换函数:

       1. convert_rgb_to_y(img)

       2. convert_rgb_to_ycbcr(img)

       3. convert_ycbcr_to_rgb(img)

YCBCR:Y表示颜色的明亮度和浓度,也可叫灰度阶。(通过RGB转换YCBCR提取Y分量也可以得到灰度图像)

Cb:表示颜色的蓝色浓度偏移量即RGB输入信号蓝色部分与RGB信号亮度值之间的差异。

Cr:表示颜色的红色浓度偏移量即RGB输入信号红色部分与RGB信号亮度值之间的差异。

转换公式:

1、RGB转YCBCR

        Y=0.257*R 0.564*G 0.098*B 16

        Cb=-0.148*R-0.291*G 0.439*B 12

        Cr=0.439*R-0.368*G-0.071*B 128

 2、YCBCR转RGB

        R=1.164*(Y-16) 1.596*(Cr-128)

        G=1.164*(Y-16)-0.392*(Cb-128)-0.813*(Cr-128)

        B=1.164*(Y-16) 2.017*(Cb-128)

3. model.parameters()与model.state_dict()的区别

    区别:model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key。

4. .item()函数的用法?

    t.item()将Tensor变量转换为python标量(int float等),其中t是一个Tensor变量,只能是标量,转换后dtype与Tensor的dtype一致。

5.最后的测试过程步骤?

    1.设置参数(训练好的权重,图片,放大倍数)

    2.创建SRCNN模型,给模型赋值最优参数

    3.对图像进行插值得到低分辨率图像

    4.对Lr低分辨率图像的y颜色空间进行训练

    5.计算PSNR值并输出

    6.将转换为图像并进行输出

6.argparse的使用以及定义

    argparse 模块是 Python 内置的用于命令项选项与参数解析的模块,argparse 模块可以让人轻松编写用户友好的命令行接口,能够帮助程序员为模型定义参数。

    定义步骤

  1. 导入argparse包 ——import argparse
  2. 创建一个命令行解析器对象 ——创建 ArgumentParser() 对象
  3. 给解析器添加命令行参数 ——调用add_argument() 方法添加参数
  4. 解析命令行的参数 ——使用 parse_args() 解析添加的参数

7.unsqueeze与squeeze的使用 

1.unsqueeze用法:在数组原来维度索引i之间增加一个维度

  1.  
    x = t.Tensor([[3, 4], [2, 7], [6, 9]]) # 3*2
  2.  
    y1 = x.unsqueeze(0) # 1*3*2
  3.  
    print(y1.size())
  4.  
    y2 = x.unsqueeze(1) # 3*1*2
  5.  
    print(y2.size())
  6.  
    y3 = x.unsqueeze(2) # 3*2*1
  7.  
    print(y3.size())

2.Squeeze用法:挤压掉tensor数据中维度特征数为1的维度

  1.  
    x = t.ones(1,1,2,3,1)
  2.  
    y1 = x.squeeze(0) # 1*2*3*1
  3.  
    print(y1.size())
  4.  
    y2 = x.squeeze(1) # 1*2*3*1
  5.  
    print(y2.size())
  6.  
    y3 = x.squeeze() # 2*3
  7.  
    print(y3.size())

8.对Python之if __name__ == ‘__main__‘的理解。

       该代码片段只在运行脚本时执行,在import到其他脚本中不会执行,把文件当做脚本直接执行的时候这个时候__name__的值是:main,而被其它文件引用的时候就是文件本身的名字。

9.自定义数据集步骤?

      训练数据集可手动生成,设放大倍数为scale,考虑到原始数据未必会被scale整除,所以要重新规划一下图像尺寸,通过双三次插值设置图像大小,然后将其保存为h5文件进行保存,训练数据集的生成分为三步:

  1. 读取图像文件夹所在目录
  2. 将所有图像转为RGB图像
  3. 将原始图像通过双三次插值重设尺寸,使之可被scale整除,作为高分辨图像数据HR
  4. 将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据
  5. 将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR
  6. 将低分辨率图像,高分辨率图像转为YCBCR图像,对y通道进行训练。
  7. 提取高分辨率,低分辨率图像补丁,用来训练低分辨率图像刀高分辨率图像之间的映射关系。

最后,可通过h5py将训练数据分块并打包,同理可以按照上述操作生成测试集文件。

3.Code部分解读

学新通

model.py

  1.  
    from torch import nn
  2.  
     
  3.  
    class SRCNN(nn.Module): #搭建SRCNN 3层卷积模型,Conve2d(输入层数,输出层数,卷积核大小,步长,填充层)
  4.  
    def __init__(self, num_channels=1):
  5.  
    super(SRCNN, self).__init__()
  6.  
    self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
  7.  
    self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
  8.  
    self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
  9.  
    self.relu = nn.ReLU(inplace=True)
  10.  
     
  11.  
    def forward(self, x):
  12.  
    x = self.relu(self.conv1(x))
  13.  
    x = self.relu(self.conv2(x))
  14.  
    x = self.conv3(x)
  15.  
    return x
学新通

dataset.py

h5py文件格式

学新通

  1.  
    import h5py # 一个h5py文件是 “dataset” 和 “group” 二合一的容器。
  2.  
    import numpy as np
  3.  
    from torch.utils.data import Dataset
  4.  
     
  5.  
    '''为这些数据创建一个读取类,以便torch中的DataLoader调用,而DataLoader中的内容则是Dataset,
  6.  
    所以新建的读取类需要继承Dataset,并实现其__getitem__和__len__这两个成员方法。
  7.  
    '''
  8.  
     
  9.  
    class TrainDataset(Dataset): # 构建训练数据集,通过np.expand_dims将h5文件中的lr(低分辨率图像)和hr(高分辨率图像)组合为训练集
  10.  
    def __init__(self, h5_file):
  11.  
    super(TrainDataset, self).__init__()
  12.  
    self.h5_file = h5_file
  13.  
     
  14.  
    def __getitem__(self, idx): #通过np.expand_dims方法得到组合的新数据
  15.  
    with h5py.File(self.h5_file, 'r') as f:
  16.  
    return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)
  17.  
     
  18.  
    def __len__(self): #得到数据大小
  19.  
    with h5py.File(self.h5_file, 'r') as f:
  20.  
    return len(f['lr'])
  21.  
     
  22.  
    # 与TrainDataset类似
  23.  
    class EvalDataset(Dataset): # 构建测试数据集,通过np.expand_dims将h5文件中的lr(低分辨率图像)和hr(高分辨率图像)组合为验证集
  24.  
    def __init__(self, h5_file):
  25.  
    super(EvalDataset, self).__init__()
  26.  
    self.h5_file = h5_file
  27.  
     
  28.  
    def __getitem__(self, idx):
  29.  
    with h5py.File(self.h5_file, 'r') as f:
  30.  
    return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)
  31.  
     
  32.  
    def __len__(self):
  33.  
    with h5py.File(self.h5_file, 'r') as f:
  34.  
    return len(f['lr'])
学新通

prepare.py(制作自定义的训练和验证的h5格式的数据集)

  1.  
    import argparse
  2.  
    import glob
  3.  
    import h5py
  4.  
    import numpy as np
  5.  
    import PIL.Image as pil_image
  6.  
    from utils import convert_rgb_to_y
  7.  
     
  8.  
    '''
  9.  
    训练数据集可手动生成,设放大倍数为scale,考虑到原始数据未必会被scale整除,所以要重新规划一下图像尺寸,所以训练数据集的生成分为三步:
  10.  
    1.将原始图像通过双三次插值重设尺寸,使之可被scale整除,作为高分辨图像数据HR
  11.  
    2.将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据
  12.  
    3.将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR
  13.  
    最后,可通过h5py将训练数据分块并打包
  14.  
    '''
  15.  
    # 生成训练集
  16.  
    def train(args):
  17.  
     
  18.  
    """
  19.  
    def是python的关键字,用来定义函数。这里通过def定义名为train的函数,函数的参数为args,args这个参数通过外部命令行传入output
  20.  
    的路径,通过h5py.File()方法的w模式--创建文件自己自写,已经存在的文件会被覆盖,文件的路径是通过args.output_path来传入
  21.  
    """
  22.  
    h5_file = h5py.File(args.output_path, 'w')
  23.  
    # #用于存储低分辨率和高分辨率的patch
  24.  
    lr_patches = []
  25.  
    hr_patches = []
  26.  
     
  27.  
    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
  28.  
    '''
  29.  
    这部分代码的目的就是搜索指定文件夹下的文件并排序,for这一句包含了几个知识点:
  30.  
    1.{}.format():-->格式化输出函数,从args.images_dir路径中格式化输出路径
  31.  
    2.glob.glob():-->返回所有匹配的文件路径列表,将1得到的路径中的所有文件返回
  32.  
    3.sorted():-->排序,将2得到的所有文件按照某种顺序返回,,默认是升序
  33.  
    4.for x in *: -->循换输出
  34.  
    '''
  35.  
    #将照片转换为RGB通道
  36.  
    hr = pil_image.open(image_path).convert('RGB')
  37.  
    '''
  38.  
    1. *.open(): 是PIL图像库的函数,用来从image_path中加载图像
  39.  
    2. *.convert(): 是PIL图像库的函数, 用来转换图像的模式
  40.  
    '''
  41.  
    #取放大倍数的倍数, width, height为可被scale整除的训练数据尺寸
  42.  
    hr_width = (hr.width // args.scale) * args.scale
  43.  
    hr_height = (hr.height // args.scale) * args.scale
  44.  
    #图像大小调整,得到高分辨率图像Hr
  45.  
    hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
  46.  
    #低分辨率图像缩小
  47.  
    lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
  48.  
    #低分辨率图像放大,得到低分辨率图像Lr
  49.  
    lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
  50.  
    #转换为浮点并取ycrcb中的y通道
  51.  
    hr = np.array(hr).astype(np.float32)
  52.  
    lr = np.array(lr).astype(np.float32)
  53.  
    hr = convert_rgb_to_y(hr)
  54.  
    lr = convert_rgb_to_y(lr)
  55.  
    '''
  56.  
    np.array():将列表list或元组tuple转换为ndarray数组
  57.  
    astype():转换数组的数据类型
  58.  
    convert_rgb_to_y():将图像从RGB格式转换为Y通道格式的图片
  59.  
    假设原始输入图像为(321,481,3)-->依次为高,宽,通道数
  60.  
    1.先把图像转为可放缩的scale大小的图片,之后hr的图像尺寸为(320,480,3)
  61.  
    2.对hr图像进行双三次上采样放大操作
  62.  
    3.将hr//scale进行双三次上采样放大操作之后×scale得到lr
  63.  
    4.接着进行通道数转换和类型转换
  64.  
    '''
  65.  
    # 将数据分割
  66.  
    for i in range(0, lr.shape[0] - args.patch_size 1, args.stride):
  67.  
    for j in range(0, lr.shape[1] - args.patch_size 1, args.stride):
  68.  
    '''
  69.  
    图像的shape是宽度、高度和通道数,shape[0]是指图像的高度=320;shape[1]是图像的宽度=480; shape[2]是指图像的通道数
  70.  
    '''
  71.  
    lr_patches.append(lr[i:i args.patch_size, j:j args.patch_size])
  72.  
    hr_patches.append(hr[i:i args.patch_size, j:j args.patch_size])
  73.  
     
  74.  
    lr_patches = np.array(lr_patches)
  75.  
    hr_patches = np.array(hr_patches)
  76.  
    #创建数据集,把得到的数据转化为数组类型
  77.  
    h5_file.create_dataset('lr', data=lr_patches)
  78.  
    h5_file.create_dataset('hr', data=hr_patches)
  79.  
    h5_file.close()
  80.  
     
  81.  
    #下同,生成测试集
  82.  
    def eval(args):
  83.  
    h5_file = h5py.File(args.output_path, 'w')
  84.  
     
  85.  
    lr_group = h5_file.create_group('lr')
  86.  
    hr_group = h5_file.create_group('hr')
  87.  
     
  88.  
    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
  89.  
    hr = pil_image.open(image_path).convert('RGB')
  90.  
    hr_width = (hr.width // args.scale) * args.scale
  91.  
    hr_height = (hr.height // args.scale) * args.scale
  92.  
    hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
  93.  
    lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
  94.  
    lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
  95.  
    hr = np.array(hr).astype(np.float32)
  96.  
    lr = np.array(lr).astype(np.float32)
  97.  
    hr = convert_rgb_to_y(hr)
  98.  
    lr = convert_rgb_to_y(lr)
  99.  
     
  100.  
    lr_group.create_dataset(str(i), data=lr)
  101.  
    hr_group.create_dataset(str(i), data=hr)
  102.  
     
  103.  
    h5_file.close()
  104.  
     
  105.  
     
  106.  
    if __name__ == '__main__':
  107.  
    parser = argparse.ArgumentParser()
  108.  
    parser.add_argument('--images-dir', type=str, required=True)
  109.  
    parser.add_argument('--output-path', type=str, required=True)
  110.  
    parser.add_argument('--patch-size', type=int, default=32)
  111.  
    parser.add_argument('--stride', type=int, default=14)
  112.  
    parser.add_argument('--scale', type=int, default=4)
  113.  
    parser.add_argument('--eval', action='store_true') #store_flase就是存储一个bool值true,也就是说在该参数在被激活时它会输出store存储的值true。
  114.  
    args = parser.parse_args()
  115.  
     
  116.  
    #决定使用哪个函数来生成h5文件,因为有俩个不同的函数train和eval生成对应的h5文件。
  117.  
    if not args.eval:
  118.  
    train(args)
  119.  
    else:
  120.  
    eval(args)
学新通

train.py(训练SRCNN模型,得到最优参数)

  1.  
    import argparse
  2.  
    import os
  3.  
    import copy
  4.  
     
  5.  
    import numpy as np
  6.  
    from torch import Tensor
  7.  
    import torch
  8.  
    from torch import nn
  9.  
    import torch.optim as optim
  10.  
     
  11.  
    # gpu加速库
  12.  
    import torch.backends.cudnn as cudnn
  13.  
     
  14.  
    from torch.utils.data.dataloader import DataLoader
  15.  
     
  16.  
    # 进度条
  17.  
    from tqdm import tqdm
  18.  
     
  19.  
    from models import SRCNN
  20.  
    from datasets import TrainDataset, EvalDataset
  21.  
    from utils import AverageMeter, calc_psnr
  22.  
     
  23.  
    ##需要修改的参数
  24.  
    # epoch.pth
  25.  
    # losslog
  26.  
    # psnrlog
  27.  
    # best.pth
  28.  
     
  29.  
    '''
  30.  
    python train.py --train-file "path_to_train_file" \
  31.  
    --eval-file "path_to_eval_file" \
  32.  
    --outputs-dir "path_to_outputs_file" \
  33.  
    --scale 3 \
  34.  
    --lr 1e-4 \
  35.  
    --batch-size 16 \
  36.  
    --num-epochs 400 \
  37.  
    --num-workers 0 \
  38.  
    --seed 123
  39.  
    '''
  40.  
    if __name__ == '__main__':
  41.  
     
  42.  
    # 初始参数设定
  43.  
    parser = argparse.ArgumentParser() # argparse是python用于解析命令行参数和选项的标准模块
  44.  
    parser.add_argument('--train-file', type=str, required=True,) # 训练 h5文件目录
  45.  
    parser.add_argument('--eval-file', type=str, required=True) # 测试 h5文件目录
  46.  
    parser.add_argument('--outputs-dir', type=str, required=True) #模型 .pth保存目录
  47.  
    parser.add_argument('--scale', type=int, default=3) # 放大倍数
  48.  
    parser.add_argument('--lr', type=float, default=1e-4) #学习率
  49.  
    parser.add_argument('--batch-size', type=int, default=16) # 一次处理的图片大小
  50.  
    parser.add_argument('--num-workers', type=int, default=0) # 线程数
  51.  
    parser.add_argument('--num-epochs', type=int, default=400) #训练次数
  52.  
    parser.add_argument('--seed', type=int, default=123) # 随机种子
  53.  
    args = parser.parse_args()
  54.  
     
  55.  
    # 输出放入固定文件夹里
  56.  
    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
  57.  
    # 没有该文件夹就新建一个文件夹
  58.  
    if not os.path.exists(args.outputs_dir):
  59.  
    os.makedirs(args.outputs_dir)
  60.  
     
  61.  
    # benckmark模式,加速计算,但寻找最优配置,计算的前馈结果会有差异
  62.  
    cudnn.benchmark = True
  63.  
     
  64.  
    # gpu或者cpu模式,取决于当前cpu是否可用
  65.  
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  66.  
     
  67.  
    # 每次程序运行生成的随机数固定
  68.  
    torch.manual_seed(args.seed)
  69.  
     
  70.  
    # 构建SRCNN模型,并且放到device上训练
  71.  
    model = SRCNN().to(device)
  72.  
     
  73.  
    # 恢复训练,从之前结束的那个地方开始
  74.  
    # model.load_state_dict(torch.load('outputs/x3/epoch_173.pth'))
  75.  
     
  76.  
    # 设置损失函数为MSE
  77.  
    criterion = nn.MSELoss()
  78.  
     
  79.  
    # 优化函数Adam,lr代表学习率,
  80.  
    optimizer = optim.Adam([
  81.  
    {'params': model.conv1.parameters()},
  82.  
    {'params': model.conv2.parameters()},
  83.  
    {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
  84.  
    ], lr=args.lr)
  85.  
     
  86.  
    # 预处理训练集
  87.  
    train_dataset = TrainDataset(args.train_file)
  88.  
    train_dataloader = DataLoader(
  89.  
    # 数据
  90.  
    dataset=train_dataset,
  91.  
    # 分块
  92.  
    batch_size=args.batch_size,
  93.  
    # 数据集数据洗牌,打乱后取batch
  94.  
    shuffle=True,
  95.  
    # 工作进程,像是虚拟存储器中的页表机制
  96.  
    num_workers=args.num_workers,
  97.  
    # 锁页内存,不换出内存,生成的Tensor数据是属于内存中的锁页内存区
  98.  
    pin_memory=True,
  99.  
    # 不取余,丢弃不足batchSize大小的图像
  100.  
    drop_last=True)
  101.  
    # 预处理验证集
  102.  
    eval_dataset = EvalDataset(args.eval_file)
  103.  
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)
  104.  
     
  105.  
    # 拷贝权重
  106.  
    best_weights = copy.deepcopy(model.state_dict())
  107.  
    best_epoch = 0
  108.  
    best_psnr = 0.0
  109.  
     
  110.  
    # 画图用
  111.  
    lossLog = []
  112.  
    psnrLog = []
  113.  
     
  114.  
    # 恢复训练
  115.  
    # for epoch in range(args.num_epochs):
  116.  
    for epoch in range(1, args.num_epochs 1):
  117.  
    # for epoch in range(174, 400):
  118.  
    # 模型训练入口
  119.  
    model.train()
  120.  
     
  121.  
    # 变量更新,计算epoch平均损失
  122.  
    epoch_losses = AverageMeter()
  123.  
     
  124.  
    # 进度条,就是不要不足batchsize的部分
  125.  
    with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
  126.  
    # t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))
  127.  
    t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs))
  128.  
     
  129.  
    # 每个batch计算一次
  130.  
    for data in train_dataloader:
  131.  
    # 对应datastes.py中的__getItem__,分别为lr,hr图像
  132.  
    inputs, labels = data
  133.  
     
  134.  
    inputs = inputs.to(device)
  135.  
    labels = labels.to(device)
  136.  
    # 送入模型训练
  137.  
    preds = model(inputs)
  138.  
     
  139.  
    # 获得损失
  140.  
    loss = criterion(preds, labels)
  141.  
     
  142.  
    # 显示损失值与长度
  143.  
    epoch_losses.update(loss.item(), len(inputs))
  144.  
     
  145.  
    # 梯度清零
  146.  
    optimizer.zero_grad()
  147.  
     
  148.  
    # 反向传播
  149.  
    loss.backward()
  150.  
     
  151.  
    # 更新参数
  152.  
    optimizer.step()
  153.  
     
  154.  
    # 进度条更新
  155.  
    t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
  156.  
    t.update(len(inputs))
  157.  
    # 记录lossLog 方面画图
  158.  
    lossLog.append(np.array(epoch_losses.avg))
  159.  
    # 可以在前面加上路径
  160.  
    np.savetxt("lossLog.txt", lossLog)
  161.  
     
  162.  
    # 保存模型
  163.  
    torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))
  164.  
     
  165.  
    # 是否更新当前最好参数
  166.  
    model.eval()
  167.  
    epoch_psnr = AverageMeter()
  168.  
     
  169.  
    for data in eval_dataloader:
  170.  
    inputs, labels = data
  171.  
     
  172.  
    inputs = inputs.to(device)
  173.  
    labels = labels.to(device)
  174.  
     
  175.  
    # 验证不用求导
  176.  
    with torch.no_grad():
  177.  
    preds = model(inputs).clamp(0.0, 1.0)
  178.  
     
  179.  
    epoch_psnr.update(calc_psnr(preds, labels), len(inputs))
  180.  
     
  181.  
    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
  182.  
     
  183.  
    # 记录psnr
  184.  
    psnrLog.append(Tensor.cpu(epoch_psnr.avg))
  185.  
    np.savetxt('psnrLog.txt', psnrLog)
  186.  
    # 找到更好的权重参数,更新
  187.  
    if epoch_psnr.avg > best_psnr:
  188.  
    best_epoch = epoch
  189.  
    best_psnr = epoch_psnr.avg
  190.  
    best_weights = copy.deepcopy(model.state_dict())
  191.  
     
  192.  
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
  193.  
     
  194.  
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
  195.  
     
  196.  
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
  197.  
     
  198.  
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
学新通

utils.py(工具包)

  1.  
    import torch
  2.  
    import numpy as np
  3.  
     
  4.  
    """
  5.  
    只操作y通道
  6.  
    因为我们感兴趣的不是颜色变化(存储在 CbCr 通道中的信息)而只是其亮度(Y 通道);
  7.  
    根本原因在于相较于色差,人类视觉对亮度变化更为敏感。
  8.  
    """
  9.  
    def convert_rgb_to_y(img):
  10.  
    if type(img) == np.ndarray:
  11.  
    return 16. (64.738 * img[:, :, 0] 129.057 * img[:, :, 1] 25.064 * img[:, :, 2]) / 256.
  12.  
    elif type(img) == torch.Tensor:
  13.  
    if len(img.shape) == 4:
  14.  
    img = img.squeeze(0)
  15.  
    return 16. (64.738 * img[0, :, :] 129.057 * img[1, :, :] 25.064 * img[2, :, :]) / 256.
  16.  
    else:
  17.  
    raise Exception('Unknown Type', type(img))
  18.  
     
  19.  
    """
  20.  
    RGB转YCBCR
  21.  
    Y=0.257*R 0.564*G 0.098*B 16
  22.  
    Cb=-0.148*R-0.291*G 0.439*B 128
  23.  
    Cr=0.439*R-0.368*G-0.071*B 128
  24.  
    """
  25.  
    def convert_rgb_to_ycbcr(img):
  26.  
    if type(img) == np.ndarray:
  27.  
    y = 16. (64.738 * img[:, :, 0] 129.057 * img[:, :, 1] 25.064 * img[:, :, 2]) / 256.
  28.  
    cb = 128. (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] 112.439 * img[:, :, 2]) / 256.
  29.  
    cr = 128. (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
  30.  
    return np.array([y, cb, cr]).transpose([1, 2, 0])
  31.  
    elif type(img) == torch.Tensor:
  32.  
    if len(img.shape) == 4:
  33.  
    img = img.squeeze(0)
  34.  
    y = 16. (64.738 * img[0, :, :] 129.057 * img[1, :, :] 25.064 * img[2, :, :]) / 256.
  35.  
    cb = 128. (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] 112.439 * img[2, :, :]) / 256.
  36.  
    cr = 128. (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
  37.  
    return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
  38.  
    else:
  39.  
    raise Exception('Unknown Type', type(img))
  40.  
     
  41.  
    """
  42.  
    YCBCR转RGB
  43.  
    R=1.164*(Y-16) 1.596*(Cr-128)
  44.  
    G=1.164*(Y-16)-0.392*(Cb-128)-0.813*(Cr-128)
  45.  
    B=1.164*(Y-16) 2.017*(Cb-128)
  46.  
    """
  47.  
    def convert_ycbcr_to_rgb(img):
  48.  
    if type(img) == np.ndarray:
  49.  
    r = 298.082 * img[:, :, 0] / 256. 408.583 * img[:, :, 2] / 256. - 222.921
  50.  
    g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. 135.576
  51.  
    b = 298.082 * img[:, :, 0] / 256. 516.412 * img[:, :, 1] / 256. - 276.836
  52.  
    return np.array([r, g, b]).transpose([1, 2, 0])
  53.  
    elif type(img) == torch.Tensor:
  54.  
    if len(img.shape) == 4:
  55.  
    img = img.squeeze(0)
  56.  
    r = 298.082 * img[0, :, :] / 256. 408.583 * img[2, :, :] / 256. - 222.921
  57.  
    g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. 135.576
  58.  
    b = 298.082 * img[0, :, :] / 256. 516.412 * img[1, :, :] / 256. - 276.836
  59.  
    return torch.cat([r, g, b], 0).permute(1, 2, 0)
  60.  
    else:
  61.  
    raise Exception('Unknown Type', type(img))
  62.  
     
  63.  
    # PSNR 计算
  64.  
    def calc_psnr(img1, img2):
  65.  
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
  66.  
     
  67.  
    # 计算 平均数,求和,长度
  68.  
    class AverageMeter(object):
  69.  
    def __init__(self):
  70.  
    self.reset()
  71.  
     
  72.  
    def reset(self):
  73.  
    self.val = 0
  74.  
    self.avg = 0
  75.  
    self.sum = 0
  76.  
    self.count = 0
  77.  
     
  78.  
    def update(self, val, n=1):
  79.  
    self.val = val
  80.  
    self.sum = val * n
  81.  
    self.count = n
  82.  
    self.avg = self.sum / self.count
学新通

test.py

  1.  
    import argparse
  2.  
     
  3.  
    import torch
  4.  
    import torch.backends.cudnn as cudnn
  5.  
    import numpy as np
  6.  
    import PIL.Image as pil_image
  7.  
     
  8.  
    from models import SRCNN
  9.  
    from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
  10.  
     
  11.  
     
  12.  
    if __name__ == '__main__':
  13.  
    # 设置权重参数目录,处理图像目录,放大倍数
  14.  
    parser = argparse.ArgumentParser()
  15.  
    parser.add_argument('--weights-file', default='outputs/x3/best.pth', type=str)
  16.  
    parser.add_argument('--image-file', default='img/butterfly_GT.bmp', type=str)
  17.  
    parser.add_argument('--scale', type=int, default=3)
  18.  
    args = parser.parse_args()
  19.  
    # Benchmark模式会提升计算速度
  20.  
    cudnn.benchmark = True
  21.  
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  22.  
     
  23.  
    model = SRCNN().to(device) # 新建一个模型
  24.  
     
  25.  
    state_dict = model.state_dict() # 通过 model.state_dict()得到模型有哪些 parameters and persistent buffers
  26.  
    # torch.load('tensors.pth', map_location=lambda storage, loc: storage) 使用函数将所有张量加载到CPU(适用在GPU训练的模型在CPU上加载)
  27.  
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items(): # 载入最好的模型参数
  28.  
    if n in state_dict.keys():
  29.  
    state_dict[n].copy_(p)
  30.  
    else:
  31.  
    raise KeyError(n)
  32.  
     
  33.  
    model.eval() # 切换为测试模式 ,取消dropout
  34.  
     
  35.  
    image = pil_image.open(args.image_file).convert('RGB') # 将图片转为RGB类型
  36.  
     
  37.  
    # 经过一个插值操作,首先将原始图片重设尺寸,使之可以被放大倍数scale整除
  38.  
    # 得到低分辨率图像Lr,即三次插值后的图像,同时保存输出
  39.  
    image_width = (image.width // args.scale) * args.scale
  40.  
    image_height = (image.height // args.scale) * args.scale
  41.  
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
  42.  
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
  43.  
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
  44.  
    image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
  45.  
    # 将图像转化为数组类型,同时图像转为ycbcr类型
  46.  
    image = np.array(image).astype(np.float32)
  47.  
    ycbcr = convert_rgb_to_ycbcr(image)
  48.  
    # 得到 ycbcr中的 y 通道
  49.  
    y = ycbcr[..., 0]
  50.  
    y /= 255. # 归一化处理
  51.  
    y = torch.from_numpy(y).to(device) #把数组转换成张量,且二者共享内存,对张量进行修改比如重新赋值,那么原始数组也会相应发生改变,并且将参数放到device上
  52.  
    y = y.unsqueeze(0).unsqueeze(0) # 增加两个维度
  53.  
    # 令reqires_grad自动设为False,关闭自动求导
  54.  
    # clamp将inputs归一化为0到1区间
  55.  
    with torch.no_grad():
  56.  
    preds = model(y).clamp(0.0, 1.0)
  57.  
     
  58.  
    psnr = calc_psnr(y, preds) # 计算y通道的psnr值
  59.  
    print('PSNR: {:.2f}'.format(psnr)) # 格式化输出PSNR值
  60.  
     
  61.  
    # 1.mul函数类似矩阵.*,即每个元素×255
  62.  
    # 2. *.cpu().numpy() 将数据的处理设备从其他设备(如gpu拿到cpu上),不会改变变量类型,转换后仍然是Tensor变量,同时将Tensor转化为ndarray
  63.  
    # 3. *.squeeze(0).squeeze(0)数据的维度进行压缩
  64.  
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0) #得到的是经过模型处理,取值在[0,255]的y通道图像
  65.  
     
  66.  
    # 将img的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),进行格式的转换后方可进行显示。
  67.  
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
  68.  
     
  69.  
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8) # 将图像格式从ycbcr转为rgb,限制取值范围[0,255],同时矩阵元素类型为uint8类型
  70.  
    output = pil_image.fromarray(output) # array转换成image,即将矩阵转为图像
  71.  
    output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale))) # 对图像进行保存
学新通

4.实验结果展示

学新通 学新通 学新通

        original                                          bicubic_x3                          SRCNN_x3

SRCNN:PSNR: 27.61

学新通 学新通 学新通

              original                                    bicubic_x3                                SRCNN_x3

SRCNN:PSNR: 29.17

GitHub项目地址传送门:SRCNN_Pytorch

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhhkjjhj
系列文章
更多 icon
同类精品
更多 icon
继续加载