Loading audiolm_pytorch/soundstream.py +61 −32 Original line number Diff line number Diff line Loading @@ -154,8 +154,7 @@ class ComplexSTFTDiscriminator(nn.Module): input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, **kwargs win_length = 1024 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) Loading Loading @@ -216,16 +215,48 @@ class ComplexSTFTDiscriminator(nn.Module): return complex_logits, intermediates # non-complex stft discriminator # simulated complex stft discriminator class ComplexConv2d(nn.Module): def __init__( self, *args, **kwargs ): super().__init__() self.conv_real = nn.Conv2d(*args, **kwargs) self.conv_imag = nn.Conv2d(*args, **kwargs) def forward(self, x): real, imag = x.unbind(dim = 1) new_real = self.conv_real(real) - self.conv_imag(imag) new_imag = self.conv_real(imag) + self.conv_imag(real) return torch.stack((new_real, new_imag), dim = 1) class ComplexModReLU(nn.Module): def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): real, imag = x.unbind(dim = 1) x_abs = (real ** 2 + imag ** 2).clamp(min = 1e-5).sqrt() x_angle = torch.atan(imag / real.clamp(min = 1e-5)) new_real = F.relu(x_abs + self.b) new_imag = x_angle.exp() return torch.stack((new_real, new_imag), dim = 1) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1), leaky_relu(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ComplexConv2d(chan_in, chan_in, 3, padding = 1), ComplexModReLU(), ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class STFTDiscriminator(nn.Module): Loading @@ -238,23 +269,11 @@ class STFTDiscriminator(nn.Module): input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, stft_normalized = True win_length = 1024 ): super().__init__() self.stft = T.Spectrogram( n_fft = n_fft, hop_length = hop_length, win_length = win_length, window_fn = torch.hann_window, normalized = stft_normalized, center = False, pad_mode = None, power = None ) input_channels *= 2 self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3) self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) Loading @@ -267,23 +286,33 @@ class STFTDiscriminator(nn.Module): for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') """ einstein notation b - batch n - sequence length x - dimension of 2 that holds the real and imaginary values """ ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: x = rearrange(x, 'b 1 n -> b n') The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = False # going to simulate complex using dimension 1 ) x = self.stft(x) x = rearrange(x, 'b ... -> b 1 ...') x = torch.cat((x.real, x.imag), dim = 1) x = rearrange(x, 'b ... x -> b x 1 ...') intermediates = [] Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.7.3', version = '0.7.4', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +61 −32 Original line number Diff line number Diff line Loading @@ -154,8 +154,7 @@ class ComplexSTFTDiscriminator(nn.Module): input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, **kwargs win_length = 1024 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) Loading Loading @@ -216,16 +215,48 @@ class ComplexSTFTDiscriminator(nn.Module): return complex_logits, intermediates # non-complex stft discriminator # simulated complex stft discriminator class ComplexConv2d(nn.Module): def __init__( self, *args, **kwargs ): super().__init__() self.conv_real = nn.Conv2d(*args, **kwargs) self.conv_imag = nn.Conv2d(*args, **kwargs) def forward(self, x): real, imag = x.unbind(dim = 1) new_real = self.conv_real(real) - self.conv_imag(imag) new_imag = self.conv_real(imag) + self.conv_imag(real) return torch.stack((new_real, new_imag), dim = 1) class ComplexModReLU(nn.Module): def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): real, imag = x.unbind(dim = 1) x_abs = (real ** 2 + imag ** 2).clamp(min = 1e-5).sqrt() x_angle = torch.atan(imag / real.clamp(min = 1e-5)) new_real = F.relu(x_abs + self.b) new_imag = x_angle.exp() return torch.stack((new_real, new_imag), dim = 1) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1), leaky_relu(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ComplexConv2d(chan_in, chan_in, 3, padding = 1), ComplexModReLU(), ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class STFTDiscriminator(nn.Module): Loading @@ -238,23 +269,11 @@ class STFTDiscriminator(nn.Module): input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, stft_normalized = True win_length = 1024 ): super().__init__() self.stft = T.Spectrogram( n_fft = n_fft, hop_length = hop_length, win_length = win_length, window_fn = torch.hann_window, normalized = stft_normalized, center = False, pad_mode = None, power = None ) input_channels *= 2 self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3) self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) Loading @@ -267,23 +286,33 @@ class STFTDiscriminator(nn.Module): for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') """ einstein notation b - batch n - sequence length x - dimension of 2 that holds the real and imaginary values """ ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: x = rearrange(x, 'b 1 n -> b n') The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = False # going to simulate complex using dimension 1 ) x = self.stft(x) x = rearrange(x, 'b ... -> b 1 ...') x = torch.cat((x.real, x.imag), dim = 1) x = rearrange(x, 'b ... x -> b x 1 ...') intermediates = [] Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.7.3', version = '0.7.4', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading