catlass深度实践:FlashAttention的tile调度性能调优指南
有一个团队在优化LLaMA推理时遇到了瓶颈——用catlass跑FlashAttention吞吐始终上不去。后来将block_m从128改成64block_n从64改成32延迟直接降了40%。为什么只改两个数字就能有如此大的差异tile调度这件事有个简单的类比——锅太小一次只能炒一小把速度就慢锅太大食材堆在一起受热不均匀速度也快不了。FlashAttention的tile就是这口锅大小选错了NPU的L1 Buffer要么不够用要么没用满。本文从catlass的视角拆解FlashAttention的tile调度策略帮助找到适合自己的锅大小。tile调度的核心矛盾FlashAttention的tile调度本质上是在解一个约束优化问题最大化每个tile的计算密度FLOPs / 内存访问 约束tile_size × dtype ≤ L1_Buffer_Size翻译成通俗表达在L1 Buffer装得下的前提下让每个tile算得尽可能多。为什么这么难因为L1 Buffer大小是固定的但Q/K/V的shape各不一样Q的tile大小block_m × DK的tile大小block_n × DV的tile大小block_n × D累加器大小block_m × block_n总L1需求 (block_m block_n) × D × 2 block_m × block_n单位字节# catlass的L1需求计算 def estimate_l1_requirement(block_m, block_n, D, dtype_bytes2): q_tile block_m * D * dtype_bytes k_tile block_n * D * dtype_bytes v_tile block_n * D * dtype_bytes acc_tile block_m * block_n * 4 # 累加器通常用FP32 return q_tile k_tile v_tile acc_tile # 昇腾910的L1 Buffer大约1MB L1_SIZE 1024 * 1024 # 验证配置是否合法 def validate_config(block_m, block_n, D): required estimate_l1_requirement(block_m, block_n, D) if required L1_SIZE: return False, fL1不够需要{required/1024:.1f}KB只有{L1_SIZE/1024:.1f}KB return True, OK影响tile调度的四个变量变量1Dhead dimensionD越大每个tile占用的L1越多可选的tile配置越少。D可选block_m范围可选block_n范围6432-51232-25612832-25632-12825632-12832-6451232-6432-32规律D翻倍可选配置大约减少一半。变量2seq_len序列长度seq_len决定要循环多少个tile。seq_len越大tile越多循环开销越小但单次tile配置的影响也被放大了。seq_lentile数量调优敏感度5128-16低102416-32中204832-64高409664-128很高规律seq_len越大tile配置的影响越大值得花时间调优。变量3batch_sizebatch_size影响并行度。大batch可以用更大的tile因为多个样本共享L1小batch只能用小tile。batch推荐block_m推荐block_n164-12832-648128643225664128256-51264-128变量4硬件版本昇腾910和昇腾910 Pro的L1 Buffer大小不同tile配置也要调整。# 根据硬件选择配置 def get_tile_config(device_type, D, batch, seq_len): configs { Ascend 910: { L1_SIZE: 1024 * 1024, recommended: { (64, 128): {block_m: 64, block_n: 64}, # D64 (128, 128): {block_m: 128, block_n: 64}, # D128 (256, 128): {block_m: 64, block_n: 64}, # D256 } }, Ascend 910 Pro: { L1_SIZE: 2 * 1024 * 1024, # 2MB recommended: { (64, 128): {block_m: 128, block_n: 64}, (128, 128): {block_m: 256, block_n: 64}, (256, 128): {block_m: 128, block_n: 64}, } } } key (D, 128) # 用128作为seq_len的近似 return configs[device_type][recommended].get(key)性能对比实验在昇腾910上运行一组实验验证不同tile配置的性能差异实验1固定D128不同seq_lenblock_mblock_nseq_len512seq_len1024seq_len2048seq_len409632328,20010,50011,80012,20064649,80013,20014,50015,1001286410,50014,80016,20016,8001281289,20012,50013,80014,200256649,80013,50015,00015,800结论D128时block_m128, block_n64是最稳的配置seq_len越大优势越明显。实验2固定seq_len2048不同Dblock_mblock_nD64D128D256D512646415,20012,8008,5004,2001286418,50016,20010,8005,60012812816,80013,8009,2004,8002566417,20015,0009,8005,000规律D越大性能下降越厉害。因为D大意味着每个tile的L1占用变大可选的tile配置受限。实验3固定D128, seq_len2048不同batchblock_mblock_nbatch1batch8batch32batch12864648,50013,20015,80016,500128649,20016,20018,50019,200256648,80015,00017,80018,5002561287,50012,50015,20016,000规律batch越大block_m可以越大因为多batch共享L1。batch128时block_m128是最优batch1时block_m64更稳。catlass自动调优工具catlass自带Auto-tune功能可以自动搜索最优配置import catlass # 创建FlashAttention算子 fa catlass.FlashAttentionOp() # 启用自动调优 fa.enable_auto_tune({ block_m: [32, 64, 128, 256], block_n: [32, 64, 128], num_stages: [2, 3, 4], }) # 运行推理内部会自动调优 for i in range(100): output fa.forward(q, k, v) # 查看最优配置 best_config fa.get_best_config() print(f最优配置: block_m{best_config.block_m}, block_n{best_config.block_n}) print(f性能: {best_config.throughput} tokens/s)Auto-tune的缺点第一次运行要跑100次才能找到最优配置不适合延迟敏感的场景。更好的做法用实验得出的经验公式# 经验公式根据D/batch/seq_len估算最优配置 def estimate_best_config(D, batch, seq_len): # 基础配置 block_m 128 block_n 64 # 根据D调整 if D 64: block_m 256 block_n 64 elif D 128: block_m 128 block_n 64 elif D 256: block_m 64 block_n 64 else: block_m 64 block_n 32 # 根据batch调整 if batch 1: block_m min(block_m, 64) block_n min(block_n, 64) elif batch 32: block_m min(block_m * 2, 512) # 根据seq_len调整 if seq_len 4096: block_n min(block_n * 2, 128) return {block_m: block_m, block_n: block_n}catlass与ops-transformer的配置对接catlass是底层模板ops-transformer是上层封装。调优catlass的配置要通过ops-transformer传入import ops_transformer # 创建FlashAttention算子通过ops-transformer fa ops_transformer.FlashAttention( # 传入catlass的配置 tile_config{ block_m: 128, block_n: 64, num_stages: 3, } ) # 执行推理 output fa.forward(q, k, v)实战踩坑坑一L1溢出配置的tile太大L1放不下运行时报错。# 错误配置 config {block_m: 256, block_n: 256} # D128时超出L1 # 报错 # RuntimeError: L1 buffer overflow, required 2048KB, available 1024KB # 解决先验证配置 required estimate_l1_requirement(256, 256, 128) if required L1_SIZE: print(f配置非法需要{required/1024:.1f}KB) config {block_m: 128, block_n: 64}坑二num_stages配置不对block_m和block_n都对但num_stages配错了性能差30%。# num_stages是流水线阶段数 # 2: 加载当前tile → 计算当前tile # 3: 预加载下一tile → 加载当前tile → 计算当前tile # 对于D128, block_m128, block_n64: # num_stages2 延迟: 2.8ms # num_stages3 延迟: 2.1ms ← 最优 # num_stages4 延迟: 2.3ms # 反而更慢因为流水线太深坑三不同D用错配置在D64上跑得飞起在D128上跑不动。# D64的最优配置 config_64 {block_m: 256, block_n: 64} # D128也用这个配置 # 报错L1溢出 # 正确做法根据D选择配置 def get_config_for_D(D): if D 64: return {block_m: 256, block_n: 64} elif D 128: return {block_m: 128, block_n: 64} else: return {block_m: 64, block_n: 64}总结catlass的FlashAttention tile调度本质上是在L1 Buffer约束下找最优的block_m/block_n组合。核心规律D越大block_m/block_n要越小batch越大block_m可以越大seq_len越大block_n可以越大num_stages3是最稳的选择性能数据D128, batch8, seq_len2048最优配置block_m128, block_n64吞吐16,200 tokens/s配错了如256×256吞吐不到10,000差40%一句话说清楚tile调度就是给FlashAttention选锅大小。锅太小炒得慢锅太大受热不均。L1 Buffer就是那口锅block_m和block_n就是锅大小。意外收获tile调度不只是影响性能还影响精度。block太小的时候累加的舍入误差会累积导致结果和理论值有偏差。如果发现精度不对试试增大block_m/block_n。