欢迎阅读「 [InferLLM大模型推理框架项目](25)kern中ARM优化模块整体分析(src/kern/optimized/arm) 」
音乐播放器
sola的小屋
 故事
文章 标签
53 20

Powered by Gridea | Theme: Fog
本站已安全运行 1796 天
17 小时 14 分 55 秒
总访问量:2194  |   访问人数:1736

[InferLLM大模型推理框架项目](25)kern中ARM优化模块整体分析(src/kern/optimized/arm)

InferLLM 框架中 ARM 优化模块分析

InferLLM 框架中的 ARM 优化模块是针对 ARM 架构处理器的优化实现,主要利用 ARM NEON 指令集进行向量化计算,提高大语言模型推理的性能。

1. 目录结构

ARM 优化模块包含以下文件:

optimized/arm/
├── kernel.cpp   # 实现各种计算函数
├── kernel.h     # 声明计算函数和注册内核
├── optimized.h  # 实现基础向量运算
└── quantize.h   # 实现量化和反量化操作

2. 核心功能模块

2.1 基础向量运算 (optimized.h)

optimized.h 文件实现了一系列基础向量运算函数:

2.1.1 元素级操作

// 向量加法
inline void elemwise_vector_add(
        const int n, const float* __restrict x, const float* __restrict y,
        float* __restrict z) {
    for (int i = 0; i < n; i++) {
        z[i] = x[i] + y[i];
    }
}

// 向量乘法
inline void elemwise_vector_mul(
        const int n, const float* __restrict x, const float* __restrict y,
        float* __restrict z) {
    for (int i = 0; i < n; i++) {
        z[i] = x[i] * y[i];
    }
}

2.1.2 激活函数

// SiLU 激活函数 (x * sigmoid(x))
inline void elemwise_vector_silu(
        const int n, const float* __restrict x, float* __restrict z) {
    for (int i = 0; i < n; i++) {
        z[i] = x[i] / (1 + exp(-x[i]));
    }
}

// GELU 激活函数
inline void elemwise_vector_gelu(
        const int n, const float* __restrict x, float* __restrict z) {
    for (int i = 0; i < n; i++) {
        float src = x[i];
        z[i] = 0.5 * src * (1 + tanh(sqrt(2.0 / PI) * (src + PGELU * src * src * src)));
    }
}

2.1.3 归约操作

// 求最大值
inline float reduce_max(const int n, const float* __restrict x) {
    float max = -INFINITY;
    for (int i = 0; i < n; i++) {
        max = std::max(max, x[i]);
    }
    return max;
}

// 求平方和
inline float reduce_square_sum(const int n, const float* __restrict x) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += x[i] * x[i];
    }
    return sum;
}

2.1.4 矩阵运算

// 带偏移的矩阵乘法(用于注意力计算)
inline void compute_src_offset_embd_matmul(
        const float* __restrict srcq_head, int offsetq,
        const float* __restrict srck_head, int offsetk, float* dst_head, int seqlen,
        int length, int sub_embd) {
    for (uint32_t row = 0; row < seqlen; row++) {
        auto p_srcq = srcq_head + row * offsetq;
        uint32_t len = 0;
        for (; len + 3 < length; len += 4) {
            // 每次处理4列,提高计算效率
            auto p_dst = dst_head + row * length + len;
            auto p_srck0 = srck_head + len * offsetk;
            auto p_srck1 = srck_head + (len + 1) * offsetk;
            auto p_srck2 = srck_head + (len + 2) * offsetk;
            auto p_srck3 = srck_head + (len + 3) * offsetk;
            float sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;
            
            for (uint32_t k = 0; k < sub_embd; k++) {
                sum0 += p_srck0[k] * p_srcq[k];
                sum1 += p_srck1[k] * p_srcq[k];
                sum2 += p_srck2[k] * p_srcq[k];
                sum3 += p_srck3[k] * p_srcq[k];
            }
            
            p_dst[0] = sum0;
            p_dst[1] = sum1;
            p_dst[2] = sum2;
            p_dst[3] = sum3;
        }
        
        // 处理剩余列
        for (; len < length; len++) {
            auto p_dst = dst_head + row * length + len;
            auto p_srck = srck_head + len * offsetk;
            float sum = 0;
            for (uint32_t k = 0; k < sub_embd; k++) {
                sum += p_srck[k] * p_srcq[k];
            }
            *p_dst = sum;
        }
    }
}

