MakKEr代码的学习
入口函数
这个Python脚本定义了一个入口函数,用于运行一个知识图谱嵌入(Knowledge Graph Embedding)模型的元学习训练过程。以下是对这个入口函数的解读:
-
导入必要的库:
argparse
:用于解析命令行参数。init_dir
:一个自定义的函数,用于初始化目录。MetaTrainer
:一个自定义的元训练器类,用于执行元学习训练过程。- 其他必要的库和模块。
-
定义命令行参数:
使用argparse.ArgumentParser
来定义各种命令行参数,这些参数将用于配置训练过程。包括数据路径、状态目录、日志目录、任务名称、实验名称等。 -
检查模块入口:
if __name__ == '__main__':
这个条件语句确保脚本仅在直接运行时才执行以下代码块,而不会在模块导入时执行。 -
解析命令行参数:
使用parser.parse_args()
解析命令行参数,并将它们存储在args
对象中。 -
根据不同的嵌入模型(kge)设置维度:
根据所选择的嵌入模型(TransE、DistMult、ComplEx、RotatE),调整实体维度(ent_dim)和关系维度(rel_dim)。不同的嵌入模型可能需要不同的维度设置。 -
创建子图数据集:
根据参数中的数据路径,生成子图数据集。如果数据路径对应的子图数据集不存在,会调用gen_subgraph_datasets(args)
函数生成子图数据集。 -
初始化目录:
调用init_dir(args)
函数,用于初始化目录,确保状态目录、日志目录等存在。 -
进行多次实验:
使用一个循环来进行多次实验。循环变量run
的范围由args.num_exp
定义,表示要运行多少次实验。 -
配置实验参数:
在每次实验开始前,根据循环变量run
配置实验参数,包括设置exp_name
为任务名称加上当前运行次数,即实验名称。 -
创建元训练器并进行训练:
- 创建一个
MetaTrainer
实例,传入args
作为参数。 - 调用
trainer.train()
方法执行元学习训练过程。
- 创建一个
-
删除训练器实例:
在每次实验完成后,通过del trainer
语句删除训练器实例,释放资源。
总之,这个入口函数通过命令行参数配置了知识图谱嵌入模型的元学习训练过程,支持多次实验,并在每次实验中根据不同参数配置进行训练。不同的命令行参数会影响嵌入模型、数据集、训练配置等。
import argparse
from utils import init_dir
from meta_trainer import MetaTrainer
import os
from subgraph import gen_subgraph_datasets
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='./data/fb_ext.pkl')
parser.add_argument('--state_dir', default='./state')
parser.add_argument('--log_dir', default='./log')
parser.add_argument('--tb_log_dir', default='./tb_log')
parser.add_argument('--task_name', default='fb_ext')
parser.add_argument('--exp_name', default=None, type=str)
parser.add_argument('--num_exp', default=1, type=int)
parser.add_argument('--train_bs', default=64, type=int)
parser.add_argument('--eval_bs', default=16, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--num_step', default=100000, type=int)
parser.add_argument('--log_per_step', default=10, type=int)
parser.add_argument('--check_per_step', default=30, type=int)
parser.add_argument('--early_stop_patience', default=20, type=int)
parser.add_argument('--num_sample_cand', default=5, type=int)
parser.add_argument('--dim', default=32, type=int)
parser.add_argument('--ent_dim', default=None, type=int)
parser.add_argument('--rel_dim', default=None, type=int)
parser.add_argument('--num_layers', default=2, type=int)
parser.add_argument('--num_rel_bases', default=4, type=int)
parser.add_argument('--kge', default='TransE', type=str, choices=['TransE', 'DistMult', 'ComplEx', 'RotatE'])
parser.add_argument('--metatrain_num_neg', default=32)
parser.add_argument('--adv_temp', default=1, type=float)
parser.add_argument('--gamma', default=10, type=float)
parser.add_argument('--cpu_num', default=10, type=float)
parser.add_argument('--gpu', default='cuda:0', type=str)
# subgraph
parser.add_argument('--db_path', default=None)
parser.add_argument('--num_train_subgraph', default=10000)
parser.add_argument('--num_sample_for_estimate_size', default=10)
parser.add_argument('--rw_0', default=10, type=int)
parser.add_argument('--rw_1', default=10, type=int)
parser.add_argument('--rw_2', default=5, type=int)
args = parser.parse_args()
if args.kge in ['TransE', 'DistMult']:
args.ent_dim = args.dim
args.rel_dim = args.dim
elif args.kge == 'RotatE':
args.ent_dim = args.dim * 2
args.rel_dim = args.dim
elif args.kge == 'ComplEx':
args.ent_dim = args.dim * 2
args.rel_dim = args.dim * 2
args.db_path = args.data_path[:-4] + '_subgraph'
if not os.path.exists(args.db_path):
gen_subgraph_datasets(args)
init_dir(args)
for run in range(args.num_exp):
args.run = run
args.exp_name = args.task_name + f'_run{args.run}'
trainer = MetaTrainer(args)
trainer.train()
del trainer
创建子图数据集
数据集的格式
data = {
'train': {
'triples': [[0, 1, 2], [3, 4, 5], ...] # a list of triples in (h, r, t), denoted by corresponding indexes
'ent2id': {'abc':0, 'def':1, ...} # map entity name from original dataset (e.g., FB15k-237) to the index of above triples
'rel2id': {'xyz':0, 'ijk':1, ...} # map relation name from original dataset (e.g., FB15k-237) to the index of above triples
}
'valid': {
'support': # support triples
'query': # query triples
'ent_map_list': [0, -1, 4, -1, -1, ...] # map entity indexes to train entities, -1 denotes an unseen entitie
'rel_map_list': [-1, 2, -1, -1, -1, ...] # map relation indexes to train relation, -1 denotes an unseen relation
'ent2id':
'rel2id':
}
'test': {
'support':
'query_uent': # query triples only containing unseen entities
'query_urel': # query triples only containing unseen relations
'query_uboth': # query triples containing unseen entities and relations
'ent_map_list':
'rel_map_list':
'ent2id':
'rel2id':
}}
解释ent_map_list和rel_map_list
当涉及到元学习或者子图生成任务时,往往需要将不同的实体和关系进行映射,以便在任务中使用。在上面提供的数据示例中,ent_map_list
和 rel_map_list
就是用来进行这种映射的列表。
举例来说,假设我们有一个训练数据集包含以下信息:
'train': {
'triples': [
[0, 1, 2], # (entity 0, relation 1, entity 2)
[3, 4, 5], # (entity 3, relation 4, entity 5)
...
],
'ent2id': {'abc': 0, 'def': 1, ...}, # entity name to index mapping
'rel2id': {'xyz': 0, 'ijk': 1, ...} # relation name to index mapping
}
在上述训练数据中,实体和关系都被映射到了数字索引。现在,假设我们有一个验证数据集,其包含了一些支持(support)和查询(query)三元组,同时也包含了一些实体和关系的映射关系。
'valid': {
'support': [...], # support triples
'query': [...], # query triples
'ent_map_list': [0, -1, 4, -1, -1, ...], # entity index mapping
'rel_map_list': [-1, 2, -1, -1, -1, ...], # relation index mapping
'ent2id': {...}, # entity name to index mapping (for the validation set)
'rel2id': {...} # relation name to index mapping (for the validation set)
}
在这个示例中,ent_map_list
是一个列表,其中的值对应着在验证数据集中的实体索引。例如,如果 ent_map_list[0]
的值为 0,那么表示验证数据集中的第一个实体(按顺序)与训练数据集中的第一个实体(索引为 0)相对应。如果 ent_map_list[1]
的值为 -1,那么表示验证数据集中的第二个实体是一个未见过的实体,没有对应的训练数据。同样,rel_map_list
的处理方式也类似,用于映射验证数据集中的关系索引到训练数据集中的关系索引。
这样的映射操作在元学习中很常见,因为在元学习任务中,模型需要在不同的子任务之间进行学习和推理,而不同的子任务可能涉及到不同的实体和关系。
-
print('----------generate tasks(sub-KGs) for meta-training----------')
:- 使用
print
函数在控制台打印消息。 - 消息是一个字符串,用于表示正在生成用于元学习的子图数据任务。
- 使用
-
data = pickle.load(open(args.data_path, 'rb'))
:- 使用
open
函数打开指定路径的文件,并以二进制模式('rb')读取。 - 被
pickle.load
函数加载的数据存储在变量data
中。 - 数据是使用
pickle
序列化的对象,包含了训练、验证和测试数据的相关信息。
- 使用
-
bg_train_g = get_g(data['train']['triples'])
:- 通过索引访问字典
data
中的'train'
键,获得训练数据。 - 从训练数据中获取
'triples'
键对应的值,即训练集中的三元组列表。 - 调用函数
get_g
,使用训练集三元组构建一个图。 - 将图对象存储在变量
bg_train_g
中。
- 通过索引访问字典
-
BYTES_PER_DATUM = get_average_subgraph_size(args, args.num_sample_for_estimate_size, bg_train_g) * 2
:- 调用函数
get_average_subgraph_size
,并传入参数args
、args.num_sample_for_estimate_size
以及bg_train_g
。 - 乘以 2,得到每个子图数据的估计字节数,并将结果存储在变量
BYTES_PER_DATUM
中。
- 调用函数
-
map_size = (args.num_train_subgraph) * BYTES_PER_DATUM
:- 计算用于存储所有子图数据的 LMDB 映射大小。
- 乘以
args.num_train_subgraph
(训练子图的数量),得到总的映射大小,并将结果存储在变量map_size
中。
-
env = lmdb.open(args.db_path, map_size=map_size, max_dbs=1)
:- 调用
lmdb.open
函数,打开一个 LMDB 环境。 - 使用指定的映射大小
map_size
和最大数据库数max_dbs
,创建一个环境对象,并将其存储在变量env
中。
- 调用
-
train_subgraphs_db = env.open_db("train_subgraphs".encode())
:- 在打开的 LMDB 环境
env
中,使用"train_subgraphs"
作为数据库名称,创建一个数据库对象,并将其存储在变量train_subgraphs_db
中。
- 在打开的 LMDB 环境
-
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:
使用多进程池进行并行子图生成,使用最多 10 个进程。- 使用
mp.Pool
创建一个多进程池对象p
,设置最大进程数为 10,初始化函数为intialize_worker
,传递的参数为args
和bg_train_g
。 - 使用
range
函数生成一个索引范围idx_
,用于迭代生成子图。 - 使用
p.imap
在多个进程中并行生成子图,使用函数sample_one_subgraph
进行子图生成,总共迭代args.num_train_subgraph
次。
- 使用
-
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph):
:- 使用
tqdm
函数创建一个进度条,显示子图生成的进度。 - 迭代并行生成的子图数据,对每个子图数据执行以下操作。
- 使用
-
with env.begin(write=True, db=train_subgraphs_db) as txn:
:- 使用
env.begin
创建一个 LMDB 事务,支持写操作。 - 打开创建的数据库
train_subgraphs_db
,将数据库连接存储在变量txn
中。
- 使用
-
txn.put(str_id, serialize(datum))
:- 在数据库事务
txn
中,使用put
函数将子图数据序列化后存储。 - 使用
str_id
作为键,datum
经过序列化后的数据作为值。
- 在数据库事务
当然,我会为您详细解释这部分代码块的每一部分。
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:
idx_ = range(args.num_train_subgraph)
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph):
with env.begin(write=True, db=train_subgraphs_db) as txn:
txn.put(str_id, serialize(datum))
-
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:
:- 使用
mp.Pool
创建一个具有 10 个进程的进程池p
。这允许我们并行地执行子图生成任务。 initializer=intialize_worker
:在每个进程启动之前,将调用名为intialize_worker
的初始化函数。它通常用于初始化每个工作进程需要的环境和资源。initargs=(args, bg_train_g)
:将参数args
和bg_train_g
传递给初始化函数。
- 使用
-
idx_ = range(args.num_train_subgraph)
:- 创建一个迭代器
idx_
,该迭代器包含了一个从 0 到args.num_train_subgraph - 1
的范围。
- 创建一个迭代器
-
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph):
:- 对于每个子图生成任务,使用
p.imap
并行地迭代idx_
中的索引。 sample_one_subgraph
函数将在进程池中的一个进程中被调用,并且传递当前索引作为参数。tqdm
是一个用于创建进度条的工具,它显示了子图生成的进度,总数为args.num_train_subgraph
。
- 对于每个子图生成任务,使用
-
with env.begin(write=True, db=train_subgraphs_db) as txn:
:- 创建一个 LMDB 事务
txn
,以进行写操作。这将在每次存储子图数据时使用。
- 创建一个 LMDB 事务
-
txn.put(str_id, serialize(datum))
:- 在 LMDB 事务
txn
中,使用put
方法将子图数据存储到数据库中。 str_id
是子图数据在数据库中的键,通常是一个字符串或字节串。serialize(datum)
将子图数据datum
序列化,以便存储在数据库中。
通过这段代码,我们创建了一个进程池,并在多个进程中并行地生成子图数据。每个生成的子图数据将序列化并存储在 LMDB 数据库中,以供后续的元学习训练使用。这种并行生成和存储的方式可以显著提高数据生成的效率。
- 在 LMDB 事务
def gen_subgraph_datasets(args):
print('----------generate tasks(sub-KGs) for meta-training----------')
data = pickle.load(open(args.data_path, 'rb'))
bg_train_g = get_g(data['train']['triples'])
BYTES_PER_DATUM = get_average_subgraph_size(args, args.num_sample_for_estimate_size, bg_train_g) * 2
map_size = (args.num_train_subgraph) * BYTES_PER_DATUM
env = lmdb.open(args.db_path, map_size=map_size, max_dbs=1)
train_subgraphs_db = env.open_db("train_subgraphs".encode())
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:
idx_ = range(args.num_train_subgraph)
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph):
with env.begin(write=True, db=train_subgraphs_db) as txn:
txn.put(str_id, serialize(datum))
代码中的函数 sample_one_subgraph
是用来生成子图的核心部分,它执行以下操作:
-
创建一个双向图(
bg_train_g_undir
):通过将原始图的边连接成双向边,用于随机游走采样。 -
随机游走采样:在双向图上进行多次随机游走,从中获取节点,形成一个子图。这个过程是为了获取一部分子图数据,确保子图的大小满足要求。
-
转换子图边为三元组:将子图中的边转换为三元组的形式(头实体、关系、尾实体),便于后续处理。
-
重新索引实体和关系:为了减小实体和关系的索引范围,将实体和关系进行重新索引,并统计各自的频率。
-
生成查询和支持三元组:从重新索引后的三元组中随机选择查询和支持三元组。查询三元组用于模型的元学习任务,支持三元组用于帮助生成查询三元组。
-
获取映射和模式三元组:根据支持和查询三元组,生成一些用于模型训练的映射(
hr2t
、rt2h
)和模式(pattern_tris
)三元组。 -
创建 LMDB 键值对:将生成的子图数据进行处理,将索引转换为字节串,并将数据存储在 LMDB 数据库中。
整个过程涉及随机采样、数据转换、频率统计和数据存储等步骤,以便为元学习模型提供任务数据。这些生成的子图数据将用于元学习模型的训练过程。
-
str_id
: 这是一个用于在 LMDB 数据库中存储数据的键,通常是一个格式化后的字符串编码成的 ASCII 字节串,用于唯一标识存储的数据。 -
sup_tris
: 这是一个列表,包含支持三元组的信息。支持三元组是用于帮助生成查询三元组的数据,通常在元学习中被使用。 -
pattern_tris
: 这也是一个列表,包含模式三元组的信息。模式三元组在元学习中用于训练模型,帮助模型学习关于关系和实体的模式。 -
que_tris
: 这同样是一个列表,包含查询三元组的信息。查询三元组是元学习任务的目标,模型的目标是根据查询三元组进行预测。 -
hr2t
: 这是一个映射,表示头实体到尾实体的关系,用于支持查询三元组的生成和预测。 -
rt2h
: 这也是一个映射,表示尾实体到头实体的关系,同样用于支持查询三元组的生成和预测。 -
ent_map_list
: 这是一个列表,包含了重新索引的实体的映射关系。它将原始数据集中的实体索引映射到训练数据集中的重新索引后的实体索引。 -
rel_map_list
: 类似于ent_map_list
,这也是一个列表,包含了重新索引的关系的映射关系。它将原始数据集中的关系索引映射到训练数据集中的重新索引后的关系索引。
这些变量在生成子图数据时承载了不同的功能,涉及支持、查询、映射和模式的信息,用于元学习任务的构建和训练。
def sample_one_subgraph(idx_):
args = args_
bg_train_g = bg_train_g_
# get graph with bi-direction
bg_train_g_undir = dgl.graph((torch.cat([bg_train_g.edges()[0], bg_train_g.edges()[1]]),
torch.cat([bg_train_g.edges()[1], bg_train_g.edges()[0]])))
# induce sub-graph by sampled nodes
while True:
while True:
sel_nodes = []
for i in range(args.rw_0):
if i == 0:
cand_nodes = np.arange(bg_train_g.num_nodes())
else:
cand_nodes = sel_nodes
try:
rw, _ = dgl.sampling.random_walk(bg_train_g_undir,
np.random.choice(cand_nodes, 1, replace=False).repeat(args.rw_1),
length=args.rw_2)
except ValueError:
print(cand_nodes)
sel_nodes.extend(np.unique(rw.reshape(-1)))
sel_nodes = list(np.unique(sel_nodes)) if -1 not in sel_nodes else list(np.unique(sel_nodes))[1:]
sub_g = dgl.node_subgraph(bg_train_g, sel_nodes)
if sub_g.num_nodes() >= 50:
break
sub_tri = torch.stack([sub_g.ndata[dgl.NID][sub_g.edges()[0]],
sub_g.edata['rel'],
sub_g.ndata[dgl.NID][sub_g.edges()[1]]])
sub_tri = sub_tri.T.tolist()
random.shuffle(sub_tri)
ent_freq = ddict(int)
rel_freq = ddict(int)
triples_reidx = []
rel_reidx = dict()
relidx = 0
ent_reidx = dict()
entidx = 0
for tri in sub_tri:
h, r, t = tri
if h not in ent_reidx.keys():
ent_reidx[h] = entidx
entidx += 1
if t not in ent_reidx.keys():
ent_reidx[t] = entidx
entidx += 1
if r not in rel_reidx.keys():
rel_reidx[r] = relidx
relidx += 1
ent_freq[ent_reidx[h]] += 1
ent_freq[ent_reidx[t]] += 1
rel_freq[rel_reidx[r]] += 1
triples_reidx.append([ent_reidx[h], rel_reidx[r], ent_reidx[t]])
ent_reidx_inv = {v: k for k, v in ent_reidx.items()}
rel_reidx_inv = {v: k for k, v in rel_reidx.items()}
ent_map_list = [ent_reidx_inv[i] for i in range(len(ent_reidx))]
rel_map_list = [rel_reidx_inv[i] for i in range(len(rel_reidx))]
# randomly get query triples
que_tris = []
sup_tris = []
for idx, tri in enumerate(triples_reidx):
h, r, t = tri
if ent_freq[h] > 2 and ent_freq[t] > 2 and rel_freq[r] > 2:
que_tris.append(tri)
ent_freq[h] -= 1
ent_freq[t] -= 1
rel_freq[r] -= 1
else:
sup_tris.append(tri)
if len(que_tris) >= int(len(triples_reidx)*0.1):
break
sup_tris.extend(triples_reidx[idx+1:])
if len(que_tris) >= int(len(triples_reidx)*0.1):
break
# hr2t, rt2h
hr2t, rt2h, rel_head, rel_tail = get_hr2t_rt2h_sup_que(sup_tris, que_tris)
pattern_tris = get_train_pattern_g(rel_head, rel_tail)
str_id = '{:08}'.format(idx_).encode('ascii')
return str_id, (sup_tris, pattern_tris, que_tris, hr2t, rt2h, ent_map_list, rel_map_list)
初始化trainer的参数
这两段代码定义了两个类,Trainer
和 MetaTrainer
,分别用于训练模型和元训练模型。下面对每个类进行解读:
Trainer
类:
-
构造函数
__init__(self, args)
:- 初始化函数接受一个
args
参数,该参数包含各种训练相关的配置选项。 - 初始化了一些实例变量,如
args
、name
、writer
、logger
、state_path
等。
- 初始化函数接受一个
-
日志和写入器初始化:
- 创建了一个命名为
name
的实验记录名称,用于在训练过程中创建 TensorBoard 日志和日志文件。 - 创建了一个
SummaryWriter
对象用于创建 TensorBoard 日志,将其保存在给定路径下。 - 创建了一个日志记录器(logger),用于记录训练中的信息,比如配置参数等。日志被写入到给定的日志目录中。
- 创建了一个命名为
-
数据加载和初始化:
- 构建了存储训练状态的目录路径
state_path
。 - 加载训练数据,使用
pickle.load
函数从指定路径读取数据文件。 - 从训练数据中获取实体和关系的数量,以便在模型中使用。
- 构建了存储训练状态的目录路径
-
构建验证和测试数据集:
- 使用训练数据和验证数据构建
ValidData
和TestData
数据集对象,这些对象将用于验证和测试模型。
- 使用训练数据和验证数据构建
-
初始化 KGE 模型和优化器:
- 创建了一个
KGEModel
对象(知识图嵌入模型),并将其放置在 GPU 上进行计算。 - 未初始化优化器,将在子类中初始化。
- 创建了一个
-
设置训练参数:
- 初始化了用于控制训练的参数,如
num_step
、log_per_step
、check_per_step
等。
- 初始化了用于控制训练的参数,如
MetaTrainer
类(继承自 Trainer
类):
-
构造函数
__init__(self, args)
:- 继承自
Trainer
类的初始化方法,并在其中进行特定的初始化操作。 - 创建一个用于迭代子图数据的迭代器
train_subgraph_iter
,并使用DataLoader
对其进行包装,以便进行批处理、洗牌等操作。
- 继承自
-
构建模型和优化器:
- 创建一个
Model
对象(可能是元学习相关的模型),并将其放置在 GPU 上进行计算。 - 初始化一个 Adam 优化器,将其绑定到模型的参数上,设置学习率。
- 创建一个
-
设置训练参数:
- 初始化了用于控制训练的参数,如
num_step
、log_per_step
、check_per_step
等,这些参数在Trainer
中已经解释过。
- 初始化了用于控制训练的参数,如
总之,这些类用于管理训练过程中的数据加载、模型构建、优化器配置等步骤,使得训练代码更加清晰和可扩展。MetaTrainer
类是 Trainer
类的一个扩展,用于特定的元学习场景。
class MetaTrainer(Trainer):
def __init__(self, args):
super(MetaTrainer, self).__init__(args)
# dataset
self.train_subgraph_iter = OneShotIterator(DataLoader(TrainSubgraphDataset(args),
batch_size=self.args.train_bs,
shuffle=True,
collate_fn=TrainSubgraphDataset.collate_fn))
# model
self.model = Model(args).to(args.gpu)
# optimizer
self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
# args for controlling training
self.num_step = args.num_step
self.log_per_step = args.log_per_step
self.check_per_step = args.check_per_step
self.early_stop_patience = args.early_stop_patience
这段代码定义了两个类和一个实例化过程,涉及数据加载和处理的一些操作。
OneShotIterator
类:
-
构造函数
__init__(self, dataloader)
:- 接受一个 PyTorch 的数据加载器
dataloader
作为参数。 - 初始化一个成员变量
iterator
,通过调用静态方法one_shot_iterator(dataloader)
来创建一个迭代器。
- 接受一个 PyTorch 的数据加载器
-
静态方法
one_shot_iterator(dataloader)
:- 将传入的 PyTorch 数据加载器
dataloader
转换成一个 Python 迭代器。 - 使用无限循环
while True
遍历数据加载器的每个批次(即数据块)。 - 每次迭代返回一个批次的数据。
- 将传入的 PyTorch 数据加载器
TrainSubgraphDataset
类(继承自 Dataset
类):
-
构造函数
__init__(self, args)
:- 接受一个参数
args
,其中包含了数据加载所需的配置信息。 - 初始化类的实例,将
args
存储为成员变量。 - 打开一个以只读方式访问的 LMDB 数据库,并获取名为 "train_subgraphs" 的数据库。
- 接受一个参数
-
实例化
train_subgraph_iter
对象:- 使用
TrainSubgraphDataset
类的构造函数TrainSubgraphDataset(args)
创建一个数据集对象。 - 将数据集对象传递给
DataLoader
构造函数,同时指定批量大小(batch_size
)、是否进行洗牌(shuffle
)以及数据集的整理函数(collate_fn
)。 - 最终创建一个
OneShotIterator
对象train_subgraph_iter
,使用上述数据加载器作为参数。
- 使用
综合来看,train_subgraph_iter
是一个用于迭代训练子图数据的迭代器,它在 OneShotIterator
类中被定义,将 PyTorch 数据加载器转换为 Python 迭代器,用于在训练过程中一次一批地加载子图数据。
class OneShotIterator(object):
def __init__(self, dataloader):
self.iterator = self.one_shot_iterator(dataloader)
@staticmethod
def one_shot_iterator(dataloader):
'''
Transform a PyTorch Dataloader into python iterator
'''
while True:
for data in dataloader:
yield data
class TrainSubgraphDataset(Dataset):
def __init__(self, args):
self.args = args
self.env = lmdb.open(args.db_path, readonly=True, max_dbs=1, lock=False)
self.subgraphs_db = self.env.open_db("train_subgraphs".encode())
class Trainer(object):
def __init__(self, args):
self.args = args
# writer and logger
self.name = args.exp_name
self.writer = SummaryWriter(os.path.join(args.tb_log_dir, self.name))
self.logger = Log(args.log_dir, self.name).get_logger()
self.logger.info(json.dumps(vars(args)))
# state dir
self.state_path = os.path.join(args.state_dir, self.name)
if not os.path.exists(self.state_path):
os.makedirs(self.state_path)
# load data
self.data = pickle.load(open(args.data_path, 'rb'))
args.num_ent = len(self.data['train']['ent2id'])
args.num_rel = len(self.data['train']['rel2id'])
# dataset for validation and testing
self.valid_data = ValidData(args, self.data['valid'])
self.test_data = TestData(args, self.data['test'])
# kge models
self.kge_model = KGEModel(args).to(args.gpu)
# optimizer
self.optimizer = None
# args for controlling training
self.num_step = None
self.log_per_step = None
self.check_per_step = None
self.early_stop_patience = None
这段代码定义了两个类 ValidData
和 TestData
,它们继承自一个名为 Data
的基类。这两个类用于处理验证和测试数据,并对数据进行预处理和准备,以便在模型中使用。
ValidData
类:
-
构造函数
__init__(self, args, data)
:- 接受参数
args
和data
,其中args
包含数据处理的配置信息,而data
包含验证数据的详细信息。 - 调用基类
Data
的构造函数,传递args
和data
。
- 接受参数
-
初始化数据:
- 从
data
中获取支持三元组(sup_triples
)和查询三元组(que_triples
)。 - 从
data
中获取实体映射列表(ent_map_list
)和关系映射列表(rel_map_list
)。 - 根据支持和查询三元组获取
hr2t_all
和rt2h_all
映射。
- 从
-
构建图和模式图:
- 使用
get_train_g
方法根据支持三元组和实体映射列表构建训练图(g
)。 - 使用
get_pattern_tri
方法根据支持三元组获取模式三元组(pattern_tri
)。 - 使用
get_pattern_g
方法根据模式三元组和关系映射列表构建模式图(pattern_g
)。
- 使用
TestData
类:
-
构造函数
__init__(self, args, data)
:- 同样,接受参数
args
和data
。 - 调用基类
Data
的构造函数,传递args
和data
。
- 同样,接受参数
-
初始化数据:
- 获取支持三元组、不同类型的查询三元组(
que_uent
、que_urel
、que_uboth
)和实体、关系映射列表。
- 获取支持三元组、不同类型的查询三元组(
-
构建图和模式图:
- 类似于
ValidData
,使用支持三元组和实体映射列表构建训练图(g
)。 - 使用支持三元组获取模式三元组(
pattern_tri
)。 - 使用模式三元组和关系映射列表构建模式图(
pattern_g
)。
- 类似于
综合来看,这两个类的作用是根据输入的数据和配置信息,准备验证和测试数据的图形表示,以及支持和查询三元组的处理,以便在训练或测试模型时使用。
class ValidData(Data):
def __init__(self, args, data):
super(ValidData, self).__init__(args, data)
self.sup_triples = data['support']
self.que_triples = data['query']
self.ent_map_list = data['ent_map_list']
self.rel_map_list = data['rel_map_list']
self.hr2t_all, self.rt2h_all = self.get_hr2t_rt2h(self.sup_triples + self.que_triples)
# g and pattern g
self.g = self.get_train_g(self.sup_triples, ent_reidx_list=self.ent_map_list).to(args.gpu)
self.pattern_tri = self.get_pattern_tri(self.sup_triples)
self.pattern_g = self.get_pattern_g(self.pattern_tri, rel_reidx_list=self.rel_map_list).to(args.gpu)
class TestData(Data):
def __init__(self, args, data):
super(TestData, self).__init__(args, data)
self.sup_triples = data['support']
self.que_triples = data['query_uent'] + data['query_urel'] + data['query_uboth']
self.que_uent = data['query_uent']
self.que_urel = data['query_urel']
self.que_uboth = data['query_uboth']
self.ent_map_list = data['ent_map_list']
self.rel_map_list = data['rel_map_list']
self.hr2t_all, self.rt2h_all = self.get_hr2t_rt2h(self.sup_triples + self.que_triples)
# g and pattern g
self.g = self.get_train_g(self.sup_triples, ent_reidx_list=self.ent_map_list).to(args.gpu)
self.pattern_tri = self.get_pattern_tri(self.sup_triples)
self.pattern_g = self.get_pattern_g(self.pattern_tri, rel_reidx_list=self.rel_map_list).to(args.gpu)
这段代码定义了一个类 KGEModel
,该类继承自 PyTorch 的 nn.Module
,用于创建知识图嵌入(Knowledge Graph Embedding)模型。
-
构造函数
__init__(self, args)
:- 接受参数
args
,其中包含了模型的配置信息。 - 调用基类
nn.Module
的构造函数,初始化模型。
- 接受参数
-
初始化模型参数和超参数:
- 从参数
args
中获取模型的相关配置,如嵌入维度(emb_dim
)和损失函数中的超参数(gamma
)等。 - 创建一个模型的名称(
model_name
),该名称由参数args
中的kge
字段指定。 - 初始化模型中的一些超参数,如损失函数中的边界项
epsilon
,表示图嵌入向量在欧几里得空间中的分布范围。
- 从参数
-
初始化其他参数:
- 创建一个张量
gamma
,其中存储了参数args
中的gamma
值,用于损失函数的计算。 - 计算一个边界值,存储在张量
embedding_range
中,用于对嵌入向量进行范围约束,该值基于gamma
和epsilon
。 - 初始化一个常数
pi
,用于存储圆周率的值。
- 创建一个张量
综合来看,KGEModel
类的作用是初始化知识图嵌入模型,并设置模型的相关参数和超参数,以及为损失函数提供必要的值,以便在训练和评估过程中使用。
class KGEModel(nn.Module):
def __init__(self, args):
super(KGEModel, self).__init__()
self.args = args
self.model_name = args.kge
# self.nrelation = args.num_rel
self.emb_dim = args.dim
self.epsilon = 2.0
self.gamma = torch.Tensor([args.gamma])
self.embedding_range = torch.Tensor([(self.gamma.item() + self.epsilon) / args.dim])
self.pi = 3.14159265358979323846