Commit f6f02a66 authored by lucidrains's avatar lucidrains
Browse files

allow for specifying which layer of hubert to use

parent 5ff887c5
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@ logging.root.setLevel(logging.ERROR)
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class HubertWithKmeans(nn.Module):
    """
    checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
@@ -29,7 +32,8 @@ class HubertWithKmeans(nn.Module):
        checkpoint_path,
        kmeans_path,
        target_sample_hz = 16000,
        seq_len_multiple_of = None
        seq_len_multiple_of = None,
        output_layer = 9
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz
@@ -74,7 +78,7 @@ class HubertWithKmeans(nn.Module):
        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model(wav_input, features_only = True)
        embed = self.model(wav_input, features_only = True, output_layer = self.output_layer)
        embed, packed_shape = pack([embed['x']], '* d')

        codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
+1 −1
Original line number Diff line number Diff line
__version__ = '0.25.0'
__version__ = '0.25.1'