音乐播放器
sola的小屋
 
文章 标签
20

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

InferLLM大模型推理框架项目(07)——KvStorage类的实现(src/core/kvstorage.h+.cpp)

KvStorage 类代码结构与功能实现分析

KvStorage 类是 InferLLM 框架中用于存储注意力机制中的 Key 和 Value 缓存的特殊张量类。它继承自 Tensor 类,并提供了动态扩展内存的能力,以适应长序列生成过程中不断增长的 KV 缓存需求。

1. KvStorageConfig 类

首先,框架定义了一个单例类 KvStorageConfig,用于管理 KV 缓存的全局配置:

class KvStorageConfig {
public:
    constexpr static uint32_t START_KV_INDEX = 100;  // 初始 KV 缓存大小
    constexpr static uint32_t KV_STEP = 100;         // KV 缓存扩展步长
    static std::shared_ptr<KvStorageConfig> instance;
    
    // 单例模式获取实例
    static std::shared_ptr<KvStorageConfig> get_instance() {
        if (instance == nullptr) {
            instance = std::make_shared<KvStorageConfig>();
        }
        return instance;
    }
    
    // 增加 KV 缓存计数
    uint32_t increase_count() {
        m_kv_count++;
        return m_kv_count;
    }
    
    // 获取起始索引
    uint32_t get_start_index() { return START_KV_INDEX + m_kv_count * 2; }
    
    // 获取当前计数
    uint32_t get_count() { return m_kv_count; }
    
private:
    uint32_t m_kv_count = 0;  // KV 缓存计数
};

这个类使用单例模式,确保全局只有一个实例,用于跟踪和管理所有 KV 缓存的分配。它定义了两个关键常量:

  • START_KV_INDEX:初始 KV 缓存大小,设为 100
  • KV_STEP:KV 缓存扩展步长,也设为 100

2. KvStorage 类

KvStorage 类继承自 Tensor 类,并添加了管理 KV 缓存的特殊功能:

class KvStorage : public Tensor {
public:
    // 构造函数
    KvStorage(std::vector<size_t> shape, DType dtype, Device* device);
    
    // 析构函数
    ~KvStorage() {
        auto data = ptr();
        device()->aligned_free(data);
    }
    
    // 获取当前数据指针
    void* get_current_data();
    
    // 设置共享内存
    void set_shared_memory(void* data, size_t length = 0) override;
    
    // 准备指定长度的数据
    TensorState prepare_data_with_length(uint32_t len);
    
    // 增加索引
    size_t add_id(uint32_t id);
    
    // 获取当前索引
    size_t current_index() const { return m_store_id; }
    
    // 重置索引
    void reset_id();
    
private:
    size_t m_store_id;    // 当前存储索引
    size_t m_total_id;    // 总索引数
    uint32_t m_curr_id;   // 当前分配的索引数
    void* m_curr_data;    // 当前数据指针
    uint32_t m_kv_id;     // KV 缓存 ID
};

2.1 构造函数

KvStorage::KvStorage(std::vector<size_t> shape, DType dtype, Device* device)
        : Tensor(device, "kvstorage") {
    m_store_id = 0;
    m_total_id = shape[0];
    m_kv_id = KvStorageConfig::get_instance()->increase_count();
    m_curr_id = KvStorageConfig::get_instance()->get_start_index();
    //! only allocate the memory of length m_curr_id * embd
    shape[0] = m_curr_id;
    set_shape(shape, dtype);
    size_t len = length_in_byte();
    //! no need use memory pool
    auto data = device->aligned_alloc(len);
    set_shared_memory(data, len);
}

构造函数初始化了 KV 缓存的关键参数:

  1. 设置当前存储索引 m_store_id 为 0
  2. 记录总索引数 m_total_id
  3. KvStorageConfig 获取一个新的 KV 缓存 ID
  4. 获取初始分配的索引数 m_curr_id
  5. 调整形状,只分配初始需要的内存
  6. 使用 aligned_alloc 分配对齐内存
  7. 设置为共享内存

注意,这里没有使用设备的内存池,而是直接分配对齐内存,这是因为 KV 缓存需要长期存在,不适合使用内存池管理。

2.2 set_shared_memory 方法

void KvStorage::set_shared_memory(void* data, size_t size) {
    Tensor::set_shared_memory(data, size);
    m_curr_data =
            static_cast<char*>(ptr()) +
            static_cast<size_t>((stride()[0] * m_store_id * dtype_in_byte(dtype())));
}

