Код: Выделить всё
class MelSpectrogramLoss(torch.nn.Module):
    """L1 distance between the log-mel spectrograms of two waveforms.

    NOTE(review): ``self.mel_spectrogram`` (a callable mapping audio to a mel
    spectrogram) and ``self.log_base`` are expected to be set in the elided
    part of this class (the ``...`` below). ``log_base`` is presumably
    ``ln(base)`` so that ``log(x) / log_base`` is a base-``base`` logarithm;
    confirm against the constructor.
    """

    # Lower bound applied to mel magnitudes before the log, so silent bins
    # produce a finite value instead of log(0) = -inf.
    _LOG_EPS = 1e-5

    ...  # elided in the original excerpt (presumably __init__) -- TODO restore

    def _log_mel(self, audio):
        """Return ``log(clamp(mel(audio), min=_LOG_EPS)) / self.log_base``."""
        mels = self.mel_spectrogram(audio)
        return torch.log(mels.clamp(min=self._LOG_EPS)) / self.log_base

    def forward(self, real, fake):
        """Compute the mean absolute error between log-mel spectrograms.

        Args:
            real: reference waveform tensor, shape (B, 1, T).
            fake: generated waveform tensor, same shape as ``real``.

        Returns:
            Scalar tensor: ``l1_loss`` (default ``'mean'`` reduction) between
            the scaled log-mel spectrograms of ``fake`` and ``real``.
        """
        real_logmels = self._log_mel(real)
        fake_logmels = self._log_mel(fake)
        # L1 is symmetric, so (input=fake, target=real) -- the usual
        # convention -- gives the same value and gradients as the reverse.
        return torch.nn.functional.l1_loss(fake_logmels, real_logmels)