Commit 1c5038f7 authored by Leon Wu's avatar Leon Wu
Browse files

Fix

parent 20d60d18
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -53,7 +53,12 @@ class EncodecWrapper(nn.Module):
        # Extract discrete codes from EnCodec
        with torch.no_grad():
            encoded_frames = self.model.encode(wav)
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1)  # [batch, timesteps, num_quantizers]
        # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor
        # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the
        # timesteps concatenated.
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1)  # [batch, num_quantizers, timesteps]
        # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers]
        codes = rearrange(codes, 'b q n -> b n q')  # result: [batch, timesteps, num_quantizers]
        # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that.
        return None, codes, None