音乐播放器
sola的小屋
 
文章 标签
20

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

InferLLM大模型推理框架项目(04)——Tensor类的实现(src/core/tensor.h and .cpp)

Tensor 类代码结构与功能实现分析

Tensor 类是 InferLLM 框架中的核心数据结构,用于表示和管理多维数组数据。下面对其代码结构和功能实现进行详细分析。

1. 数据类型定义

首先,框架定义了多种数据类型,以支持不同精度的计算:

enum class DType {
    Float32 = 0,
    Float16 = 1,
    Float8 = 2,
    Int32 = 3,
    Int16 = 4,
    Int8 = 5,
    Uint8 = 6,
    Int4 = 7,
    Uint4 = 8,
    Int2 = 9,
};

同时定义了两个辅助函数:

  • dtype_in_byte:获取数据类型的字节大小
  • dtype_block_size:获取数据类型的块大小(用于量化类型)

2. Tensor 类的内存管理

Tensor 类支持三种内存来源方式:

  1. 自己分配的内存
  2. 外部共享的内存
  3. 从文件映射的内存(主要用于权重)

这通过 TensorState 枚举进行管理:

enum class TensorState {
    Own = 0,    // 拥有内存
    OutSide = 1, // 内存在外部
};

3. Tensor 类的主要成员变量

private:
    bool m_shared = false;           // 是否是共享内存
    int32_t m_usr_count = 0;         // 用户计数(引用计数)
    int32_t m_cur_count = 0;         // 当前用户计数
    
    Device* m_device;                // 设备指针
    OpBase* m_owner_op;              // 拥有该张量的算子
    
    TensorState m_state;             // 张量状态
    std::shared_ptr<InputFile> m_file; // 文件指针(用于从文件读取数据)
    size_t m_file_offset = 0;        // 文件偏移
    
    uint32_t m_dims = 0;             // 维度数
    size_t m_length = 0;             // 元素总数
    DType m_dtype;                   // 数据类型
    std::vector<size_t> m_shape;     // 形状
    std::vector<size_t> m_stride;    // 步长
    void* m_data = nullptr;          // 数据指针
    std::string m_name;              // 名称

4. 主要功能实现

4.1 构造函数

Tensor 类提供了两种构造方式:

// 通过设备和名称构造
Tensor(Device* device, std::string name);

// 通过形状、数据类型和设备构造
Tensor(std::vector<size_t> shape, DType dtype, Device* device);

初始状态都是 TensorState::OutSide,表示数据不在设备上。

4.2 形状和数据类型管理

// 设置形状
void set_shape(std::vector<size_t> shape);

// 计算步长和总长度
m_stride.resize(m_dims);
m_stride[m_dims - 1] = 1;
for (uint32_t i = 1; i < m_dims; i++) {
    m_stride[m_dims - 1 - i] = m_stride[m_dims - i] * m_shape[m_dims - i];
}
m_length = m_shape[0] * m_stride[0];

这里的步长计算采用了行优先(row-major)的存储方式,与 C/C++ 的多维数组存储一致。

4.3 内存管理核心方法

prepare_data()

TensorState Tensor::prepare_data() {
    size_t length = length_in_byte();
    if (!m_data && m_state == TensorState::OutSide) {
        if (m_file) {
            read_data_from_file();
        } else {
            m_data = m_device->allocate(length);
        }
    }
    m_state = TensorState::Own;
    return m_state;
}

该方法确保张量数据在设备上可用:

  • 如果数据不存在且状态为 OutSide
    • 如果有关联文件,从文件读取
    • 否则,从设备分配内存
  • 将状态设置为 Own

recall_data()

TensorState Tensor::recall_data() {
    if (m_shared) {
        return m_state;
    }
    if (!m_file && m_data != nullptr && m_state == TensorState::Own) {
        m_device->free_device(m_data);
        m_data = nullptr;
    }
    m_state = TensorState::OutSide;
    return m_state;
}

该方法释放张量占用的设备内存:

  • 如果是共享内存,不做任何操作
  • 如果不是从文件读取的,且数据存在且状态为 Own,释放设备内存
  • 将状态设置为 OutSide

4.4 文件数据读取

size_t Tensor::read_data_from_file() {
    size_t length = length_in_byte();
    if (m_file->enable_mmap()) {
        // 使用内存映射
        if (!m_device->unified_memory()) {
            // 非统一内存设备(如GPU)
            auto temp_ptr = m_file->get_mmap_data(length, m_file_offset);
            m_data = m_device->allocate(length);
            m_device->host2device_copy(m_data, temp_ptr, length);
        } else {
            // 统一内存设备(如CPU)
            m_data = m_file->get_mmap_data(length, m_file_offset);
        }
    } else {
        // 不使用内存映射,直接读取
        // ...(处理统一内存和非统一内存的情况)
    }
    return length;
}

该方法从文件读取数据到张量:

  • 支持内存映射和直接读取两种方式
  • 针对统一内存设备(CPU)和非统一内存设备(GPU)有不同处理
  • 支持权重预处理

4.5 引用计数管理

int32_t add_user() {
    m_usr_count++;
    return m_usr_count;
}

int32_t decrease_curr_user_count() {
    if (!m_shared) {
        INFER_ASSERT(m_cur_count > 0, "The user count is less than 0.");
        m_cur_count--;
        if (m_cur_count == 0) {
            recall_data();
        }
    }
    return m_cur_count;
}

int32_t resume_user_count() {
    m_cur_count = m_usr_count;
    return m_cur_count;
}

这组方法实现了引用计数机制:

  • add_user():增加用户计数
  • decrease_curr_user_count():减少当前用户计数,当计数为0时释放内存
  • resume_user_count():恢复用户计数(用于重用张量)

4.6 数据访问

void* ptr() {
    INFER_ASSERT(is_own(), "Tensor is OutSide the device, can't get the memory.");
    return m_data;
}

template <typename T>
T* ptr() {
    INFER_ASSERT(is_own(), "Tensor is OutSide the device, can't get the memory.");
    return static_cast<T*>(m_data);
}

这些方法提供了对张量数据的访问:

  • 确保数据在设备上可用
  • 支持类型转换

5. WorkSpace 类

WorkSpace 是一个简单的内存容器类,用于临时存储计算过程中的数据:

class WorkSpace {
public:
    void* ptr() { return m_data; };
    
    template <typename T>
    T* ptr() {
        return static_cast<T*>(m_data);
    }
    
    size_t length() { return m_length; }
    
    void set_memory(void* data, size_t length) {
        m_data = data;
        m_length = length;
    }
    
private:
    void* m_data = nullptr;
    size_t m_length = 0;
};

它不负责内存的分配和释放,只是提供一个访问接口。

6. 内存优化策略

Tensor 类实现了几种内存优化策略:

  1. 延迟分配:只有在实际需要时才分配内存
  2. 引用计数:通过引用计数管理内存生命周期
  3. 内存映射:支持从文件直接映射内存,减少内存拷贝
  4. 内存共享:支持共享外部内存,避免不必要的拷贝

总结

Tensor类是 InferLLM 框架的核心数据结构,提供了灵活的内存管理和数据访问机制。它支持多种数据类型、多种内存来源方式,并通过引用计数优化内存使用。同时,它还提供了从文件读取数据和数据预处理的功能,为模型权重的加载和处理提供了支持。