2.2 量化计算 (quantize.h)

quantize.h 文件实现了量化和反量化操作:

2.2.1 4位整数量化

inline void quantize_row_q4_0(const float* __restrict x, void* __restrict vy, int k) {
    const int nb = k / QK40;

    BlockQ40* __restrict y = static_cast<BlockQ40*>(vy);
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv[8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        // 加载数据
        for (int l = 0; l < 8; l++)
            srcv[l] = vld1q_f32(x + i * 32 + 4 * l);
        
        // 计算绝对值
        for (int l = 0; l < 8; l++)
            asrcv[l] = vabsq_f32(srcv[l]);
        
        // 计算最大值
        for (int l = 0; l < 4; l++)
            amaxv[2 * l] = vmaxq_f32(asrcv[2 * l], asrcv[2 * l + 1]);
        for (int l = 0; l < 2; l++)
            amaxv[4 * l] = vmaxq_f32(amaxv[4 * l], amaxv[4 * l + 2]);
        for (int l = 0; l < 1; l++)
            amaxv[8 * l] = vmaxq_f32(amaxv[8 * l], amaxv[8 * l + 4]);

        const float amax = vmaxvq_f32(amaxv[0]);

        // 计算量化参数
        const float d = amax / ((1 << 3) - 1);
        const float id = d ? 1.0f / d : 0.0f;

        y[i].d = d;

        // 量化数据
        for (int l = 0; l < 8; l++) {
            const float32x4_t v = vmulq_n_f32(srcv[l], id);
            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
            const int32x4_t vi = vcvtq_s32_f32(vf);

            // 将4个int32压缩为2个uint8(每个uint8存储2个4位整数)
            y[i].qs[2 * l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
            y[i].qs[2 * l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
        }
    }
}

2.2.2 4位整数反量化

inline void dequantize_row_q4_0(const void* __restrict vx, float* __restrict y, int k) {
    assert(k % QK40 == 0);
    const int nb = k / QK40;

    const BlockQ40* __restrict x = static_cast<const BlockQ40*>(vx);
    for (int i = 0; i < nb; i++) {
        const float32x4_t vd = vdupq_n_f32(x[i].d);
        const uint8_t* __restrict pp = x[i].qs;

        for (int l = 0; l < QK40; l += 16) {
            // 加载8个uint8(每个uint8存储2个4位整数)
            const uint8x8_t v8 = vld1_u8(pp + l / 2);

            // 提取4位整数
            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));  // 低4位
            const uint8x8_t v1 = vshr_n_u8(v8, 4);              // 高4位

            // 转换为有符号整数
            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);

            // 减去偏移值8
            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));

            // 交错排列
            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);

            const int8x16_t vq = vcombine_s8(vx_0, vx_1);

            // 转换为int16
            const int16x8_t vi_0 = vmovl_s8(vget_low_s8(vq));
            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));

            // 转换为float32
            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vi_0)));
            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vi_1)));
            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));

            // 乘以量化参数
            const float32x4_t r0 = vmulq_f32(vf_0, vd);
            const float32x4_t r1 = vmulq_f32(vf_1, vd);
            const float32x4_t r2 = vmulq_f32(vf_2, vd);
            const float32x4_t r3 = vmulq_f32(vf_3, vd);

            // 存储结果
            vst1q_f32(y + i * QK40 + l + 0, r0);
            vst1q_f32(y + i * QK40 + l + 4, r1);
            vst1q_f32(y + i * QK40 + l + 8, r2);
            vst1q_f32(y + i * QK40 + l + 12, r3);
        }
    }
}

2.2.3 量化点积计算

量化点积计算是 ARM 优化模块中的关键优化技术之一,主要用于加速矩阵乘法计算。在 quantize.h 文件中,vec_vec_dot_q40_with_q80 函数实现了 4 位整数量化与 8 位整数量化的点积计算:

