


[InferLLM大模型推理框架项目](25)kern中ARM优化模块整体分析(src/kern/optimized/arm)
InferLLM 框架中 ARM 优化模块分析
InferLLM 框架中的 ARM 优化模块是针对 ARM 架构处理器的优化实现,主要利用 ARM NEON 指令集进行向量化计算,提高大语言模型推理的性能。
1. 目录结构
ARM 优化模块包含以下文件:
optimized/arm/
├── kernel.cpp # 实现各种计算函数
├── kernel.h # 声明计算函数和注册内核
├── optimized.h # 实现基础向量运算
└── quantize.h # 实现量化和反量化操作
2. 核心功能模块
2.1 基础向量运算 (optimized.h)
optimized.h
文件实现了一系列基础向量运算函数:
2.1.1 元素级操作
// 向量加法
inline void elemwise_vector_add(
const int n, const float* __restrict x, const float* __restrict y,
float* __restrict z) {
for (int i = 0; i < n; i++) {
z[i] = x[i] + y[i];
}
}
// 向量乘法
inline void elemwise_vector_mul(
const int n, const float* __restrict x, const float* __restrict y,
float* __restrict z) {
for (int i = 0; i < n; i++) {
z[i] = x[i] * y[i];
}
}
2.1.2 激活函数
// SiLU 激活函数 (x * sigmoid(x))
inline void elemwise_vector_silu(
const int n, const float* __restrict x, float* __restrict z) {
for (int i = 0; i < n; i++) {
z[i] = x[i] / (1 + exp(-x[i]));
}
}
// GELU 激活函数
inline void elemwise_vector_gelu(
const int n, const float* __restrict x, float* __restrict z) {
for (int i = 0; i < n; i++) {
float src = x[i];
z[i] = 0.5 * src * (1 + tanh(sqrt(2.0 / PI) * (src + PGELU * src * src * src)));
}
}
2.1.3 归约操作
// 求最大值
inline float reduce_max(const int n, const float* __restrict x) {
float max = -INFINITY;
for (int i = 0; i < n; i++) {
max = std::max(max, x[i]);
}
return max;
}
// 求平方和
inline float reduce_square_sum(const int n, const float* __restrict x) {
float sum = 0.0f;
for (int i = 0; i < n; i++) {
sum += x[i] * x[i];
}
return sum;
}
2.1.4 矩阵运算
// 带偏移的矩阵乘法(用于注意力计算)
inline void compute_src_offset_embd_matmul(
const float* __restrict srcq_head, int offsetq,
const float* __restrict srck_head, int offsetk, float* dst_head, int seqlen,
int length, int sub_embd) {
for (uint32_t row = 0; row < seqlen; row++) {
auto p_srcq = srcq_head + row * offsetq;
uint32_t len = 0;
for (; len + 3 < length; len += 4) {
// 每次处理4列,提高计算效率
auto p_dst = dst_head + row * length + len;
auto p_srck0 = srck_head + len * offsetk;
auto p_srck1 = srck_head + (len + 1) * offsetk;
auto p_srck2 = srck_head + (len + 2) * offsetk;
auto p_srck3 = srck_head + (len + 3) * offsetk;
float sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;
for (uint32_t k = 0; k < sub_embd; k++) {
sum0 += p_srck0[k] * p_srcq[k];
sum1 += p_srck1[k] * p_srcq[k];
sum2 += p_srck2[k] * p_srcq[k];
sum3 += p_srck3[k] * p_srcq[k];
}
p_dst[0] = sum0;
p_dst[1] = sum1;
p_dst[2] = sum2;
p_dst[3] = sum3;
}
// 处理剩余列
for (; len < length; len++) {
auto p_dst = dst_head + row * length + len;
auto p_srck = srck_head + len * offsetk;
float sum = 0;
for (uint32_t k = 0; k < sub_embd; k++) {
sum += p_srck[k] * p_srcq[k];
}
*p_dst = sum;
}
}
}
2.2 量化计算 (quantize.h)
quantize.h
文件实现了量化和反量化操作:
2.2.1 4位整数量化
inline void quantize_row_q4_0(const float* __restrict x, void* __restrict vy, int k) {
const int nb = k / QK40;
BlockQ40* __restrict y = static_cast<BlockQ40*>(vy);
for (int i = 0; i < nb; i++) {
float32x4_t srcv[8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
// 加载数据
for (int l = 0; l < 8; l++)
srcv[l] = vld1q_f32(x + i * 32 + 4 * l);
// 计算绝对值
for (int l = 0; l < 8; l++)
asrcv[l] = vabsq_f32(srcv[l]);
// 计算最大值
for (int l = 0; l < 4; l++)
amaxv[2 * l] = vmaxq_f32(asrcv[2 * l], asrcv[2 * l + 1]);
for (int l = 0; l < 2; l++)
amaxv[4 * l] = vmaxq_f32(amaxv[4 * l], amaxv[4 * l + 2]);
for (int l = 0; l < 1; l++)
amaxv[8 * l] = vmaxq_f32(amaxv[8 * l], amaxv[8 * l + 4]);
const float amax = vmaxvq_f32(amaxv[0]);
// 计算量化参数
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f / d : 0.0f;
y[i].d = d;
// 量化数据
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
const int32x4_t vi = vcvtq_s32_f32(vf);
// 将4个int32压缩为2个uint8(每个uint8存储2个4位整数)
y[i].qs[2 * l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
y[i].qs[2 * l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
}
}
2.2.2 4位整数反量化
inline void dequantize_row_q4_0(const void* __restrict vx, float* __restrict y, int k) {
assert(k % QK40 == 0);
const int nb = k / QK40;
const BlockQ40* __restrict x = static_cast<const BlockQ40*>(vx);
for (int i = 0; i < nb; i++) {
const float32x4_t vd = vdupq_n_f32(x[i].d);
const uint8_t* __restrict pp = x[i].qs;
for (int l = 0; l < QK40; l += 16) {
// 加载8个uint8(每个uint8存储2个4位整数)
const uint8x8_t v8 = vld1_u8(pp + l / 2);
// 提取4位整数
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); // 低4位
const uint8x8_t v1 = vshr_n_u8(v8, 4); // 高4位
// 转换为有符号整数
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
// 减去偏移值8
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
// 交错排列
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
// 转换为int16
const int16x8_t vi_0 = vmovl_s8(vget_low_s8(vq));
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
// 转换为float32
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vi_0)));
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vi_1)));
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
// 乘以量化参数
const float32x4_t r0 = vmulq_f32(vf_0, vd);
const float32x4_t r1 = vmulq_f32(vf_1, vd);
const float32x4_t r2 = vmulq_f32(vf_2, vd);
const float32x4_t r3 = vmulq_f32(vf_3, vd);
// 存储结果
vst1q_f32(y + i * QK40 + l + 0, r0);
vst1q_f32(y + i * QK40 + l + 4, r1);
vst1q_f32(y + i * QK40 + l + 8, r2);
vst1q_f32(y + i * QK40 + l + 12, r3);
}
}
}
2.2.3 量化点积计算
量化点积计算是 ARM 优化模块中的关键优化技术之一,主要用于加速矩阵乘法计算。在 quantize.h
文件中,vec_vec_dot_q40_with_q80
函数实现了 4 位整数量化与 8 位整数量化的点积计算:
inline float vec_vec_dot_q40_with_q80(
const int n, const void* __restrict vx, const void* __restrict vy) {
const int nb = n / QK80;
assert(n % QK80 == 0);
assert(nb % 2 == 0);
const BlockQ40* __restrict x = (BlockQ40*)vx;
const BlockQ80* __restrict y = (BlockQ80*)vy;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i += 2) {
// 加载数据
const BlockQ40* __restrict x0 = &x[i + 0];
const BlockQ40* __restrict x1 = &x[i + 1];
const BlockQ80* __restrict y0 = &y[i + 0];
const BlockQ80* __restrict y1 = &y[i + 1];
// 处理4位整数量化数据
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 提取4位整数
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// 减去偏移值8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
// 加载8位整数量化数据
const int8x16_t v1_0 = vld1q_s8(y0->qs);
const int8x16_t v1_1 = vld1q_s8(y1->qs);
// 计算点积
// 使用 SDOT 指令(如果支持)或模拟实现
#if defined(__ARM_FEATURE_DOTPROD)
// 使用 SDOT 指令计算点积
int32x4_t p0_0 = vdupq_n_s32(0);
int32x4_t p0_1 = vdupq_n_s32(0);
int32x4_t p0_2 = vdupq_n_s32(0);
int32x4_t p0_3 = vdupq_n_s32(0);
p0_0 = vdotq_s32(p0_0, v0_0ls, v1_0);
p0_1 = vdotq_s32(p0_1, v0_0hs, v1_0);
p0_2 = vdotq_s32(p0_2, v0_1ls, v1_1);
p0_3 = vdotq_s32(p0_3, v0_1hs, v1_1);
#else
// 模拟 SDOT 指令
// 将 int8x16_t 转换为 int16x8_t
int16x8_t pl0_0 = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0));
int16x8_t ph0_0 = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0));
int16x8_t pl0_1 = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0));
int16x8_t ph0_1 = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0));
int16x8_t pl0_2 = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1));
int16x8_t ph0_2 = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1));
int16x8_t pl0_3 = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1));
int16x8_t ph0_3 = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1));
// 水平求和
int32x4_t p0_0 = vpaddlq_s16(pl0_0);
p0_0 = vpadalq_s16(p0_0, ph0_0);
int32x4_t p0_1 = vpaddlq_s16(pl0_1);
p0_1 = vpadalq_s16(p0_1, ph0_1);
int32x4_t p0_2 = vpaddlq_s16(pl0_2);
p0_2 = vpadalq_s16(p0_2, ph0_2);
int32x4_t p0_3 = vpaddlq_s16(pl0_3);
p0_3 = vpadalq_s16(p0_3, ph0_3);
#endif
// 转换为浮点数并乘以量化参数
const float d0_0 = x0->d * y0->d;
const float d0_1 = x0->d * y0->d;
const float d0_2 = x1->d * y1->d;
const float d0_3 = x1->d * y1->d;
const float32x4_t d0_0v = vdupq_n_f32(d0_0);
const float32x4_t d0_1v = vdupq_n_f32(d0_1);
const float32x4_t d0_2v = vdupq_n_f32(d0_2);
const float32x4_t d0_3v = vdupq_n_f32(d0_3);
const float32x4_t p0_0f = vcvtq_f32_s32(p0_0);
const float32x4_t p0_1f = vcvtq_f32_s32(p0_1);
const float32x4_t p0_2f = vcvtq_f32_s32(p0_2);
const float32x4_t p0_3f = vcvtq_f32_s32(p0_3);
// 累加结果
sumv0 = vmlaq_f32(sumv0, p0_0f, d0_0v);
sumv0 = vmlaq_f32(sumv0, p0_1f, d0_1v);
sumv1 = vmlaq_f32(sumv1, p0_2f, d0_2v);
sumv1 = vmlaq_f32(sumv1, p0_3f, d0_3v);
}
// 水平求和
sumv0 = vaddq_f32(sumv0, sumv1);
float32x2_t sum = vadd_f32(vget_low_f32(sumv0), vget_high_f32(sumv0));
sum = vpadd_f32(sum, sum);
return vget_lane_f32(sum, 0);
}
这个函数的主要优化点包括:
- NEON 指令集优化:使用 NEON 指令集进行向量化计算,每次处理多个数据元素
- 条件编译:根据 ARM 处理器是否支持 SDOT 指令(ARMv8.2-A 及以上),使用不同的实现
- 分块处理:每次处理两个块,减少循环开销
- 并行计算:使用多个向量寄存器并行计算,提高指令级并行性
- 内存访问优化:使用连续的内存访问模式,提高缓存命中率
这个函数是矩阵乘法中的核心计算部分,通过优化点积计算,可以显著提高矩阵乘法的性能。在 llm_matmul_compute_int4_float
函数中,使用这个函数计算矩阵乘法:
dst[m * N + n] = vec_vec_dot_q40_with_q80(K, q_weight0, src) + b0;
dst[m * N + n + 1] = vec_vec_dot_q40_with_q80(K, q_weight1, src) + b1;
dst[m * N + n + 2] = vec_vec_dot_q40_with_q80(K, q_weight2, src) + b2;
dst[m * N + n + 3] = vec_vec_dot_q40_with_q80(K, q_weight3, src) + b3;
通过使用量化点积计算,可以减少内存占用和内存带宽需求,提高计算效率,特别是对于大型矩阵乘法,这种优化可以显著提高性能。
与其他平台的量化点积计算比较
与 x86 平台相比,ARM 平台的量化点积计算有以下特点:
- 指令集差异:ARM 使用 NEON 和 SDOT 指令,x86 使用 SSE/AVX/AVX2 和 VNNI 指令
- 向量宽度:ARM NEON 的向量宽度为 128 位,x86 AVX/AVX2 的向量宽度为 256 位
- 专用指令:ARMv8.2-A 引入了 SDOT 指令,专门用于加速点积计算;x86 引入了 VNNI 指令,也是专门用于加速点积计算
总体而言,ARM 平台的量化点积计算实现充分利用了 NEON 指令集的特性,在支持 SDOT 指令的处理器上可以获得更好的性能。
3. 计算函数实现 (kernel.cpp)
kernel.cpp
文件实现了各种计算函数,这些函数使用 TaskSet 实现多线程并行计算。
3.1 嵌入层计算
TaskSet llm_embedding_get_int4_float(
const void* weights, const uint32_t* index, float* dst, uint32_t len_seq,
uint32_t embd) {
auto task = [=](const TaskId& id) {
for (uint32_t i = id.start; i < id.end; ++i) {
const int row = index[i];
const int weight_stride =
embd * dtype_in_byte(DType::Int4) / dtype_block_size(DType::Int4);
dequantize_row_q4_0(
(static_cast<const char*>(weights) + row * weight_stride),
dst + i * embd, embd);
}
};
return TaskSet{{task, len_seq}};
}
这个函数实现了嵌入层的计算,将4位整数量化的权重反量化为浮点数。每个任务处理一个或多个序列位置,通过 dequantize_row_q4_0
函数将量化权重反量化为浮点数。
3.2 元素级计算
TaskSet llm_elemwise_compute_float(
InData<float> srcs, float* dst, size_t length, ElemMode mode) {
MultiThreadingTask task;
switch (mode) {
case ElemMode::Add: {
task = [=](const TaskId& id) {
uint32_t offset = id.start;
uint32_t len = id.end - id.start;
elemwise_vector_add(
len, srcs[0] + offset, srcs[1] + offset, dst + offset);
};
break;
}
case ElemMode::Mul: {
task = [=](const TaskId& id) {
uint32_t offset = id.start;
uint32_t len = id.end - id.start;
elemwise_vector_mul(
len, srcs[0] + offset, srcs[1] + offset, dst + offset);
};
break;
}
case ElemMode::Silu: {
task = [=](const TaskId& id) {
uint32_t offset = id.start;
uint32_t len = id.end - id.start;
return elemwise_vector_silu(len, srcs[0] + offset, dst + offset);
};
break;
}
case ElemMode::Gelu: {
task = [=](const TaskId& id) {
uint32_t offset = id.start;
uint32_t len = id.end - id.start;
return elemwise_vector_gelu(len, srcs[0] + offset, dst + offset);
};
break;
}
default:
INFER_ASSERT(0, "Not supported.");
}
return TaskSet{{task, length}};
}
这个函数实现了元素级计算,支持加法、乘法、SiLU 激活函数和 GELU 激活函数。每个任务处理一段连续的数据,通过 elemwise_vector_add
、elemwise_vector_mul
、elemwise_vector_silu
和 elemwise_vector_gelu
函数实现具体的计算。
3.3 广播计算
TaskSet llm_elemwise_broadcast_dim0_src1_compute_float(
const float* src0, const float* src1, float* dst, uint32_t len0, uint32_t len1,
ElemMode mode) {
MultiThreadingTask task;
switch (mode) {
case ElemMode::Add: {
task = [=](const TaskId& id) {
for (size_t i = id.start; i < id.end; i++) {
const float* p_src = src0 + i * len1;
float* p_dst = dst + i * len1;
elemwise_vector_add(len1, p_src, src1, p_dst);
}
};
break;
}
case ElemMode::Mul: {
task = [=](const TaskId& id) {
for (size_t i = id.start; i < id.end; i++) {
auto p_src = src0 + i * len1;
auto p_dst = dst + i * len1;
elemwise_vector_mul(len1, p_src, src1, p_dst);
}
};
break;
}
default:
INFER_ASSERT(0, "Not supported.");
}
return TaskSet{{task, len0}};
}
这个函数实现了广播计算,将 src1
广播到 src0
的第0维,然后进行元素级计算。每个任务处理 src0
的一个或多个行,通过 elemwise_vector_add
和 elemwise_vector_mul
函数实现具体的计算。
3.4 RMS 归一化
TaskSet llm_rms_norm_compute_float(
const float* src, float* dst, uint32_t seq_len, uint32_t embd, float eps) {
auto task = [=](const TaskId& id) {
for (uint32_t i = id.start; i < id.end; i++) {
const float* row = src + i * embd;
float* out = dst + i * embd;
float mean = reduce_square_sum(embd, row) / embd;
const float scale = 1.0 / sqrt(mean + eps);
elemwise_vec_scale(embd, row, scale, out);
}
};
return TaskSet{{task, seq_len}};
}
这个函数实现了 RMS 归一化,每个任务处理一个或多个序列位置。首先计算平方和的均值,然后计算缩放因子,最后将输入向量乘以缩放因子得到归一化结果。
3.5 Softmax 计算
TaskSet llm_softmax_compute_float(
const float* src, float* dst, uint32_t len_row, uint32_t col) {
auto task = [=](const TaskId& id) {
for (uint32_t row = id.start; row < id.end; row++) {
const float* psrc = src + row * col;
float* pdst = dst + row * col;
float max = reduce_max(col, psrc);
float sum = select_sub_max_and_reduce_sum(col, psrc, pdst, max);
sum = 1.0 / sum;
elemwise_vec_scale(col, pdst, sum, pdst);
}
};
return TaskSet{{task, len_row}};
}
这个函数实现了 Softmax 计算,每个任务处理一个或多个行。首先找到最大值,然后减去最大值并计算指数和,最后将每个元素除以指数和得到 Softmax 结果。
3.6 量化矩阵乘法
TaskSet llm_matmul_compute_int4_float(
float* dst, const void* src0, const float* bias, const float* src1, uint32_t M,
uint32_t N, uint32_t K, void* workspace, uint32_t size) {
// ... 参数检查和计算步长 ...
// 第一阶段:量化输入
auto task1 = [=](const TaskId& id) {
for (uint32_t m = id.start; m < id.end; m++) {
BlockQ80* q_src1 = (BlockQ80*)(static_cast<uint8_t*>(workspace) +
m * weight_q80_stride);
quantize_row_q8_0(src1 + m * K, q_src1, K);
}
};
// 第二阶段:计算矩阵乘法
int8_t* q_src = static_cast<int8_t*>(workspace);
auto task2 = [=](const TaskId& id) {
// 每次处理4列,提高计算效率
uint32_t N_len = id.end - id.start;
uint32_t n_block_4 = N_len / 4;
uint32_t n_block_4_left = N_len - n_block_4 * 4;
// 处理4列一组的部分
for (uint32_t block4 = 0; block4 < n_block_4; block4++) {
uint32_t n = block4 * 4 + id.start;
// ... 加载偏置和权重 ...
// 计算矩阵乘法
for (uint32_t m = 0; m < M; m++) {
int8_t* src = q_src + m * weight_q80_stride;
dst[m * N + n] = vec_vec_dot_q40_with_q80(K, q_weight0, src) + b0;
dst[m * N + n + 1] = vec_vec_dot_q40_with_q80(K, q_weight1, src) + b1;
dst[m * N + n + 2] = vec_vec_dot_q40_with_q80(K, q_weight2, src) + b2;
dst[m * N + n + 3] = vec_vec_dot_q40_with_q80(K, q_weight3, src) + b3;
}
}
// 处理剩余列
for (uint32_t left = 0; left < n_block_4_left; left++) {
uint32_t n = n_block_4 * 4 + left + id.start;
// ... 加载偏置和权重 ...
// 计算矩阵乘法
for (uint32_t m = 0; m < M; m++) {
int8_t* src = q_src + m * weight_q80_stride;
dst[m * N + n] = vec_vec_dot_q40_with_q80(K, q_weight, src) + b0;
}
}
};
return TaskSet{{task1, M}, {task2, N}};
}
这个函数实现了4位整数权重与浮点数激活值的矩阵乘法。它分为两个阶段:第一阶段将输入量化为8位整数,第二阶段计算矩阵乘法。每个阶段都使用多线程并行计算,第一阶段按行分解,第二阶段按列分解。
3.7 多头注意力计算
TaskSet llm_matmul_compute_with_head_stride_float(
float* dst, const float* srck, const float* srcq, uint32_t seqlen,
uint32_t embd, uint32_t head, uint32_t nr_past) {
uint32_t sub_embd = embd / head;
uint32_t length = nr_past + seqlen;
uint32_t line_stride = embd;
auto task = [=](const TaskId& id) {
for (uint32_t h = id.start; h < id.end; h++) {
auto dst_head = dst + h * seqlen * (nr_past + seqlen);
auto srck_head = srck + h * sub_embd;
auto srcq_head = srcq + h * sub_embd;
compute_src_offset_embd_matmul(
srcq_head, embd, srck_head, embd, dst_head, seqlen, length,
sub_embd);
}
};
return TaskSet{{task, head}};
}
TaskSet llm_head_batched_matmul_compute_float(
float* dst, const float* v, const float* qk, uint32_t seqlen, uint32_t embd,
uint32_t head, uint32_t nr_past) {
uint32_t sub_embd = embd / head;
uint32_t length = nr_past + seqlen;
uint32_t line_stride = embd;
auto task = [=](const TaskId& id) {
for (uint32_t h = id.start; h < id.end; h++) {
float* dst_head = dst + h * sub_embd;
const float* v_head = v + h * sub_embd;
const float* qk_head = qk + h * seqlen * length;
comput_matmul_with_dst_uncontinue(
dst_head, embd, v_head, embd, qk_head, seqlen, length, sub_embd);
}
};
return TaskSet{{task, head}};
}
这两个函数实现了多头注意力计算中的矩阵乘法。llm_matmul_compute_with_head_stride_float
计算 Q 和 K 的矩阵乘法,llm_head_batched_matmul_compute_float
计算 QK 和 V 的矩阵乘法。每个任务处理一个注意力头,通过 compute_src_offset_embd_matmul
和 comput_matmul_with_dst_uncontinue
函数实现具体的计算。
4. 优化策略分析
4.1 NEON 指令集优化
ARM 优化模块使用 NEON 指令集进行向量化计算,提高计算效率:
// 使用 NEON 指令集优化的向量加法
inline void elemwise_vector_add(
const int n, const float* __restrict x, const float* __restrict y,
float* __restrict z) {
int i = 0;
#if defined(__ARM_NEON)
for (; i + 15 < n; i += 16) {
float32x4_t vx0 = vld1q_f32(x + i);
float32x4_t vy0 = vld1q_f32(y + i);
float32x4_t vx1 = vld1q_f32(x + i + 4);
float32x4_t vy1 = vld1q_f32(y + i + 4);
float32x4_t vx2 = vld1q_f32(x + i + 8);
float32x4_t vy2 = vld1q_f32(y + i + 8);
float32x4_t vx3 = vld1q_f32(x + i + 12);
float32x4_t vy3 = vld1q_f32(y + i + 12);
float32x4_t vz0 = vaddq_f32(vx0, vy0);
float32x4_t vz1 = vaddq_f32(vx1, vy1);
float32x4_t vz2 = vaddq_f32(vx2, vy2);
float32x4_t vz3 = vaddq_f32(vx3, vy3);
vst1q_f32(z + i, vz0);
vst1q_f32(z + i + 4, vz1);
vst1q_f32(z + i + 8, vz2);
vst1q_f32(z + i + 12, vz3);
}
#endif
// 标量回退实现
for (; i < n; i++) {
z[i] = x[i] + y[i];
}
}
NEON 指令集允许同时处理多个数据元素,大大提高了计算效率。上面的代码每次处理 16 个浮点数,使用 4 个 float32x4_t 寄存器,每个寄存器可以同时处理 4 个浮点数。
4.2 分块处理策略
ARM 优化模块使用分块处理策略,将数据分成多个块进行处理:
// 矩阵乘法中的分块处理
auto task2 = [=](const TaskId& id) {
uint32_t N_len = id.end - id.start;
uint32_t n_block_4 = N_len / 4;
uint32_t n_block_4_left = N_len - n_block_4 * 4;
// 处理4列一组的部分
for (uint32_t block4 = 0; block4 < n_block_4; block4++) {
// ... 处理4列 ...
}
// 处理剩余列
for (uint32_t left = 0; left < n_block_4_left; left++) {
// ... 处理1列 ...
}
};
这种策略可以最大化利用 NEON 指令集的并行性,同时处理所有数据。
4.3 多线程并行策略
ARM 优化模块使用 TaskSet 实现多线程并行,每个计算函数都返回一个 TaskSet,包含一个或多个任务及其子任务数量:
TaskSet llm_rms_norm_compute_float(
const float* src, float* dst, uint32_t seq_len, uint32_t embd, float eps) {
auto task = [=](const TaskId& id) {
for (uint32_t i = id.start; i < id.end; i++) {
// ... 计算逻辑 ...
}
};
return TaskSet{{task, seq_len}};
}
不同的计算函数使用不同的任务分解策略:
- 按序列长度分解:如
llm_rms_norm_compute_float
,每个任务处理一个或多个序列位置 - 按头数分解:如
llm_matmul_compute_with_head_stride_float
,每个任务处理一个或多个注意力头 - 按矩阵行列分解:如
llm_matmul_compute_int4_float
,使用两个任务集,一个按行分解,一个按列分解
4.4 量化计算优化
ARM 优化模块使用量化计算减少内存占用和计算量:
TaskSet llm_matmul_compute_int4_float(
float* dst, const void* src0, const float* bias, const float* src1, uint32_t M,
uint32_t N, uint32_t K, void* workspace, uint32_t size) {
// 第一阶段:量化输入
auto task1 = [=](const TaskId& id) {
for (uint32_t m = id.start; m < id.end; m++) {
BlockQ80* q_src1 = (BlockQ80*)(static_cast<uint8_t*>(workspace) +
m * weight_q80_stride);
quantize_row_q8_0(src1 + m * K, q_src1, K);
}
};
// 第二阶段:使用量化数据计算
auto task2 = [=](const TaskId& id) {
// ... 使用 vec_vec_dot_q40_with_q80 计算点积 ...
};
return TaskSet{{task1, M}, {task2, N}};
}
这种设计减少了内存占用和内存带宽需求,提高了计算效率。特别是对于大型矩阵乘法,量化计算可以显著提高性能。
5. 内核注册机制
ARM 优化模块使用内核注册机制,将函数与内核 ID 关联起来:
// kernel.h
#define ImplementKernel(kernel_id, fun) \
template <> \
struct Comp<KernelID::kernel_id, KernelPlatform::Optimized> { \
template <typename... Args> \
static TaskSet exec(Args... args) { \
return fun(std::forward<Args>(args)...); \
} \
};
#define ImplementSpace(kernel_id, fun) \
template <> \
struct Space<KernelID::kernel_id, KernelPlatform::Optimized> { \
template <typename... Args> \
static size_t get(Args... args) { \
return fun(std::forward<Args>(args)...); \
} \
};
// 注册内核
ImplementKernel(ElemwiseFloat, llm_elemwise_compute_float);
ImplementKernel(ElemwiseBroadcastDim0Src1Float, llm_elemwise_broadcast_dim0_src1_compute_float);
ImplementKernel(NormFloat, llm_norm_compute_float);
ImplementKernel(RmsNormFloat, llm_rms_norm_compute_float);
ImplementKernel(EmbeddingGetInt4Float, llm_embedding_get_int4_float);
ImplementKernel(SoftmaxFloat, llm_softmax_compute_float);
ImplementKernel(MatmulInt4Float, llm_matmul_compute_int4_float);
ImplementKernel(MatmulWithHeadStrideFloat, llm_matmul_compute_with_head_stride_float);
ImplementKernel(HeadBatchedMatmulFloat, llm_head_batched_matmul_compute_float);
// 注册工作空间计算函数
ImplementSpace(MatmulInt4Float, llm_matmul_get_workspace_float);
这些宏和注册语句将内核 ID 与具体的实现函数关联起来,使得框架可以在运行时根据内核 ID 选择合适的实现。
6. 与 naive 实现的比较
ARM 优化模块与 naive 模块的主要区别:
- 向量化实现:ARM 优化模块使用 NEON 指令集进行向量化计算,naive 模块使用标量实现
- 分块处理:ARM 优化模块使用分块处理策略,naive 模块使用简单的循环
- 量化计算:ARM 优化模块使用量化计算减少内存占用和计算量,naive 模块使用浮点计算
- 多线程并行:ARM 优化模块使用 TaskSet 实现多线程并行,naive 模块也使用 TaskSet,但任务分解策略不同
在实际应用中,ARM 优化模块的性能明显优于 naive 模块,特别是在支持 NEON 指令集的 ARM 处理器上。
7. 与 x86 和 RVV 优化的比较
ARM 优化模块与 x86 和 RVV 优化模块的主要区别:
- 指令集:ARM 优化模块使用 NEON 指令集,x86 优化模块使用 SSE/AVX/AVX2 指令集,RVV 优化模块使用 RISC-V 向量扩展指令集
- 向量宽度:NEON 指令集的向量宽度为 128 位,可以同时处理 4 个单精度浮点数;AVX/AVX2 指令集的向量宽度为 256 位,可以同时处理 8 个单精度浮点数;RVV 指令集的向量宽度可变,取决于硬件实现
- 优化程度:x86 优化模块的优化程度最高,实现了更多的优化技术;ARM 优化模块次之;RVV 优化模块的优化程度最低,主要依赖 RVV 指令集的基本向量操作
8. 未来优化方向
基于当前实现,可以考虑以下优化方向:
8.1 更多 ARM 指令集支持
- 支持 ARMv8.2-A 的 FP16 指令,使用半精度浮点数进行计算
- 支持 ARMv8.2-A 的 DotProd 指令,加速点积计算
- 支持 ARMv8.6-A 的 BFloat16 指令,使用 BFloat16 进行计算
8.2 更高效的算法
- 使用 Winograd 算法优化矩阵乘法
- 使用 Flash Attention 算法优化注意力计算
- 使用混合精度计算提高性能
8.3 更多量化方法
- 支持 3 位、2 位甚至 1 位量化
- 支持非对称量化
- 支持组量化
8.4 更高级的并行策略
- 使用流水线并行减少内存占用
- 使用张量并行和模型并行处理大型模型
- 使用异步计算提高计算效率
总结
InferLLM 框架中的 ARM 优化模块提供了针对 ARM 架构处理器的优化实现,主要利用 NEON 指令集进行向量化计算,提高大语言模型推理的性能。通过使用 NEON 指令集、分块处理策略、多线程并行和量化计算等技术,ARM 优化模块实现了高效的计算。
与 naive 模块相比,ARM 优化模块的性能明显更高,特别是在支持 NEON 指令集的 ARM 处理器上。与 x86 和 RVV 优化模块相比,ARM 优化模块使用了不同的指令集和优化技术,适用于不同的硬件平台。
未来可以考虑支持更多 ARM 指令集、使用更高效的算法、支持更多量化方法和使用更高级的并行策略等方向进行优化,进一步提高大语言模型在 ARM 平台上的推理性能。
