音乐播放器
sola的小屋
 
文章 标签
20

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

InferLLM大模型推理框架项目(06)——Graph类的实现(src/core/graph.h+.cpp)

Graph 类代码结构与功能实现分析

Graph 类是 InferLLM 框架中的核心组件,负责构建和执行模型的计算图。下面对 graph.hgraph.cpp 的代码结构和功能实现进行详细分析。

1. 核心数据结构

1.1 LlmParams 结构体

struct LlmParams {
    bool is_multi_query = false;
    int32_t multi_query_group_num = 1;
    int32_t n_vocab;        // 词汇表大小
    int32_t n_embd;         // 嵌入维度
    int32_t n_mult;         // 中间层倍数
    int32_t n_head;         // 注意力头数
    int32_t n_layer;        // 层数
    int32_t n_rot;          // 旋转位置编码维度
    int32_t ftype;          // 权重类型
    int32_t n_ctx;          // 上下文长度
};

这个结构体存储了 LLM 模型的核心参数,包括模型结构信息和计算配置。

1.2 UserConfig 结构体

struct UserConfig {
    DType compt_type;       // 计算类型
};

这个结构体存储了用户配置的计算类型。

2. OprModuleBase 类

OprModuleBase 是所有操作模块的基类,封装了一组相关的算子:

class OprModuleBase {
public:
    // 构造函数
    OprModuleBase(std::shared_ptr<Tensor> input, Device* device, const std::string& name);
    OprModuleBase(std::vector<std::shared_ptr<Tensor>> inputs, Device* device, const std::string& name);
    
    // 核心方法
    size_t get_workspace_in_byte();
    void deduce_output_shape();
    virtual void execute(WorkSpace* workspace, uint32_t nr_past, bool is_prefill = false);
    
    // 添加算子
    template <typename Op, typename... Args>
    std::vector<std::shared_ptr<Tensor>> add_opr(Args&&... args);
    
    // 获取权重
    std::vector<std::shared_ptr<Tensor>> get_all_weights();
    
    // 输入输出管理
    std::vector<std::shared_ptr<Tensor>> inputs() const;
    std::shared_ptr<Tensor> input(int id = 0) const;
    std::shared_ptr<Tensor> output() const;
    
    // 重置上下文
    virtual void reset_ctx() {}
};

OprModuleBase 的主要功能:

  1. 管理一组相关的算子
  2. 提供统一的执行接口
  3. 管理输入输出张量
  4. 计算工作空间需求

2.1 execute 方法实现

void OprModuleBase::execute(WorkSpace* workspace, uint32_t nr_past, bool) {
    for (auto opr : m_oprs) {
        opr->pre_execute();
#ifdef INFER_PROFILE
        struct timeval start, end;
        gettimeofday(&start, NULL);
#endif
        opr->execute(workspace, nr_past);

#ifdef INFER_PROFILE
        gettimeofday(&end, NULL);
        long seconds = end.tv_sec - start.tv_sec;
        float micros = (seconds * 1000) + (float)(end.tv_usec - start.tv_usec) / 1000;
        printf("Op %s spent time %f ms\n", opr->name().c_str(), micros);
#endif
        opr->end_execute();
    }
}

这个方法按顺序执行模块中的所有算子,并支持性能分析。

3. 特定模块实现

框架定义了多种特定的模块,用于构建不同类型的模型:

3.1 AttentionModule

template <typename Attention>
class AttentionModule : public OprModuleBase {
    // 构造函数中构建注意力计算图
    AttentionModule(...) {
        // 创建注意力算子
        m_attention_op = std::make_shared<Attention>(...);
        oprs().push_back(m_attention_op);
        
        // 创建投影算子
        auto proj_out = add_opr<MatMul>(...)[0];
        set_output(proj_out);
    }
    
    // 重置上下文
    void reset_ctx() override { m_attention_op->reset_ctx(); }
};

这是一个模板类,可以使用不同的注意力实现(如 LlamaAttention、GlmAttention 等)。

