【Hackathon 8th No.23】Improved Training of Wasserstein GANs 论文复现 by robinbg · Pull Request #1147 · PaddlePaddle/PaddleScience (original) (raw)

这里运行时会报错,Paddle 的 nn.Sequential 只能包含继承自 nn.Layer 的对象,不能包含 lambda 表达式,会触发 assert isinstance(layer, Layer) 错误,可以改为类似如下代码:

def __init__(self, noise_dim=100, output_channels=3):
      super(CIFAR10Generator, self).__init__()

      self.layers1 = nn.Sequential(
          nn.Linear(noise_dim, 512 * 4 * 4),
          nn.BatchNorm1D(512 * 4 * 4),
          nn.ReLU(),
      )
      self.layers2 = nn.Sequential(
          nn.Conv2DTranspose(512, 256, 4, 2, 1),
          nn.BatchNorm2D(256),
          nn.ReLU(),
          nn.Conv2DTranspose(256, 128, 4, 2, 1),
          nn.BatchNorm2D(128),
          nn.ReLU(),
          nn.Conv2DTranspose(128, output_channels, 4, 2, 1),
          nn.Tanh(),
      )

    def forward(self, x):
        x = self.layers1(x)
        x = x.reshape([-1, 512, 4, 4])
        x = self.layers2(x)
        return x