GET代码学习

csjywu01 / 2024-10-08 / 原文

这段代码展示了如何将数据从 DataFrame 形式转化为 “blocks” 的过程。具体来说,它通过处理 DataFrame 的每一行(每一行代表一个原子),将它们按残基(residue)分组,每个组被称为一个 "block"。以下是对代码的详细分析,解释如何将数据转化为 "block"。

主要步骤:

  1. 循环遍历每个原子(DataFrame 的每一行)

    • df.itertuples() 用于按行遍历 DataFrame,每一行都代表一个原子。
    • 从每一行中提取原子所属的残基 ID、原子名称、元素类型及其空间坐标等信息。
  2. 按残基(residue)分组

    • res_id 是一个字符串,它将残基 ID 和插入代码(insertion code)拼接起来形成独特的标识符。如果没有插入代码,只使用残基 ID。
    • 如果 res_id 发生变化(即遇到新的残基),意味着当前残基的原子分组已经结束,形成了一个完整的 "block"。
    • 新的 "block" 使用 Block 类来表示,包含残基的符号(res_symbol)和对应的原子列表(units)。
  3. 跳过氢原子

    • 如果元素为氢(H),则跳过这一行(continue),因为氢原子在许多情况下被忽略。
  4. 创建 Atom 实例

    • 每个原子被创建为一个 Atom 对象,包含原子的名称、三维坐标(x, y, z),以及元素类型。
    • 这些 Atom 对象会被添加到当前残基的 "block" 中。
  5. 生成 Block 对象

    • 当遍历到新的残基时,前一个残基的 "block" 完成,创建并存储到 blocks 列表中。
    • 遍历结束后,最后一个残基的 "block" 被加入到 blocks 中。

代码分析:

1. df_to_blocks 函数:

这是将 DataFrame 中的原子数据分块(block)的核心函数。

def df_to_blocks(df, key_residue='residue', key_insertion_code='insertion_code', key_resname='resname',
                 key_atom_name='atom_name', key_element='element', key_x='x', key_y='y', key_z='z') -> List[Block]:
    last_res_id, last_res_symbol = None, None  # 保存上一个残基的ID和符号
    blocks, units = [], []  # blocks存放每个block,units存放每个block中的原子
    for row in df.itertuples():  # 遍历每一行(原子)
        residue = getattr(row, key_residue)  # 提取残基ID
        if key_insertion_code is None:
            res_id = str(residue)
        else:
            insert_code = getattr(row, key_insertion_code)  # 提取插入代码
            res_id = f'{residue}{insert_code}'.rstrip()  # 拼接残基ID和插入代码
        
        # 如果残基ID发生变化,意味着当前block结束
        if res_id != last_res_id:
            block = Block(last_res_symbol, units)  # 创建一个新的Block对象
            blocks.append(block)  # 将Block添加到blocks列表
            units = []  # 清空units以开始新的Block
            last_res_id = res_id  # 更新为新的残基ID
            last_res_symbol = VOCAB.abrv_to_symbol(getattr(row, key_resname))  # 通过残基名称获取符号
        
        # 处理当前原子
        atom = getattr(row, key_atom_name)  # 获取原子名称
        element = getattr(row, key_element)  # 获取元素类型
        if element == 'H':  # 跳过氢原子
            continue
        units.append(Atom(atom, [getattr(row, axis) for axis in [key_x, key_y, key_z]], element))  # 添加原子
    blocks = blocks[1:]  # 忽略第一个空block
    blocks.append(Block(last_res_symbol, units))  # 添加最后一个Block
    return blocks

详细步骤解释:

  1. 初始化

    • last_res_idlast_res_symbol 用于追踪前一个残基的信息,blocks 用于存放最终的 Block 列表,units 用于暂存当前 Block 的原子。
  2. 按行遍历 DataFrame

    • df.itertuples() 会将 DataFrame 中每一行作为一个元组返回,通过 getattr() 可以访问每个属性(例如 residueatom_name 等)。
  3. 检查残基 ID 变化

    • res_id 发生变化时,说明已经处理完当前残基,需要创建一个新的 Block。
    • 使用 Block(last_res_symbol, units) 创建一个 Block 对象,units 则是该 Block 中的所有原子。
  4. 处理原子信息

    • 每一行代表一个原子,创建 Atom 对象时,会提取该原子的名称、三维坐标和元素信息。
  5. 返回结果

    • 最后返回 blocks 列表,包含所有的 Block 对象。

