fplloss2
import torch
from Utils.utils import *


def getHighLowFre(image, threshold=0.1):
    """Split ``image`` into high- and low-frequency parts with an ideal low-pass mask.

    The 2-D FFT is taken over the last two dimensions of ``image``. A frequency
    bin is kept in the low-frequency part when its normalized frequency (as
    returned by ``torch.fft.fftfreq``) has absolute value below ``threshold``
    along BOTH of those dimensions. The high-frequency part is the residual
    ``image - low_freq``.

    Args:
        image: real-valued tensor with at least two dimensions; the filter acts
            on the last two (presumably H and W -- TODO confirm against callers).
        threshold: cutoff on |normalized frequency|; defaults to 0.1, the value
            previously hard-coded.

    Returns:
        ``(high_freq, low_freq)``: real tensors with the same shape as ``image``.
    """
    spectrum = torch.fft.fft2(image)
    freq_h = torch.fft.fftfreq(image.shape[-2], device=image.device)
    freq_w = torch.fft.fftfreq(image.shape[-1], device=image.device)
    # fft2 transforms both trailing axes, so the mask must cover both. A 1-D
    # mask over the last axis alone broadcasts across rows and lets every
    # vertical frequency leak into the "low-frequency" part.
    mask = (
        (freq_h.abs() < threshold)[:, None] & (freq_w.abs() < threshold)[None, :]
    ).float()
    # |fftfreq| is unchanged under k -> -k (mod n), so the masked spectrum of a
    # real image stays Hermitian and its inverse is real up to rounding error.
    # Keep the real part rather than propagating a complex tensor.
    low_freq = torch.fft.ifft2(spectrum * mask).real
    high_freq = image - low_freq
    return high_freq, low_freq


class LossNetwork(torch.nn.Module):
    """Frequency-weighted L1 loss between a prediction and its ground truth.

    The element-wise absolute error ``|gt - pred|`` is multiplied by a detached
    weight map derived from the magnitude of ``gt``'s high-frequency part, so
    errors where ``gt`` has strong high-frequency content count more.

    With ``use_sp=True`` the weight is additionally multiplied by a map derived
    from the low-frequency discrepancy between ``gt`` and ``pred``. This is the
    variant that was previously present only as commented-out code.

    Args:
        use_sp: also apply the low-frequency-discrepancy weight. Defaults to
            False, which matches the previous behavior.
    """

    def __init__(self, use_sp=False):
        super().__init__()
        # (alpha, gamma) for the low-frequency-discrepancy weight (use_sp=True).
        self.alpha_sp, self.gamma_sp = 1, 0.5
        # (alpha, gamma) for the ground-truth high-frequency weight.
        self.alpha_lp, self.gamma_lp = 1, 1
        self.use_sp = use_sp

    @staticmethod
    def weig_func(x, y, z):
        """Return ``z * exp(y * (x - min(x)) / (max(x) - min(x)))``.

        min/max are taken over ALL elements of ``x`` (across the whole batch),
        so for non-negative ``y`` and positive ``z`` the result lies in
        ``[z, z * exp(y)]``. When ``x`` is constant the range is zero. It is
        clamped to the dtype's smallest positive normal value, so the result is
        ``z`` everywhere instead of NaN (0 / 0).
        """
        x_min = x.min()
        span = (x.max() - x_min).clamp_min(torch.finfo(x.dtype).tiny)
        return torch.exp((x - x_min) / span * y) * z

    def forward(self, pred, gt):
        """Return the weighted mean absolute error between ``pred`` and ``gt``.

        The weight maps are detached, so gradients flow only through
        ``|gt - pred|``.
        """
        gt_high_freq, gt_low_freq = getHighLowFre(gt)
        weight = self.weig_func(
            torch.abs(gt_high_freq), self.alpha_lp, self.gamma_lp
        ).detach()
        if self.use_sp:
            # pred's decomposition only feeds this optional weight. Skipping it
            # by default avoids an FFT round-trip the default loss never used.
            _, pred_low_freq = getHighLowFre(pred)
            y_sp = torch.abs(gt_low_freq - pred_low_freq)
            weight = weight * self.weig_func(
                y_sp, self.alpha_sp, self.gamma_sp
            ).detach()
        return torch.mean(weight * torch.abs(gt - pred))