音乐播放器
sola的小屋
 
文章 标签
20

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

InferLLM大模型推理框架项目(10)——ThreadPool类的实现(src/core/thread_pool.h+.cpp)

ThreadPool 类代码结构与功能实现分析

ThreadPool 类是 InferLLM 框架中的线程池实现,用于并行执行计算任务,提高模型推理的性能。下面对 thread_pool.hthread_pool.cpp 的代码结构和功能实现进行详细分析。

1. INFER_PAUSE 宏定义

首先,框架定义了一个平台相关的 INFER_PAUSE 宏,用于实现 CPU 级别的暂停操作:

#ifndef INFER_PAUSE
# if defined __GNUC__ && (defined __i386__ || defined __x86_64__)
#   if !defined(__SSE2__)
      static inline void non_sse_mm_pause() { __asm__ __volatile__ ("rep; nop"); }
#     define _mm_pause non_sse_mm_pause
#   else
#       include <immintrin.h>
#   endif
#   define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { _mm_pause(); } } while (0)
# elif defined __GNUC__ && defined __aarch64__
#   define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { asm volatile("yield" ::: "memory"); } } while (0)
# elif defined __GNUC__ && defined __arm__
#   define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { asm volatile("" ::: "memory"); } } while (0)
# elif defined __GNUC__ && defined __riscv
#   define INFER_PAUSE(v) do { for (int __delay = (v); __delay > 0; --__delay) { asm volatile("nop"); } } while (0)
# else
#   warning "Can't detect 'pause' (CPU-yield) instruction on the target platform. Specify INFER_PAUSE() definition via compiler flags."
#   define INFER_PAUSE(...) do { /* no-op: works, but not effective */ } while (0)
# endif
#endif // MTDA_PAUSE

这个宏根据不同的平台和架构,使用不同的指令实现 CPU 暂停操作,主要用于自旋锁中减少 CPU 资源消耗。

2. Worker 结构体

Worker 结构体表示一个工作线程及其状态:

struct Worker {
public:
    Worker(std::function<void()>&& run) : thread{run} {}
    ~Worker() { thread.join(); }
    //! Worker thread
    std::thread thread;
    //! Indicate whether the Worker thread need run
    std::atomic<bool> work_flag{false};
};

每个 Worker 包含一个线程和一个原子标志 work_flag,用于指示线程是否需要执行任务。

3. ThreadPool 类

ThreadPool 类是线程池的主要实现:

class ThreadPool {
public:
    //! Create thread-pool nr_threads thread_pool
    ThreadPool(uint32_t nr_threads);
    //! The main thread set the task, parallelism and worker flag to
    //! notify other thread.
    void add_task(const MultiThreadingTask& task, uint32_t nr_task);

    inline void sync();
    //! wake up all the threads from cv.wait(), when the thread pool is not
    //! active, all the threads will go to sleep.
    inline void active();
    //! all the threads go to sleep which will reduce CPU occupation
    void deactive();
    ~ThreadPool();

    uint32_t nr_threads() const { return m_nr_threads; }

    //! The number of iterations < main thread yeild resource>
    static constexpr int MAIN_THREAD_ACTIVE_WAIT = 10000;
    //! The number of iterations < worker thread yeild resource>
    static constexpr int WORKER_ACTIVE_WAIT = 2000;
    //! The number of iterations <pause>
    static constexpr int ACTIVE_WAIT_PAUSE_LIMIT = 16;

private:
    uint32_t m_nr_threads = 1;
    //! All the sub task number
    uint32_t m_nr_task = 0;
    uint32_t m_task_per_thread = 0;
    std::atomic_bool m_stop{false};
    std::atomic_bool m_active{false};
    //! The executable funcition pointer
    MultiThreadingTask m_task;

    std::vector<Worker*> m_workers;
    //! The cv and mutex for threading activity
    std::condition_variable m_cv;
    std::mutex m_mutex;
};

3.1 构造函数

