Commit 15391183 authored by Phil Wang's avatar Phil Wang
Browse files

add an extra assert to protect against empty audiofiles

parent 9d992724
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -66,6 +66,8 @@ class SoundDataset(Dataset):

        data, sample_hz = torchaudio.load(file)

        assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'

        num_outputs = len(self.target_sample_hz)
        data = cast_tuple(data, num_outputs)

+9 −4
Original line number Diff line number Diff line
@@ -267,6 +267,7 @@ class SoundStream(nn.Module):
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 24000,
        use_local_attn = True,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8
@@ -301,7 +302,7 @@ class SoundStream(nn.Module):
            causal = True
        )

        self.encoder_attn = LocalMHA(**attn_kwargs)
        self.encoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None

        self.rq = ResidualVQ(
            dim = codebook_dim,
@@ -315,7 +316,7 @@ class SoundStream(nn.Module):
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

        self.decoder_attn = LocalMHA(**attn_kwargs)
        self.decoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None

        decoder_blocks = []

@@ -390,11 +391,15 @@ class SoundStream(nn.Module):
        x = self.encoder(x)

        x = rearrange(x, 'b c n -> b n c')

        if exists(self.encoder_attn):
            x = self.encoder_attn(x) + x

        x, indices, commit_loss = self.rq(x)

        if exists(self.decoder_attn):
            x = self.decoder_attn(x) + x

        x = rearrange(x, 'b n c -> b c n')

        if return_encoded:
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.7',
  version = '0.4.8',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',