这个方法重写了 Tensor 类的 set_shared_memory 方法,除了设置共享内存外,还更新了当前数据指针 m_curr_data

2.3 prepare_data_with_length 方法

TensorState KvStorage::prepare_data_with_length(uint32_t len) {
    Tensor::prepare_data();
    //! if memory is not enough, allocate a new memory and copy the data to the new
    if (m_store_id + len >= m_curr_id) {
        auto shape = this->shape();
        shape[0] = m_curr_id + KvStorageConfig::KV_STEP;
        size_t old_len = length_in_byte();
        void* old_ptr = ptr();

        set_shape(shape, dtype());
        size_t len = length_in_byte();
        auto data = device()->aligned_alloc(len);
        device()->device2device_copy(data, old_ptr, old_len);

        device()->aligned_free(old_ptr);

        set_shared_memory(data, len);
        m_curr_id += KvStorageConfig::KV_STEP;
    }
    m_curr_data =
            static_cast<char*>(ptr()) +
            static_cast<size_t>((stride()[0] * m_store_id * dtype_in_byte(dtype())));
    return TensorState::Own;
}

这是 KvStorage 类的核心方法,它实现了动态扩展内存的功能:

  1. 首先调用基类的 prepare_data 方法确保数据已准备好
  2. 检查当前内存是否足够存储新的数据
  3. 如果内存不足,则:
    • 计算新的形状,增加 KV_STEP 个索引
    • 记录旧内存的长度和指针
    • 更新形状
    • 分配新的内存
    • 将旧内存的数据复制到新内存
    • 释放旧内存
    • 设置新的共享内存
    • 更新当前分配的索引数
  4. 更新当前数据指针
  5. 返回 TensorState::Own 表示数据已准备好

这个方法确保了 KV 缓存可以在需要时动态扩展,而不会丢失已有的数据。

2.4 其他方法

void* KvStorage::get_current_data() {
    INFER_ASSERT(
            is_own(),
            "The Kvstorage is not ready, please call prepare_data ahead.");
    m_curr_data = static_cast<char*>(ptr()) +
                  static_cast<size_t>(
                          (stride()[0] * m_store_id * dtype_in_byte(dtype())));
    return m_curr_data;
}

get_current_data 方法返回当前数据指针,它首先检查数据是否已准备好,然后计算当前数据的位置并返回。

size_t KvStorage::add_id(uint32_t id) {
    INFER_ASSERT(id + m_store_id < m_total_id, "KvStorage add id error!");
    m_store_id += id;
    m_curr_data = static_cast<char*>(ptr()) +
                  static_cast<size_t>(
                          (stride()[0] * m_store_id * dtype_in_byte(dtype())));
    return m_store_id;
}

add_id 方法增加当前存储索引,并更新当前数据指针。它首先检查增加后的索引是否超过总索引数,然后更新索引和数据指针。

void KvStorage::reset_id() {
    m_store_id = 0;
    m_curr_data = ptr();
}

reset_id 方法重置当前存储索引和数据指针,通常在处理新的序列时调用。

3. KvStorage 的工作流程

KvStorage 的典型工作流程如下:

  1. 创建 KvStorage

    auto kv_cache = std::make_shared<KvStorage>(shape, dtype, device);
    

    这会分配初始大小的内存。

  2. 准备数据

    kv_cache->prepare_data_with_length(len);
    

    这会确保有足够的内存存储新的数据,必要时会扩展内存。

  3. 获取当前数据指针

    auto data = kv_cache->get_current_data();
    

    这会返回当前数据的指针,用于读写数据。

  4. 增加索引

    kv_cache->add_id(1);
    

    这会增加当前存储索引,通常在处理完一个 token 后调用。

  5. 重置索引

    kv_cache->reset_id();
    

    这会重置当前存储索引,通常在处理新的序列时调用。

4. KvStorage 的优化策略

KvStorage 实现了几种优化策略:

  1. 延迟分配:初始只分配一部分内存,随着需求增加再扩展
  2. 内存对齐:使用对齐内存分配,提高内存访问效率
  3. 内存复用:在扩展内存时,复制已有数据,避免重新计算
  4. 索引管理:通过索引管理数据位置,避免不必要的内存移动

总结

KvStorage 类是 InferLLM 框架中用于管理注意力机制 KV 缓存的特殊张量类。它通过动态扩展内存的方式,支持长序列生成过程中不断增长的 KV 缓存需求。它的设计考虑了内存效率和访问效率,是框架中实现高效推理的关键组件之一。