Unverified commit 7a84f5ee, authored by eonglints, committed by GitHub
Browse files

Merge pull request #2 from eonglints/dropout-and-get_tokens-hookup

Attention and feedforward dropout
parents 3ff361d5 d83457f0
Loading
Loading
Loading
Loading
+25 −8
Original line number Diff line number Diff line
@@ -185,12 +185,13 @@ class GEGLU(nn.Module):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4):
def FeedForward(dim, mult = 4, dropout = 0.1):
    """Build the feedforward sub-block as an ``nn.Sequential``.

    Layers, in order: LayerNorm -> Linear (no bias) -> GEGLU -> Dropout ->
    Linear (no bias) back to ``dim``.

    Args:
        dim: feature size of the input and of the output.
        mult: expansion factor; the hidden width is ``int(dim * 2 * mult / 3)``.
        dropout: probability passed to ``nn.Dropout``, applied after the
            gating and before the output projection.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    hidden = int(dim * 2 * mult / 3)

    # The first projection is twice the hidden width: GEGLU chunks its input
    # in two along the last dim (value and gate), leaving `hidden` features.
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim, bias = False),
    ]
    return nn.Sequential(*stages)

@@ -205,7 +206,8 @@ class Attention(nn.Module):
        dim_context = None,
        heads = 8,
        norm_context = False,
        num_null_kv = 0
        num_null_kv = 0,
        dropout = 0.1
    ):
        super().__init__()
        self.heads = heads
@@ -223,7 +225,10 @@ class Attention(nn.Module):

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
@@ -286,6 +291,8 @@ class Transformer(nn.Module):
        heads,
        dim_context = None,
        cross_attend = False,
        attn_dropout = 0.1,
        ff_dropout = 0.1,
        grad_shrink_alpha = 0.1,
        **kwargs
    ):
@@ -298,9 +305,9 @@ class Transformer(nn.Module):

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, causal = True, **kwargs),
                Attention(dim = dim, heads = heads, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
                FeedForward(dim = dim)
                Attention(dim = dim, heads = heads, dropout = attn_dropout, causal = True, **kwargs),
                Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
                FeedForward(dim = dim, dropout = ff_dropout)
            ]))

        self.norm = nn.LayerNorm(dim)
@@ -341,6 +348,8 @@ class SemanticTransformer(nn.Module):
        depth,
        heads,
        num_semantic_tokens,
        attn_dropout = 0.1,
        ff_dropout = 0.1,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -361,7 +370,15 @@ class SemanticTransformer(nn.Module):
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.transformer = Transformer(dim = dim,
                                       depth = depth,
                                       heads = heads,
                                       attn_dropout = attn_dropout,
                                       ff_dropout = ff_dropout,
                                       dim_context = get_encoded_dim(t5_name),
                                       cross_attend = has_condition,
                                       grad_shrink_alpha = grad_shrink_alpha,
                                       **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
@@ -414,7 +431,7 @@ class SemanticTransformer(nn.Module):
        if return_loss:
            labels, ids = ids.clone(), ids[:, :-1]

        tokens = self.semantic_embedding(ids)
        tokens = get_embeds(self.semantic_embedding, ids)

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0])