总结:

  • 代码中的 df_to_blocks 函数从 DataFrame 中提取原子,并将它们按残基分组,创建了一个个 Block。每个 Block 包含一个残基的信息(符号和原子列表)。
  • 如果 DataFrame 中的数据已经按照这种结构组织好,那么这段代码可以将数据高效地转换为模型可以处理的块状数据结构。

这个分组和转换过程的核心是识别残基 ID 的变化,然后将每个残基的原子聚合成一个 Block。

按残基分组是什么意思。也就是说一个block里面会有多个相同的残基么?

“按残基分组”是指将同一个残基(residue)中所有的原子归类到一个 block 中。在蛋白质结构中,残基通常指氨基酸,而每个氨基酸残基包含若干原子,例如碳、氢、氧、氮等。

一个 block 代表的是一个残基及其包含的所有原子。因此,一个 block 里面不会有多个相同的残基,而是每个残基单独作为一个 block,包含该残基的所有原子。每个残基只对应一个 block,多个相同的残基(例如多个“甘氨酸”残基)会分别出现在不同的 block 中。

详细解释:

  • 在蛋白质结构中,残基通常指的是氨基酸。每个氨基酸残基包含多个原子(例如,氮原子、碳原子、氧原子等),这些原子组成了氨基酸的骨架和侧链。
  • 当你“按残基分组”时,代码会遍历每个原子,并根据原子所属的残基(通过残基 ID res_id)将同一残基的所有原子放在一起。
  • 一个 block 就是代表了一个残基的原子集群。举例来说,一个“甘氨酸”残基,它包含 C、N、O 等原子,这些原子会被归为同一个 block。

举例:

假设我们有一个蛋白质片段,包含两个氨基酸残基:甘氨酸(Gly)和丙氨酸(Ala)。每个氨基酸残基包含若干原子:

  1. 甘氨酸残基(Gly)

    • 原子:N、C、O、H
    • 这些原子被分组到同一个 block 中,表示为 Block 1。
  2. 丙氨酸残基(Ala)

    • 原子:N、C、O、H、Cβ(侧链上的原子)
    • 这些原子被分组到另一个 block 中,表示为 Block 2。

代码的具体过程:

  • 当代码遍历 DataFrame 时,它会检查每一行原子所属的残基 ID(res_id)。
  • 当它遇到一个新的残基 ID 时,意味着前一个残基的所有原子已经处理完毕,生成一个 block。
  • 一个 block 包含当前残基的所有原子。然后,开始处理下一个残基,创建新的 block。

总结:

  • 一个 block 只包含一个残基的所有原子
  • 不同的残基会分别生成不同的 block
  • 多个相同的残基(例如多个“甘氨酸”残基)会出现在不同的 block 中,而不是合并在一起。

希望这能帮助你理解按残基分组和 block 之间的关系!如果还有其他问题,请告诉我。

if isinstance(items, list):
            data = process_map(self._preprocess, items, max_workers=n_cpu, chunksize=10)  所以一个items包含了两个蛋白质,将蛋白质1的所有残基转换为block1,而将蛋白质2的所有残基转换为了block2,是这样么?

是的,理解是正确的。在这段代码中,items 代表的是输入数据,其中包含了两个蛋白质的原子信息。当 items 是一个列表时,process_map(self._preprocess, items) 会将列表中的每一项传递给 self._preprocess 函数进行处理。

具体流程:

  1. items 包含了两个蛋白质的结构数据:

    • items['atoms_interface1']:包含了蛋白质 1 的所有原子信息。
    • items['atoms_interface2']:包含了蛋白质 2 的所有原子信息。
  2. _preprocess 函数中:

    • blocks1 = df_to_blocks(item['atoms_interface1']) 会将 蛋白质 1 的所有残基转换为 blocks1,即一个包含每个残基的“block”的列表。
    • blocks2 = df_to_blocks(item['atoms_interface2']) 会将 蛋白质 2 的所有残基转换为 blocks2,同样也是一个包含每个残基的“block”的列表。
  3. 最后,blocks1blocks2 分别表示两个蛋白质的残基结构,这两个列表会被传递到 blocks_to_data(blocks1, blocks2) 中进行进一步处理,可能是用于对接或相互作用分析。

