Loading audiolm_pytorch/soundstream.py +10 −5 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from pathlib import Path from functools import partial, wraps from itertools import zip_longest from typing import Optional import torch from torch import nn, einsum Loading Loading @@ -438,7 +439,8 @@ class SoundStream(nn.Module): attn_xpos_scale_base = None, attn_dynamic_pos_bias = False, squeeze_excite = False, complex_stft_discr_logits_abs = True complex_stft_discr_logits_abs = True, stft_discriminator: Optional[nn.Module] = None # can pass in own stft discriminator ): super().__init__() Loading Loading @@ -522,6 +524,9 @@ class SoundStream(nn.Module): discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])] self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors]) self.stft_discriminator = stft_discriminator if not exists(self.stft_discriminator): self.stft_discriminator = ComplexSTFTDiscriminator( stft_normalized = stft_normalized, logits_abs = complex_stft_discr_logits_abs # whether to output as abs() or use view_as_real Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.28.0' __version__ = '0.28.1' Loading
audiolm_pytorch/soundstream.py +10 −5 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from pathlib import Path from functools import partial, wraps from itertools import zip_longest from typing import Optional import torch from torch import nn, einsum Loading Loading @@ -438,7 +439,8 @@ class SoundStream(nn.Module): attn_xpos_scale_base = None, attn_dynamic_pos_bias = False, squeeze_excite = False, complex_stft_discr_logits_abs = True complex_stft_discr_logits_abs = True, stft_discriminator: Optional[nn.Module] = None # can pass in own stft discriminator ): super().__init__() Loading Loading @@ -522,6 +524,9 @@ class SoundStream(nn.Module): discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])] self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors]) self.stft_discriminator = stft_discriminator if not exists(self.stft_discriminator): self.stft_discriminator = ComplexSTFTDiscriminator( stft_normalized = stft_normalized, logits_abs = complex_stft_discr_logits_abs # whether to output as abs() or use view_as_real Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.28.0' __version__ = '0.28.1'