ThreadPool::ThreadPool(uint32_t threads_num)
        : m_nr_threads(threads_num), m_stop{false}, m_active{false} {
    if (threads_num < 1) {
        m_nr_threads = 1;
    }
    if (m_nr_threads > 1) {
        auto system_cpu_count = std::thread::hardware_concurrency();
        if (m_nr_threads > system_cpu_count) {
            INFER_LOG(
                    "The number of threads is bigger than number of "
                    "physical cpu cores, got: %d core_number: %d",
                    system_cpu_count, nr_threads());
        }
        for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
            m_workers.push_back(new Worker([this, i]() {
                while (!m_stop) {
                    while (m_active) {
                        //! if the thread should work
                        if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
                            m_task(TaskId{
                                    i * m_task_per_thread,
                                    std::min((i + 1) * m_task_per_thread, m_nr_task),
                                    i});
                            //! Flag worker is finished
                            m_workers[i]->work_flag.store(
                                    false, std::memory_order_release);
                        }
                        //! Wait next task coming
                        for (int it = 0; it < WORKER_ACTIVE_WAIT; it++) {
                            if (m_workers[i]->work_flag.load(
                                        std::memory_order_acquire)) {
                                break;
                            }
                            if (it < ACTIVE_WAIT_PAUSE_LIMIT || (it & 1)) {
                                INFER_PAUSE(16);  // Spin lock's CPU-level yield
                            } else {
                                // Spin lock's OS-level yield
                                std::this_thread::yield();
                            }
                        }
                    }
                    {
                        std::unique_lock<std::mutex> lock(m_mutex);
                        if (!m_stop && !m_active) {
                            m_cv.wait(lock, [this] { return m_stop || m_active; });
                        }
                    }
                }
            }));
        }
    }
}

构造函数创建指定数量的工作线程:

  1. 确保线程数至少为 1
  2. 检查线程数是否超过系统 CPU 核心数,如果超过则发出警告
  3. 创建 m_nr_threads - 1 个工作线程(主线程也会参与计算)
  4. 每个工作线程的主循环:
    • 外层循环:当 m_stop 为 false 时持续运行
    • 中层循环:当 m_active 为 true 时处于活动状态
    • 内层逻辑:
      • 如果 work_flag 为 true,执行任务并将 work_flag 设为 false
      • 使用自旋等待检查是否有新任务
      • 如果超过等待限制,则使用 std::this_thread::yield() 让出 CPU
    • m_active 为 false 时,使用条件变量等待唤醒

3.2 add_task 方法

void ThreadPool::add_task(const MultiThreadingTask& task, uint32_t nr_task) {
    //! If only one thread or one task, execute directly
    if (m_nr_threads == 1 || nr_task == 1) {
        task({0, nr_task, m_nr_threads - 1});
        return;
    } else {
        active();
        INFER_ASSERT(m_active, "thread pool is not actived.");
        m_nr_task = nr_task;
        //! Set the task number, task iter and task
        m_task_per_thread = (nr_task + m_nr_threads - 1) / m_nr_threads;
        m_task = std::move(task);
        for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
            m_workers[i]->work_flag.store(true, std::memory_order_release);
        }
        //! Main thread working
        uint32_t start = (m_nr_threads - 1) * m_task_per_thread;
        // printf("main threads start\n");
        m_task({start, nr_task, m_nr_threads - 1});
        //! make sure all threads done
        sync();
    }
}

add_task 方法向线程池添加任务:

  1. 如果只有一个线程或一个任务,直接在当前线程执行
  2. 否则:
    • 激活线程池
    • 设置任务数量和每个线程的任务数
    • 设置任务函数
    • 将所有工作线程的 work_flag 设为 true,通知它们开始工作
    • 主线程也参与计算,处理剩余的任务
    • 调用 sync() 等待所有线程完成任务

3.3 sync 方法

inline void ThreadPool::sync() {
    bool no_finished = false;
    uint32_t no_finished_id = 0;
    do {
        no_finished = false;
        for (uint32_t i = no_finished_id; i < m_nr_threads - 1; ++i) {
            if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
                no_finished = true;
                no_finished_id = i;
                break;
            }
        }
        if (no_finished) {
            for (int it = 0; it < MAIN_THREAD_ACTIVE_WAIT; it++) {
                if (!m_workers[no_finished_id]->work_flag.load(
                            std::memory_order_acquire)) {
                    break;
                }
                if ((it < ACTIVE_WAIT_PAUSE_LIMIT || (it & 1))) {
                    INFER_PAUSE(16);
                } else {
                    std::this_thread::yield();
                }
            }
        }
    } while (no_finished);
}

sync 方法等待所有工作线程完成任务:

  1. 检查是否有未完成的线程(work_flag 为 true)
  2. 如果有,使用自旋等待该线程完成
  3. 如果超过等待限制,则使用 std::this_thread::yield() 让出 CPU
  4. 重复上述过程,直到所有线程完成任务

3.4 active 和 deactive 方法

inline void ThreadPool::active() {
    if (!m_active) {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_active = true;
        m_cv.notify_all();
    }
}

void ThreadPool::deactive() {
    std::unique_lock<std::mutex> lock(m_mutex);
    m_active = false;
}
  • active 方法激活线程池,唤醒所有等待的工作线程
  • deactive 方法停用线程池,使工作线程进入休眠状态,减少 CPU 占用

3.5 析构函数

ThreadPool::~ThreadPool() {
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_stop = true;
        m_active = false;
        m_cv.notify_all();
    }
    for (auto& worker : m_workers) {
        delete worker;
    }
}