结论:

是的,_preprocess 函数的作用就是将蛋白质 1 的残基转换为 blocks1,将蛋白质 2 的残基转换为 blocks2,每个 block 都代表一个残基及其包含的所有原子信息。

这里的blocks1和blocks2是指什么

在这段代码中,blocks1blocks2 是从 item['atoms_interface1']item['atoms_interface2'] 中提取出来的数据。它们代表了两个不同的分子界面(interface1interface2)中的残基和原子的分组信息,经过 df_to_blocks() 函数处理后被表示为“block”的形式。

blocks1blocks2 具体指什么?

  • blocks1:是从 item['atoms_interface1'] 中提取出来的“block”,表示分子界面 1(interface1)中按残基分组的原子列表。这个分子界面包含了某个蛋白质或分子片段的一部分(通常是参与对接或相互作用的表面区域)。

  • blocks2:是从 item['atoms_interface2'] 中提取出来的“block”,表示分子界面 2(interface2)中按残基分组的原子列表。这个界面通常是与 blocks1 进行相互作用的另一部分(例如另一个蛋白质或分子)。

每个 block 是由一个残基的所有原子组成的对象,可能包含残基的符号和它的原子坐标、原子类型等信息。

df_to_blocks() 函数的作用:

  • df_to_blocks() 函数将 DataFrame 中的原子信息按残基分组,生成多个 Block 对象,每个 Block 包含一个残基的所有原子信息。
  • blocks1blocks2 代表两个不同分子界面中的所有残基,每个界面都被分成若干 Block

示例场景:

假设 item['atoms_interface1']item['atoms_interface2'] 分别包含两个不同蛋白质(或分子)的界面原子信息,那么:

  • blocks1 会包含 interface1 中所有残基及其原子的分组;
  • blocks2 会包含 interface2 中所有残基及其原子的分组。

这些分子界面信息可能被用于计算两个蛋白质之间的对接相互作用,或者分析它们的几何结构。

blocks_to_data(blocks1, blocks2) 的作用:

  • 这个函数(虽然代码中没有定义)看起来是将 blocks1blocks2 结合在一起,可能用于计算两个界面的相互作用或距离矩阵,或者其他结构相关的指标。它会将两个界面的“block”数据转化为模型需要的输入形式。

data['label'] = item['affinity']['neglog_aff']

  • 最后一步将 item['affinity']['neglog_aff'] 赋值给 data['label'],表示这个样本的目标值(label),通常是一个关于亲和力的数值(例如结合自由能的负对数),用作监督学习中的标签。

总结:

  • blocks1blocks2 是分别从 item['atoms_interface1']item['atoms_interface2'] 中提取出来的“block”对象,代表了按残基分组的原子信息。
  • 它们代表了两个分子界面的结构,这些结构可能用于进一步的相互作用分析或模型输入。
blocks_to_data的函数

这段代码实现了 blocks_to_data 函数,它的作用是将多个 Block 对象的列表转化为统一的数据格式,通常用于神经网络模型的输入。具体来说,它将 Block 中的各类数据(如原子类型、原子坐标、残基类型等)整合成模型需要的输入格式。以下是对每个步骤的详细分析:

函数功能:

  • 输入:函数接受若干个 blocks_list 参数,每个参数是一个 Block 对象的列表(即多个残基的集合)。
  • 输出:函数返回一个字典,字典中包含了处理后的所有 Block 数据,组织成统一格式,方便后续处理。