inline float vec_vec_dot_q40_with_q80(
        const int n, const void* __restrict vx, const void* __restrict vy) {
    const int nb = n / QK80;

    assert(n % QK80 == 0);
    assert(nb % 2 == 0);

    const BlockQ40* __restrict x = (BlockQ40*)vx;
    const BlockQ80* __restrict y = (BlockQ80*)vy;

    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (int i = 0; i < nb; i += 2) {
        // 加载数据
        const BlockQ40* __restrict x0 = &x[i + 0];
        const BlockQ40* __restrict x1 = &x[i + 1];
        const BlockQ80* __restrict y0 = &y[i + 0];
        const BlockQ80* __restrict y1 = &y[i + 1];

        // 处理4位整数量化数据
        const uint8x16_t m4b = vdupq_n_u8(0x0F);
        const int8x16_t s8b = vdupq_n_s8(0x8);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 提取4位整数
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // 减去偏移值8
        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

        // 加载8位整数量化数据
        const int8x16_t v1_0 = vld1q_s8(y0->qs);
        const int8x16_t v1_1 = vld1q_s8(y1->qs);

        // 计算点积
        // 使用 SDOT 指令(如果支持)或模拟实现
#if defined(__ARM_FEATURE_DOTPROD)
        // 使用 SDOT 指令计算点积
        int32x4_t p0_0 = vdupq_n_s32(0);
        int32x4_t p0_1 = vdupq_n_s32(0);
        int32x4_t p0_2 = vdupq_n_s32(0);
        int32x4_t p0_3 = vdupq_n_s32(0);

        p0_0 = vdotq_s32(p0_0, v0_0ls, v1_0);
        p0_1 = vdotq_s32(p0_1, v0_0hs, v1_0);
        p0_2 = vdotq_s32(p0_2, v0_1ls, v1_1);
        p0_3 = vdotq_s32(p0_3, v0_1hs, v1_1);
#else
        // 模拟 SDOT 指令
        // 将 int8x16_t 转换为 int16x8_t
        int16x8_t pl0_0 = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0));
        int16x8_t ph0_0 = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0));
        int16x8_t pl0_1 = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0));
        int16x8_t ph0_1 = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0));
        int16x8_t pl0_2 = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1));
        int16x8_t ph0_2 = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1));
        int16x8_t pl0_3 = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1));
        int16x8_t ph0_3 = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1));

        // 水平求和
        int32x4_t p0_0 = vpaddlq_s16(pl0_0);
        p0_0 = vpadalq_s16(p0_0, ph0_0);
        int32x4_t p0_1 = vpaddlq_s16(pl0_1);
        p0_1 = vpadalq_s16(p0_1, ph0_1);
        int32x4_t p0_2 = vpaddlq_s16(pl0_2);
        p0_2 = vpadalq_s16(p0_2, ph0_2);
        int32x4_t p0_3 = vpaddlq_s16(pl0_3);
        p0_3 = vpadalq_s16(p0_3, ph0_3);
#endif

        // 转换为浮点数并乘以量化参数
        const float d0_0 = x0->d * y0->d;
        const float d0_1 = x0->d * y0->d;
        const float d0_2 = x1->d * y1->d;
        const float d0_3 = x1->d * y1->d;

        const float32x4_t d0_0v = vdupq_n_f32(d0_0);
        const float32x4_t d0_1v = vdupq_n_f32(d0_1);
        const float32x4_t d0_2v = vdupq_n_f32(d0_2);
        const float32x4_t d0_3v = vdupq_n_f32(d0_3);

        const float32x4_t p0_0f = vcvtq_f32_s32(p0_0);
        const float32x4_t p0_1f = vcvtq_f32_s32(p0_1);
        const float32x4_t p0_2f = vcvtq_f32_s32(p0_2);
        const float32x4_t p0_3f = vcvtq_f32_s32(p0_3);

        // 累加结果
        sumv0 = vmlaq_f32(sumv0, p0_0f, d0_0v);
        sumv0 = vmlaq_f32(sumv0, p0_1f, d0_1v);
        sumv1 = vmlaq_f32(sumv1, p0_2f, d0_2v);
        sumv1 = vmlaq_f32(sumv1, p0_3f, d0_3v);
    }

    // 水平求和
    sumv0 = vaddq_f32(sumv0, sumv1);
    float32x2_t sum = vadd_f32(vget_low_f32(sumv0), vget_high_f32(sumv0));
    sum = vpadd_f32(sum, sum);

    return vget_lane_f32(sum, 0);
}

这个函数的主要优化点包括:

  1. NEON 指令集优化:使用 NEON 指令集进行向量化计算,每次处理多个数据元素
  2. 条件编译:根据 ARM 处理器是否支持 SDOT 指令(ARMv8.2-A 及以上),使用不同的实现
  3. 分块处理:每次处理两个块,减少循环开销
  4. 并行计算:使用多个向量寄存器并行计算,提高指令级并行性
  5. 内存访问优化:使用连续的内存访问模式,提高缓存命中率

