FedR——攻击代码的学习
攻击客户机1
请注意,这是一个简化的示例,实际数据和嵌入可能更复杂。攻击的成功率取决于实体嵌入的质量和相似性度量的准确性。在实际应用中,可能需要更多的数据和更复杂的算法来实现更准确和有效的攻击。
这段代码是用于进行攻击的部分。它试图通过使用客户端0的信息(实体嵌入和关系嵌入)来破解客户端1的信息(部分实体和关系的嵌入)。攻击的过程包括以下步骤:
-
加载训练得到的模型参数:通过
torch.load()
函数加载之前训练得到的模型参数,其中ent_embed
和rel_embed
分别表示实体嵌入和关系嵌入。 -
创建客户端0的字典信息:从数据中提取客户端0的实体,并将其与对应的实体嵌入组成字典
c0_ent_embed_dict
。 -
对客户端0的实体进行映射:由于在客户端1的数据中,实体的索引可能与客户端0的数据中不同,因此需要建立映射关系
c0_mapping
来将客户端0的实体索引映射到客户端1的实体索引。 -
在客户端1上执行攻击:对客户端1进行攻击,通过在客户端1的实体池中选择一部分实体(由
p
参数控制选择比例),然后计算这些实体与客户端0的实体的嵌入之间的余弦距离,并选择距离最近的客户端0的实体作为对应的伪造实体,形成伪造实体列表syc_ent_list
。 -
计算攻击成功率:计算成功破解的实体的比例和关系的比例,即伪造实体列表中与客户端1的实体池中实体相同的实体数量与客户端1数据中所有三元组数量之间的比值。
请注意,这段代码是为了演示攻击方法,并且使用余弦距离来测量实体之间的相似性。在实际应用中,可能需要更复杂的攻击策略和更准确的相似性度量来实现更高效的攻击。
已知客户机0的所有信息,以及客户机1的向量信息,破解客户机1的实体信息
假设数据集中包含以下数据:
客户端0的数据:
- 训练集包含3个三元组:(A, R1, B), (B, R2, C), (C, R1, D)
- 实体集合:
- 关系集合:
客户端1的数据:
- 训练集包含4个三元组:(X, R2, Y), (Y, R3, Z), (Z, R1, X), (W, R2, Z)
- 实体集合:
- 关系集合:
攻击参数设置:
p = 0.5
客户端0的实体嵌入:
c0_ent_embed_dict = {
A: [0.1, 0.2],
B: [0.3, 0.4],
C: [0.5, 0.6],
D: [0.7, 0.8]
}
客户端0的实体索引映射:
c0_mapping = {
0: A,
1: B,
2: C,
3: D
}
客户端1的实体池和实体嵌入:
c1_ent_pool = [X, Y]
c1_ent_embed = {
X: [0.9, 1.0],
Y: [1.1, 1.2]
}
攻击过程:
- 计算客户端1中选中实体(X, Y)与客户端0实体的余弦距离,并选择最相似的实体:
# 以X为例计算与客户端0实体的余弦距离
cosine_distance_X_A = spatial.distance.cosine([0.9, 1.0], [0.1, 0.2]) # 假设计算结果为0.5
cosine_distance_X_B = spatial.distance.cosine([0.9, 1.0], [0.3, 0.4]) # 假设计算结果为0.8
cosine_distance_X_C = spatial.distance.cosine([0.9, 1.0], [0.5, 0.6]) # 假设计算结果为0.2
cosine_distance_X_D = spatial.distance.cosine([0.9, 1.0], [0.7, 0.8]) # 假设计算结果为0.9
# 选择余弦距离最小的实体作为对应的伪造实体
syn_ent_list = [C, X]
- 计算攻击成功率:
tru_ent_list = [X, Y] # 客户端1实体池中的实体列表
syn_ent_list = [C, X] # 由攻击生成的伪造实体列表
# 计算成功破解的实体的比例
success_rate = sum(first == second for (first, second) in zip(syn_ent_list, tru_ent_list)) / len(c1_ent_pool)
# 假设攻击成功破解的实体数量为1,计算结果为1/2 = 0.5
请注意,这是一个简化的示例,实际数据和嵌入可能更复杂。攻击的成功率取决于实体嵌入的质量和相似性度量的准确性。在实际应用中,可能需要更多的数据和更复杂的算法来实现更准确和有效的攻击。
import torch
import pickle
import numpy as np
import random
from scipy import spatial
emb = torch.load('./state/fb15k237_fed3_fed_TransE.best', map_location=torch.device('cpu'))
ent_embed = emb['ent_embed']
rel_embed = emb['rel_embed']
data = pickle.load(open("Fed_data/FB15K237-Fed3.pkl", "rb" ))
#生成第一个客户端的字典信息
c0_ent = np.unique(data[0]['train']['edge_index'])
c0_ent_embed_dict = {}
value = ent_embed[0]
for idx,ent in enumerate(c0_ent):
c0_ent_embed_dict[ent] = value[idx]
c0_mapping = dict(zip(data[0]['train']['edge_index'][0], data[0]['train']['edge_index_ori'][0]))
c0_mapping.update(dict(zip(data[0]['train']['edge_index'][1], data[0]['train']['edge_index_ori'][1])))
c0_ent_embed_dict_mapped = dict((c0_mapping[key], value) for (key, value) in c0_ent_embed_dict.items())
c0_ent_pool_mapped = [c0_mapping[i] for i in c0_ent]
# map local to global
c1_mapping = dict(zip(data[1]['train']['edge_index'][0], data[1]['train']['edge_index_ori'][0]))
c1_mapping.update(dict(zip(data[1]['train']['edge_index'][1], data[1]['train']['edge_index_ori'][1])))
c1_ent = np.unique(data[1]['train']['edge_index'])
random.seed(10)
np.random.seed(10)
p = 1
c1_ent_pool = np.random.choice(c1_ent, int(p * len(c1_ent)), replace = False)
c1_ent_embed = ent_embed[1][[c1_ent_pool]]
c1_ent_pool_mapped = [c1_mapping[i] for i in c1_ent_pool]
syn_ent_list = [] # synthetic entity label
for i in c1_ent_pool:
c1_ent_embed = ent_embed[1][i]
count = 0
loss_bound = 0
ent_idx = []
for j in c0_ent_embed_dict_mapped:
loss = spatial.distance.cosine(c1_ent_embed.detach().numpy(), c0_ent_embed_dict_mapped[j].detach().numpy())
if count == 0: # first round
loss_bound = loss
ent_idx.append(j)
count += 1
else:
if loss < loss_bound:
loss_bound = loss
ent_idx.append(j)
syn_ent_list.append(ent_idx[-1]) # global index of the entity
tru_ent_list = [c1_mapping[i] for i in c1_ent_pool]
# calculate the number of correct reconstruction
sum(first == second for (first, second) in zip(syn_ent_list, tru_ent_list)) / len(c1_ent)
c0_rel = np.unique(data[0]['train']['edge_type_ori'])
# creat relation pool based on selected entities (global relation index)
c1_triple_all = np.array([data[1]['train']['edge_index_ori'][0],
data[1]['train']['edge_type_ori'],
data[1]['train']['edge_index_ori'][1]])
tru_trr_list = []
# the adversary knows all relation embeddings and their corresponding index, so here we use ori directly
len_c1_triple = c1_triple_all[0].shape[0]
for i in range(len_c1_triple):
triple = c1_triple_all[:,i]
h, r, t= triple[0], triple[1], triple[2]
if (h in c1_ent_pool_mapped) and (t in c1_ent_pool_mapped):
if h not in tru_trr_list:
tru_trr_list.append(h)
if t not in tru_trr_list:
tru_trr_list.append(t)
syn_trr_list = []
for (first, second) in zip(syn_ent_list, tru_ent_list):
if first == second:
syn_trr_list.append(first)
# calculate the number of correct reconstruction
len(list(set(syn_trr_list).intersection(tru_trr_list))) / len_c1_triple