Unverified Commit e67d4661 authored by eonglints's avatar eonglints Committed by GitHub
Browse files

Depth and heads options for initializing transformers

parent 06eecbdb
Loading
Loading
Loading
Loading
+13 −6
Original line number Diff line number Diff line
@@ -283,6 +283,7 @@ class Transformer(nn.Module):
        *,
        dim,
        depth,
        heads,
        dim_context = None,
        cross_attend = False,
        grad_shrink_alpha = 0.1,
@@ -293,12 +294,12 @@ class Transformer(nn.Module):

        self.layers = nn.ModuleList([])

        self.rel_pos_bias = RelativePositionBias()
        self.rel_pos_bias = RelativePositionBias(heads = heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = True, **kwargs),
                Attention(dim = dim, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
                Attention(dim = dim, heads = heads, causal = True, **kwargs),
                Attention(dim = dim, heads = heads, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
                FeedForward(dim = dim)
            ]))

@@ -337,6 +338,8 @@ class SemanticTransformer(nn.Module):
        self,
        *,
        dim,
        depth,
        heads,
        num_semantic_tokens,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
@@ -358,7 +361,7 @@ class SemanticTransformer(nn.Module):
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
@@ -428,6 +431,8 @@ class CoarseTransformer(nn.Module):
        codebook_size,
        num_coarse_quantizers,
        dim,
        depth,
        heads,
        num_semantic_tokens,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
@@ -450,7 +455,7 @@ class CoarseTransformer(nn.Module):
        codebook_size_with_eos = codebook_size + 1
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)

        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
@@ -574,6 +579,8 @@ class FineTransformer(nn.Module):
        num_fine_quantizers,
        codebook_size,
        dim,
        depth,
        heads,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -597,7 +604,7 @@ class FineTransformer(nn.Module):

        self.eos_id = codebook_size

        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers