Commit 48380d5d authored by Phil Wang's avatar Phil Wang
Browse files

force input to complex conv to have the same dtype as the weights, in the...

force input to complex conv to have the same dtype as the weights, in the hacked class to make sure complex valued network can do distributed training
parent ce717d99
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -191,6 +191,8 @@ class ComplexConv2d(nn.Module):

    def forward(self, x):
        weight, bias = map(torch.view_as_complex, (self.weight, self.bias))

        x = x.to(weight.dtype)
        return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding)

def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
@@ -269,6 +271,7 @@ class ComplexSTFTDiscriminator(nn.Module):
        intermediates = []

        x = self.init_conv(x)

        intermediates.append(x)

        for layer in self.layers:
+1 −1
Original line number Diff line number Diff line
__version__ = '0.23.6'
__version__ = '0.23.7'