Loading audiolm_pytorch/soundstream.py +27 −129 Original line number Diff line number Diff line Loading @@ -130,14 +130,35 @@ class ModReLU(nn.Module): def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) class ComplexConv2d(nn.Module): def __init__( self, dim, dim_out, kernel_size, stride = 1, padding = 0 ): super().__init__() conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64) self.weight = nn.Parameter(torch.view_as_real(conv.weight)) self.bias = nn.Parameter(torch.view_as_real(conv.bias)) self.stride = stride self.padding = padding def forward(self, x): weight, bias = map(torch.view_as_complex, (self.weight, self.bias)) return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding) def ComplexSTFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ComplexConv2d(chan_in, chan_in, 3, padding = 1), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class ComplexSTFTDiscriminator(nn.Module): Loading @@ -153,7 +174,7 @@ class ComplexSTFTDiscriminator(nn.Module): win_length = 1024 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) Loading @@ -166,7 +187,7 @@ class ComplexSTFTDiscriminator(nn.Module): for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 # stft settings Loading Loading @@ -213,127 +234,6 @@ class ComplexSTFTDiscriminator(nn.Module): return complex_logits_abs, intermediates # simulated complex stft discriminator class ComplexConv2d(nn.Module): def __init__( self, *args, **kwargs ): super().__init__() self.conv_real = nn.Conv2d(*args, **kwargs) self.conv_imag = nn.Conv2d(*args, **kwargs) def forward(self, x): real, imag = x.unbind(dim = 1) new_real = self.conv_real(real) - self.conv_imag(imag) new_imag = self.conv_real(imag) + self.conv_imag(real) return torch.stack((new_real, new_imag), dim = 1) def complex_abs(t, dim = 1, eps = 1e-8): real, imag = t.unbind(dim = 1) return (real ** 2 + imag ** 2).clamp(min = eps).sqrt() class ComplexModReLU(nn.Module): def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): x_abs = complex_abs(x) real, imag = x.unbind(dim = 1) x_angle = torch.atan2(imag, real) new_real = F.relu(x_abs + self.b) new_imag = x_angle.exp() return torch.stack((new_real, new_imag), dim = 1) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( ComplexConv2d(chan_in, chan_in, 3, padding = 1), ComplexModReLU(), ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class STFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024 ): super().__init__() self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): """ einstein notation b - batch n - sequence length x - dimension of 2 that holds the real and imaginary values """ x = rearrange(x, 'b 1 n -> b n') x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = False # going to simulate complex using dimension 1 ) x = rearrange(x, 'b ... x -> b x 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) logits = self.final_conv(x) logits_abs = complex_abs(logits) if not return_intermediates: return logits_abs return logits_abs, intermediates # learned EMA blocks class MultiHeadEMABlock(nn.Module): Loading Loading @@ -470,8 +370,7 @@ class SoundStream(nn.Module): attn_window_size = 128, attn_dim_head = 64, attn_heads = 8, attn_depth = 1, use_complex_stft_discriminator = True attn_depth = 1 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading Loading @@ -545,8 +444,7 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) stft_discr_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator self.stft_discriminator = stft_discr_klass() self.stft_discriminator = ComplexSTFTDiscriminator() # multi spectral reconstruction Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.11.9', version = '0.11.11', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +27 −129 Original line number Diff line number Diff line Loading @@ -130,14 +130,35 @@ class ModReLU(nn.Module): def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) class ComplexConv2d(nn.Module): def __init__( self, dim, dim_out, kernel_size, stride = 1, padding = 0 ): super().__init__() conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64) self.weight = nn.Parameter(torch.view_as_real(conv.weight)) self.bias = nn.Parameter(torch.view_as_real(conv.bias)) self.stride = stride self.padding = padding def forward(self, x): weight, bias = map(torch.view_as_complex, (self.weight, self.bias)) return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding) def ComplexSTFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ComplexConv2d(chan_in, chan_in, 3, padding = 1), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class ComplexSTFTDiscriminator(nn.Module): Loading @@ -153,7 +174,7 @@ class ComplexSTFTDiscriminator(nn.Module): win_length = 1024 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) Loading @@ -166,7 +187,7 @@ class ComplexSTFTDiscriminator(nn.Module): for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 # stft settings Loading Loading @@ -213,127 +234,6 @@ class ComplexSTFTDiscriminator(nn.Module): return complex_logits_abs, intermediates # simulated complex stft discriminator class ComplexConv2d(nn.Module): def __init__( self, *args, **kwargs ): super().__init__() self.conv_real = nn.Conv2d(*args, **kwargs) self.conv_imag = nn.Conv2d(*args, **kwargs) def forward(self, x): real, imag = x.unbind(dim = 1) new_real = self.conv_real(real) - self.conv_imag(imag) new_imag = self.conv_real(imag) + self.conv_imag(real) return torch.stack((new_real, new_imag), dim = 1) def complex_abs(t, dim = 1, eps = 1e-8): real, imag = t.unbind(dim = 1) return (real ** 2 + imag ** 2).clamp(min = eps).sqrt() class ComplexModReLU(nn.Module): def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): x_abs = complex_abs(x) real, imag = x.unbind(dim = 1) x_angle = torch.atan2(imag, real) new_real = F.relu(x_abs + self.b) new_imag = x_angle.exp() return torch.stack((new_real, new_imag), dim = 1) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( ComplexConv2d(chan_in, chan_in, 3, padding = 1), ComplexModReLU(), ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class STFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024 ): super().__init__() self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): """ einstein notation b - batch n - sequence length x - dimension of 2 that holds the real and imaginary values """ x = rearrange(x, 'b 1 n -> b n') x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = False # going to simulate complex using dimension 1 ) x = rearrange(x, 'b ... x -> b x 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) logits = self.final_conv(x) logits_abs = complex_abs(logits) if not return_intermediates: return logits_abs return logits_abs, intermediates # learned EMA blocks class MultiHeadEMABlock(nn.Module): Loading Loading @@ -470,8 +370,7 @@ class SoundStream(nn.Module): attn_window_size = 128, attn_dim_head = 64, attn_heads = 8, attn_depth = 1, use_complex_stft_discriminator = True attn_depth = 1 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading Loading @@ -545,8 +444,7 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) stft_discr_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator self.stft_discriminator = stft_discr_klass() self.stft_discriminator = ComplexSTFTDiscriminator() # multi spectral reconstruction Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.11.9', version = '0.11.11', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading