diff --git a/models/networks.py b/models/networks.py index 7519e80..3c29522 100644 --- a/models/networks.py +++ b/models/networks.py @@ -1413,12 +1413,11 @@ class MLPDiscriminator(nn.Module): self.dropout = nn.Dropout(dropout) def forward(self, x): - x = self.linear1(x) - x = self.activation(x) + features = self.linear1(x) # 中间特征,即 D_real 或 D_fake + x = self.activation(features) x = self.dropout(x) - x = self.linear2(x) - return self.dropout(x) - + scores = self.linear2(x) # 最终分数,即 real_scores 或 fake_scores + return scores, features class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator"""