【Hackathon 8th No.23】Improved Training of Wasserstein GANs 论文复现 by robinbg · Pull Request #1147 · PaddlePaddle/PaddleScience (original) (raw)
这里运行时会报错,Paddle 的 nn.Sequential 只能包含继承自 nn.Layer 的对象,不能包含 lambda 表达式,会触发 assert isinstance(layer, Layer) 错误,可以改为类似如下代码:
def __init__(self, noise_dim=100, output_channels=3):
super(CIFAR10Generator, self).__init__()
self.layers1 = nn.Sequential(
nn.Linear(noise_dim, 512 * 4 * 4),
nn.BatchNorm1D(512 * 4 * 4),
nn.ReLU(),
)
self.layers2 = nn.Sequential(
nn.Conv2DTranspose(512, 256, 4, 2, 1),
nn.BatchNorm2D(256),
nn.ReLU(),
nn.Conv2DTranspose(256, 128, 4, 2, 1),
nn.BatchNorm2D(128),
nn.ReLU(),
nn.Conv2DTranspose(128, output_channels, 4, 2, 1),
nn.Tanh(),
)
def forward(self, x):
x = self.layers1(x)
x = x.reshape([-1, 512, 4, 4])
x = self.layers2(x)
return x