Commit 914ba79a authored by Phil Wang's avatar Phil Wang
Browse files

fix for coarse transformer raw wave shapes

parent 241bae31
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -1411,8 +1411,8 @@ class CoarseTransformerWrapper(nn.Module):
            with torch.no_grad():
                self.codec.eval()
                _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True)
                batch = raw_wave.shape[0]
                num_timesteps = raw_wave.shape[1]
                batch = raw_wave_for_codec.shape[0]
                num_timesteps = raw_wave_for_codec.shape[1]
                num_frames = int(num_timesteps / self.codec.seq_len_multiple_of)
                assert indices.shape[0] == batch and indices.shape[1] == num_frames, \
                    f'Expected indices to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}'
+1 −1
Original line number Diff line number Diff line
__version__ = '0.26.5'
__version__ = '0.26.6'