3.2 FFN 模块

框架实现了多种 FFN(前馈网络)模块,适用于不同的模型架构:

LlamaFFNModule

LlamaFFNModule::LlamaFFNModule(...) {
    size_t nff = ((2 * (4 * embd) / 3 + mult - 1) / mult) * mult;
    // 创建第一个矩阵乘法
    auto matmul_out0 = add_opr<MatMul>(...)[0];
    // 创建第二个矩阵乘法
    auto matmul_out1 = add_opr<MatMul>(...)[0];
    // SiLU 激活函数
    auto silu_out = add_opr<Elemwise>(..., ElemMode::Silu)[0];
    // 元素乘法
    auto mul_out = add_opr<Elemwise>(..., ElemMode::Mul)[0];
    // 最后的矩阵乘法
    auto matmul_out2 = add_opr<MatMul>(...)[0];
    set_output(matmul_out2);
}

GlmFFNModule

GlmFFNModule::GlmFFNModule(...) {
    // 矩阵乘法
    auto matmul_out1 = add_opr<MatMul>(..., true)[0];
    // GELU 激活函数
    auto gelu_out = add_opr<Elemwise>(..., ElemMode::Gelu)[0];
    // 矩阵乘法
    auto matmul_out2 = add_opr<MatMul>(..., true)[0];
    set_output(matmul_out2);
}

Glm2FFNModule

Glm2FFNModule::Glm2FFNModule(...) {
    // 矩阵乘法
    auto matmul_out1 = add_opr<MatMul>(..., false)[0];
    // SpliteHalfActiveMul 操作
    auto gelu_out = add_opr<SpliteHalfActiveMul>(..., ElemMode::Silu)[0];
    // 矩阵乘法
    auto matmul_out2 = add_opr<MatMul>(..., false)[0];
    set_output(matmul_out2);
}

3.3 HeadModule

HeadModule::HeadModule(...) {
    // 层归一化
    auto norm_out = add_opr<LayerNorm>(...)[0];
    // 矩阵乘法(输出层)
    auto matmul_out = add_opr<MatMulLast>(...)[0];
    set_output(matmul_out);
}

void HeadModule::execute(WorkSpace* workspace, uint32_t nr_past, bool is_prefill) {
    // 预填充模式下不执行
    if (!is_prefill) {
        for (auto opr : oprs()) {
            opr->pre_execute();
            opr->execute(workspace, nr_past);
            opr->end_execute();
        }
    }
}

HeadModule 实现了模型的输出头,将隐藏状态映射到词汇表空间。它有一个特殊的 execute 方法,在预填充模式下不执行。

3.4 EmbdModule

EmbdModule::EmbdModule(...) {
    // 嵌入层
    auto embd_out = add_opr<Embedding>(...)[0];
    set_output(embd_out);
}

EmbdModule 实现了模型的嵌入层,将输入 token 映射到嵌入空间。

3.5 OneOpModule

template <class Op>
class OneOpModule : public OprModuleBase {
    // 添加单个算子
    template <typename... Args>
    std::shared_ptr<Tensor> add_opr(Args&&... args) {
        auto opr = std::make_shared<Op>(
                device(), name(), inputs(), std::forward<Args>(args)...);
        oprs().push_back(opr);
        set_output(opr->outputs()[0]);
        return opr->outputs()[0];
    }
};

OneOpModule 是一个辅助类,用于创建只包含一个算子的模块。

4. Graph 类

Graph 类是整个计算图的管理者,负责构建模型结构、加载权重和执行推理:

class Graph : public std::enable_shared_from_this<Graph> {
public:
    // 构造函数
    Graph(UserConfig model_config, Device* device, const std::string& name);
    
    // 静态工厂方法
    static std::shared_ptr<Graph> make_graph(UserConfig model_config, Device* device, const std::string& name);
    
    // 执行方法
    void execute(std::vector<int32_t> in_token, std::vector<float>& logist, uint32_t nr_past, bool prefill = false);
    
