Loading audiolm_pytorch/soundstream.py +2 −116 Original line number Diff line number Diff line Loading @@ -28,28 +28,11 @@ def exists(val): def default(val, d): return val if exists(val) else d # decorators def auto_handle_complex(fn): @wraps(fn) def inner(*args): if args[0].dtype not in (torch.complex64, torch.complex32): return fn(*args) real_args = tuple(arg.real for arg in args) imag_args = tuple(arg.imag for arg in args) return (fn(*real_args) + fn(*imag_args)) * 0.5 return inner # gan losses @auto_handle_complex def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() @auto_handle_complex def hinge_gen_loss(fake): return -fake.mean() Loading Loading @@ -120,101 +103,6 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates # complex stft discriminator class ModReLU(nn.Module): """ https://arxiv.org/abs/1705.09792 https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 """ def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) def ComplexSTFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ) class ComplexSTFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = True ) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # simulated complex stft discriminator class ComplexConv2d(nn.Module): Loading Loading @@ -461,8 +349,7 @@ class SoundStream(nn.Module): mhesa_dim_head = 32, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8, use_complex_stft_discriminator = False attn_heads = 8 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading Loading @@ -536,8 +423,7 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) stft_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator self.stft_discriminator = stft_klass() self.stft_discriminator = STFTDiscriminator() # loss weights Loading Loading
audiolm_pytorch/soundstream.py +2 −116 Original line number Diff line number Diff line Loading @@ -28,28 +28,11 @@ def exists(val): def default(val, d): return val if exists(val) else d # decorators def auto_handle_complex(fn): @wraps(fn) def inner(*args): if args[0].dtype not in (torch.complex64, torch.complex32): return fn(*args) real_args = tuple(arg.real for arg in args) imag_args = tuple(arg.imag for arg in args) return (fn(*real_args) + fn(*imag_args)) * 0.5 return inner # gan losses @auto_handle_complex def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() @auto_handle_complex def hinge_gen_loss(fake): return -fake.mean() Loading Loading @@ -120,101 +103,6 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates # complex stft discriminator class ModReLU(nn.Module): """ https://arxiv.org/abs/1705.09792 https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 """ def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) def ComplexSTFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ) class ComplexSTFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = True ) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # simulated complex stft discriminator class ComplexConv2d(nn.Module): Loading Loading @@ -461,8 +349,7 @@ class SoundStream(nn.Module): mhesa_dim_head = 32, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8, use_complex_stft_discriminator = False attn_heads = 8 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading Loading @@ -536,8 +423,7 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) stft_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator self.stft_discriminator = stft_klass() self.stft_discriminator = STFTDiscriminator() # loss weights Loading