Llama3 Learning Notes

AiHorizon / 2024-09-20

Llama3 is a dense Transformer model used to predict the next token in a text sequence. Compared with earlier Llama versions, its performance gains come mainly from higher data quality and diversity, and it also benefits from the larger parameter count.

1. Network Architecture

  • As the figure above shows, Llama3 is a decoder-only model

  • The layer-by-layer architecture of Llama3 is shown in the figure above: the model uses pre-normalization with RMSNorm, adopts GQA (grouped-query attention) in the attention mechanism, and applies RoPE rotary positional encoding to Q and K (a minimal sketch of how these pieces fit together follows this list)
  • Because the model predicts the next token, Llama masks out all tokens that come after the current token during training
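
A minimal sketch (not from the original post) of how these pieces are typically wired in a Llama-style pre-norm decoder block. `RMSNorm` and `Attention` refer to the classes shown in sections 2.1 and 2.3 below, while the SwiGLU `FeedForward` and the `args.hidden_dim` / `args.norm_eps` fields are assumptions for illustration (a matching `ModelArgs` sketch appears in section 2.3):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    # SwiGLU feed-forward used in the Llama family: w2(silu(w1(x)) * w3(x))
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    # Pre-norm decoder block: x + Attn(RMSNorm(x)), then h + FFN(RMSNorm(h))
    def __init__(self, args):
        super().__init__()
        self.attention = Attention(args)                            # GQA attention (section 2.3)
        self.feed_forward = FeedForward(args.dim, args.hidden_dim)  # SwiGLU MLP (assumed field)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  # pre-norm before attention
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)        # pre-norm before the FFN

    def forward(self, x, freqs_cos, freqs_sin):
        # RMSNorm is applied before each sub-layer; the residual adds back the raw input
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out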

2. Core Concepts

2.1 RMSNorm:

RMSNorm is a variant of LayerNorm that normalizes activations by their root mean square.

Advantages:

  1. It does not compute the mean, so it has less computational overhead than LayerNorm
  2. It avoids over-normalization, which makes training more stable: the mean is not normalized away, only the scale (root mean square) is, so the mean information in the activations is preserved and less information is destroyed (see the comparison sketch after this list)
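
A quick comparison sketch (not from the original post; the values are only illustrative) showing the difference on a toy tensor:

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0, 10.0]])

# LayerNorm subtracts the per-row mean before rescaling
layer_norm = F.layer_norm(x, normalized_shape=(4,))

# RMSNorm only divides by the root mean square; the mean is not removed
rms_norm = x / x.pow(2).mean(-1, keepdim=True).sqrt()

print(layer_norm)  # zero-mean output: the mean information is gone
print(rms_norm)    # same signs and ordering as the input: the mean information is preserved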

2.1.1 Computation Steps:

  1. Compute the root mean square RMS(x) of the input feature vector x, where n is the feature dimension: $\mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$

  2. Normalize the input using the computed root mean square (in the implementation a small $\epsilon$ is added under the square root for numerical stability): $\hat{x}_i = \dfrac{x_i}{\mathrm{RMS}(x)}$

  3. Optionally scale and shift the result with learnable parameters, where g is the scale parameter and b is the shift parameter: $y_i = g_i\,\hat{x}_i + b_i$
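
A quick worked example (not in the original post): for $x = [1, 2, 2]$ with $n = 3$, $\mathrm{RMS}(x) = \sqrt{(1 + 4 + 4)/3} = \sqrt{3} \approx 1.732$, so the normalized vector is roughly $[0.577, 1.155, 1.155]$; with $g = 1$ and $b = 0$ this is also the final output.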

2.1.2 Code Implementation:

import torch
from torch import nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # learnable scale parameter (self.weight), initialized to ones
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # torch.rsqrt computes the element-wise reciprocal of the square root,
        # so this divides x by sqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32 for numerical stability, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        # scale the normalized result (self.weight is the scale parameter)
        return output * self.weight
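
A quick check of the class above (shapes and values are only illustrative):

norm = RMSNorm(dim=8)
x = torch.randn(2, 16, 8)            # (batch, seq_len, dim)
y = norm(x)

print(y.shape)                       # torch.Size([2, 16, 8])
# with weight initialized to ones, the RMS of each feature vector is close to 1
print(y.pow(2).mean(-1).sqrt())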

2.2 RoPE: Rotary Positional Encoding

  • The input vector is split into pairs of dimensions and viewed as a set of 2D vectors. RoPE applies a rotation to each 2D pair: at position $p$, the rotation angle for dimension pair $i$ is $p\,\theta_i$ with $\theta_i = 10000^{-2i/d}$, a constant that varies with the dimension index and controls the rotation speed. Each 2D pair $(x_1, x_2)$ is rotated as $x_1' = x_1\cos(p\theta_i) - x_2\sin(p\theta_i)$, $\;x_2' = x_1\sin(p\theta_i) + x_2\cos(p\theta_i)$

  • In matrix form: $\begin{pmatrix} x_1' \\ x_2' \end{pmatrix} = \begin{pmatrix} \cos(p\theta_i) & -\sin(p\theta_i) \\ \sin(p\theta_i) & \cos(p\theta_i) \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}$

  • The inner product between rotary-encoded queries and keys depends only on their relative position and gradually decays as the relative distance grows
  • Code implementation:
import torch
from typing import Tuple

# Compute the frequency matrices and return the cosine and sine matrices
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # theta_i = 10000^(-2i/dim) for each pair of dimensions
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # positions 0 .. end-1
    t = torch.arange(end, device=freqs.device)
    # outer product: one rotation angle per (position, dimension pair)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin

# Reshape freqs_cis of shape (seqlen, head_dim/2) so it broadcasts against x
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # keep the seqlen and last dimensions, set every other dimension to 1
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    # reshape xq and xk into (real, imaginary) pairs to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape the cos/sin tensors so they broadcast against xq_r / xk_r
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply the rotation: complex multiplication written out in real arithmetic
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten the last two dimensions back to head_dim
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)
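
A small usage sketch of the functions above (the shapes are only illustrative; head_dim must be even so it can be split into 2D pairs):

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64

xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

# precompute the cos/sin tables for every position up to seqlen
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
print(xq_rot.shape, xk_rot.shape)  # torch.Size([2, 16, 8, 64]) for both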

2.3 Attention Implementation
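
The class below relies on a `ModelArgs` configuration object and a `repeat_kv` helper that the post does not show. A minimal sketch of both, consistent with how they are used here (the field values are illustrative placeholders, not the real Llama3 configuration), together with the imports the class needs:

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class ModelArgs:
    dim: int = 512            # model (embedding) dimension
    n_heads: int = 8          # number of query heads
    n_group: int = 4          # query heads per KV head (GQA group size)
    hidden_dim: int = 2048    # FFN hidden dimension (used by the block sketch in section 1)
    norm_eps: float = 1e-5    # RMSNorm epsilon
    max_seq_len: int = 64     # maximum sequence length for the causal mask
    dropout: float = 0.0

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # repeat each KV head n_rep times so K/V end up with as many heads as Q:
    # (bs, seqlen, kv_heads, head_dim) -> (bs, seqlen, kv_heads * n_rep, head_dim)
    bs, seqlen, kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, kv_heads * n_rep, head_dim)
    )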

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        
        self.group = args.n_group
        self.heads = args.n_heads
        assert args.n_heads % args.n_group == 0
        self.kv_heads = args.n_heads // args.n_group
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.group)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.group)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)


        # skip flash attention for now and walk through the computation from scratch
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, 'mask')
        scores = scores + self.mask[:, :, :seqlen, :seqlen]   # (bs, n_heads, seqlen, seqlen), causal mask
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # project through the output layer and apply residual dropout to get the final result
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
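
A quick end-to-end check of the attention module (using the `ModelArgs` and `repeat_kv` sketch above and the RoPE functions from section 2.2; all shapes are illustrative):

args = ModelArgs(dim=512, n_heads=8, n_group=4, max_seq_len=64, dropout=0.0)
attn = Attention(args)

bsz, seqlen = 2, 32
x = torch.randn(bsz, seqlen, args.dim)

# precompute the RoPE tables for the full context, then slice to the current length
head_dim = args.dim // args.n_heads
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, args.max_seq_len)

out = attn(x, freqs_cos[:seqlen], freqs_sin[:seqlen])
print(out.shape)  # torch.Size([2, 32, 512])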

3. Deploying Llama3 Locally

Simple deployment and usage, following the Hugging Face documentation:

import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline(
    "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
)
pipeline("Hey how are you doing today?")
  • Local deployment and WebUI usage:
    • https://gitcode.csdn.net/6630586675c93e11c804127d.html
    • https://ducafecat.com/blog/llama3-model-api-local