这个函数是矩阵乘法中的核心计算部分,通过优化点积计算,可以显著提高矩阵乘法的性能。在 llm_matmul_compute_int4_float 函数中,使用这个函数计算矩阵乘法:

dst[m * N + n] = vec_vec_dot_q40_with_q80(K, q_weight0, src) + b0;
dst[m * N + n + 1] = vec_vec_dot_q40_with_q80(K, q_weight1, src) + b1;
dst[m * N + n + 2] = vec_vec_dot_q40_with_q80(K, q_weight2, src) + b2;
dst[m * N + n + 3] = vec_vec_dot_q40_with_q80(K, q_weight3, src) + b3;

通过使用量化点积计算,可以减少内存占用和内存带宽需求,提高计算效率,特别是对于大型矩阵乘法,这种优化可以显著提高性能。

与其他平台的量化点积计算比较

与 x86 平台相比,ARM 平台的量化点积计算有以下特点:

  1. 指令集差异:ARM 使用 NEON 和 SDOT 指令,x86 使用 SSE/AVX/AVX2 和 VNNI 指令
  2. 向量宽度:ARM NEON 的向量宽度为 128 位,x86 AVX/AVX2 的向量宽度为 256 位
  3. 专用指令:ARMv8.2-A 引入了 SDOT 指令,专门用于加速点积计算;x86 引入了 VNNI 指令,也是专门用于加速点积计算

总体而言,ARM 平台的量化点积计算实现充分利用了 NEON 指令集的特性,在支持 SDOT 指令的处理器上可以获得更好的性能。

3. 计算函数实现 (kernel.cpp)

kernel.cpp 文件实现了各种计算函数,这些函数使用 TaskSet 实现多线程并行计算。

3.1 嵌入层计算

TaskSet llm_embedding_get_int4_float(
        const void* weights, const uint32_t* index, float* dst, uint32_t len_seq,
        uint32_t embd) {
    auto task = [=](const TaskId& id) {
        for (uint32_t i = id.start; i < id.end; ++i) {
            const int row = index[i];
            const int weight_stride =
                    embd * dtype_in_byte(DType::Int4) / dtype_block_size(DType::Int4);
            dequantize_row_q4_0(
                    (static_cast<const char*>(weights) + row * weight_stride),
                    dst + i * embd, embd);
        }
    };
    return TaskSet{{task, len_seq}};
}

这个函数实现了嵌入层的计算,将4位整数量化的权重反量化为浮点数。每个任务处理一个或多个序列位置,通过 dequantize_row_q4_0 函数将量化权重反量化为浮点数。

3.2 元素级计算

TaskSet llm_elemwise_compute_float(
        InData<float> srcs, float* dst, size_t length, ElemMode mode) {
    MultiThreadingTask task;
    switch (mode) {
        case ElemMode::Add: {
            task = [=](const TaskId& id) {
                uint32_t offset = id.start;
                uint32_t len = id.end - id.start;
                elemwise_vector_add(
                        len, srcs[0] + offset, srcs[1] + offset, dst + offset);
            };
            break;
        }
        case ElemMode::Mul: {
            task = [=](const TaskId& id) {
                uint32_t offset = id.start;
                uint32_t len = id.end - id.start;
                elemwise_vector_mul(
                        len, srcs[0] + offset, srcs[1] + offset, dst + offset);
            };
            break;
        }
        case ElemMode::Silu: {
            task = [=](const TaskId& id) {
                uint32_t offset = id.start;
                uint32_t len = id.end - id.start;
                return elemwise_vector_silu(len, srcs[0] + offset, dst + offset);
            };
            break;
        }
        case ElemMode::Gelu: {
            task = [=](const TaskId& id) {
                uint32_t offset = id.start;
                uint32_t len = id.end - id.start;
                return elemwise_vector_gelu(len, srcs[0] + offset, dst + offset);
            };
            break;
        }
        default:
            INFER_ASSERT(0, "Not supported.");
    }
    return TaskSet{{task, length}};
}

这个函数实现了元素级计算,支持加法、乘法、SiLU 激活函数和 GELU 激活函数。每个任务处理一段连续的数据,通过 elemwise_vector_addelemwise_vector_mulelemwise_vector_siluelemwise_vector_gelu 函数实现具体的计算。

