Commit b6eeb93f authored by Phil Wang's avatar Phil Wang
Browse files

add some asserts

parent 1b15751a
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -1343,6 +1343,10 @@ class FineTransformerWrapper(nn.Module):

        self.num_fine_quantizers = transformer.num_fine_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers

        if exists(soundstream):
            assert (self.num_fine_quantizers + self.num_coarse_quantizers) == soundstream.num_quantizers, 'number of fine and coarse quantizers on fine transformer must add up to total number of quantizers on soundstream'

        self.eos_id = transformer.eos_id

        assert self.num_coarse_quantizers > 0
@@ -1563,6 +1567,7 @@ class AudioLM(nn.Module):
        assert semantic_transformer.num_semantic_tokens == coarse_transformer.num_semantic_tokens
        assert coarse_transformer.codebook_size == fine_transformer.codebook_size
        assert coarse_transformer.num_coarse_quantizers == fine_transformer.num_coarse_quantizers
        assert (fine_transformer.num_coarse_quantizers + fine_transformer.num_fine_quantizers) == soundstream.num_quantizers

        self.semantic_has_condition = semantic_transformer.has_condition
        self.coarse_has_condition = coarse_transformer.has_condition
+2 −0
Original line number Diff line number Diff line
@@ -424,6 +424,8 @@ class SoundStream(nn.Module):

        self.encoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None

        self.num_quantizers = rq_num_quantizers

        self.rq = ResidualVQ(
            dim = codebook_dim,
            num_quantizers = rq_num_quantizers,
+1 −1
Original line number Diff line number Diff line
__version__ = '0.18.1'
__version__ = '0.18.2'