FedR代码学习文档
main.py
参数设置,进入主函数
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# parser.add_argument('--data_path', default='Fed_data/WN18RR-Fed3.pkl', type=str)
parser.add_argument('--data_path', default='Fed_data/DDB14-Fed3.pkl', type=str)
parser.add_argument('--name', default='wn18rr_fed3_fed_TransE', type=str)
parser.add_argument('--state_dir', '-state_dir', default='./state', type=str)
parser.add_argument('--log_dir', '-log_dir', default='./log', type=str)
parser.add_argument('--tb_log_dir', '-tb_log_dir', default='./tb_log', type=str)
parser.add_argument('--run_mode', default='FedR', choices=['FedE', 'Single', 'test_pretrain'])
parser.add_argument('--num_multi', default=3, type=int)
parser.add_argument('--model', default='TransE', choices=['TransE', 'RotatE', 'DistMult', 'ComplEx'])
# one task hyperparam
parser.add_argument('--one_client_idx', default=0, type=int)
parser.add_argument('--max_epoch', default=10000, type=int)
parser.add_argument('--log_per_epoch', default=1, type=int)
parser.add_argument('--check_per_epoch', default=10, type=int)
parser.add_argument('--batch_size', default=512, type=int)
parser.add_argument('--test_batch_size', default=16, type=int)
parser.add_argument('--num_neg', default=256, type=int)
parser.add_argument('--lr', default=0.001, type=int)
# for FedE
parser.add_argument('--num_client', default=3, type=int)
parser.add_argument('--max_round', default=10000, type=int)
parser.add_argument('--local_epoch', default=3, type=int)
parser.add_argument('--fraction', default=1, type=float)
parser.add_argument('--log_per_round', default=1, type=int)
parser.add_argument('--check_per_round', default=5, type=int)
parser.add_argument('--early_stop_patience', default=5, type=int)
parser.add_argument('--gamma', default=10.0, type=float)
parser.add_argument('--epsilon', default=2.0, type=float)
parser.add_argument('--hidden_dim', default=128, type=int)
parser.add_argument('--gpu', default='0', type=str)
parser.add_argument('--num_cpu', default=10, type=int)
parser.add_argument('--adversarial_temperature', default=1.0, type=float)
# parser.add_argument('--negative_adversarial_sampling', default=True, type=bool)
parser.add_argument('--seed', default=12345, type=int)
args = parser.parse_args()
args_str = json.dumps(vars(args))
args.gpu = torch.device('cuda:' + args.gpu)
# args.gpu = torch.device(("cuda:" + args.gpu) if torch.cuda.is_available() else "cpu")
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
init_dir(args)
writer = SummaryWriter(os.path.join(args.tb_log_dir, args.name))
args.writer = writer
init_logger(args)
logging.info(args_str)
if args.run_mode == 'FedR':
all_data = pickle.load(open(args.data_path, 'rb'))
learner = FedR(args, all_data)
learner.train()
elif args.run_mode == 'Single':
all_data = pickle.load(open(args.data_path, 'rb'))
data = all_data[args.one_client_idx]
learner = KGERunner(args, data)
learner.train()
数据导入
.pkl形式的数据 (通过csv的代码可以进行转换)
这里的数据分给三个客户端,每个客户端当中又有train,valid,test
- edge_index:是一个二维数组,表示第i个三元组的起始节点和终止节点
- edge_type:表示第i个三元组的relation
- edge_index_ori:
- edge_type_ori:
,train,test,valid
0,"{'edge_index': array([[3515, 3614, 3299, ..., 246, 3912, 2853],
[ 961, 2501, 703, ..., 211, 1904, 442]], dtype=int64), 'edge_type': array([1, 9, 1, ..., 7, 1, 1], dtype=int64), 'edge_index_ori': array([[1796, 4767, 3939, ..., 345, 3215, 4054],
[1787, 3036, 950, ..., 341, 3204, 537]], dtype=int64), 'edge_type_ori': array([2, 8, 2, ..., 7, 2, 2], dtype=int64)}","{'edge_index': array([[ 392, 822, 331, ..., 1207, 247, 902],
[ 199, 261, 175, ..., 802, 195, 540]], dtype=int64), 'edge_type': array([1, 1, 1, ..., 1, 2, 1], dtype=int64), 'edge_index_ori': array([[ 424, 655, 373, ..., 1531, 364, 1047],
[ 416, 530, 366, ..., 1530, 360, 527]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 3, 2], dtype=int64)}","{'edge_index': array([[ 358, 1204, 2395, ..., 210, 1139, 371],
[2581, 813, 1564, ..., 211, 583, 288]], dtype=int64), 'edge_type': array([1, 1, 6, ..., 8, 1, 1], dtype=int64), 'edge_index_ori': array([[ 393, 1601, 2983, ..., 229, 1223, 572],
[2692, 1580, 1188, ..., 341, 644, 554]], dtype=int64), 'edge_type_ori': array([2, 2, 5, ..., 6, 2, 2], dtype=int64)}"
1,"{'edge_index': array([[4881, 5080, 2512, ..., 531, 876, 3547],
[4882, 30, 38, ..., 532, 574, 95]], dtype=int64), 'edge_type': array([10, 0, 1, ..., 1, 4, 0], dtype=int64), 'edge_index_ori': array([[8515, 2880, 4337, ..., 3921, 721, 1996],
[8391, 2695, 4333, ..., 234, 1556, 2442]], dtype=int64), 'edge_type_ori': array([0, 2, 4, ..., 4, 5, 2], dtype=int64)}","{'edge_index': array([[ 309, 1661, 2880, ..., 1861, 1831, 652],
[ 72, 2083, 2154, ..., 127, 2730, 3940]], dtype=int64), 'edge_type': array([0, 0, 1, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 402, 4842, 827, ..., 2229, 2742, 228],
[ 379, 5955, 826, ..., 2256, 890, 5910]], dtype=int64), 'edge_type_ori': array([2, 2, 4, ..., 2, 2, 2], dtype=int64)}","{'edge_index': array([[1821, 91, 1049, ..., 59, 749, 560],
[1800, 398, 2353, ..., 511, 34, 381]], dtype=int64), 'edge_type': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[4224, 62, 620, ..., 1059, 2398, 1696],
[5502, 3330, 4833, ..., 266, 13, 125]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 2, 2], dtype=int64)}"
2,"{'edge_index': array([[1048, 5151, 2026, ..., 3552, 1835, 897],
[ 286, 33, 2180, ..., 4712, 1836, 56]], dtype=int64), 'edge_type': array([2, 2, 0, ..., 8, 6, 2], dtype=int64), 'edge_index_ori': array([[5172, 4261, 1779, ..., 6148, 222, 1803],
[6817, 6663, 1859, ..., 9069, 2987, 6810]], dtype=int64), 'edge_type_ori': array([ 0, 0, 2, ..., 8, 12, 0], dtype=int64)}","{'edge_index': array([[ 508, 5263, 1230, ..., 577, 1646, 439],
[ 649, 4329, 649, ..., 1496, 2298, 598]], dtype=int64), 'edge_type': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 630, 1847, 1297, ..., 266, 576, 876],
[1256, 8628, 1256, ..., 4247, 1734, 3295]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 2, 2], dtype=int64)}","{'edge_index': array([[ 91, 1798, 1622, ..., 2358, 4665, 427],
[ 672, 2482, 82, ..., 506, 136, 1011]], dtype=int64), 'edge_type': array([0, 3, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 527, 3702, 566, ..., 1209, 251, 39],
[ 224, 2336, 4523, ..., 666, 1172, 4816]], dtype=int64), 'edge_type_ori': array([2, 5, 2, ..., 2, 2, 2], dtype=int64)}"
数据分发
1.将隐私数据分发到客户机 (客户拥有),初始化服务器
2.统计客户机测试集、验证集的数量,以及权重数量
class FedR(object):
def __init__(self, args, all_data):
self.args = args
train_dataloader_list, valid_dataloader_list, test_dataloader_list, \
self.rel_freq_mat, ent_embed_list, nrelation = get_all_clients(all_data, args)
self.args.nrelation = nrelation # question
# client
self.num_clients = len(train_dataloader_list)
# Create client objects for each client
self.clients = []
for i in range(self.num_clients):
client = Client(args, i, all_data[i], train_dataloader_list[i], valid_dataloader_list[i],
test_dataloader_list[i], ent_embed_list[i])
self.clients.append(client)
# Create the server object
self.server = Server(args, nrelation)
# 统计客户机测试集、验证集的数量,以及权重数量
# Calculate total test data size and test evaluation weights
self.total_test_data_size = 0
for client in self.clients:
self.total_test_data_size += len(client.test_dataloader.dataset)
self.test_eval_weights = []
for client in self.clients:
weight = len(client.test_dataloader.dataset) / self.total_test_data_size
self.test_eval_weights.append(weight)
# Calculate total valid data size and valid evaluation weights
self.total_valid_data_size = 0
for client in self.clients:
self.total_valid_data_size += len(client.valid_dataloader.dataset)
self.valid_eval_weights = []
for client in self.clients:
weight = len(client.valid_dataloader.dataset) / self.total_valid_data_size
self.valid_eval_weights.append(weight)
对初始数据集进行分发
1.all_rel = np.union1d(all_rel, data['train']['edge_type_ori']).reshape(-1):在这里,通过 np.union1d 函数将当前客户端的训练数据中的关系类型与 all_rel 数组进行合并并去除重复项。最后通过 reshape(-1) 将结果变为一维数组,并更新 all_rel。
2.train_dataloader_list, valid_dataloader_list, test_dataloader_list, ent_embed_list, rel_freq_list 初始化:这里分别初始化了存储训练、验证和测试数据加载器、实体嵌入向量以及关系频率的列表。
3.for data in tqdm(all_data):这个循环遍历所有客户端的数据,并对每个客户端进行处理
4.nentity = len(np.unique(data['train']['edge_index'])): 这行代码计算当前客户端训练数据中的实体数量,通过获取边索引 'edge_index' 并使用 np.unique 函数获取独特的实体索引,然后通过 len 函数计算实体的数量。
5.构建训练、验证和测试数据集:这部分代码通过整理当前客户端的训练、验证和测试数据来创建相应的数据集。训练数据集使用了 TrainDataset 类,而验证和测试数据集则使用了 valid_dataset和TestDataset 类。
6.构建数据加载器:用于将数据分发给不同客户端
7.初始化实体嵌入向量 ent_embed:这部分代码根据模型的不同(args.model)初始化实体嵌入向量 ent_embed,并将其添加到 ent_embed_list 列表中。
8.计算关系频率:计算不同客户机中relation的出现频率,并将其保存在 rel_freq 中。这样可以用于在后续任务中根据关系频率进行权重调整等操作。
9.rel_freq_mat = torch.stack(rel_freq_list).to(args.gpu):将关系频率列表 rel_freq_list 转换为 PyTorch 张量,并将其放置在指定的 GPU 上。
10.返回结果:最后,函数返回所有客户端的数据加载器、关系频率矩阵、实体嵌入向量列表和总关系数量 nrelation
def get_all_clients(all_data, args):
all_rel = np.array([], dtype=int)
for data in all_data:
all_rel = np.union1d(all_rel, data['train']['edge_type_ori']).reshape(-1)
nrelation = len(all_rel) # all relations of training set in all clients
train_dataloader_list = []
test_dataloader_list = []
valid_dataloader_list = []
ent_embed_list = []
rel_freq_list = []
for data in tqdm(all_data): # in a client
nentity = len(np.unique(data['train']['edge_index'])) # entities of training in a client
train_triples = np.stack((data['train']['edge_index'][0],
data['train']['edge_type_ori'],
data['train']['edge_index'][1])).T
valid_triples = np.stack((data['valid']['edge_index'][0],
data['valid']['edge_type_ori'],
data['valid']['edge_index'][1])).T
test_triples = np.stack((data['test']['edge_index'][0],
data['test']['edge_type_ori'],
data['test']['edge_index'][1])).T
client_mask_rel = np.setdiff1d(np.arange(nrelation),
np.unique(data['train']['edge_type_ori'].reshape(-1)), assume_unique=True)
all_triples = np.concatenate([train_triples, valid_triples, test_triples]) # in a client
train_dataset = TrainDataset(train_triples, nentity, args.num_neg)
valid_dataset = TestDataset(valid_triples, all_triples, nentity, client_mask_rel)
test_dataset = TestDataset(test_triples, all_triples, nentity, client_mask_rel)
# dataloader,数据划分
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=TrainDataset.collate_fn
)
train_dataloader_list.append(train_dataloader)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=args.test_batch_size,
collate_fn=TestDataset.collate_fn
)
valid_dataloader_list.append(valid_dataloader)
test_dataloader = DataLoader(
test_dataset,
batch_size=args.test_batch_size,
collate_fn=TestDataset.collate_fn
)
test_dataloader_list.append(test_dataloader)
embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
'''use n of entity in train or all (train, valid, test)?'''
if args.model in ['RotatE', 'ComplEx']:
ent_embed = torch.zeros(nentity, args.hidden_dim*2).to(args.gpu).requires_grad_()
else:
ent_embed = torch.zeros(nentity, args.hidden_dim).to(args.gpu).requires_grad_()
nn.init.uniform_(
tensor=ent_embed,
a=-embedding_range.item(),
b=embedding_range.item()
)
ent_embed_list.append(ent_embed)
rel_freq = torch.zeros(nrelation)
for r in data['train']['edge_type_ori'].reshape(-1):
rel_freq[r] += 1
rel_freq_list.append(rel_freq)
rel_freq_mat = torch.stack(rel_freq_list).to(args.gpu)
return train_dataloader_list, valid_dataloader_list, test_dataloader_list, \
rel_freq_mat, ent_embed_list, nrelation
客户端的数据分发
每个客户端都有数据,并且拥有自己的模型
class Client(object):
def __init__(self, args, client_id, data, train_dataloader,
valid_dataloader, test_dataloader, ent_embed):
self.args = args
self.data = data
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.test_dataloader = test_dataloader
self.ent_embed = ent_embed
self.client_id = client_id
self.score_local = []
self.score_global = []
self.kge_model = KGEModel(args, args.model)
self.rel_embed = None
class KGEModel(nn.Module):
def __init__(self, args, model_name):
super(KGEModel, self).__init__()
self.model_name = model_name
self.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
self.gamma = nn.Parameter(
torch.Tensor([args.gamma]),
requires_grad=False
)
服务器的数据分发
1.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim]):这行代码计算了关系嵌入向量初始化的范围 embedding_range。参数 args.gamma 和 args.epsilon 是模型的一些超参数,用于控制关系嵌入向量初始化范围的大小。args.hidden_dim 是模型中嵌入向量的维度。
2.self.rel_embed = torch.zeros(nrelation, args.hidden_dim*2).to(args.gpu).requires_grad_():如果模型类型是 'ComplEx',则创建一个形状为 (nrelation, args.hidden_dim2) 的全零张量 self.rel_embed,用于存储关系嵌入向量。nrelation 是关系的数量,args.hidden_dim2 是每个关系嵌入向量的维度。通过 .to(args.gpu) 将张量放置在指定的 GPU 上(如果使用了 GPU)。最后,通过 requires_grad_() 方法指定张量需要计算梯度,用于后续的模型训练和优化。
3.nn.init.uniform_(tensor=self.rel_embed, a=-embedding_range.item(), b=embedding_range.item()):这行代码使用均匀分布初始化关系嵌入向量 self.rel_embed。关系嵌入向量的值被随机采样自均匀分布,范围是从-embedding_range.item() 到 embedding_range.item()。
def __init__(self, args, nrelation):
self.args = args
embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
if args.model in ['ComplEx']:
self.rel_embed = torch.zeros(nrelation, args.hidden_dim*2).to(args.gpu).requires_grad_()
else:
self.rel_embed = torch.zeros(nrelation, args.hidden_dim).to(args.gpu).requires_grad_()
nn.init.uniform_(
tensor=self.rel_embed,
a=-embedding_range.item(),
b=embedding_range.item()
)
self.nrelation = nrelation
模型训练
1.best_epoch = 0, best_mrr = 0, bad_count = 0:初始化一些变量,best_epoch 用于存储在训练过程中获得最佳性能的轮次,best_mrr 用于存储最佳的 Mean Reciprocal Rank (MRR) 值,bad_count 用于记录模型性能没有提升的轮次数量。
2.mrr_plot_result = [], loss_plot_result = []:初始化空列表,用于保存每轮评估时的 MRR 值和每轮训练时的损失值。
3.for num_round in range(self.args.max_round):外层循环是训练的主循环,根据 self.args.max_round 指定的最大轮次进行训练。
4.n_sample = max(round(self.args.fraction * self.num_clients), 1):根据 self.args.fraction 和客户端的数量 self.num_clients 计算出本轮次选择的客户端数量 n_sample,保证选择的客户端数量不少于 1。
5.self.send_emb():将服务器relation向量传到客户机中
6.round_loss = 0: 初始化一个变量 round_loss 用于记录当前轮次的总损失值。
7.for k in iter(sample_set):这是内层循环,遍历本轮次选择的客户端。
8.self.server.aggregation(self.clients, self.rel_freq_mat):执行一个函数 aggregation(),该函数可能是用于在服务器端聚合从客户端接收到的更新,以更新全局模型参数。在分布式学习中,通常需要将不同客户端的更新进行聚合,以得到全局的模型。
9.if num_round % self.args.check_per_round == 0 and num_round != 0:检查是否到了评估模型的轮次,self.args.check_per_round 是指定的评估间隔。num_round != 0 确保从第一轮之后才进行评估。
10.eval_res = self.evaluate(): 执行一个函数 evaluate(),该函数用于评估当前轮次模型在验证集上的性能,并返回评估结果。
11.if eval_res['mrr'] > best_mrr:判断当前轮次的 MRR 是否优于最佳 MRR,如果是,则更新最佳 MRR 和最佳轮次 best_mrr 和 best_epoch。
12.bad_count += 1 和 logging.info('best model is at round {0}, mrr {1:.4f}, bad count {2}'.format(best_epoch, best_mrr, bad_count)):如果当前轮次 MRR 不如最佳 MRR,则增加 bad_count 记录性能没有提升的轮次数量,并打印当前最佳轮次的信息。
13.if bad_count >= self.args.early_stop_patience:检查是否达到早停止条件。self.args.early_stop_patience 是设定的早停止容忍度,如果连续 bad_count 轮模型性能没有提升,则触发早停止。
14.self.save_model(best_epoch):保存获得最佳性能的模型参数。
15.self.before_test_load() 和 self.evaluate(istest=True):在最后完成训练后,加载之前保存的最佳模型参数,并在测试集上进行最终的模型评估。
def train(self):
best_epoch = 0
best_mrr = 0
bad_count = 0
mrr_plot_result = []
loss_plot_result = []
for num_round in range(self.args.max_round):
n_sample = max(round(self.args.fraction * self.num_clients), 1)
sample_set = np.random.choice(self.num_clients, n_sample, replace=False)
self.send_emb()
round_loss = 0
for k in iter(sample_set):#不同客户机的损失值相加
client_loss = self.clients[k].client_update()
round_loss += client_loss
round_loss /= n_sample
self.server.aggregation(self.clients, self.rel_freq_mat)
logging.info('round: {} | loss: {:.4f}'.format(num_round, np.mean(round_loss)))
self.write_training_loss(np.mean(round_loss), num_round)
loss_plot_result.append(np.mean(round_loss))
if num_round % self.args.check_per_round == 0 and num_round != 0:
eval_res = self.evaluate()
self.write_evaluation_result(eval_res, num_round)
if eval_res['mrr'] > best_mrr:
best_mrr = eval_res['mrr']
best_epoch = num_round
logging.info('best model | mrr {:.4f}'.format(best_mrr))
self.save_checkpoint(num_round)
bad_count = 0
else:
bad_count += 1
logging.info('best model is at round {0}, mrr {1:.4f}, bad count {2}'.format(
best_epoch, best_mrr, bad_count))
mrr_plot_result.append(eval_res['mrr'])
if bad_count >= self.args.early_stop_patience:
logging.info('early stop at round {}'.format(num_round))
loss_file_name = 'loss/' + self.args.name + '_loss.pkl'
with open(loss_file_name, 'wb') as fp:
pickle.dump(loss_plot_result, fp)
mrr_file_name = 'loss/' + self.args.name + '_mrr.pkl'
with open(mrr_file_name, 'wb') as fp:
pickle.dump(mrr_plot_result, fp)
break
logging.info('finish training')
logging.info('save best model')
self.save_model(best_epoch)
self.before_test_load()
self.evaluate(istest=True)
客户机的训练
def client_update(self):
optimizer = optim.Adam([{'params': self.rel_embed},
{'params': self.ent_embed}], lr=self.args.lr)
losses = []
for i in range(self.args.local_epoch):
for batch in self.train_dataloader:
positive_sample, negative_sample, sample_idx = batch
positive_sample = positive_sample.to(self.args.gpu)
negative_sample = negative_sample.to(self.args.gpu)
#这里会调用kge_model的forward函数
negative_score = self.kge_model((positive_sample, negative_sample),
self.rel_embed, self.ent_embed)
negative_score = (F.softmax(negative_score * self.args.adversarial_temperature, dim=1).detach()
* F.logsigmoid(-negative_score)).sum(dim=1)
positive_score = self.kge_model(positive_sample,
self.rel_embed, self.ent_embed, neg=False)
positive_score = F.logsigmoid(positive_score).squeeze(dim=1)
positive_sample_loss = - positive_score.mean()
negative_sample_loss = - negative_score.mean()
loss = (positive_sample_loss + negative_sample_loss) / 2
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
return np.mean(losses)
计算分数f
分别为positive_sample和negative_sample计算分数、先执行forward函数,再执行TransE函数
def forward(self, sample, relation_embedding, entity_embedding, neg=True):
if not neg:
head = torch.index_select(
entity_embedding,
dim=0,
index=sample[:, 0]
).unsqueeze(1)
relation = torch.index_select(
relation_embedding,
dim=0,
index=sample[:, 1]
).unsqueeze(1)
tail = torch.index_select(
entity_embedding,
dim=0,
index=sample[:, 2]
).unsqueeze(1)
else:
head_part, tail_part = sample
batch_size = head_part.shape[0]
head = torch.index_select(
entity_embedding,
dim=0,
index=head_part[:, 0]
).unsqueeze(1)
relation = torch.index_select(
relation_embedding,
dim=0,
index=head_part[:, 1]
).unsqueeze(1)
if tail_part == None:
tail = entity_embedding.unsqueeze(0)
else:
negative_sample_size = tail_part.size(1)
tail = torch.index_select(
entity_embedding,
dim=0,
index=tail_part.view(-1)
).view(batch_size, negative_sample_size, -1)
model_func = {
'TransE': self.TransE,
'DistMult': self.DistMult,
'ComplEx': self.ComplEx,
'RotatE': self.RotatE,
}
score = model_func[self.model_name](head, relation, tail)
return score
def TransE(self, head, relation, tail):
score = (head + relation) - tail
score = self.gamma.item() - torch.norm(score, p=1, dim=2)
return score
服务器的聚合
def aggregation(self, clients, rel_update_weights):
agg_rel_mask = rel_update_weights #relation在三个客户机的权重
agg_rel_mask[rel_update_weights != 0] = 1 #非0元素标注为1
rel_w_sum = torch.sum(agg_rel_mask, dim=0) #对relation求和
rel_w = agg_rel_mask / rel_w_sum
rel_w[torch.isnan(rel_w)] = 0 #归一化
if self.args.model in ['ComplEx']:
update_rel_embed = torch.zeros(self.nrelation, self.args.hidden_dim * 2).to(self.args.gpu)
else:
update_rel_embed = torch.zeros(self.nrelation, self.args.hidden_dim).to(self.args.gpu) #初始化
for i, client in enumerate(clients):
local_rel_embed = client.rel_embed.clone().detach()
# 这行代码将当前客户端的局部关系嵌入向量 local_rel_embed 与其对应的关系权重 rel_w[i] 相乘,
# 并加到 update_rel_embed 中。
# 由于 local_rel_embed 的形状为 (self.nrelation, self.args.hidden_dim),
# 而 rel_w[i] 的形状为 (self.nrelation,),所以通过 rel_w[i].reshape(-1, 1) 将其转换为形状为 (self.nrelation, 1),这样两个张量可以进行逐元素乘法。
update_rel_embed += local_rel_embed * rel_w[i].reshape(-1, 1)
self.rel_embed = update_rel_embed.requires_grad_()