3.3 广播计算

TaskSet llm_elemwise_broadcast_dim0_src1_compute_float(
        const float* src0, const float* src1, float* dst, uint32_t len0, uint32_t len1,
        ElemMode mode) {
    MultiThreadingTask task;
    switch (mode) {
        case ElemMode::Add: {
            task = [=](const TaskId& id) {
                for (size_t i = id.start; i < id.end; i++) {
                    const float* p_src = src0 + i * len1;
                    float* p_dst = dst + i * len1;
                    elemwise_vector_add(len1, p_src, src1, p_dst);
                }
            };
            break;
        }
        case ElemMode::Mul: {
            task = [=](const TaskId& id) {
                for (size_t i = id.start; i < id.end; i++) {
                    auto p_src = src0 + i * len1;
                    auto p_dst = dst + i * len1;
                    elemwise_vector_mul(len1, p_src, src1, p_dst);
                }
            };
            break;
        }
        default:
            INFER_ASSERT(0, "Not supported.");
    }
    return TaskSet{{task, len0}};
}

这个函数实现了广播计算,将 src1 广播到 src0 的第0维,然后进行元素级计算。每个任务处理 src0 的一个或多个行,通过 elemwise_vector_addelemwise_vector_mul 函数实现具体的计算。

3.4 RMS 归一化

TaskSet llm_rms_norm_compute_float(
        const float* src, float* dst, uint32_t seq_len, uint32_t embd, float eps) {
    auto task = [=](const TaskId& id) {
        for (uint32_t i = id.start; i < id.end; i++) {
            const float* row = src + i * embd;
            float* out = dst + i * embd;
            float mean = reduce_square_sum(embd, row) / embd;
            const float scale = 1.0 / sqrt(mean + eps);
            elemwise_vec_scale(embd, row, scale, out);
        }
    };
    return TaskSet{{task, seq_len}};
}

这个函数实现了 RMS 归一化,每个任务处理一个或多个序列位置。首先计算平方和的均值,然后计算缩放因子,最后将输入向量乘以缩放因子得到归一化结果。

3.5 Softmax 计算

TaskSet llm_softmax_compute_float(
        const float* src, float* dst, uint32_t len_row, uint32_t col) {
    auto task = [=](const TaskId& id) {
        for (uint32_t row = id.start; row < id.end; row++) {
            const float* psrc = src + row * col;
            float* pdst = dst + row * col;

            float max = reduce_max(col, psrc);
            float sum = select_sub_max_and_reduce_sum(col, psrc, pdst, max);
            sum = 1.0 / sum;
            elemwise_vec_scale(col, pdst, sum, pdst);
        }
    };
    return TaskSet{{task, len_row}};
}

这个函数实现了 Softmax 计算,每个任务处理一个或多个行。首先找到最大值,然后减去最大值并计算指数和,最后将每个元素除以指数和得到 Softmax 结果。

3.6 量化矩阵乘法

TaskSet llm_matmul_compute_int4_float(
        float* dst, const void* src0, const float* bias, const float* src1, uint32_t M,
        uint32_t N, uint32_t K, void* workspace, uint32_t size) {
    // ... 参数检查和计算步长 ...
    
    // 第一阶段:量化输入
    auto task1 = [=](const TaskId& id) {
        for (uint32_t m = id.start; m < id.end; m++) {
            BlockQ80* q_src1 = (BlockQ80*)(static_cast<uint8_t*>(workspace) +
                                           m * weight_q80_stride);
            quantize_row_q8_0(src1 + m * K, q_src1, K);
        }
    };
    
    // 第二阶段:计算矩阵乘法
    int8_t* q_src = static_cast<int8_t*>(workspace);
    auto task2 = [=](const TaskId& id) {
        // 每次处理4列,提高计算效率
        uint32_t N_len = id.end - id.start;
        uint32_t n_block_4 = N_len / 4;
        uint32_t n_block_4_left = N_len - n_block_4 * 4;
        
        // 处理4列一组的部分
        for (uint32_t block4 = 0; block4 < n_block_4; block4++) {
            uint32_t n = block4 * 4 + id.start;
            // ... 加载偏置和权重 ...
            
            // 计算矩阵乘法
            for (uint32_t m = 0; m < M; m++) {
                int8_t* src = q_src + m * weight_q80_stride;
                dst[m * N + n] = vec_vec_dot_q40_with_q80(K, q_weight0, src) + b0;
                dst[m * N + n + 1] = vec_vec_dot_q40_with_q80(K, q_weight1, src) + b1;
                dst[m * N + n + 2] = vec_vec_dot_q40_with_q80(K, q_weight2, src) + b2;
                dst[m * N + n + 3] = vec_vec_dot_q40_with_q80(K, q_weight3, src) + b3;
            }
        }
        
        // 处理剩余列
        for (uint32_t left = 0; left < n_block_4_left; left++) {
            uint32_t n = n_block_4 * 4 + left + id.start;
            // ... 加载偏置和权重 ...
            
            // 计算矩阵乘法
            for (uint32_t m = 0; m < M; m++) {
                int8_t* src = q_src + m * weight_q80_stride;
                dst[m * N + n] = vec_vec_dot_q40_with_q80(K, q_weight, src) + b0;
            }
        }
    };
    
    return TaskSet{{task1, M}, {task2, N}};
}

