
InferLLM大模型推理框架项目(10)——ThreadPool类的实现(src/core/thread_pool.h+.cpp)
ThreadPool 类代码结构与功能实现分析
ThreadPool
类是 InferLLM 框架中的线程池实现,用于并行执行计算任务,提高模型推理的性能。下面对 thread_pool.h
和 thread_pool.cpp
的代码结构和功能实现进行详细分析。
1. INFER_PAUSE 宏定义
首先,框架定义了一个平台相关的 INFER_PAUSE
宏,用于实现 CPU 级别的暂停操作:
#ifndef INFER_PAUSE
# if defined __GNUC__ && (defined __i386__ || defined __x86_64__)
# if !defined(__SSE2__)
static inline void non_sse_mm_pause() { __asm__ __volatile__ ("rep; nop"); }
# define _mm_pause non_sse_mm_pause
# else
# include <immintrin.h>
# endif
# define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { _mm_pause(); } } while (0)
# elif defined __GNUC__ && defined __aarch64__
# define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { asm volatile("yield" ::: "memory"); } } while (0)
# elif defined __GNUC__ && defined __arm__
# define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { asm volatile("" ::: "memory"); } } while (0)
# elif defined __GNUC__ && defined __riscv
# define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { asm volatile("nop"); } } while (0)
# else
# warning "Can't detect 'pause' (CPU-yield) instruction on the target platform. Specify INFER_PAUSE() definition via compiler flags."
# define INFER_PAUSE(...) do { /* no-op: works, but not effective */ } while (0)
# endif
#endif // MTDA_PAUSE
这个宏根据不同的平台和架构,使用不同的指令实现 CPU 暂停操作,主要用于自旋锁中减少 CPU 资源消耗。
2. Worker 结构体
Worker
结构体表示一个工作线程及其状态:
struct Worker {
public:
Worker(std::function<void()>&& run) : thread{run} {}
~Worker() { thread.join(); }
//! Worker thread
std::thread thread;
//! Indicate whether the Worker thread need run
std::atomic<bool> work_flag{false};
};
每个 Worker
包含一个线程和一个原子标志 work_flag
,用于指示线程是否需要执行任务。
3. ThreadPool 类
ThreadPool
类是线程池的主要实现:
class ThreadPool {
public:
//! Create thread-pool nr_threads thread_pool
ThreadPool(uint32_t nr_threads);
//! The main thread set the task, parallelism and worker flag to
//! notify other thread.
void add_task(const MultiThreadingTask& task, uint32_t nr_task);
inline void sync();
//! wake up all the threads from cv.wait(), when the thread pool is not
//! active, all the threads will go to sleep.
inline void active();
//! all the threads go to sleep which will reduce CPU occupation
void deactive();
~ThreadPool();
uint32_t nr_threads() const { return m_nr_threads; }
//! The number of iterations < main thread yeild resource>
static constexpr int MAIN_THREAD_ACTIVE_WAIT = 10000;
//! The number of iterations < worker thread yeild resource>
static constexpr int WORKER_ACTIVE_WAIT = 2000;
//! The number of iterations <pause>
static constexpr int ACTIVE_WAIT_PAUSE_LIMIT = 16;
private:
uint32_t m_nr_threads = 1;
//! All the sub task number
uint32_t m_nr_task = 0;
uint32_t m_task_per_thread = 0;
std::atomic_bool m_stop{false};
std::atomic_bool m_active{false};
//! The executable funcition pointer
MultiThreadingTask m_task;
std::vector<Worker*> m_workers;
//! The cv and mutex for threading activity
std::condition_variable m_cv;
std::mutex m_mutex;
};
3.1 构造函数
ThreadPool::ThreadPool(uint32_t threads_num)
: m_nr_threads(threads_num), m_stop{false}, m_active{false} {
if (threads_num < 1) {
m_nr_threads = 1;
}
if (m_nr_threads > 1) {
auto system_cpu_count = std::thread::hardware_concurrency();
if (m_nr_threads > system_cpu_count) {
INFER_LOG(
"The number of threads is bigger than number of "
"physical cpu cores, got: %d core_number: %d",
system_cpu_count, nr_threads());
}
for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
m_workers.push_back(new Worker([this, i]() {
while (!m_stop) {
while (m_active) {
//! if the thread should work
if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
m_task(TaskId{
i * m_task_per_thread,
std::min((i + 1) * m_task_per_thread, m_nr_task),
i});
//! Flag worker is finished
m_workers[i]->work_flag.store(
false, std::memory_order_release);
}
//! Wait next task coming
for (int it = 0; it < WORKER_ACTIVE_WAIT; it++) {
if (m_workers[i]->work_flag.load(
std::memory_order_acquire)) {
break;
}
if (it < ACTIVE_WAIT_PAUSE_LIMIT || (it & 1)) {
INFER_PAUSE(16); // Spin lock's CPU-level yield
} else {
// Spin lock's OS-level yield
std::this_thread::yield();
}
}
}
{
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_stop && !m_active) {
m_cv.wait(lock, [this] { return m_stop || m_active; });
}
}
}
}));
}
}
}
构造函数创建指定数量的工作线程:
- 确保线程数至少为 1
- 检查线程数是否超过系统 CPU 核心数,如果超过则发出警告
- 创建
m_nr_threads - 1
个工作线程(主线程也会参与计算) - 每个工作线程的主循环:
- 外层循环:当
m_stop
为 false 时持续运行 - 中层循环:当
m_active
为 true 时处于活动状态 - 内层逻辑:
- 如果
work_flag
为 true,执行任务并将work_flag
设为 false - 使用自旋等待检查是否有新任务
- 如果超过等待限制,则使用
std::this_thread::yield()
让出 CPU
- 如果
- 当
m_active
为 false 时,使用条件变量等待唤醒
- 外层循环:当
3.2 add_task 方法
void ThreadPool::add_task(const MultiThreadingTask& task, uint32_t nr_task) {
//! If only one thread or one task, execute directly
if (m_nr_threads == 1 || nr_task == 1) {
task({0, nr_task, m_nr_threads - 1});
return;
} else {
active();
INFER_ASSERT(m_active, "thread pool is not actived.");
m_nr_task = nr_task;
//! Set the task number, task iter and task
m_task_per_thread = (nr_task + m_nr_threads - 1) / m_nr_threads;
m_task = std::move(task);
for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
m_workers[i]->work_flag.store(true, std::memory_order_release);
}
//! Main thread working
uint32_t start = (m_nr_threads - 1) * m_task_per_thread;
// printf("main threads start\n");
m_task({start, nr_task, m_nr_threads - 1});
//! make sure all threads done
sync();
}
}
add_task
方法向线程池添加任务:
- 如果只有一个线程或一个任务,直接在当前线程执行
- 否则:
- 激活线程池
- 设置任务数量和每个线程的任务数
- 设置任务函数
- 将所有工作线程的
work_flag
设为 true,通知它们开始工作 - 主线程也参与计算,处理剩余的任务
- 调用
sync()
等待所有线程完成任务
3.3 sync 方法
inline void ThreadPool::sync() {
bool no_finished = false;
uint32_t no_finished_id = 0;
do {
no_finished = false;
for (uint32_t i = no_finished_id; i < m_nr_threads - 1; ++i) {
if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
no_finished = true;
no_finished_id = i;
break;
}
}
if (no_finished) {
for (int it = 0; it < MAIN_THREAD_ACTIVE_WAIT; it++) {
if (!m_workers[no_finished_id]->work_flag.load(
std::memory_order_acquire)) {
break;
}
if ((it < ACTIVE_WAIT_PAUSE_LIMIT || (it & 1))) {
INFER_PAUSE(16);
} else {
std::this_thread::yield();
}
}
}
} while (no_finished);
}
sync
方法等待所有工作线程完成任务:
- 检查是否有未完成的线程(
work_flag
为 true) - 如果有,使用自旋等待该线程完成
- 如果超过等待限制,则使用
std::this_thread::yield()
让出 CPU - 重复上述过程,直到所有线程完成任务
3.4 active 和 deactive 方法
inline void ThreadPool::active() {
if (!m_active) {
std::unique_lock<std::mutex> lock(m_mutex);
m_active = true;
m_cv.notify_all();
}
}
void ThreadPool::deactive() {
std::unique_lock<std::mutex> lock(m_mutex);
m_active = false;
}
active
方法激活线程池,唤醒所有等待的工作线程deactive
方法停用线程池,使工作线程进入休眠状态,减少 CPU 占用
3.5 析构函数
ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(m_mutex);
m_stop = true;
m_active = false;
m_cv.notify_all();
}
for (auto& worker : m_workers) {
delete worker;
}
}
析构函数清理线程池资源:
- 设置
m_stop
为 true,通知所有工作线程退出 - 唤醒所有等待的工作线程
- 删除所有工作线程对象
4. 任务分配策略
ThreadPool 使用简单的任务分配策略:
- 将总任务数
nr_task
平均分配给所有线程 - 每个线程处理
m_task_per_thread = (nr_task + m_nr_threads - 1) / m_nr_threads
个任务 - 第 i 个线程处理的任务范围是
[i * m_task_per_thread, min((i + 1) * m_task_per_thread, m_nr_task)]
- 主线程(最后一个线程)处理剩余的任务
5. 同步机制
ThreadPool 使用多种同步机制:
-
原子变量:
m_stop
:指示线程池是否停止m_active
:指示线程池是否处于活动状态work_flag
:指示工作线程是否需要执行任务
-
条件变量:
m_cv
:用于在线程池激活时唤醒工作线程
-
自旋锁:
- 使用
INFER_PAUSE
和std::this_thread::yield()
实现自旋等待,减少线程切换开销
- 使用
6. 优化策略
ThreadPool 实现了几种优化策略:
-
主线程参与计算:
主线程也参与任务计算,减少线程创建和切换开销。 -
自旋等待:
使用自旋等待检查任务状态,减少线程切换开销。自旋等待分为两个阶段:- 前
ACTIVE_WAIT_PAUSE_LIMIT
次迭代使用 CPU 级别的暂停操作 - 之后交替使用 CPU 级别的暂停和 OS 级别的让出
- 前
-
线程休眠:
当线程池不活动时,工作线程进入休眠状态,减少 CPU 占用。 -
任务合并:
当只有一个线程或一个任务时,直接在当前线程执行,避免线程调度开销。 -
动态休眠:
线程池可以通过deactive()
方法进入休眠状态,在不需要并行计算时减少 CPU 占用。 -
快速检测:
在sync()
方法中,记录上一个未完成的线程 ID,优先检查该线程,减少不必要的检查。
7. 使用示例
ThreadPool 的典型使用方式如下:
// 创建线程池
ThreadPool pool(4); // 4个线程
// 定义任务函数
auto task = [&](const TaskId& task_id) {
// task_id.start: 任务起始索引
// task_id.end: 任务结束索引
// task_id.thread_id: 线程ID
for (uint32_t i = task_id.start; i < task_id.end; ++i) {
// 处理第i个任务
}
};
// 添加任务
pool.add_task(task, 100); // 100个任务
// 停用线程池(不需要并行计算时)
pool.deactive();
// 激活线程池(需要并行计算时)
pool.active();
8. TaskId 结构体
struct TaskId {
uint32_t start; // 任务起始索引
uint32_t end; // 任务结束索引
uint32_t thread_id; // 线程ID
};
TaskId
结构体用于描述任务的范围和执行线程,它包含三个字段:
start
:任务的起始索引end
:任务的结束索引thread_id
:执行任务的线程ID
9. MultiThreadingTask 类型
using MultiThreadingTask = std::function<void(const TaskId&)>;
MultiThreadingTask
是一个函数类型,表示可以并行执行的任务。它接受一个 TaskId
参数,用于指定任务的范围和执行线程。
10. 线程池的生命周期管理
ThreadPool 的生命周期管理非常重要,它涉及到线程的创建、激活、停用和销毁:
-
创建:
在构造函数中创建工作线程,但线程初始处于非活动状态。 -
激活:
调用active()
方法激活线程池,唤醒所有工作线程。 -
停用:
调用deactive()
方法停用线程池,使工作线程进入休眠状态。 -
销毁:
在析构函数中设置m_stop
为 true,通知所有工作线程退出,然后等待线程结束并释放资源。
11. 与其他组件的交互
ThreadPool 主要与 CPUDevice 类交互,用于并行执行 CPU 上的计算任务:
// 在 CPUDevice 中创建线程池
m_thread_pool = std::make_unique<ThreadPool>(nr_thread);
// 使用线程池执行并行任务
m_thread_pool->add_task([&](const TaskId& task_id) {
// 执行计算任务
}, total_tasks);
总结
ThreadPool 类是 InferLLM 框架中的线程池实现,用于并行执行计算任务,提高模型推理的性能。它采用了多种优化策略,如主线程参与计算、自旋等待、线程休眠和任务合并等,以减少线程调度开销并提高并行效率。它还提供了灵活的接口,支持动态激活和停用线程池,以适应不同的计算需求。
通过分析 thread_pool.h
和 thread_pool.cpp
,我们可以看到 InferLLM 框架在并行计算方面的设计思想和实现细节,这对于理解框架的性能优化策略非常有帮助。