主要步骤:

  1. 初始化变量

    • B: 存储每个 block(残基)的类型,来自 VOCAB.symbol_to_idx
    • A: 存储原子的类型,来自 VOCAB.get_atom_global_idx()block.to_data()
    • X: 存储原子的坐标信息。
    • atom_positions: 存储原子位置的索引,用于进一步计算原子的几何信息。
    • block_lengths: 记录每个 block(残基)包含的原子数量。
    • segment_ids: 记录这些 block 属于哪个分子或片段(即哪一个 blocks_list)。
  2. 处理每个 blocks_list

    • 外层循环 for i, blocks in enumerate(blocks_list) 遍历输入的多个 blocks_list(每个界面或分子),并逐一处理其中的 blocks(代表每个残基)。
  3. 处理每个 block

    • 对于每个 blocks_list,首先创建一个 "global node"(全局节点),即代表整个分子的虚拟中心点,用于保存整个分子的全局信息。
    • 每个 Block 对象调用 block.to_data() 方法,该方法返回 b(残基类型)、a(原子类型)、x(原子坐标)、positions(原子位置索引)和 block_len(该残基包含的原子数)。
  4. 计算全局节点的中心点

    • cur_X[0] = np.mean(cur_X[1:], axis=0):该操作将当前 Block 的所有原子坐标(去掉全局节点本身)取平均值,并将结果存储为全局节点的坐标。也就是说,"全局节点" 的坐标是所有原子坐标的质心。
  5. 生成段 ID

    • cur_segment_ids = [i for _ in cur_B]:创建一个 segment_ids 列表,用于标记这些 block 来自于哪个 blocks_list(可以理解为它们属于哪个分子或片段)。
  6. 将当前 block 的信息存入全局列表

    • B.extend(cur_B)A.extend(cur_A) 等操作将处理好的 block 数据添加到全局变量中,最终将所有 block 的信息合并在一起。
  7. 返回数据

    • 最终返回的数据是一个字典 data,其中包含了处理好的所有 block 和原子信息。这些信息包括:
      • X: 原子的坐标数组。
      • B: 残基类型的索引。
      • A: 原子类型的索引。
      • atom_positions: 原子的几何位置信息。
      • block_lengths: 每个残基包含的原子数量。
      • segment_ids: 每个残基所属的分子片段。

详细字段解释:

  1. X:

    • 形状为 [Natom, 2, 3],表示每个原子的三维坐标。X[0] 是全局节点的坐标(残基的中心),其余的是残基中每个原子的坐标。
  2. B:

    • B 表示每个残基的类型。它是一个列表,每个元素对应一个残基,表示残基的种类(由 VOCAB.symbol_to_idx() 转换)。
  3. A:

    • A 是所有原子的类型列表,每个原子通过 VOCAB.get_atom_global_idx()block.to_data() 得到其类型。
  4. atom_positions:

    • atom_positions 记录每个原子的位置信息,它可能是用于更复杂的几何计算或进一步分析。
  5. block_lengths:

    • block_lengths 记录每个残基包含的原子数量。每个残基有不同数量的原子,这个列表提供了每个残基的长度信息。
  6. segment_ids:

    • segment_ids 用于标记这些 block 属于哪个片段(如不同的分子或不同的界面),使得不同 blocks_list 可以区分开。

总结:

blocks_to_data 函数的核心目的是将 Block 对象的列表转化为一个统一的数据格式,包括残基类型、原子类型、坐标等信息。这些数据会被用于后续的神经网络模型或其他分析任务。

形状为 [Natom, 2, 3] 这里的维度2是什么意思

X 的形状 [Natom, 2, 3] 中,维度 2 对应的是每个原子的两种坐标表示。具体解释如下:

  1. Natom: 表示所有原子的数量。这个维度的大小取决于输入数据中总共包含的原子数(包括所有 block 中的原子)。

  2. 2: 这个维度表示两种不同的坐标表示方式:

    • 第一种表示X[:, 0, :])通常是指全局节点的坐标,即代表整个残基或整个分子片段的质心坐标(block 中所有原子坐标的平均值)。
    • 第二种表示X[:, 1, :])是原子的实际坐标,即每个原子相对于残基的三维空间位置。
  3. 3: 表示每个原子的三维空间坐标(即 x, y, z 三个轴的坐标)。这个维度的大小始终为 3,表示三维坐标系中的位置。

详细解释:

  • X[:, 0, :]: 对应的是每个 block 的全局坐标,即质心(或残基的中心位置)。这种坐标表示是通过 np.mean(cur_X[1:], axis=0) 计算得出的,表示该 block 中所有原子的平均坐标。

  • X[:, 1, :]: 对应的是原子的实际坐标,表示该原子在三维空间中的 x, y, z 位置。

举个例子:

假设有 2 个原子,X 的形状为 [2, 2, 3]