这个函数实现了4位整数权重与浮点数激活值的矩阵乘法。它分为两个阶段:第一阶段将输入量化为8位整数,第二阶段计算矩阵乘法。每个阶段都使用多线程并行计算,第一阶段按行分解,第二阶段按列分解。

3.7 多头注意力计算

TaskSet llm_matmul_compute_with_head_stride_float(
        float* dst, const float* srck, const float* srcq, uint32_t seqlen,
        uint32_t embd, uint32_t head, uint32_t nr_past) {
    uint32_t sub_embd = embd / head;
    uint32_t length = nr_past + seqlen;
    uint32_t line_stride = embd;

    auto task = [=](const TaskId& id) {
        for (uint32_t h = id.start; h < id.end; h++) {
            auto dst_head = dst + h * seqlen * (nr_past + seqlen);
            auto srck_head = srck + h * sub_embd;
            auto srcq_head = srcq + h * sub_embd;
            compute_src_offset_embd_matmul(
                    srcq_head, embd, srck_head, embd, dst_head, seqlen, length,
                    sub_embd);
        }
    };
    return TaskSet{{task, head}};
}

TaskSet llm_head_batched_matmul_compute_float(
        float* dst, const float* v, const float* qk, uint32_t seqlen, uint32_t embd,
        uint32_t head, uint32_t nr_past) {
    uint32_t sub_embd = embd / head;
    uint32_t length = nr_past + seqlen;
    uint32_t line_stride = embd;

    auto task = [=](const TaskId& id) {
        for (uint32_t h = id.start; h < id.end; h++) {
            float* dst_head = dst + h * sub_embd;
            const float* v_head = v + h * sub_embd;
            const float* qk_head = qk + h * seqlen * length;
            comput_matmul_with_dst_uncontinue(
                    dst_head, embd, v_head, embd, qk_head, seqlen, length, sub_embd);
        }
    };
    return TaskSet{{task, head}};
}

这两个函数实现了多头注意力计算中的矩阵乘法。llm_matmul_compute_with_head_stride_float 计算 Q 和 K 的矩阵乘法,llm_head_batched_matmul_compute_float 计算 QK 和 V 的矩阵乘法。每个任务处理一个注意力头,通过 compute_src_offset_embd_matmulcomput_matmul_with_dst_uncontinue 函数实现具体的计算。

4. 优化策略分析

4.1 NEON 指令集优化

ARM 优化模块使用 NEON 指令集进行向量化计算,提高计算效率:

// 使用 NEON 指令集优化的向量加法
inline void elemwise_vector_add(
        const int n, const float* __restrict x, const float* __restrict y,
        float* __restrict z) {
    int i = 0;
#if defined(__ARM_NEON)
    for (; i + 15 < n; i += 16) {
        float32x4_t vx0 = vld1q_f32(x + i);
        float32x4_t vy0 = vld1q_f32(y + i);
        float32x4_t vx1 = vld1q_f32(x + i + 4);
        float32x4_t vy1 = vld1q_f32(y + i + 4);
        float32x4_t vx2 = vld1q_f32(x + i + 8);
        float32x4_t vy2 = vld1q_f32(y + i + 8);
        float32x4_t vx3 = vld1q_f32(x + i + 12);
        float32x4_t vy3 = vld1q_f32(y + i + 12);
        float32x4_t vz0 = vaddq_f32(vx0, vy0);
        float32x4_t vz1 = vaddq_f32(vx1, vy1);
        float32x4_t vz2 = vaddq_f32(vx2, vy2);
        float32x4_t vz3 = vaddq_f32(vx3, vy3);
        vst1q_f32(z + i, vz0);
        vst1q_f32(z + i + 4, vz1);
        vst1q_f32(z + i + 8, vz2);
        vst1q_f32(z + i + 12, vz3);
    }
#endif
    // 标量回退实现
    for (; i < n; i++) {
        z[i] = x[i] + y[i];
    }
}

