@@ -142,7 +153,8 @@ class LeSigTransformer(nn.Module):
depths=cast_tuple(depth,stages)
layer_heads=cast_tuple(heads,stages)
assertall(map(lambdat:len(t)==stages,(dims,depths,layer_heads))),'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
@@ -126,8 +136,10 @@ class Transformer(nn.Module):
x=ff(x)+x
returnx
classLeViT(nn.Module):
__model_name__='levit'
__dims__=2
def__init__(
self,
@@ -149,7 +161,8 @@ class LeViT(nn.Module):
depths=cast_tuple(depth,stages)
layer_heads=cast_tuple(heads,stages)
assertall(map(lambdat:len(t)==stages,(dims,depths,layer_heads))),'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'