    // 添加模块
    template <typename OpModule, typename... Args>
    std::shared_ptr<Tensor> add_module(Args&&... args);
    
    // 重置上下文
    void reset_ctx();
    
    // 收集权重
    void collect_weights();
    
    // 加载模型
    virtual void load(std::shared_ptr<InputFile> fin, LlmParams& param, std::shared_ptr<Vocab> vocab);
    
    // 构建模型结构(由派生类实现)
    virtual void construct_llm() = 0;
};

4.1 execute 方法实现

void Graph::execute(std::vector<int32_t> in_token, std::vector<float>& logist, uint32_t nr_past, bool prefill) {
    // 检查输入形状,必要时重新分配工作空间
    if (m_input->dims() == 0 || !same_input_shape(in_token)) {
        m_input->set_shape({in_token.size()}, DType::Int32);
        size_t len = get_workspace_in_byte();
        if(m_workspace->ptr() == nullptr) {
            auto data = m_device->allocate(len);
            m_workspace->set_memory(data, len);
        } else if (m_workspace->ptr() && len > m_workspace->length()) {
            m_device->free_device(m_workspace->ptr());
            auto data = m_device->allocate(len);
            m_workspace->set_memory(data, len);
        }
    }
    
    // 准备输入数据
    m_input->resume_user_count();
    m_input->prepare_data();
    m_device->host2device_copy(m_input->ptr(), in_token.data(), in_token.size() * sizeof(int32_t), true);
    
    // 执行所有模块
    for (size_t i = 0; i < m_modules.size(); i++) {
        m_modules[i]->execute(m_workspace.get(), nr_past, prefill);
    }
    
    // 如果不是预填充模式,复制输出数据
    if (!prefill) {
        m_device->device2host_copy(logist.data(), m_output->ptr(), logist.size() * sizeof(float), true);
    }
    
    // 同步设备并释放输出内存
    m_device->sync();
    m_output->recall_data();
}

这个方法是模型推理的核心,它执行以下步骤:

  1. 检查输入形状,必要时重新分配工作空间
  2. 准备输入数据
  3. 按顺序执行所有模块
  4. 复制输出数据(非预填充模式)
  5. 同步设备并释放输出内存

4.2 load 方法实现

