音乐播放器
sola的小屋
 
文章 标签
20

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

InferLLM大模型推理框架项目(08)——Model类的实现(src/core/model.h+.cpp)

Model 类代码结构与功能实现分析

Model 类是 InferLLM 框架的对外接口层,它封装了底层模型实现的细节,提供了一套简洁的 API 供用户使用。下面对 model.hmodel.cpp 的代码结构和功能实现进行详细分析。

1. API 宏定义

#if defined(_WIN32)
#define API __declspec(dllexport)
#else
#define API __attribute__((visibility("default")))
#endif

这个宏定义用于控制函数和类的导出属性,确保在不同平台上正确导出符号:

  • 在 Windows 平台上使用 __declspec(dllexport)
  • 在其他平台上使用 __attribute__((visibility("default")))

2. ModelConfig 结构体

struct ModelConfig {
    //! dtype include 'float32','float16','int8','int4'
    std::string compt_type = "float32";
    //! device_type include 'cpu','gpu'
    std::string device_type = "cpu";
    uint32_t nr_thread;
    uint32_t nr_ctx;
    int32_t device_id;
    bool enable_mmap;
};

ModelConfig 结构体定义了模型的配置参数:

  • compt_type:计算类型,支持 "float32"、"float16"、"int8"、"int4"
  • device_type:设备类型,支持 "cpu"、"gpu"
  • nr_thread:线程数量
  • nr_ctx:上下文长度
  • device_id:设备 ID(对于 GPU)
  • enable_mmap:是否启用内存映射

3. Model 类

Model 类是框架的主要接口类,它使用了桥接模式,将具体实现委托给 ModelImp 类:

class API Model {
public:
    //! 构造函数,创建指定类型的模型
    Model(const ModelConfig& config, const std::string& model_name);

    //! 加载模型
    void load(const std::string& model_path);

    //! 初始化模型参数
    void init(uint32_t top_k, float top_p, float temp, float repeat_penalty,
              int repeat_last_n, int32_t seed, int32_t end_token);

    //! 获取剩余 token 数量
    uint32_t get_remain_token();

    //! 重置 token
    void reset_token();

    //! 预填充模型
    void prefill(const std::string& promote);

    //! 解码(带用户输入)
    std::string decode(const std::string& user_input, int& token);

    //! 迭代解码
    std::string decode_iter(int& token);

    //! 获取解码摘要
    std::string decode_summary() const;

private:
    std::shared_ptr<ModelImp> m_model_imp;
};

3.1 构造函数

Model::Model(const ModelConfig& config, const std::string& model_name) {
    //! TODO: create the model implement by the model name
    m_model_imp = std::make_shared<ModelImp>(config, model_name);
}

构造函数接收模型配置和模型名称,创建对应的 ModelImp 实例。注释中的 TODO 表明这里未来可能会根据模型名称创建不同的实现类。

3.2 load 方法

void Model::load(const std::string& model_path) {
    m_model_imp->load(model_path);
}

load 方法从指定路径加载模型文件,它直接调用 ModelImpload 方法。

3.3 init 方法

void Model::init(
        uint32_t top_k, float top_p, float temp, float repeat_penalty,
        int repeat_last_n, int32_t seed, int32_t end_token) {
    m_model_imp->init(
            top_k, top_p, temp, repeat_penalty, repeat_last_n, seed, end_token);
}

init 方法初始化模型的生成参数:

  • top_k:保留概率最高的 k 个 token
  • top_p:保留累积概率达到 p 的 token
  • temp:温度参数,控制生成的随机性
  • repeat_penalty:重复惩罚,降低已生成 token 的概率
  • repeat_last_n:考虑最后 n 个 token 进行重复惩罚
  • seed:随机数种子
  • end_token:结束 token

3.4 token 管理方法

uint32_t Model::get_remain_token() {
    return m_model_imp->get_remain_token();
}

void Model::reset_token() {
    return m_model_imp->reset_token();
}

这两个方法用于管理 token:

  • get_remain_token:获取剩余可生成的 token 数量
  • reset_token:重置 token 计数,通常在开始新的生成任务时调用

3.5 生成方法

void Model::prefill(const std::string& promote) {
    return m_model_imp->prefill(promote);
}

std::string Model::decode(const std::string& user_input, int& token) {
    return m_model_imp->decode(user_input, token);
}

std::string Model::decode_iter(int& token) {
    return m_model_imp->decode_iter(token);
}

std::string Model::decode_summary() const {
    return m_model_imp->decode_summary();
}

这些方法用于文本生成:

  • prefill:预填充模型,处理初始提示文本
  • decode:解码用户输入,生成下一个 token
  • decode_iter:迭代解码,生成下一个 token(不需要用户输入)
  • decode_summary:获取生成摘要,通常包含生成统计信息

4. 桥接模式的应用

Model 类采用了桥接模式,将接口与实现分离:

  • Model 类提供稳定的对外接口
  • ModelImp 类负责具体实现
  • 这种设计使得可以在不改变接口的情况下,更换或升级底层实现

桥接模式的优点在这里得到了充分体现:

  1. 接口稳定:用户代码只需要依赖 Model 类的接口,不需要关心具体实现
  2. 实现可替换:可以根据不同的模型类型或硬件平台,提供不同的 ModelImp 实现
  3. 隐藏复杂性:复杂的模型加载、推理逻辑都封装在 ModelImp 中,对用户透明

5. 工作流程

使用 Model 类的典型工作流程如下:

  1. 创建模型

    ModelConfig config;
    config.compt_type = "float32";
    config.device_type = "cpu";
    config.nr_thread = 4;
    config.nr_ctx = 2048;
    config.enable_mmap = true;
    
    Model model(config, "llama");
    
  2. 加载模型

    model.load("path/to/model.bin");
    
  3. 初始化参数

    model.init(40, 0.9f, 0.8f, 1.1f, 64, 42, -1);
    
  4. 预填充模型

    model.prefill("你好,请介绍一下自己。");
    
  5. 迭代生成

    int token;
    std::string result;
    while (true) {
        std::string next = model.decode_iter(token);
        if (token == end_token) break;
        result += next;
    }
    

总结

Model 类是 InferLLM 框架的对外接口层,它通过桥接模式封装了底层模型实现的细节,提供了一套简洁、稳定的 API 供用户使用。它支持模型加载、参数配置和文本生成等核心功能,使用户可以方便地使用大型语言模型进行推理。

通过分析 model.hmodel.cpp,我们可以看到 InferLLM 框架采用了良好的设计模式和接口抽象,使得框架具有良好的可扩展性和可维护性。