析构函数清理线程池资源:

  1. 设置 m_stop 为 true,通知所有工作线程退出
  2. 唤醒所有等待的工作线程
  3. 删除所有工作线程对象

4. 任务分配策略

ThreadPool 使用简单的任务分配策略:

  1. 将总任务数 nr_task 平均分配给所有线程
  2. 每个线程处理 m_task_per_thread = (nr_task + m_nr_threads - 1) / m_nr_threads 个任务
  3. 第 i 个线程处理的任务范围是 [i * m_task_per_thread, min((i + 1) * m_task_per_thread, m_nr_task)]
  4. 主线程(最后一个线程)处理剩余的任务

5. 同步机制

ThreadPool 使用多种同步机制:

  1. 原子变量

    • m_stop:指示线程池是否停止
    • m_active:指示线程池是否处于活动状态
    • work_flag:指示工作线程是否需要执行任务
  2. 条件变量

    • m_cv:用于在线程池激活时唤醒工作线程
  3. 自旋锁

    • 使用 INFER_PAUSEstd::this_thread::yield() 实现自旋等待,减少线程切换开销

6. 优化策略

ThreadPool 实现了几种优化策略:

  1. 主线程参与计算
    主线程也参与任务计算,减少线程创建和切换开销。

  2. 自旋等待
    使用自旋等待检查任务状态,减少线程切换开销。自旋等待分为两个阶段:

    • ACTIVE_WAIT_PAUSE_LIMIT 次迭代使用 CPU 级别的暂停操作
    • 之后交替使用 CPU 级别的暂停和 OS 级别的让出
  3. 线程休眠
    当线程池不活动时,工作线程进入休眠状态,减少 CPU 占用。

  4. 任务合并
    当只有一个线程或一个任务时,直接在当前线程执行,避免线程调度开销。

  5. 动态休眠
    线程池可以通过 deactive() 方法进入休眠状态,在不需要并行计算时减少 CPU 占用。

  6. 快速检测
    sync() 方法中,记录上一个未完成的线程 ID,优先检查该线程,减少不必要的检查。

7. 使用示例

ThreadPool 的典型使用方式如下:

// 创建线程池
ThreadPool pool(4);  // 4个线程

// 定义任务函数
auto task = [&](const TaskId& task_id) {
    // task_id.start: 任务起始索引
    // task_id.end: 任务结束索引
    // task_id.thread_id: 线程ID
    for (uint32_t i = task_id.start; i < task_id.end; ++i) {
        // 处理第i个任务
    }
};

// 添加任务
pool.add_task(task, 100);  // 100个任务

// 停用线程池(不需要并行计算时)
pool.deactive();

// 激活线程池(需要并行计算时)
pool.active();

8. TaskId 结构体

struct TaskId {
    uint32_t start;     // 任务起始索引
    uint32_t end;       // 任务结束索引
    uint32_t thread_id; // 线程ID
};

TaskId 结构体用于描述任务的范围和执行线程,它包含三个字段:

  • start:任务的起始索引
  • end:任务的结束索引
  • thread_id:执行任务的线程ID

9. MultiThreadingTask 类型

using MultiThreadingTask = std::function<void(const TaskId&)>;

MultiThreadingTask 是一个函数类型,表示可以并行执行的任务。它接受一个 TaskId 参数,用于指定任务的范围和执行线程。

10. 线程池的生命周期管理

ThreadPool 的生命周期管理非常重要,它涉及到线程的创建、激活、停用和销毁:

  1. 创建
    在构造函数中创建工作线程,但线程初始处于非活动状态。

  2. 激活
    调用 active() 方法激活线程池,唤醒所有工作线程。

  3. 停用
    调用 deactive() 方法停用线程池,使工作线程进入休眠状态。

  4. 销毁
    在析构函数中设置 m_stop 为 true,通知所有工作线程退出,然后等待线程结束并释放资源。

11. 与其他组件的交互

ThreadPool 主要与 CPUDevice 类交互,用于并行执行 CPU 上的计算任务:

// 在 CPUDevice 中创建线程池
m_thread_pool = std::make_unique<ThreadPool>(nr_thread);

// 使用线程池执行并行任务
m_thread_pool->add_task([&](const TaskId& task_id) {
    // 执行计算任务
}, total_tasks);

总结

ThreadPool 类是 InferLLM 框架中的线程池实现,用于并行执行计算任务,提高模型推理的性能。它采用了多种优化策略,如主线程参与计算、自旋等待、线程休眠和任务合并等,以减少线程调度开销并提高并行效率。它还提供了灵活的接口,支持动态激活和停用线程池,以适应不同的计算需求。

通过分析 thread_pool.hthread_pool.cpp,我们可以看到 InferLLM 框架在并行计算方面的设计思想和实现细节,这对于理解框架的性能优化策略非常有帮助。