Commit 1905a608 authored by Phil Wang's avatar Phil Wang
Browse files

allow for sampling multiple pieces of music and selecting top match with audio clip mulan

parent 22fd63bb
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -133,7 +133,7 @@ musiclm = MusicLM(
    mulan_embed_quantizer = mulan_embed_quantizer
)

music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.Tensor
music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4) # sample 4 and pick the top match with mulan
```

## Todo
+48 −6
Original line number Diff line number Diff line
@@ -23,6 +23,9 @@ from beartype import beartype
def exists(val):
    return val is not None

def first(it):
    return it[0]

def default(val, d):
    return val if exists(val) else d

@@ -243,6 +246,8 @@ class AudioSpectrogramTransformer(nn.Module):
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        accept_spec = False,
        accept_spec_time_first = True,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
@@ -268,6 +273,9 @@ class AudioSpectrogramTransformer(nn.Module):
            nn.LayerNorm(dim)
        )

        self.accept_spec = accept_spec
        self.accept_spec_time_first = accept_spec_time_first

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
@@ -321,7 +329,12 @@ class AudioSpectrogramTransformer(nn.Module):
        force_no_patch_dropout = False
    ):
        batch, device = x.shape[0], x.device
        assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2)

        if self.accept_spec and self.accept_spec_time_first:
            x = rearrange(x, 'b t f -> b f t')

        if not self.accept_spec:
            x = self.spec(x)

        if self.training:
@@ -525,18 +538,26 @@ class MuLaN(nn.Module):
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_similarities = False
        return_latents = False,
        return_similarities = False,
        return_pairwise_similarities = False
    ):
        batch, device = wavs.shape[0], wavs.device

        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)

        if return_latents:
            return audio_latents, text_latents

        if return_similarities:
            return einsum('i d, i d -> i', audio_latents, text_latents)

        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)

        assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'

        if return_similarities:
        if return_pairwise_similarities:
            return cosine_sim

        cosine_sim = cosine_sim * self.temperature.exp()
@@ -661,13 +682,34 @@ class MusicLM(nn.Module):
    @torch.no_grad()
    def forward(
        self,
        raw_texts: List[str],
        text: str,
        num_samples = 1,
        **audio_lm_kwargs
    ):
        self.eval()

        texts = tokenizer.tokenize(raw_texts).to(self.device)
        texts = tokenizer.tokenize([text]).to(self.device)

        text_embeds = self.mulan_embed_quantizer(texts = texts)

        return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)
        # unable to deal with variable lengthed audio for now

        samples = []

        for _ in range(num_samples):
            music = self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)
            samples.append(music)

        # if one sample, just return it

        if num_samples == 1:
            return first(samples)

        mulan = self.mulan_embed_quantizer.mulan

        # get the one with the highest similarity score, of all the samples

        sims = torch.cat([mulan(texts = texts, wavs = music, return_similarities = True) for music in samples], dim = 0)
        top_matching_index = sims.topk(1, dim = 0).indices.item()

        return samples[top_matching_index]
+2 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.26',
  version = '0.0.28',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',
@@ -20,7 +20,7 @@ setup(
  ],
  install_requires=[
    'accelerate',
    'audiolm-pytorch>=0.10.4',
    'audiolm-pytorch>=0.17.0',
    'beartype',
    'einops>=0.6',
    'lion-pytorch',