NEON 指令集允许同时处理多个数据元素,大大提高了计算效率。上面的代码每次处理 16 个浮点数,使用 4 个 float32x4_t 寄存器,每个寄存器可以同时处理 4 个浮点数。

4.2 分块处理策略

ARM 优化模块使用分块处理策略,将数据分成多个块进行处理:

// 矩阵乘法中的分块处理
auto task2 = [=](const TaskId& id) {
    uint32_t N_len = id.end - id.start;
    uint32_t n_block_4 = N_len / 4;
    uint32_t n_block_4_left = N_len - n_block_4 * 4;
    
    // 处理4列一组的部分
    for (uint32_t block4 = 0; block4 < n_block_4; block4++) {
        // ... 处理4列 ...
    }
    
    // 处理剩余列
    for (uint32_t left = 0; left < n_block_4_left; left++) {
        // ... 处理1列 ...
    }
};

这种策略可以最大化利用 NEON 指令集的并行性,同时处理所有数据。

4.3 多线程并行策略

ARM 优化模块使用 TaskSet 实现多线程并行,每个计算函数都返回一个 TaskSet,包含一个或多个任务及其子任务数量:

TaskSet llm_rms_norm_compute_float(
        const float* src, float* dst, uint32_t seq_len, uint32_t embd, float eps) {
    auto task = [=](const TaskId& id) {
        for (uint32_t i = id.start; i < id.end; i++) {
            // ... 计算逻辑 ...
        }
    };
    return TaskSet{{task, seq_len}};
}

不同的计算函数使用不同的任务分解策略:

  1. 按序列长度分解:如 llm_rms_norm_compute_float,每个任务处理一个或多个序列位置
  2. 按头数分解:如 llm_matmul_compute_with_head_stride_float,每个任务处理一个或多个注意力头
  3. 按矩阵行列分解:如 llm_matmul_compute_int4_float,使用两个任务集,一个按行分解,一个按列分解

4.4 量化计算优化

ARM 优化模块使用量化计算减少内存占用和计算量:

TaskSet llm_matmul_compute_int4_float(
        float* dst, const void* src0, const float* bias, const float* src1, uint32_t M,
        uint32_t N, uint32_t K, void* workspace, uint32_t size) {
    // 第一阶段:量化输入
    auto task1 = [=](const TaskId& id) {
        for (uint32_t m = id.start; m < id.end; m++) {
            BlockQ80* q_src1 = (BlockQ80*)(static_cast<uint8_t*>(workspace) +
                                           m * weight_q80_stride);
            quantize_row_q8_0(src1 + m * K, q_src1, K);
        }
    };
    
    // 第二阶段:使用量化数据计算
    auto task2 = [=](const TaskId& id) {
        // ... 使用 vec_vec_dot_q40_with_q80 计算点积 ...
    };
    
    return TaskSet{{task1, M}, {task2, N}};
}

这种设计减少了内存占用和内存带宽需求,提高了计算效率。特别是对于大型矩阵乘法,量化计算可以显著提高性能。

5. 内核注册机制

ARM 优化模块使用内核注册机制,将函数与内核 ID 关联起来:

// kernel.h
#define ImplementKernel(kernel_id, fun)                                \
    template <>                                                         \
    struct Comp<KernelID::kernel_id, KernelPlatform::Optimized> {      \
        template <typename... Args>                                     \
        static TaskSet exec(Args... args) {                             \
            return fun(std::forward<Args>(args)...);                    \
        }                                                               \
    };

#define ImplementSpace(kernel_id, fun)                                 \
    template <>                                                         \
    struct Space<KernelID::kernel_id, KernelPlatform::Optimized> {      \
        template <typename... Args>                                     \
        static size_t get(Args... args) {                               \
            return fun(std::forward<Args>(args)...);                    \
        }                                                               \
    };

