Swin Transformer Explained, with a Code Re-implementation

Tags: Transformer, deep learning, PyTorch, artificial intelligence, torch, paddle

1. Swin Transformer network structure

(Figure: overall Swin Transformer architecture)

In practice, when re-implementing the code we should follow the figure below; the rest of this post implements it piece by piece according to that figure.

(Figure: the block-level structure this implementation follows)
2. Patch Partition & Patch Embedding

The image is first fed into the Patch Partition module, which splits it into patches: every 4x4 group of neighbouring pixels forms one patch, which is then flattened along the channel dimension. For an RGB input, each patch contains 4x4 = 16 pixels with 3 values each, so flattening gives 16x3 = 48 channels; Patch Partition therefore turns a [H, W, 3] image into a [H/4, W/4, 48] tensor. A Linear Embedding layer then applies a linear projection to the channel dimension of every position, mapping 48 to C, so the shape goes from [H/4, W/4, 48] to [H/4, W/4, C]. In the source code, Patch Partition and Linear Embedding are in fact implemented together as a single convolution layer, with exactly the same structure as the Embedding layer described in the earlier Vision Transformer post.

import paddle
import paddle.nn as nn

class PatchEmbedding(nn.Layer):
    def __init__(self, patch_size=4, embed_dim=96, in_channels=3):
        super().__init__()
        # Patch Partition + Linear Embedding in one conv: kernel_size = stride = patch_size
        self.patch_embed = nn.Conv2D(in_channels, out_channels=embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)     # [B, embed_dim, H/4, W/4]
        x = x.flatten(2)            # [B, embed_dim, H/4 * W/4]
        x = x.transpose([0, 2, 1])  # [B, num_patches, embed_dim]
        x = self.norm(x)
        return x
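As a quick sanity check (a minimal sketch; the 224x224 input and batch size of 1 are just example values), the module turns one RGB image into 56*56 = 3136 patch tokens of dimension 96:

x = paddle.randn([1, 3, 224, 224])            # dummy input image
patch_embed = PatchEmbedding(patch_size=4, embed_dim=96)
print(patch_embed(x).shape)                   # [1, 3136, 96]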

3. Patch Merging

As mentioned earlier, every Stage except Stage 1 starts with a Patch Merging layer for downsampling. As shown in the figure, suppose the input to Patch Merging is a 4x4 single-channel feature map. Patch Merging groups every 2x2 block of neighbouring pixels and gathers the pixels sitting at the same position of each block (same colour in the figure), producing four feature maps. These four maps are concatenated along the depth dimension and passed through a LayerNorm layer. Finally, a fully-connected layer applies a linear projection along the depth dimension, reducing the depth from 4C to 2C. From this small example it is clear that after Patch Merging the height and width of the feature map are halved while its depth is doubled.

class PatchMerging(nn.Layer):
    def __init__(self, resolution, dim):
        super().__init__()
        self.resolution = resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim)   # 4C -> 2C
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        h, w = self.resolution
        b, _, c = x.shape
        x = x.reshape([b, h, w, c])
        x0 = x[:, 0::2, 0::2, :]   # top-left pixel of every 2x2 block
        x1 = x[:, 0::2, 1::2, :]   # top-right
        x2 = x[:, 1::2, 0::2, :]   # bottom-left
        x3 = x[:, 1::2, 1::2, :]   # bottom-right
        x = paddle.concat([x0, x1, x2, x3], axis=-1)   # [b, h/2, w/2, 4c]
        x = x.reshape([b, -1, 4 * c])
        x = self.norm(x)
        x = self.reduction(x)                          # [b, h*w/4, 2c]
        return x
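A minimal shape check (example values matching Stage 1 of the default configuration: a 56x56 token grid with C = 96):

x = paddle.randn([1, 56 * 56, 96])
merging = PatchMerging(resolution=(56, 56), dim=96)
print(merging(x).shape)                       # [1, 784, 192]: 4x fewer tokens, channels doubled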

PS: a quick demonstration of what x[:, 0::2, 0::2, :] and the other three slices actually select.
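Since the figure from the original post is not reproduced here, the small stand-in demo below shows how the four strided slices pick out the four positions of every 2x2 block:

toy = paddle.arange(16).reshape([1, 4, 4, 1])          # a 4x4 single-channel map with values 0..15
print(toy[:, 0::2, 0::2, :].flatten().numpy())         # [ 0  2  8 10] -> top-left of each 2x2 block
print(toy[:, 0::2, 1::2, :].flatten().numpy())         # [ 1  3  9 11] -> top-right
print(toy[:, 1::2, 0::2, :].flatten().numpy())         # [ 4  6 12 14] -> bottom-left
print(toy[:, 1::2, 1::2, :].flatten().numpy())         # [ 5  7 13 15] -> bottom-right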

4. W-MSA (Window Multi-head Self-Attention) and SW-MSA (Shifted Window Multi-head Self-Attention)

The Window Multi-head Self-Attention (W-MSA) module is introduced to reduce the amount of computation: with W-MSA, self-attention is computed only within each window, so no information can flow between windows. To solve this, the authors introduce the Shifted Window Multi-head Self-Attention (SW-MSA) module.
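For reference, the Swin Transformer paper gives the computational cost of global MSA versus window-based W-MSA for an h x w feature map with channel dimension C and window size M:

Ω(MSA)   = 4hwC^2 + 2(hw)^2 C
Ω(W-MSA) = 4hwC^2 + 2M^2 hwC

The term that is quadratic in the number of tokens hw becomes linear in hw, which is what makes window attention affordable at high resolutions.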

# Split the feature map into windows; attention is then computed inside each window.
def windows_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C])
    x = x.transpose([0, 1, 3, 2, 4, 5])
    # [B, H//window_size, W//window_size, window_size, window_size, C]
    x = x.reshape([-1, window_size, window_size, C])
    # [B * H//window_size * W//window_size, window_size, window_size, C]
    return x

# Merge the windows back into a full feature map.
def window_reverse(window, window_size , H , W ):
    B = window.shape[0]//((H//window_size)*(W//window_size))
    x = window.reshape([B,H//window_size,W//window_size,window_size,window_size,-1])
    x = x.transpose([0,1,3,2,4,5])
    x = x.reshape([B,H,W,-1])
    return x
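A round-trip check (a sketch with example sizes that are divisible by the window size) confirms that window_reverse undoes windows_partition:

feat = paddle.randn([2, 56, 56, 96])
wins = windows_partition(feat, window_size=7)
print(wins.shape)                             # [128, 7, 7, 96]: 2 images * 8 * 8 windows
restored = window_reverse(wins, 7, 56, 56)
print(paddle.allclose(feat, restored))        # Tensor containing True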




Next, self-attention is computed inside each window. If we ignore the mask for a moment, this attention is no different from the ordinary self-attention in a Transformer.

class window_attention(nn.Layer):
    def __init__(self,dim,window_size,num_heads):
        super().__init__()
        self.dim = dim
        self.dim_head = dim//num_heads
        self.num_heads = num_heads
        self.scale = self.dim_head**-0.5
        self.softmax = nn.Softmax(-1)
        self.qkv = nn.Linear(dim,int(dim*3))
        self.proj = nn.Linear(dim,dim)
    
    def transpose_multi_head(self,x):
        new_shape = x.shape[:-1]+[self.num_heads,self.dim_head]
        x = x.reshape(new_shape)
        # [B,num_patches,num_heads,dim_head]
        x = x.transpose([0,2,1,3])
         # [B,num_heads,num_patches,dim_head]
        return x
    def forward(self, x, mask=None):
        B, N, C = x.shape                              # B = num_windows * batch_size, N = window_size**2
        qkv = self.qkv(x).chunk(3, -1)
        q, k, v = map(self.transpose_multi_head, qkv)
        q = q * self.scale
        attn = paddle.matmul(q, k, transpose_y=True)   # [B, num_heads, N, N]
        if mask is None:
            attn = self.softmax(attn)
        else:
            # mask: [num_windows, N, N]; broadcast it over batch and heads before adding
            attn = attn.reshape([B // mask.shape[0], mask.shape[0], self.num_heads, mask.shape[1], mask.shape[1]])
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.reshape([-1, self.num_heads, mask.shape[1], mask.shape[1]])
            attn = self.softmax(attn)
        out = paddle.matmul(attn, v)                   # [B, num_heads, N, dim_head]
        out = out.transpose([0, 2, 1, 3])              # [B, N, num_heads, dim_head]
        out = out.reshape([B, N, C])
        out = self.proj(out)
        return out
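A quick shape check (example values: 128 windows of 7x7 tokens, dim 96, 3 heads) shows that the output shape matches the input:

tokens = paddle.randn([128, 49, 96])          # [num_windows * B, window_size * window_size, dim]
attn = window_attention(dim=96, window_size=7, num_heads=3)
print(attn(tokens).shape)                     # [128, 49, 96]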

As for how SW-MSA (Shifted Window Multi-head Self-Attention) is actually implemented, see the blog post referenced in the original article; here I only wrote a few demos for the parts I found hardest to understand.

paddle.roll()

Regarding paddle.roll (the same as torch.roll): in the figure from the original post, a is shifted down by two positions, along axis 0 and then along axis 1, to obtain b; applying the same operation to b again brings it back to a.
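Because that figure is not reproduced here, the following stand-in demo (a 4x4 example of my own) shows the same idea; since the side length is 4, rolling by -2 along both axes twice returns to the starting point:

a = paddle.arange(16).reshape([4, 4])
b = paddle.roll(a, shifts=(-2, -2), axis=(0, 1))               # the top-left 2x2 block moves to the bottom-right
print(b.numpy())
print(paddle.roll(b, shifts=(-2, -2), axis=(0, 1)).numpy())    # identical to a again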

How the attention mask is generated

About self.register_buffer and the attention mask

        if self.shift_size > 0:
            H, W = self.resolution
            img_mask = paddle.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = windows_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape((-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = paddle.where(attn_mask != 0,
                                     paddle.ones_like(attn_mask) * float(-100.0),
                                     attn_mask)
            attn_mask = paddle.where(attn_mask == 0,
                                     paddle.zeros_like(attn_mask),
                                     attn_mask)
        else:
            attn_mask = None
            
        self.register_buffer("attn_mask", attn_mask)

Normally, the parameters of a network are saved in an OrderedDict. These actually come in two kinds: one is the parameters contained in the network's modules, i.e. nn.Parameter (and we can of course define extra nn.Parameter objects in the network ourselves); the other kind is buffers. The former are updated on every optim.step(), while the latter are not.
Next, the mask image is split into windows and flattened; the flattened masks are then compared against themselves (a broadcast subtraction), which yields the attention mask (this was illustrated by a figure in the original post).
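To make the construction concrete, here is a standalone version of the mask-building code above, run on a toy case (H = W = 8, window_size = 4, shift_size = 2; these are illustration values, not the ones used in the real model):

H = W = 8
window_size, shift_size = 4, 2
img_mask = paddle.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt            # label each of the 9 regions with its own id
        cnt += 1
mask_windows = windows_partition(img_mask, window_size).reshape((-1, window_size * window_size))
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
print(img_mask[0, :, :, 0].numpy())           # 8x8 grid of region labels 0..8
print(attn_mask.shape)                        # [4, 16, 16]: one mask per window; non-zero entries get -100 later

Before assembling the block itself, two small helpers are defined below: an Identity layer (used when a stage has no Patch Merging) and the standard two-layer MLP.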

class Identity(nn.Layer):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return x
class Mlp(nn.Layer):
    def __init__(self,embed_dim,mlp_ratio=4.0,dropout=0.):
        super().__init__()
        w_att_1,b_att_1 = self.init_weight()
        w_att_2,b_att_2 = self.init_weight()
        self.fc1 = nn.Linear(embed_dim,int(embed_dim*mlp_ratio),weight_attr=w_att_1,bias_attr=b_att_1)
        self.fc2 = nn.Linear(int(embed_dim*mlp_ratio),embed_dim,weight_attr=w_att_2,bias_attr=b_att_2)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()
    def init_weight(self):
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=0.2))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(.0))
        return  weight_attr,bias_attr
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

5. Swin Block

With all the components written, we now need to chain them together into a Swin block. Apart from deciding between W-MSA and SW-MSA, it is no different from the encoder block in a Transformer: after patch embedding, the tokens are split into windows, W-MSA or SW-MSA is applied inside each window, a residual connection is added, then the MLP, followed by another residual connection.

class SwinBlock(nn.Layer):
    def __init__(self,dim,input_resolution,num_heads,window_size,shift_size):
        super().__init__()
        self.dim = dim
        self.resolution = input_resolution
        self.window_size = window_size
        self.att_norm = nn.LayerNorm(dim)
        self.attn = window_attention(dim=dim,window_size=window_size, num_heads=num_heads)
        self.mlp = Mlp(dim)
        self.shift_size = shift_size
        self.mlp_norm = nn.LayerNorm(dim)
        if self.shift_size > 0:
            H, W = self.resolution
            img_mask = paddle.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = windows_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape((-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = paddle.where(attn_mask != 0,
                                     paddle.ones_like(attn_mask) * float(-100.0),
                                     attn_mask)
            attn_mask = paddle.where(attn_mask == 0,
                                     paddle.zeros_like(attn_mask),
                                     attn_mask)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.resolution
        B, N, C = x.shape
        h = x
        x = self.att_norm(x)
        x = x.reshape([B, H, W, C])
        # cyclic shift for SW-MSA
        if self.shift_size > 0:
            shifted_x = paddle.roll(x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
        else:
            shifted_x = x
        x_windows = windows_partition(shifted_x, self.window_size)
        x_windows = x_windows.reshape([-1, self.window_size * self.window_size, C])
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        # reverse the cyclic shift
        if self.shift_size > 0:
            x = paddle.roll(shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2))
        else:
            x = shifted_x
        x = x.reshape([B, -1, C])
        x = h + x
        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
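A minimal forward check (example numbers matching Stage 1: a 56x56 token grid, dim 96, 3 heads, window size 7), once as a plain W-MSA block and once as a shifted SW-MSA block:

x = paddle.randn([1, 56 * 56, 96])
w_block = SwinBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0)
sw_block = SwinBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=3)
print(w_block(x).shape)                       # [1, 3136, 96]
print(sw_block(x).shape)                      # [1, 3136, 96]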

6. Chaining all the modules into a stage

A stage is built from several Swin Transformer blocks and one Patch Merging layer (the last stage uses an Identity instead). Within a stage, the blocks alternate between W-MSA (shift_size = 0) and SW-MSA (shift_size = window_size // 2).

class SwinTransformerStage(nn.Layer):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size, patch_merging=None):
        super().__init__()
        self.blocks = nn.LayerList()
        for i in range(depth):
            # even blocks use W-MSA (shift_size = 0), odd blocks use SW-MSA
            self.blocks.append(SwinBlock(dim=dim, input_resolution=input_resolution,
                                         num_heads=num_heads, window_size=window_size,
                                         shift_size=0 if (i % 2 == 0) else window_size // 2))
        if patch_merging is None:
            self.patch_merging = Identity()
        else:
            self.patch_merging = patch_merging(input_resolution, dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.patch_merging(x)
        return x
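A quick check of one whole stage (Stage 1 numbers again: two blocks followed by Patch Merging, so the token count drops by 4x and the channel dimension doubles):

x = paddle.randn([1, 56 * 56, 96])
stage = SwinTransformerStage(dim=96, input_resolution=(56, 56), depth=2,
                             num_heads=3, window_size=7, patch_merging=PatchMerging)
print(stage(x).shape)                         # [1, 784, 192]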

class Swin(nn.Layer):
    def __init__(self,
                 image_size=224,
                 patch_size=4,
                 in_channels=3,
                 embed_dim=96,
                 window_size=7,
                 num_heads=[3, 6, 12, 24],
                 depths=[2, 2, 6, 2],
                 num_classes=1000):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.num_stages = len(depths)
        # channel dim of the last stage: embed_dim * 2^(num_stages - 1)
        self.num_features = int(self.embed_dim * 2 ** (self.num_stages - 1))
        self.patch_resolution = [image_size // patch_size, image_size // patch_size]
        self.patch_embedding = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dim,
                                              in_channels=in_channels)
        self.stages = nn.LayerList()
        for idx, (depth, n_heads) in enumerate(zip(self.depths, self.num_heads)):
            stage = SwinTransformerStage(dim=int(self.embed_dim * 2 ** idx),
                                         input_resolution=(self.patch_resolution[0] // (2 ** idx),
                                                           self.patch_resolution[1] // (2 ** idx)),
                                         depth=depth,
                                         num_heads=n_heads,
                                         window_size=window_size,
                                         patch_merging=PatchMerging if (idx < self.num_stages - 1) else None)
            self.stages.append(stage)
        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1D(1)
        self.fc = nn.Linear(self.num_features, self.num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for stage in self.stages:
            x = stage(x)
        x = self.norm(x)            # [B, num_patches, num_features]
        x = x.transpose([0, 2, 1])
        x = self.avgpool(x)         # [B, num_features, 1]
        x = x.flatten(1)
        x = self.fc(x)
        return x

7. Printing the network

t = paddle.randn([1, 3, 224, 224])    # a dummy batch with one 224x224 RGB image
model = Swin()
print(model)
out = model(t)
print(out.shape)                      # [1, 1000]

Printing the model gives a structure like the following (abridged):
Swin(
  (patch_embedding): PatchEmbedding(
    (patch_embed): Conv2D(3, 96, kernel_size=[4, 4], stride=[4, 4], data_format=NCHW)
    (norm): LayerNorm(normalized_shape=[96], epsilon=1e-05)
  )
  (stages): LayerList(
    (0): SwinTransformerStage(
      (blocks): LayerList(
        (0): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[96], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=96, out_features=288, dtype=float32)
            (proj): Linear(in_features=96, out_features=96, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, dtype=float32)
            (fc2): Linear(in_features=384, out_features=96, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[96], epsilon=1e-05)
        )
        (1): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[96], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=96, out_features=288, dtype=float32)
            (proj): Linear(in_features=96, out_features=96, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, dtype=float32)
            (fc2): Linear(in_features=384, out_features=96, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[96], epsilon=1e-05)
        )
      )
      (patch_merging): PatchMerging(
        (reduction): Linear(in_features=384, out_features=192, dtype=float32)
        (norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
      )
    )
    (1): SwinTransformerStage(
      (blocks): LayerList(
        (0): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[192], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=192, out_features=576, dtype=float32)
            (proj): Linear(in_features=192, out_features=192, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, dtype=float32)
            (fc2): Linear(in_features=768, out_features=192, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[192], epsilon=1e-05)
        )
        (1): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[192], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=192, out_features=576, dtype=float32)
            (proj): Linear(in_features=192, out_features=192, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, dtype=float32)
            (fc2): Linear(in_features=768, out_features=192, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[192], epsilon=1e-05)
        )
      )
      (patch_merging): PatchMerging(
        (reduction): Linear(in_features=768, out_features=384, dtype=float32)
        (norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
      )
    )
    (2): SwinTransformerStage(
      (blocks): LayerList(
        (0): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (1): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (2): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (3): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (4): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (5): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (6): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (7): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (8): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (9): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (10): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (11): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (12): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (13): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (14): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (15): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (16): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (17): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (18): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (19): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (20): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (21): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (22): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (23): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (24): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (25): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (26): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (27): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (28): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (29): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (30): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (31): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (32): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (33): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (34): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (35): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (36): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (37): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (38): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (39): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (40): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (41): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (42): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (43): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (44): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (45): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (46): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (47): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (48): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (49): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        (50): SwinBlock(
          (att_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
          (attn): window_attention(
            (softmax): Softmax(axis=-1)
            (qkv): Linear(in_features=384, out_features=1152, dtype=float32)
            (proj): Linear(in_features=384, out_features=384, dtype=float32)
          )
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, dtype=float32)
            (fc2): Linear(in_features=1536, out_features=384, dtype=float32)
            (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
            (act): GELU(approximate=False)
          )
          (mlp_norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
        )
        ... (blocks (51) through (61) repeat the same SwinBlock structure shown in (50) above) ...
      )
      (patch_merging): Identity()
    )
  )
  (norm): LayerNorm(normalized_shape=[384], epsilon=1e-05)
  (avgpool): AdaptiveAvgPool1D(output_size=1)
  (fc): Linear(in_features=384, out_features=1000, dtype=float32)
)
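The tail of the printout is the classification head: a final LayerNorm over the 384-dim tokens, a global average pool over the token dimension, and a Linear layer mapping 384 features to 1000 classes. As a rough standalone sketch (the class name SwinHead is my own, not code from this post), assuming the last stage outputs tokens of shape [B, L, 384]:

import paddle
import paddle.nn as nn

# sketch of the classification head seen at the end of the printout:
# LayerNorm -> global average pool over tokens -> Linear classifier
class SwinHead(nn.Layer):
    def __init__(self, dim=384, num_classes=1000):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.avgpool = nn.AdaptiveAvgPool1D(1)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):            # x: [B, L, dim] tokens from the last stage
        x = self.norm(x)             # [B, L, dim]
        x = x.transpose([0, 2, 1])   # [B, dim, L] so the pool averages over tokens
        x = self.avgpool(x)          # [B, dim, 1]
        x = x.flatten(1)             # [B, dim]
        return self.fc(x)            # [B, num_classes]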



---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

/tmp/ipykernel_790/2976751405.py in <module>
      1 model = Swin()
      2 print(model)
----> 3 out = model(t)
      4 print(out.shape)


NameError: name 't' is not defined
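The NameError above simply means the dummy input tensor t was never created before calling the model. A minimal fix (assuming a 224x224 RGB input, which is consistent with the patch size of 4 used earlier):

import paddle

# build a dummy batch of one 224x224 RGB image as the test input
t = paddle.randn([1, 3, 224, 224])

model = Swin()     # the Swin model assembled earlier in this post
print(model)
out = model(t)
print(out.shape)   # expected [1, 1000] given the fc layer shown above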

7. About the Relative Position Bias

You can refer to this article, or this video.
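In short, window attention adds a learnable relative position bias B to the attention logits, i.e. softmax(QK^T / sqrt(d) + B)V. B is looked up, per head, from a table with (2M-1)x(2M-1) entries (M is the window size) using a precomputed relative position index. The sketch below is my own illustration in Paddle (names such as RelativePositionBias and bias_table are made up), not the exact code referenced above:

import paddle
import paddle.nn as nn

class RelativePositionBias(nn.Layer):
    def __init__(self, window_size=7, num_heads=3):
        super().__init__()
        self.window_size = window_size
        # learnable table: one bias per relative offset per head,
        # offsets range over (2M-1) x (2M-1) possibilities
        self.bias_table = self.create_parameter(
            shape=[(2 * window_size - 1) ** 2, num_heads],
            default_initializer=nn.initializer.TruncatedNormal(std=0.02))
        # precompute the relative position index for every (query, key) pair
        coords_h = paddle.arange(window_size)
        coords_w = paddle.arange(window_size)
        coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))   # [2, M, M]
        coords_flatten = coords.flatten(1)                             # [2, M*M]
        relative = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)  # [2, M*M, M*M]
        relative = relative.transpose([1, 2, 0])                       # [M*M, M*M, 2]
        relative[:, :, 0] += window_size - 1        # shift row offsets to start at 0
        relative[:, :, 1] += window_size - 1        # shift column offsets to start at 0
        relative[:, :, 0] *= 2 * window_size - 1    # row-major flattening of the 2D offset
        self.register_buffer('relative_position_index', relative.sum(-1))  # [M*M, M*M]

    def forward(self):
        # gather one bias per (query, key) pair and reshape to [num_heads, M*M, M*M]
        bias = paddle.index_select(
            self.bias_table, self.relative_position_index.flatten(), axis=0)
        n = self.window_size * self.window_size
        return bias.reshape([n, n, -1]).transpose([2, 0, 1])

The returned tensor of shape [num_heads, M*M, M*M] would then be added to the attention matrix inside window_attention right before the softmax.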

8. References

Code reference

Video reference

Blog reference

Copyright notice: this is an original article by the blogger, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original article: https://blog.csdn.net/apodx/article/details/123941720
