Loading README.md +2 −1 Original line number Diff line number Diff line Loading @@ -190,7 +190,8 @@ semantic_transformer = SemanticTransformer( num_semantic_tokens = 500, dim = 1024, depth = 6, has_condition = True # this will have to be set to True has_condition = True, # this will have to be set to True cond_as_self_attn_prefix = True # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper ).cuda() # mock text video dataset (as an example) Loading audiolm_pytorch/audiolm_pytorch.py +78 −9 Original line number Diff line number Diff line Loading @@ -274,28 +274,58 @@ class Attention(nn.Module): x, context = None, mask = None, attn_bias = None attn_bias = None, prefix_context = None, prefix_context_mask = None ): b = x.shape[0] b, n, _, device = *x.shape, x.device if exists(context): context = self.context_norm(context) kv_input = default(context, x) # take care of prefix-based self attention conditioning # make sure to either concat the to the self attention mask or lengthen it accordingly if exists(prefix_context): kv_input = torch.cat((prefix_context, kv_input), dim = -2) prefix_seq_len = prefix_context.shape[-2] if not exists(mask): mask = torch.ones((b, n), device = device, dtype = torch.bool) if exists(prefix_context_mask): mask = torch.cat((prefix_context_mask, mask), dim = -1) else: mask = F.pad(mask, (prefix_seq_len, 0), value = True) if exists(attn_bias): attn_bias = F.pad(attn_bias, (prefix_seq_len, 0), value = 0.) # prenorm x = self.norm(x) # project for queries, keys, values q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1) # null key / values if self.num_null_kv > 0: null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0) k = torch.cat((null_k, k), dim = -2) v = torch.cat((null_v, v), dim = -2) # split for multi-headed attention q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) q = q * self.scale # similarities sim = einsum('b h i d, b j d -> b h i j', q, k) if exists(attn_bias): Loading @@ -312,11 +342,17 @@ class Attention(nn.Module): causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1) sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) # attention attn = sim.softmax(dim = -1) attn = self.attn_dropout(attn) # aggregate out = einsum('b h i j, b j d -> b h i d', attn, v) # merge heads out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) Loading @@ -334,9 +370,13 @@ class Transformer(nn.Module): attn_dropout = 0., ff_dropout = 0., grad_shrink_alpha = 0.1, cond_as_self_attn_prefix = False, **kwargs ): super().__init__() assert not (cross_attend and cond_as_self_attn_prefix) self.cond_as_self_attn_prefix = cond_as_self_attn_prefix self.grad_shrink = partial(grad_shrink, alpha = grad_shrink_alpha) self.layers = nn.ModuleList([]) Loading @@ -359,14 +399,23 @@ class Transformer(nn.Module): context = None, context_mask = None ): assert not (self.cond_as_self_attn_prefix and not exists(context)) n, device = x.shape[1], x.device x = self.grad_shrink(x) # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability rel_pos_bias = self.rel_pos_bias(n, n, device = device) self_attn_kwargs = dict() if self.cond_as_self_attn_prefix: self_attn_kwargs = dict( prefix_context = context, prefix_context_mask = context_mask ) for attn, cross_attn, ff in self.layers: x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask) + x x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask, **self_attn_kwargs) + x if exists(cross_attn): assert exists(context) Loading @@ -392,6 +441,7 @@ class SemanticTransformer(nn.Module): ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, pad_id = -1, Loading @@ -410,14 +460,17 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id text_dim = get_encoded_dim(t5_name) self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity() self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, cross_attend = has_condition and not cond_as_self_attn_prefix, cond_as_self_attn_prefix = cond_as_self_attn_prefix, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) Loading Loading @@ -466,6 +519,9 @@ class SemanticTransformer(nn.Module): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) if exists(text_embeds): text_embeds = self.proj_text_embed(text_embeds) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if exists(text_mask) and cond_drop_prob > 0: Loading Loading @@ -502,6 +558,7 @@ class CoarseTransformer(nn.Module): ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, **kwargs Loading @@ -523,14 +580,17 @@ class CoarseTransformer(nn.Module): codebook_size_with_eos = codebook_size + 1 self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim) text_dim = get_encoded_dim(t5_name) self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity() self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, cross_attend = has_condition and not cond_as_self_attn_prefix, cond_as_self_attn_prefix = cond_as_self_attn_prefix, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) Loading Loading @@ -589,6 +649,8 @@ class CoarseTransformer(nn.Module): if exists(text_embeds): text_mask = torch.any(text_embeds != 0, dim = -1) text_embeds = self.proj_text_embed(text_embeds) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if exists(text_mask) and cond_drop_prob > 0: Loading Loading @@ -663,6 +725,7 @@ class FineTransformer(nn.Module): ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, **kwargs Loading @@ -684,14 +747,17 @@ class FineTransformer(nn.Module): self.eos_id = codebook_size text_dim = get_encoded_dim(t5_name) self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity() self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, cross_attend = has_condition and not cond_as_self_attn_prefix, cond_as_self_attn_prefix = cond_as_self_attn_prefix, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) Loading Loading @@ -747,6 +813,9 @@ class FineTransformer(nn.Module): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) if exists(text_embeds): text_embeds = self.proj_text_embed(text_embeds) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if exists(text_mask) and cond_drop_prob > 0: Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.3.6', version = '0.4.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
README.md +2 −1 Original line number Diff line number Diff line Loading @@ -190,7 +190,8 @@ semantic_transformer = SemanticTransformer( num_semantic_tokens = 500, dim = 1024, depth = 6, has_condition = True # this will have to be set to True has_condition = True, # this will have to be set to True cond_as_self_attn_prefix = True # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper ).cuda() # mock text video dataset (as an example) Loading
audiolm_pytorch/audiolm_pytorch.py +78 −9 Original line number Diff line number Diff line Loading @@ -274,28 +274,58 @@ class Attention(nn.Module): x, context = None, mask = None, attn_bias = None attn_bias = None, prefix_context = None, prefix_context_mask = None ): b = x.shape[0] b, n, _, device = *x.shape, x.device if exists(context): context = self.context_norm(context) kv_input = default(context, x) # take care of prefix-based self attention conditioning # make sure to either concat the to the self attention mask or lengthen it accordingly if exists(prefix_context): kv_input = torch.cat((prefix_context, kv_input), dim = -2) prefix_seq_len = prefix_context.shape[-2] if not exists(mask): mask = torch.ones((b, n), device = device, dtype = torch.bool) if exists(prefix_context_mask): mask = torch.cat((prefix_context_mask, mask), dim = -1) else: mask = F.pad(mask, (prefix_seq_len, 0), value = True) if exists(attn_bias): attn_bias = F.pad(attn_bias, (prefix_seq_len, 0), value = 0.) # prenorm x = self.norm(x) # project for queries, keys, values q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1) # null key / values if self.num_null_kv > 0: null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0) k = torch.cat((null_k, k), dim = -2) v = torch.cat((null_v, v), dim = -2) # split for multi-headed attention q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) q = q * self.scale # similarities sim = einsum('b h i d, b j d -> b h i j', q, k) if exists(attn_bias): Loading @@ -312,11 +342,17 @@ class Attention(nn.Module): causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1) sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) # attention attn = sim.softmax(dim = -1) attn = self.attn_dropout(attn) # aggregate out = einsum('b h i j, b j d -> b h i d', attn, v) # merge heads out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) Loading @@ -334,9 +370,13 @@ class Transformer(nn.Module): attn_dropout = 0., ff_dropout = 0., grad_shrink_alpha = 0.1, cond_as_self_attn_prefix = False, **kwargs ): super().__init__() assert not (cross_attend and cond_as_self_attn_prefix) self.cond_as_self_attn_prefix = cond_as_self_attn_prefix self.grad_shrink = partial(grad_shrink, alpha = grad_shrink_alpha) self.layers = nn.ModuleList([]) Loading @@ -359,14 +399,23 @@ class Transformer(nn.Module): context = None, context_mask = None ): assert not (self.cond_as_self_attn_prefix and not exists(context)) n, device = x.shape[1], x.device x = self.grad_shrink(x) # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability rel_pos_bias = self.rel_pos_bias(n, n, device = device) self_attn_kwargs = dict() if self.cond_as_self_attn_prefix: self_attn_kwargs = dict( prefix_context = context, prefix_context_mask = context_mask ) for attn, cross_attn, ff in self.layers: x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask) + x x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask, **self_attn_kwargs) + x if exists(cross_attn): assert exists(context) Loading @@ -392,6 +441,7 @@ class SemanticTransformer(nn.Module): ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, pad_id = -1, Loading @@ -410,14 +460,17 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id text_dim = get_encoded_dim(t5_name) self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity() self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, cross_attend = has_condition and not cond_as_self_attn_prefix, cond_as_self_attn_prefix = cond_as_self_attn_prefix, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) Loading Loading @@ -466,6 +519,9 @@ class SemanticTransformer(nn.Module): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) if exists(text_embeds): text_embeds = self.proj_text_embed(text_embeds) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if exists(text_mask) and cond_drop_prob > 0: Loading Loading @@ -502,6 +558,7 @@ class CoarseTransformer(nn.Module): ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, **kwargs Loading @@ -523,14 +580,17 @@ class CoarseTransformer(nn.Module): codebook_size_with_eos = codebook_size + 1 self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim) text_dim = get_encoded_dim(t5_name) self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity() self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, cross_attend = has_condition and not cond_as_self_attn_prefix, cond_as_self_attn_prefix = cond_as_self_attn_prefix, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) Loading Loading @@ -589,6 +649,8 @@ class CoarseTransformer(nn.Module): if exists(text_embeds): text_mask = torch.any(text_embeds != 0, dim = -1) text_embeds = self.proj_text_embed(text_embeds) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if exists(text_mask) and cond_drop_prob > 0: Loading Loading @@ -663,6 +725,7 @@ class FineTransformer(nn.Module): ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, **kwargs Loading @@ -684,14 +747,17 @@ class FineTransformer(nn.Module): self.eos_id = codebook_size text_dim = get_encoded_dim(t5_name) self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity() self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, cross_attend = has_condition and not cond_as_self_attn_prefix, cond_as_self_attn_prefix = cond_as_self_attn_prefix, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) Loading Loading @@ -747,6 +813,9 @@ class FineTransformer(nn.Module): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) if exists(text_embeds): text_embeds = self.proj_text_embed(text_embeds) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if exists(text_mask) and cond_drop_prob > 0: Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.3.6', version = '0.4.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading