Commit 2d13bf13 authored by Phil Wang's avatar Phil Wang
Browse files

turn off masking in hubert kmeans forward, thanks to @maitycyrus

parent e2579052
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -79,7 +79,13 @@ class HubertWithKmeans(nn.Module):
        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model(wav_input, features_only = True, output_layer = self.output_layer)
        embed = self.model(
            wav_input,
            features_only = True,
            mask = False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
            output_layer = self.output_layer
        )

        embed, packed_shape = pack([embed['x']], '* d')

        codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
+1 −1
Original line number Diff line number Diff line
__version__ = '0.26.0'
__version__ = '0.26.1'