# fplloss2
#
# helloWorld / 2023-07-29 / 原文 (original post)

import torch
from Utils.utils import *


def getHighLowFre(image, threshold=0.1):
    """Split ``image`` into high- and low-frequency parts using an FFT mask.

    A 2-D FFT is taken over the last two dimensions of ``image``. Spectrum
    bins whose normalized frequency along the LAST dimension satisfies
    ``|f| < threshold`` (``torch.fft.fftfreq`` units, i.e. cycles/sample in
    [-0.5, 0.5)) are kept and inverse-transformed as the low-frequency part.
    The high-frequency part is whatever remains: ``image - low_freq``.

    NOTE(review): the mask is 1-D and only varies along the last (width)
    axis, so every frequency along the second-to-last axis passes into
    ``low_freq``. Confirm this is intended rather than a 2-D low-pass.

    Args:
        image: tensor whose last two dimensions are transformed.
        threshold: cut-off on the absolute normalized frequency. Defaults to
            0.1, the value previously hard-coded here.

    Returns:
        ``(high_freq, low_freq)`` with ``high_freq + low_freq == image`` up to
        floating-point round-off. Both are real tensors for real input and
        complex tensors for complex input.
    """
    spectrum = torch.fft.fft2(image)
    # Build the mask on the input's own device rather than a global
    # ``args.device``, which failed when ``args`` was missing or the input
    # lived on a different device.
    freqs = torch.fft.fftfreq(image.shape[-1], device=image.device)
    # Shape (W,): broadcasts against the last axis of the spectrum only.
    mask = (freqs.abs() < threshold).to(spectrum.real.dtype)
    low_freq = torch.fft.ifft2(spectrum * mask)
    if not image.is_complex():
        # |f| < threshold is symmetric in f, so the masked spectrum of a real
        # signal stays conjugate-symmetric. Its inverse is therefore real, and
        # the imaginary part is pure round-off.
        low_freq = low_freq.real
    high_freq = image - low_freq
    return high_freq, low_freq


class LossNetwork(torch.nn.Module):
    """Frequency-weighted L1 loss between a prediction and a target.

    The per-element absolute error ``|gt - pred|`` is scaled by a weight
    computed from the magnitude of the target's high-frequency component, as
    split by ``getHighLowFre``. Errors where ``gt`` has strong high-frequency
    content therefore count more. All weights are detached: they rescale the
    error but receive no gradient themselves.

    Args:
        use_sp_weight: if True, additionally multiply by a weight computed from
            the low-frequency discrepancy ``|gt_low - pred_low|``. This is the
            variant that was previously left commented out. Defaults to False,
            which reproduces the original loss exactly.
    """

    def __init__(self, *, use_sp_weight=False):
        super().__init__()
        # (alpha, gamma) for the low-frequency-discrepancy weight; only used
        # when use_sp_weight is True.
        self.alpha_sp, self.gamma_sp = 1, 0.5
        # (alpha, gamma) for the target high-frequency-magnitude weight.
        self.alpha_lp, self.gamma_lp = 1, 1
        self.use_sp_weight = use_sp_weight

    @staticmethod
    def weig_func(x, y, z):
        """Return ``exp(norm(x) * y) * z``, where ``norm`` min-max scales ``x`` to [0, 1].

        The min and max are taken over the whole tensor, across all batch
        elements together. For a constant ``x`` the normalized values are
        defined as 0, giving a weight of ``z``, instead of the NaN that
        ``0 / 0`` used to produce.

        Args:
            x: tensor of values to convert into weights.
            y: exponent scale (alpha); the weight reaches ``exp(y) * z`` at ``x.max()``.
            z: multiplicative scale (gamma).
        """
        x_min = x.min()
        value_range = x.max() - x_min
        # Guard against 0/0 for constant input; the numerator is all zeros then.
        value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
        return torch.exp((x - x_min) / value_range * y) * z

    def forward(self, pred, gt):
        """Return the scalar weighted mean absolute error between ``pred`` and ``gt``."""
        gt_high_freq, gt_low_freq = getHighLowFre(gt)
        weight = self.weig_func(torch.abs(gt_high_freq), self.alpha_lp, self.gamma_lp).detach()

        if self.use_sp_weight:
            # Only this optional term needs the prediction's decomposition.
            _, pred_low_freq = getHighLowFre(pred)
            y_sp = torch.abs(gt_low_freq - pred_low_freq)
            weight = weight * self.weig_func(y_sp, self.alpha_sp, self.gamma_sp).detach()

        return torch.mean(weight * torch.abs(gt - pred))