Commit af321ce9 authored by dmMaze's avatar dmMaze
Browse files

remove fp16 for lama due to overflow

parent 3fa1ef97
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -421,7 +421,6 @@ class LamaLarge(LamaInpainterMPE):
        'precision': {
            'type': 'selector',
            'options': [
                'fp16', 
                'fp32',
                'bf16'
            ], 
+0 −51
Original line number Diff line number Diff line
@@ -69,57 +69,6 @@ class FourierUnit(nn.Module):
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    # def forward(self, x):
    #     batch = x.shape[0]
    #     input_dtype = x.dtype

    #     if self.spatial_scale_factor is not None:
    #         orig_size = x.shape[-2:]
    #         x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)

    #     # (batch, c, h, w/2+1, 2)
    #     fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
    #     # x: torch.float16

    #     if input_dtype != torch.float32:
    #         x = x.type(torch.float32)
    #     ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
    #     ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
    #     ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
    #     ffted = ffted.view((batch, -1,) + ffted.size()[3:])

    #     if self.spectral_pos_encoding:
    #         height, width = ffted.shape[-2:]
    #         coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
    #         coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
    #         ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

    #     if self.use_se:
    #         ffted = self.se(ffted)

    #     if ffted.dtype != input_dtype:
    #         ffted = ffted.type(input_dtype)
    #     ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
    #     ffted = self.relu(self.bn(ffted))

    #     ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
    #         0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        
    #     if input_dtype != torch.float32:
    #         ffted = ffted.type(torch.float32)
    #     ffted = torch.complex(ffted[..., 0], ffted[..., 1])

    #     ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
    #     output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

    #     if output.dtype != input_dtype:
    #         output = output.type(input_dtype)

    #     if self.spatial_scale_factor is not None:
    #         output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

    #     return output

    def forward(self, x):
        batch = x.shape[0]