void Graph::load(std::shared_ptr<InputFile> fin, LlmParams& param, std::shared_ptr<Vocab> vocab) {
    // 验证魔数
    uint32_t magic;
    fin->read_raw((char*)&magic, sizeof(magic));
    INFER_ASSERT(magic == 0x123456, "model magic is not create!!!!");
    
    // 加载参数
    load_param(fin, param, vocab);
    
    // 构建模型结构
    construct_llm();
    collect_weights();
    set_weights_alias();
    
    // 加载权重
    while (true) {
        // 读取权重信息
        int32_t n_dims, length, ftype;
        if (fin->eof()) break;
        
        fin->read_raw(reinterpret_cast<char*>(&n_dims), sizeof(n_dims));
        fin->read_raw(reinterpret_cast<char*>(&length), sizeof(length));
        fin->read_raw(reinterpret_cast<char*>(&ftype), sizeof(ftype));
        
        if (fin->eof()) break;
        
        // 读取形状
        size_t nr_number = 1;
        int32_t shape[2] = {1, 1};
        for (int i = 0; i < n_dims; ++i) {
            fin->read_raw(reinterpret_cast<char*>(&shape[i]), sizeof(shape[i]));
            nr_number *= shape[i];
        }
        
        // 读取权重名称
        std::string name(length, 0);
        fin->read_raw(&name[0], length);
        auto alias_name = get_weight_alias(name);
        
        // 如果权重不在映射表中,跳过
        if (m_weights_map.count(alias_name) == 0) {
            auto dtype = convert_dtype(ftype);
            size_t length = nr_number * dtype_in_byte(dtype) / dtype_block_size(dtype);
            fin->skip(length);
            continue;
        }
        
        // 设置权重的文件信息
        auto weight = m_weights_map[alias_name];
        weight->set_file(fin, fin->tell());
        weight->set_shape({shape[0], shape[1]}, convert_dtype(ftype));

        // 跳过权重数据(实际读取在需要时进行)
        size_t length = weight->length_in_byte();
        fin->skip(length);

这里采用了延迟加载策略,只记录权重在文件中的位置,实际读取在需要时进行,这样可以减少内存使用。

4.3 collect_weights 方法

void Graph::collect_weights() {
    for (auto module : m_modules) {
        auto weights = module->get_all_weights();
        for (auto weight : weights) {
            auto name = weight->name();
            m_weights_map[name] = weight;
        }
    }
}

这个方法收集所有模块的权重,并将它们存储在权重映射表中,以便后续加载。

4.4 reset_ctx 方法

void Graph::reset_ctx() {
    for (auto module : m_modules) {
        module->reset_ctx();
    }
}

这个方法重置所有模块的上下文,通常在处理新的序列时调用。

5. 模型构建流程

InferLLM 框架中的模型构建流程如下:

  1. 创建 Graph 对象

    auto graph = Graph::make_graph(model_config, device, name);
    
  2. 加载模型参数

    graph->load(fin, param, vocab);
    

    这会调用 load_param 加载模型参数,然后调用 construct_llm 构建模型结构。

  3. 构建模型结构
    construct_llm 方法由派生类实现,用于构建特定模型的结构。例如,LLaMA 模型的构建过程会包含以下步骤:

    • 创建输入张量
    • 创建嵌入模块
    • 创建多个 Transformer 层,每层包含:
      • 层归一化
      • 注意力模块
      • 残差连接
      • 层归一化
      • FFN 模块
      • 残差连接
    • 创建输出头模块
  4. 收集权重

    graph->collect_weights();
    

    这会收集所有模块的权重,并将它们存储在权重映射表中。

  5. 设置权重别名

    graph->set_weights_alias();
    

    这会设置权重的别名,用于处理不同模型格式的权重名称差异。

6. 模型执行流程

InferLLM 框架中的模型执行流程如下:

  1. 准备输入

    graph->execute(tokens, logits, nr_past, prefill);
    

    这会将输入 token 复制到设备内存,并准备工作空间。

  2. 执行模块

    for (size_t i = 0; i < m_modules.size(); i++) {
        m_modules[i]->execute(m_workspace.get(), nr_past, prefill);
    }
    

    这会按顺序执行所有模块。

  3. 获取输出

    m_device->device2host_copy(logist.data(), m_output->ptr(), logist.size() * sizeof(float), true);
    

    这会将输出 logits 复制到主机内存。

  4. 同步和清理

    m_device->sync();
    m_output->recall_data();
    

    这会等待所有操作完成,并释放输出内存。

7. 优化策略

Graph 类实现了几种优化策略:

  1. 延迟加载:权重数据只在需要时才从文件读取,减少内存使用
  2. 工作空间复用:使用同一块内存作为所有算子的工作空间,减少内存分配
  3. 预填充模式:支持预填充模式,可以跳过输出头的计算,提高性能
  4. 上下文重置:支持重置上下文,便于处理新的序列

8. Tensor 类与 Graph 类的交互

Tensor 类是 Graph 类的基础,它提供了数据存储和管理功能。Graph 类通过以下方式与 Tensor 类交互:

  1. 创建张量:Graph 类创建输入、输出和中间张量
  2. 管理权重:Graph 类管理模型权重,这些权重是特殊的张量
  3. 数据传输:Graph 类负责在主机和设备之间传输数据
  4. 内存管理:Graph 类通过 Tensor 类的 prepare_datarecall_data 方法管理内存

总结

Graph 类是 InferLLM 框架的核心组件,它通过模块化的设计,实现了灵活、高效的模型构建和执行。它支持多种模型架构(如 LLaMA、ChatGLM 等),并提供了统一的接口进行模型加载和推理。通过延迟加载、工作空间复用等优化策略,它可以在有限的内存资源下高效运行大型语言模型。