// 注册内核
ImplementKernel(ElemwiseFloat, llm_elemwise_compute_float);
ImplementKernel(ElemwiseBroadcastDim0Src1Float, llm_elemwise_broadcast_dim0_src1_compute_float);
ImplementKernel(NormFloat, llm_norm_compute_float);
ImplementKernel(RmsNormFloat, llm_rms_norm_compute_float);
ImplementKernel(EmbeddingGetInt4Float, llm_embedding_get_int4_float);
ImplementKernel(SoftmaxFloat, llm_softmax_compute_float);
ImplementKernel(MatmulInt4Float, llm_matmul_compute_int4_float);
ImplementKernel(MatmulWithHeadStrideFloat, llm_matmul_compute_with_head_stride_float);
ImplementKernel(HeadBatchedMatmulFloat, llm_head_batched_matmul_compute_float);

// 注册工作空间计算函数
ImplementSpace(MatmulInt4Float, llm_matmul_get_workspace_float);

这些宏和注册语句将内核 ID 与具体的实现函数关联起来,使得框架可以在运行时根据内核 ID 选择合适的实现。

6. 与 naive 实现的比较

ARM 优化模块与 naive 模块的主要区别:

  1. 向量化实现:ARM 优化模块使用 NEON 指令集进行向量化计算,naive 模块使用标量实现
  2. 分块处理:ARM 优化模块使用分块处理策略,naive 模块使用简单的循环
  3. 量化计算:ARM 优化模块使用量化计算减少内存占用和计算量,naive 模块使用浮点计算
  4. 多线程并行:ARM 优化模块使用 TaskSet 实现多线程并行,naive 模块也使用 TaskSet,但任务分解策略不同

在实际应用中,ARM 优化模块的性能明显优于 naive 模块,特别是在支持 NEON 指令集的 ARM 处理器上。

7. 与 x86 和 RVV 优化的比较

ARM 优化模块与 x86 和 RVV 优化模块的主要区别:

  1. 指令集:ARM 优化模块使用 NEON 指令集,x86 优化模块使用 SSE/AVX/AVX2 指令集,RVV 优化模块使用 RISC-V 向量扩展指令集
  2. 向量宽度:NEON 指令集的向量宽度为 128 位,可以同时处理 4 个单精度浮点数;AVX/AVX2 指令集的向量宽度为 256 位,可以同时处理 8 个单精度浮点数;RVV 指令集的向量宽度可变,取决于硬件实现
  3. 优化程度:x86 优化模块的优化程度最高,实现了更多的优化技术;ARM 优化模块次之;RVV 优化模块的优化程度最低,主要依赖 RVV 指令集的基本向量操作

8. 未来优化方向

基于当前实现,可以考虑以下优化方向:

8.1 更多 ARM 指令集支持

  • 支持 ARMv8.2-A 的 FP16 指令,使用半精度浮点数进行计算
  • 支持 ARMv8.2-A 的 DotProd 指令,加速点积计算
  • 支持 ARMv8.6-A 的 BFloat16 指令,使用 BFloat16 进行计算

8.2 更高效的算法

  • 使用 Winograd 算法优化矩阵乘法
  • 使用 Flash Attention 算法优化注意力计算
  • 使用混合精度计算提高性能

8.3 更多量化方法

  • 支持 3 位、2 位甚至 1 位量化
  • 支持非对称量化
  • 支持组量化

8.4 更高级的并行策略

  • 使用流水线并行减少内存占用
  • 使用张量并行和模型并行处理大型模型
  • 使用异步计算提高计算效率

总结

InferLLM 框架中的 ARM 优化模块提供了针对 ARM 架构处理器的优化实现,主要利用 NEON 指令集进行向量化计算,提高大语言模型推理的性能。通过使用 NEON 指令集、分块处理策略、多线程并行和量化计算等技术,ARM 优化模块实现了高效的计算。

与 naive 模块相比,ARM 优化模块的性能明显更高,特别是在支持 NEON 指令集的 ARM 处理器上。与 x86 和 RVV 优化模块相比,ARM 优化模块使用了不同的指令集和优化技术,适用于不同的硬件平台。

未来可以考虑支持更多 ARM 指令集、使用更高效的算法、支持更多量化方法和使用更高级的并行策略等方向进行优化,进一步提高大语言模型在 ARM 平台上的推理性能。