X = [
    [[x_global1, y_global1, z_global1], [x_atom1, y_atom1, z_atom1]],
    [[x_global2, y_global2, z_global2], [x_atom2, y_atom2, z_atom2]]
]
  • 对于第一个原子:

    • X[0, 0, :] 是全局节点的坐标(残基的中心或质心坐标)。
    • X[0, 1, :] 是该原子的具体坐标。
  • 对于第二个原子:

    • X[1, 0, :] 是另一个 block 的全局坐标。
    • X[1, 1, :] 是第二个原子的具体坐标。

总结:

  • 维度 2 中的第一个位置存储全局节点的坐标(质心),而第二个位置存储实际的原子坐标。通过这种方式,每个原子都有两种坐标表示,既有它本身的坐标,也有全局坐标。

index里面这些变量是什么

从你提供的截图来看,self.indexes 是一个包含字典的列表,每个字典对应一个蛋白质-配体对的结合数据。这些变量存储的是蛋白质和配体的亲和力及其相关的物理化学数据。下面是这些变量的具体含义:

每个字典中的键值解释:

  1. affinity:

    • 包含结合亲和力的相关信息,通常用于评估蛋白质与配体之间的结合强度。
    • Kd:解离常数(dissociation constant),表示蛋白质和配体在平衡时的结合强度。Kd 值越小,表示结合强度越高。
    • dG:自由能变化(Gibbs free energy change),通常是由解离常数计算得出,代表了结合反应的能量变化。通常,dG 值越低,结合越稳定。
    • neglog_aff:亲和力的负对数值,可能用于处理 Kd 值较小时的数值稳定性。
  2. id:

    • 这是蛋白质-配体对的唯一标识符(通常是 PDB ID),每个 PDB ID 对应一种特定的蛋白质结构或蛋白质-配体复合体。

具体示例:

  • affinity

    • 'Kd': 5.59e-07:表示蛋白质-配体的解离常数为 5.59 × 10^-7 M,表明蛋白质和配体结合较紧密。
    • 'dG': -8.53:自由能变化为 -8.53 kcal/mol,结合的稳定性较高。
    • 'neglog_aff': 6.25:可能是亲和力的负对数值,通常用于评估结合强度。
  • id

    • 例如,'id': '3bpl':这是蛋白质-配体对的唯一标识符,在 PDB 数据库中可能查找到这个复合物的结构。

总结:

self.indexes 是一个包含多个蛋白质-配体对的数据列表,每个字典包含了:

  • 亲和力相关的数值(Kd, dG, neglog_aff)。
  • 蛋白质-配体对的唯一标识符(id)。
    这些信息可能用于分析蛋白质与配体的相互作用,以及模型的输入或标签(label)。

是的,data['label'] = item['affinity']['neglog_aff'] 这行代码的作用是将结合强度(以负对数亲和力 neglog_aff 的形式)提取出来,并将其作为 data['label'],用于模型的标签(label)。

具体解释:

  • item['affinity']['neglog_aff']:这里的 neglog_aff 是亲和力的负对数(通常是 -log(Kd),其中 Kd 是解离常数),用于衡量蛋白质与配体的结合强度。

    • Kd (解离常数):Kd 值越小,表明蛋白质和配体结合越紧密。为了便于数值计算和处理,Kd 常常以负对数形式表示,因此 neglog_aff 越大,表示结合强度越强。
  • data['label']:将提取的 neglog_aff 值赋值给 data['label'],这个 label 将作为模型的目标值,用于监督学习或预测蛋白质-配体之间的结合强度。

为什么使用 neglog_aff

使用亲和力的负对数(neglog_aff)而不是 Kd 值的原因是:

  1. 数值稳定性:Kd 值的范围可能非常广,从 10^-9 到 10^-3 甚至更大。直接使用 Kd 值进行计算可能导致数值不稳定。通过对 Kd 取负对数,可以缩小数值范围,使其更适合用于机器学习模型。
  2. 线性关系:在某些情况下,-log(Kd) 与结合强度的物理化学过程有更好的线性关系,因此更适合用于回归模型。

总结:

这行代码是提取蛋白质-配体结合强度的 neglog_aff 值,并将其作为标签(label),用于模型训练或评估结合强度的任务。

以batch为单位的数据

