
InferLLM大模型推理框架项目(07)——KvStorage类的实现(src/core/kvstorage.h+.cpp)
KvStorage 类代码结构与功能实现分析
KvStorage
类是 InferLLM 框架中用于存储注意力机制中的 Key 和 Value 缓存的特殊张量类。它继承自 Tensor
类,并提供了动态扩展内存的能力,以适应长序列生成过程中不断增长的 KV 缓存需求。
1. KvStorageConfig 类
首先,框架定义了一个单例类 KvStorageConfig
,用于管理 KV 缓存的全局配置:
class KvStorageConfig {
public:
constexpr static uint32_t START_KV_INDEX = 100; // 初始 KV 缓存大小
constexpr static uint32_t KV_STEP = 100; // KV 缓存扩展步长
static std::shared_ptr<KvStorageConfig> instance;
// 单例模式获取实例
static std::shared_ptr<KvStorageConfig> get_instance() {
if (instance == nullptr) {
instance = std::make_shared<KvStorageConfig>();
}
return instance;
}
// 增加 KV 缓存计数
uint32_t increase_count() {
m_kv_count++;
return m_kv_count;
}
// 获取起始索引
uint32_t get_start_index() { return START_KV_INDEX + m_kv_count * 2; }
// 获取当前计数
uint32_t get_count() { return m_kv_count; }
private:
uint32_t m_kv_count = 0; // KV 缓存计数
};
这个类使用单例模式,确保全局只有一个实例,用于跟踪和管理所有 KV 缓存的分配。它定义了两个关键常量:
START_KV_INDEX
:初始 KV 缓存大小,设为 100KV_STEP
:KV 缓存扩展步长,也设为 100
2. KvStorage 类
KvStorage
类继承自 Tensor
类,并添加了管理 KV 缓存的特殊功能:
class KvStorage : public Tensor {
public:
// 构造函数
KvStorage(std::vector<size_t> shape, DType dtype, Device* device);
// 析构函数
~KvStorage() {
auto data = ptr();
device()->aligned_free(data);
}
// 获取当前数据指针
void* get_current_data();
// 设置共享内存
void set_shared_memory(void* data, size_t length = 0) override;
// 准备指定长度的数据
TensorState prepare_data_with_length(uint32_t len);
// 增加索引
size_t add_id(uint32_t id);
// 获取当前索引
size_t current_index() const { return m_store_id; }
// 重置索引
void reset_id();
private:
size_t m_store_id; // 当前存储索引
size_t m_total_id; // 总索引数
uint32_t m_curr_id; // 当前分配的索引数
void* m_curr_data; // 当前数据指针
uint32_t m_kv_id; // KV 缓存 ID
};
2.1 构造函数
KvStorage::KvStorage(std::vector<size_t> shape, DType dtype, Device* device)
: Tensor(device, "kvstorage") {
m_store_id = 0;
m_total_id = shape[0];
m_kv_id = KvStorageConfig::get_instance()->increase_count();
m_curr_id = KvStorageConfig::get_instance()->get_start_index();
//! only allocate the memory of length m_curr_id * embd
shape[0] = m_curr_id;
set_shape(shape, dtype);
size_t len = length_in_byte();
//! no need use memory pool
auto data = device->aligned_alloc(len);
set_shared_memory(data, len);
}
构造函数初始化了 KV 缓存的关键参数:
- 设置当前存储索引
m_store_id
为 0 - 记录总索引数
m_total_id
- 从
KvStorageConfig
获取一个新的 KV 缓存 ID - 获取初始分配的索引数
m_curr_id
- 调整形状,只分配初始需要的内存
- 使用
aligned_alloc
分配对齐内存 - 设置为共享内存
注意,这里没有使用设备的内存池,而是直接分配对齐内存,这是因为 KV 缓存需要长期存在,不适合使用内存池管理。
2.2 set_shared_memory 方法
void KvStorage::set_shared_memory(void* data, size_t size) {
Tensor::set_shared_memory(data, size);
m_curr_data =
static_cast<char*>(ptr()) +
static_cast<size_t>((stride()[0] * m_store_id * dtype_in_byte(dtype())));
}
这个方法重写了 Tensor
类的 set_shared_memory
方法,除了设置共享内存外,还更新了当前数据指针 m_curr_data
。
2.3 prepare_data_with_length 方法
TensorState KvStorage::prepare_data_with_length(uint32_t len) {
Tensor::prepare_data();
//! if memory is not enough, allocate a new memory and copy the data to the new
if (m_store_id + len >= m_curr_id) {
auto shape = this->shape();
shape[0] = m_curr_id + KvStorageConfig::KV_STEP;
size_t old_len = length_in_byte();
void* old_ptr = ptr();
set_shape(shape, dtype());
size_t len = length_in_byte();
auto data = device()->aligned_alloc(len);
device()->device2device_copy(data, old_ptr, old_len);
device()->aligned_free(old_ptr);
set_shared_memory(data, len);
m_curr_id += KvStorageConfig::KV_STEP;
}
m_curr_data =
static_cast<char*>(ptr()) +
static_cast<size_t>((stride()[0] * m_store_id * dtype_in_byte(dtype())));
return TensorState::Own;
}
这是 KvStorage
类的核心方法,它实现了动态扩展内存的功能:
- 首先调用基类的
prepare_data
方法确保数据已准备好 - 检查当前内存是否足够存储新的数据
- 如果内存不足,则:
- 计算新的形状,增加
KV_STEP
个索引 - 记录旧内存的长度和指针
- 更新形状
- 分配新的内存
- 将旧内存的数据复制到新内存
- 释放旧内存
- 设置新的共享内存
- 更新当前分配的索引数
- 计算新的形状,增加
- 更新当前数据指针
- 返回
TensorState::Own
表示数据已准备好
这个方法确保了 KV 缓存可以在需要时动态扩展,而不会丢失已有的数据。
2.4 其他方法
void* KvStorage::get_current_data() {
INFER_ASSERT(
is_own(),
"The Kvstorage is not ready, please call prepare_data ahead.");
m_curr_data = static_cast<char*>(ptr()) +
static_cast<size_t>(
(stride()[0] * m_store_id * dtype_in_byte(dtype())));
return m_curr_data;
}
get_current_data
方法返回当前数据指针,它首先检查数据是否已准备好,然后计算当前数据的位置并返回。
size_t KvStorage::add_id(uint32_t id) {
INFER_ASSERT(id + m_store_id < m_total_id, "KvStorage add id error!");
m_store_id += id;
m_curr_data = static_cast<char*>(ptr()) +
static_cast<size_t>(
(stride()[0] * m_store_id * dtype_in_byte(dtype())));
return m_store_id;
}
add_id
方法增加当前存储索引,并更新当前数据指针。它首先检查增加后的索引是否超过总索引数,然后更新索引和数据指针。
void KvStorage::reset_id() {
m_store_id = 0;
m_curr_data = ptr();
}
reset_id
方法重置当前存储索引和数据指针,通常在处理新的序列时调用。
3. KvStorage 的工作流程
KvStorage 的典型工作流程如下:
-
创建 KvStorage:
auto kv_cache = std::make_shared<KvStorage>(shape, dtype, device);
这会分配初始大小的内存。
-
准备数据:
kv_cache->prepare_data_with_length(len);
这会确保有足够的内存存储新的数据,必要时会扩展内存。
-
获取当前数据指针:
auto data = kv_cache->get_current_data();
这会返回当前数据的指针,用于读写数据。
-
增加索引:
kv_cache->add_id(1);
这会增加当前存储索引,通常在处理完一个 token 后调用。
-
重置索引:
kv_cache->reset_id();
这会重置当前存储索引,通常在处理新的序列时调用。
4. KvStorage 的优化策略
KvStorage 实现了几种优化策略:
- 延迟分配:初始只分配一部分内存,随着需求增加再扩展
- 内存对齐:使用对齐内存分配,提高内存访问效率
- 内存复用:在扩展内存时,复制已有数据,避免重新计算
- 索引管理:通过索引管理数据位置,避免不必要的内存移动
总结
KvStorage 类是 InferLLM 框架中用于管理注意力机制 KV 缓存的特殊张量类。它通过动态扩展内存的方式,支持长序列生成过程中不断增长的 KV 缓存需求。它的设计考虑了内存效率和访问效率,是框架中实现高效推理的关键组件之一。
