音乐播放器
sola的小屋
 
文章 标签
20

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

[InferLLM大模型推理框架项目](19)kernel类的定义(src/kern/kernel.h)

kernel.h 代码结构与功能实现分析

kernel.h 是 InferLLM 框架中的核心头文件,定义了 Kernel 类,该类作为不同平台内核实现的统一接口,负责根据硬件平台选择合适的计算内核实现,并管理多线程任务调度。

1. 文件结构概览

文件结构可以分为以下几个部分:

  1. 头文件引入和条件编译
  2. Kernel 类定义
  3. 构造函数和初始化
  4. 线程和优化相关方法
  5. 计算操作接口
  6. 成员变量

2. 头文件引入和条件编译

#pragma once

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include "core/thread_pool.h"
#include "kern/kernel_define.h"
#include "utils.h"

#if INFER_X86
#include "kern/optimized/x86/kernel.h"
#elif INFER_ARM
#include "kern/optimized/arm/kernel.h"
#elif INFER_RVV
#include "kern/optimized/rvv/kernel.h"
#else
#include "kern/naive/naive.h"
#endif

#ifdef ENABLE_GPU
#include "kern/gpu/kernel_gpu.h"
#endif

这部分代码通过条件编译引入不同平台的优化实现:

  • 如果定义了 INFER_X86,则引入 x86 平台的优化实现
  • 如果定义了 INFER_ARM,则引入 ARM 平台的优化实现
  • 如果定义了 INFER_RVV,则引入 RISC-V 向量扩展的优化实现
  • 否则,引入朴素实现

此外,如果定义了 ENABLE_GPU,则引入 GPU 平台的优化实现。

这种设计使得编译时可以根据目标平台选择合适的优化实现,而不需要修改源代码。

3. Kernel 类定义

namespace inferllm {

class Kernel {
public:
    // 构造函数和方法
    // ...
};

}  // namespace inferllm

Kernel 类是整个内核系统的核心,它提供了统一的接口,根据硬件平台选择合适的计算内核实现,并管理多线程任务调度。

4. 构造函数和初始化

Kernel(KernelType kernel_type) : m_kernel_type(kernel_type) {}
Kernel(KernelType kernel_type, ThreadPool* thread_pool)
        : m_kernel_type(kernel_type), m_thread_pool(thread_pool) {
#ifdef INFER_RVV
    opt::init();
#endif
}

Kernel 类提供了两个构造函数:

  • 第一个构造函数只接受内核类型参数,不使用线程池
  • 第二个构造函数接受内核类型和线程池指针,可以进行多线程计算

如果定义了 INFER_RVV,则在构造函数中调用 opt::init() 进行 RISC-V 向量扩展的初始化。

5. 线程和优化相关方法

uint32_t nr_thread() const {
    if (m_thread_pool == nullptr)
        return 1;
    return m_thread_pool->nr_threads();
}

bool supported_optimization(KernelOptMethod method) {
    if (m_kernel_type == KernelType::Arm || m_kernel_type == KernelType::Naive) {
        if (method == KernelOptMethod::MatmulInt4Reorder) {
#if defined(__ARM_FEATURE_DOTPROD)
            return true;
#else
            return false;
#endif
        }
        return false;
    }
    return false;
}

这部分代码提供了两个方法:

  • nr_thread():返回线程池中的线程数量,如果没有线程池则返回 1
  • supported_optimization():检查是否支持特定的优化方法,目前只检查 ARM 平台是否支持 MatmulInt4Reorder 优化

supported_optimization() 方法使用条件编译检查 ARM 平台是否支持点积指令(__ARM_FEATURE_DOTPROD),如果支持则返回 true,否则返回 false

6. 计算操作接口

//! compute
template <KernelID Id, typename... Args>
void operator()(Args... args) {
    //! parallel to execute tasks
    if (m_kernel_type == KernelType::GPU) {
#if ENABLE_GPU
        gpu::Comp<Id, Args...>::exec(std::forward<Args>(args)..., m_handle);
#endif

    } else {
        TaskSet task_set =
                opt::Comp<Id, Args...>::get_all_task(std::forward<Args>(args)...);
        for (auto& task : task_set) {
            m_thread_pool->add_task(task.first, task.second);
        }
    }
}
template <KernelID Id, typename... Args>
size_t get_workspace(Args... args) {
    return opt::Space<Id, Args...>::get(std::forward<Args>(args)...);
}

这部分代码提供了两个模板方法:

  • operator():执行计算操作,根据内核类型选择不同的实现
  • get_workspace():获取计算操作所需的工作空间大小

operator() 方法是 Kernel 类的核心,它根据内核类型选择不同的实现:

  • 如果内核类型是 GPU,则调用 GPU 实现的 exec 方法
  • 否则,调用优化实现的 get_all_task 方法获取任务集,然后将任务添加到线程池中执行

这种设计使得不同平台的优化实现可以共享相同的接口,而具体实现由不同平台的优化代码提供。

7. 成员变量

ThreadPool* m_thread_pool = nullptr;
KernelType m_kernel_type;
#if ENABLE_GPU
void set_handle(cudaHandle* handle) { m_handle = handle; }
cudaHandle* m_handle;
#endif

Kernel 类有两个主要成员变量:

  • m_thread_pool:线程池指针,用于多线程计算
  • m_kernel_type:内核类型,用于选择合适的计算内核实现

如果定义了 ENABLE_GPU,则还有一个 m_handle 成员变量,用于 GPU 计算。

8. 功能分析

通过分析 kernel.h 文件,可以看出它实现了以下功能:

8.1 统一接口

通过 Kernel 类提供统一的接口,使得不同平台的优化实现可以共享相同的接口,而具体实现由不同平台的优化代码提供。

8.2 平台选择

通过条件编译和 KernelType 枚举,可以在编译时和运行时选择合适的优化实现,从而在不同硬件平台上获得最佳性能。

8.3 多线程计算

通过线程池和任务集,可以将计算任务分解为多个子任务,并分配给不同的线程处理,从而提高计算效率。

8.4 GPU 加速

通过条件编译和 KernelType::GPU,可以在支持 GPU 的平台上使用 GPU 加速计算。

9. 设计模式分析

kernel.h 文件采用了多种设计模式:

9.1 策略模式

通过 KernelType 枚举和不同平台的优化实现,可以在运行时选择不同的计算策略。

9.2 模板方法模式

通过模板方法 operator()get_workspace(),提供了统一的接口,而具体实现由不同平台的优化代码提供。

9.3 工厂模式

通过 Kernel 类的构造函数和 KernelType 枚举,可以创建不同类型的内核实例。

总结

kernel.h 是 InferLLM 框架中的核心头文件,定义了 Kernel 类,该类作为不同平台内核实现的统一接口,负责根据硬件平台选择合适的计算内核实现,并管理多线程任务调度。通过条件编译、模板方法和策略模式等技术,实现了在不同硬件平台上的高效计算,同时保持了代码的可维护性和扩展性。这种设计使得 InferLLM 框架能够在不同硬件平台上高效运行,而不需要修改核心代码。