根据你提供的截图内容,下面是对各个变量的解释以及它们可能对应的数据含义:

  1. A:

    • 内容: tensor([2, 9, 8, ..., 8, 10, 8], device='cuda:0')
    • 解释: 这是一个表示原子索引的 Tensor,通常对应的是某种类型的原子属性或索引。它可能是每个原子所对应的类型(例如不同元素的类别)。
  2. B:

    • 内容: tensor([3, 2, 7, ..., 12, 14, 5], device='cuda:0')
    • 解释: 这个变量通常表示块(block)或残基(residue)的索引或类型。它可能是表示不同残基或块的分类信息。
  3. Z:

    • 内容: tensor([[[0.2207, 17.2229, -14.8194], [-11.4810, 22.9620, -30.7930]]], device='cuda:0')
    • 解释: Z 表示原子的三维坐标信息。这个 Tensor 的形状是 (8548, 1, 3),表示 8548 个原子的三维坐标。这通常用于计算分子之间的距离或作为模型输入。
  4. atom_positions:

    • 内容: tensor([2, 3, 4, ..., 3, 3, 5], device='cuda:0')
    • 解释: 这是每个原子的位置信息,可能是原子在结构中的顺序或与三维坐标有关的索引。结合 Z 使用时,它可以帮助定位每个原子的具体位置。
  5. block_lengths:

    • 内容: tensor([1, 8, 8, ..., 8, 8, 5], device='cuda:0')
    • 解释: 这是一个记录每个块(block)或残基(residue)包含多少原子的 Tensor。它可以帮助定义每个块的大小和结构。
  6. label:

    • 内容: tensor([10.3841, 5.7959, 5.0969, ..., 9.7696], device='cuda:0')
    • 解释: 这是目标变量的标签信息,通常用于表示结合亲和力(affinity)或其他分子间属性。这些标签是模型训练时的目标值。
  7. lengths:

    • 内容: tensor([92, 111, 82, ..., 47, 175, 39], device='cuda:0')
    • 解释: 这是每个样本(分子或块)的长度信息,表示每个分子或残基块的长度,用于动态批处理或其他模型输入处理。
  8. return_loss:

    • 内容: False
    • 解释: 这是一个布尔值,表示是否返回损失值(loss)。它通常用于控制在训练过程中是否需要计算并返回损失值。
  9. return_noise:

    • 内容: False
    • 解释: 这是一个布尔值,表示是否返回噪声信息。它可能用于控制在一些模型中是否引入噪声以进行数据增强或模拟不同情况。
  10. segment_ids:

    • 内容: tensor([0, 0, 0, ..., 1, 1, 1], device='cuda:0')
    • 解释: 这是一个表示分段信息的 Tensor,用于标记数据中不同片段的边界。它可能用于区分来自不同分子或残基的部分。
  11. self (AffinityPredictor):

    • 解释: 这是模型的实例对象(AffinityPredictor),它可能是你正在使用的模型类,专门用于预测分子亲和力或其他相关任务。这个模型可能包含了一个块嵌入层(BlockEmbedding),用于将分子块嵌入为向量表示。

总结:

这些变量是模型输入数据的不同部分。大致可以归类为以下几类:

  • 几何和结构信息:如 Z(三维坐标)、atom_positionsblock_lengthslengths
  • 标签和目标值:如 label(亲和力标签)。
  • 模型的控制和处理标志:如 return_lossreturn_noise
  • 其他辅助信息:如 segment_idsAB(分别表示原子和块的索引或类型)。

这些信息结合起来构成了用于分子建模和预测的输入数据,以及用于控制训练过程的标志。

边序号的构建

在这段代码中,construct_edges 函数的主要作用是根据输入的块(block)、批次(batch)、段(segment)等信息,利用 edge_constructor 来构建 边集(edges),这些边表示分子结构或其他复杂结构中各节点(如原子、块)之间的连接关系。

