vLLM 中,LLM 推理的 prefill 階段 attention 計算使用第三方庫 xformers 的優(yōu)化實現(xiàn),decoding 階段 attention 計算則使用項目編譯 CUDA 代碼實現(xiàn)。具體代碼在 vllm 的 csrc/attention/attention_kernels.cu 文件里,開發(fā)者洋洋灑灑寫了八百多行 CUDA 代碼。
Attention 計算時使用頁式(paged)管理 KVCache 用于增加服務吞吐率,但對延遲有負面影響,因此高效的 PA 實現(xiàn)方法,利用頁式內(nèi)存管理同時盡量降低其負面影響,對框架的綜合性能表現(xiàn)至關重要。
本文章將描述 PA CUDA Kernel 的實現(xiàn)細節(jié),這些細節(jié)是公開的論文和博客所不涉及的,但卻對框架的速度至關重要。另外,PA 實現(xiàn)改編自 FasterTransformers 某個版本的 MHA 實現(xiàn),NV 原始版本對 GPU 特性的運用也是相當老道的,值得大家借鑒。
vLLM 中有兩個版本 PA,使用一個簡單的啟發(fā)式方法來決定是使用 V1 還是 V2 版本。V1 是本文介紹的版本,改編自 FasterTransformers 的 MHA 實現(xiàn)。V2 是參考 FlashDecoding 方式進行實現(xiàn),對 sequence 維度進行切分以增加并行粒度,關于 FlashDecoding 可以參考本人知乎文章。V1 適合長度小于 8192 或者 num_seqs * num_heads>512 的情況。
參數(shù)定義和數(shù)據(jù)結構
num_seq:本次推理請求 sequence 數(shù)目。
num_head:Query 的 head 數(shù)目。
num_kv_heads:Key、Value 的 head 數(shù)目,對于 MHA 和 num_head 相同,如果是 GQA、MQA 則 num_kv_heads 小于 num_head。
head_size hidden dimension,特征的維度。
PA 使用 tensor 的維度信息:
out [num_seqs, num_heads, head_size]
Q [num_seqs, num_heads, head_size]
KCache [num_blocks, num_kv_heads, head_size/x, block_size, x]:x 表示一個向量化的大小,如 float16 -> 16 / sizeof(float16) = 8。
VCache [num_blocks, num_kv_heads, head_size, block_size]
Paged 內(nèi)存管理相關的輔助數(shù)據(jù)結構:
blk_size:也就是 block_size,是 KVCache page 的最高維,KVCache 是若干個 page 的集合,每個 page 存(blk_size, num_head,head_size)個 K、V 的元素。
head_mapping [num_heads] 用于 MQA, GQA,確定用的 KV_head
block_tables [num_seqs, max_num_blocks_per_seq] block_tables 映射表,表示每個 sequence 映射到哪幾個 block 上
context_lens [num_seqs] 用于變長
課前問題
如果你能回答以下兩個問題,那么說明你已經(jīng)非常熟練地掌握了 PA 實現(xiàn),并可以用批判性的眼光審閱本文,找出其中可能存在的錯誤。如果你暫時無法回答這些問題,請不要擔憂,閱讀完本文后會給你答案。
Q1:為什么 K Cache 的 layout 和 V Cache layout 不一樣?
Q2:PA 實現(xiàn)和 FlashAttention 有什么區(qū)別?
PagedAttention算子計算流程
首先,按照 CUDA 編程模型對任務進行并行劃分,grid 大小(num_heads, num_seqs),grid 中每個 CUDA thread block 大小(NUM_THREADS),NUM_THREADS 是常量默認為 128,也就說每個 thread block 包含 128 個線程,負責完成 output 矩陣一行(包含 head_size 個元素)結果的 attention 計算任務。thread block 中的線程進一步劃分若干個WARP。
眾所周知,WARP 是 GPU 一個基本的執(zhí)行單元,由 32 個線程組成,這些線程以 SMIT 方式在硬件上同時執(zhí)行相同的指令,在不同的數(shù)據(jù)上進行操作。在 PA 中比較特殊的是,warp 內(nèi) 32 個線程進一步劃分為 blk_size 個 thread group,這和 paged KVCache 設計 x 息息相關的,馬上會細講。
Attention 計算 softmax(QK^T)V,一圖勝前言,后面流程介紹將圍繞下面這幅圖展開。其中 thread block, warp, thread group, thread 別用不同顏色表示。
▲ 圖1:PagedAttention CUDA計算流程
在上圖的左側(cè)部分,我們看到了 Q 矩陣,這部分描述了從顯存讀取 Q 數(shù)據(jù)到共享內(nèi)存的過程。在這個過程中,一個 CUDA 線程塊會讀取圖中 Q 矩陣的一行(包含 head_size個元素)并將其存入共享內(nèi)存。
這個過程是通過一個循環(huán)來實現(xiàn)的,在每次迭代中,每個 thread group 會讀取 16 字節(jié)的 Q 數(shù)據(jù)(例如,如果使用 float16,那么就是 8 個元素)。每個 warp 會讀取 16*blk_size 字節(jié)的 Q 數(shù)據(jù),這些數(shù)據(jù)對應于一個 sequence 的一個 head,由 CUDA grid 索引指定。當循環(huán)訪問結束后,共享內(nèi)存存儲 Q 行的一部分。如下圖所示,綠色部分表示存儲在一個線程讀入共享內(nèi)存中的數(shù)據(jù)。
圖 1 中上面部分 K 矩陣部分描述了從顯存讀取 K Cache 到寄存器的過程。每個序列的 K Cache 包含 cxt_length * num_kv_heads * head_size 個元素,但由于采用了頁式內(nèi)存管理,這些元素在內(nèi)存中的存儲并不連續(xù)。每個 thread block 只負責計算一個 sequence 一個 head 的 QK^T,因此只需要 ctx_length * head_size 個 K Cache 元素。
然而,由于 ctx_length 維度的存儲是不連續(xù)的,并且以 blk_size 個 token 為粒度分布在不同的內(nèi)存地址,我們需要根據(jù)query的head_idx和 seq_idx 訪問 block_table 以找到 K Cache的physical_block_num。為了方便后續(xù)的描述,我們可以將 K Cache 視為(:, head_size)的形狀,其中 head_size 個元素組成一行。
K Cache 的布局為 [num_blocks, num_kv_heads, head_size/x, block_size, x],這是為了優(yōu)化寫入 shared memory 的操作。在 Q 和 K 矩陣的同一行元素被讀入寄存器并進行點乘運算后,結果需要被存入 shared memory。
如果一個 warp 中所有線程都計算 Q、K 同一行數(shù)據(jù),會導致寫入 shared memory 的同一個位置,這將造成 warp 內(nèi)不同線程順序地寫入。因此,為了優(yōu)化,warp的線程最好計算 Q 和 K 的不同行數(shù)據(jù)。因此,在設計 K Cache 布局時,我們將 block_size 放在比 head_size 更低的維度。
由于 warp size 大于 block_size,我們需要將 head_size 拆分為 head_size/x 和 x 兩個維度,借 x 到最低維度,以確保每個線程讀入的數(shù)據(jù)量和計算量都足夠大。最后,每個線程組派一個線程去寫入 shared memory,這樣一個 warp 有 blk_size 個線程并行寫入 shared memory,從而增加了 shared memory 的訪問帶寬。這種設計策略是為了實現(xiàn)高效的并行計算和內(nèi)存訪問,以提高整體的計算性能。
在代碼實現(xiàn)中,訪問 K 矩陣需要一個循環(huán),該循環(huán)使得 CUDA 線程塊中的所有 warp 依次訪問 num_block 個頁面。在每次循環(huán)迭代中,每個 warp 負責訪問連續(xù)的 blk_size個K Cache 行,這涉及到的數(shù)據(jù)量為 blk_size * head_size 個元素。同時,每個 thread group 負責訪問 K Cache 的一行,將 head_size 個元素加載到自己的寄存器中。
接著,寄存器中的 Q 和 K 數(shù)據(jù)元素立即進行點乘運算,運算結果被寫入 shared memory 中。因此,線程塊的 shared memory 存儲了一行 QK^T 的結果,包含 ctx_length 個元素。這種實現(xiàn)方式充分利用了 CUDA 的并行計算能力,以提高數(shù)據(jù)處理的效率。
然后,thread block 對 shared memory 中元素進行 max,sum 方式 reduction,然后計算得到 softmax 結果。
圖 1 右邊 V 矩陣部分描述從顯存讀 V Cache 到寄存器過程。和 K Cache 一樣,CUDA thread block 依次訪問 num_blk 個物理塊到寄存器,每個 warp 負責 blk_size 個 token 的 page 內(nèi)存,page 的真實物理地址同樣需要進行索引。
不過這里不需要以 thread group 為單位訪問 16 字節(jié),而是每個 thread 訪問 16 字節(jié)的元素。訪問完就可以與 shared memory 的 softmax(QK^T) 中間結果對應位置 16 字節(jié)的數(shù)據(jù)進行點乘,得到一個 float 結果,寫到 output 對應位置中。
為什么V Cache的layout是 [num_blocks, num_kv_heads, head_size, block_size],和 K Cache layout 不一樣?這是因為 V 要去做點乘的對象在shared memory,只需要讀,不涉及并行寫的問題。
和 FlashAttention(FA)有什么不同?結合我的圖和中間 FAv2 的流程圖對比就一目了然了。FA 用了兩層循環(huán),每次寫一個 Tile 的 output tensor,而 PA 一直只有一層循環(huán),每次寫一行 output tensor。因為每次都有整行的 QK^T 中間結果,不需要 online softmax 這種花哨技巧。
PAv1的問題
以我粗淺的理解指出幾點 vLLM PAv1 的問題。一、和 MHA 相比,MQA 和 GAQ 沒有減少對 KV Cache 的讀寫次數(shù)。讀 K、V Cache 時候只是做了一個 head_idx 的轉(zhuǎn)換,會重復從顯存讀相同的 head。二、對于 seq length 很長情況沒法適應,因為沒有沿著 ctx_length 或者 batch 維度做切分。這點 FlashAttention 和 FlashDecoding 就做了,因此 PAv2 借鑒了 FA 的切分思想。
總結
vLLM 的 paged attention v1 實現(xiàn)繼承自 FasterTransformers MHA 實現(xiàn),它和 FlashAttention 的并行任務劃分方式不同。其中對 KVCache layout 的設計比較巧妙,充分利用了 shared memory 寫帶寬,是一種常用 CUDA 編程技巧。
審核編輯:劉清
-
寄存器
+關注
關注
31文章
5394瀏覽量
122430 -
Cache
+關注
關注
0文章
129瀏覽量
28752 -
內(nèi)存管理
+關注
關注
0文章
168瀏覽量
14385 -
MQA
+關注
關注
0文章
3瀏覽量
6074
原文標題:vLLM皇冠上的明珠:深入淺出理解PagedAttention CUDA實現(xiàn)
文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關推薦
評論