pytorch实现unet网络的方法-创新互联
创新互联www.cdcxhl.cn八线动态BGP香港云服务器提供商,新人活动买多久送多久,划算不套路!
成都创新互联公司-专业网站定制、快速模板网站建设、高性价比定兴网站开发、企业建站全套包干低至880元,成熟完善的模板库,直接使用。一站式定兴网站制作公司更省心,省钱,快速模板网站建设找我们,业务覆盖定兴地区。费用合理售后完善,10余年实体公司更值得信赖。这期内容当中小编将会给大家带来有关pytorch实现unet网络的方法,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。
设计神经网络的一般步骤:
1. 设计框架
2. 设计骨干网络
Unet网络设计的步骤:
1. 设计Unet网络工厂模式
2. 设计编解码结构
3. 设计卷积模块
4. unet实例模块
Unet网络最重要的特征:
1. 编解码结构。
2. 解码结构,比FCN更加完善,采用连接方式。
3. 本质是一个框架,编码部分可以使用很多图像分类网络。
示例代码:
import torch import torch.nn as nn class Unet(nn.Module): #初始化参数:Encoder,Decoder,bridge #bridge默认值为无,如果有参数传入,则用该参数替换None def __init__(self,Encoder,Decoder,bridge = None): super(Unet,self).__init__() self.encoder = Encoder(encoder_blocks) self.decoder = Decoder(decoder_blocks) self.bridge = bridge def forward(self,x): res = self.encoder(x) out,skip = res[0],res[1,:] if bridge is not None: out = bridge(out) out = self.decoder(out,skip) return out #设计编码模块 class Encoder(nn.Module): def __init__(self,blocks): super(Encoder,self).__init__() #assert:断言函数,避免出现参数错误 assert len(blocks) > 0 #nn.Modulelist():模型列表,所有的参数可以纳入网络,但是没有forward函数 self.blocks = nn.Modulelist(blocks) def forward(self,x): skip = [] for i in range(len(self.blocks) - 1): x = self.blocks[i](x) skip.append(x) res = [self.block[i+1](x)] #列表之间可以通过+号拼接 res += skip return res #设计Decoder模块 class Decoder(nn.Module): def __init__(self,blocks): super(Decoder, self).__init__() assert len(blocks) > 0 self.blocks = nn.Modulelist(blocks) def ceter_crop(self,skips,x): _,_,height1,width2 = skips.shape() _,_,height2,width3 = x.shape() #对图像进行剪切处理,拼接的时候保持对应size参数一致 ht,wt = min(height1,height2),min(width2,width3) dh2 = (height1 - height2)//2 if height1 > height2 else 0 dw1 = (width2 - width3)//2 if width2 > width3 else 0 dh3 = (height2 - height1)//2 if height2 > height1 else 0 dw2 = (width3 - width2)//2 if width3 > width2 else 0 return skips[:,:,dh2:(dh2 + ht),dw1:(dw1 + wt)],\ x[:,:,dh3:(dh3 + ht),dw2 : (dw2 + wt)] def forward(self, skips,x,reverse_skips = True): assert len(skips) == len(blocks) - 1 if reverse_skips is True: skips = skips[: : -1] x = self.blocks[0](x) for i in range(1, len(self.blocks)): skip = skips[i-1] x = torch.cat(skip,x,1) x = self.blocks[i](x) return x #定义了一个卷积block def unet_convs(in_channels,out_channels,padding = 0): #nn.Sequential:与Modulelist相比,包含了forward函数 return nn.Sequential( nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False), nn.BatchNorm2d(outchannels), nn.ReLU(inplace = True), nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False), nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True), ) #实例化Unet模型 def unet(in_channels,out_channels): encoder_blocks = [unet_convs(in_channels, 64),\ nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),\ unet_convs(64,128)), \ nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \ unet_convs(128, 256)), nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \ unet_convs(256, 512)), ] bridge = nn.Sequential(unet_convs(512, 1024)) decoder_blocks = [nn.conTranpose2d(1024, 512), \ nn.Sequential(unet_convs(1024, 512), nn.conTranpose2d(512, 256)),\ nn.Sequential(unet_convs(512, 256), nn.conTranpose2d(256, 128)), \ nn.Sequential(unet_convs(512, 256), nn.conTranpose2d(256, 128)), \ nn.Sequential(unet_convs(256, 128), nn.conTranpose2d(128, 64)) ] return Unet(encoder_blocks,decoder_blocks,bridge)
分享文章:pytorch实现unet网络的方法-创新互联
标题来源:http://abwzjs.com/article/pepjj.html