2. A Simple Residual Network in PyTorch
import torch.nn as nn

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(...)

    def build_conv_block(self, ...):
        conv_block = []

        conv_block += [nn.Conv2d(...),
                       norm_layer(...),
                       nn.ReLU()]
        if use_dropout:
            conv_block += [nn.Dropout(...)]

        conv_block += [nn.Conv2d(...),
                       norm_layer(...)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
Here, the ResNet block's skip connection is implemented directly in the forward pass: PyTorch allows dynamic operations during forward computation.
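As a concrete sketch of the pattern above, the elided arguments might be filled in as follows. The kernel size, padding, channel count, and the choice of BatchNorm2d as the norm layer are illustrative assumptions, not taken from the original:

```python
import torch
import torch.nn as nn

class ResnetBlockSketch(nn.Module):
    """Residual block: two 3x3 convolutions with a skip connection.
    Layer hyperparameters here are assumptions for illustration."""
    def __init__(self, dim, use_dropout=False, use_bias=True):
        super().__init__()
        layers = [
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias),
            nn.BatchNorm2d(dim),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        # Skip connection: add the input to the conv block's output.
        return x + self.conv_block(x)

block = ResnetBlockSketch(dim=8)
y = block(torch.randn(2, 8, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Because both convolutions use padding 1 and preserve the channel count, the block's output shape matches its input shape, which is what makes the element-wise addition in `forward` valid.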
3. A Network with Multiple Outputs in PyTorch
For a network with multiple outputs (for example, using a pretrained VGG network to build a perceptual loss), we use the following pattern:
import torch
from torchvision import models

class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()

        for x in range(7):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 21):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        out = [h_relu1, h_relu2, h_relu3]
        return out
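The same slicing pattern can be demonstrated without downloading pretrained weights. The following toy backbone stands in for the VGG feature stack; its layer sizes and the slice boundaries are assumptions for illustration only:

```python
import torch
import torch.nn as nn

class MultiOutputNet(nn.Module):
    """Toy analogue of the Vgg19 wrapper: a sequential backbone is
    split into slices, and each slice's activation is returned."""
    def __init__(self, requires_grad=False):
        super().__init__()
        # Stand-in backbone; layer sizes are made up for illustration.
        backbone = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
        )
        # nn.Sequential supports slicing, which yields sub-Sequentials.
        self.slice1 = backbone[0:2]
        self.slice2 = backbone[2:4]
        self.slice3 = backbone[4:6]
        # Freeze all parameters when used purely as a feature extractor.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        # Each slice consumes the previous slice's output, so one pass
        # through the backbone yields all intermediate activations.
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        return [h1, h2, h3]

net = MultiOutputNet()
outs = net(torch.randn(1, 3, 16, 16))
print(len(outs))  # 3
```

A perceptual loss would then compare these intermediate activations between a generated image and a target image, for example by summing an L1 or L2 distance over the returned list.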