边的构建步骤

  1. 不进行切片的边构建:
    complexity == -1 时,不进行任何切片操作,直接调用 edge_constructor 来构建所有边:

    intra_edges, inter_edges, global_global_edges, global_normal_edges, _ = edge_constructor(B, batch_id, segment_ids, X=X, block_id=block_id)
    return intra_edges, inter_edges, global_global_edges, global_normal_edges
    

    在这里,edge_constructor 会直接生成以下几种边:

    • intra_edges: 在块(block)内部的边。
    • inter_edges: 块与块之间的边。
    • global_global_edges: 全局节点之间的边。
    • global_normal_edges: 全局节点与普通节点之间的边。
  2. 进行切片的边构建:
    complexity != -1 时,会根据指定的复杂度 complexity 对数据进行切片。每次处理数据的一个小批次(mini-batch),并为每个小批次构建边,最后再合并这些边。

    具体过程如下:

    • 初始化部分变量:

      offset, bs_id_start, bs_id_end = 0, 0, 0
      mini_intra_edges, mini_inter_edges, mini_global_global_edges, mini_global_normal_edges = [], [], [], []
      batch_size = batch_id.max() + 1
      unit_batch_id = batch_id[block_id]
      lengths = scatter_sum(torch.ones_like(batch_id), batch_id, dim=0)
      

      这里 batch_size 是指总的批次数量,unit_batch_id 表示每个原子的批次ID,lengths 计算每个批次中的块的数量。

    • 逐个批次处理:

      while bs_id_end < batch_size:
          bs_id_start = bs_id_end
          bs_id_end += 1
          while bs_id_end + 1 <= batch_size and \
                (lengths[bs_id_start:bs_id_end + 1] * lengths[bs_id_start:bs_id_end + 1].max()).sum() < complexity:
              bs_id_end += 1
      

      这部分代码用于迭代处理每个批次的数据,同时根据 complexity 控制每个小批次的复杂度。lengths[bs_id_start:bs_id_end + 1] 计算当前批次中块的数量,用于决定是否要继续向下一个批次扩展。

    • 选择当前批次中的块和原子:

      block_is_in = (batch_id >= bs_id_start) & (batch_id < bs_id_end)
      unit_is_in = (unit_batch_id >= bs_id_start) & (unit_batch_id < bs_id_end)
      B_mini, batch_id_mini, segment_ids_mini = B[block_is_in], batch_id[block_is_in], segment_ids[block_is_in]
      X_mini, block_id_mini = X[unit_is_in], block_id[unit_is_in]
      

      这里 block_is_inunit_is_in 用于选择当前批次中的块和原子。B_minibatch_id_mini 等变量存储的是当前小批次中的块、批次ID、段ID、位置等信息。

    • 为当前批次构建边:

      intra_edges, inter_edges, global_global_edges, global_normal_edges, _ = edge_constructor(
          B_mini, batch_id_mini - bs_id_start, segment_ids_mini, X=X_mini, block_id=block_id_mini - offset)
      

      对于当前小批次,通过 edge_constructor 构建边集,包括块内边、块间边、全局边等。

    • 累加偏移量:

      offset += B_mini.shape[0]
      

      由于在每个小批次之间的索引会重新计算,因此需要通过 offset 来保持原始的全局索引。

    • 将小批次的边合并:
      通过检查 edge_constructor 是否有提供 given_intra_edgesgiven_inter_edges,如果没有的话,将小批次中的边添加到合并后的列表中,最终将这些小批次的边合并成全局的边:

      if hasattr(edge_constructor, 'given_intra_edges'):
          intra_edges = edge_constructor.given_intra_edges
      else:
          intra_edges = torch.cat(mini_intra_edges, dim=1)
      if hasattr(edge_constructor, 'given_inter_edges'):
          inter_edges = edge_constructor.given_inter_edges
      else:
          inter_edges = torch.cat(mini_inter_edges, dim=1)
      if global_global_edges is not None:
          global_global_edges = torch.cat(mini_global_global_edges, dim=1)
      if global_normal_edges is not None:
          global_normal_edges = torch.cat(mini_global_normal_edges, dim=1)
      

边的类型总结

最终返回的边包括以下几种类型:

  • intra_edges: 块内部的边,即块中原子之间的边。
  • inter_edges: 块之间的边,即不同块中的原子之间的边。
  • global_global_edges: 全局节点之间的边,用于捕捉全局信息。
  • global_normal_edges: 全局节点与普通节点之间的边,用于连接全局和局部信息。

这些边将被用于进一步的图神经网络中,用来建模分子结构的几何和拓扑信息。

如果你有关于这个过程的具体问题,欢迎进一步提问!