CANN EasyAsc DSL a2 Cube-Vec-Cube-Vec Pattern
a2 Cube-to-Vec-to-Cube-to-Vec Pattern (Triple Bridge, Normalized Online Softmax)
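[Free download link] cannbot-skills: CANNBot is a family of agents for improving development efficiency in CANN development; this repository provides reusable Skills modules for it. Project address: https://gitcode.com/cann/cannbot-skills
Read this file when writing an a2 (easyasc.a2, device b3) kernel with: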
- one cube stage that produces a score tile
- vec logic that updates running row max and running row sum
- a later cube stage that consumes the delayed probability tile
- a final vec stage that accumulates the delayed cube output
- one final vec-only divide by the accumulated row sum
Typical target formula:
    score_j = q.float() @ k_j.float().t() * scale
    curr_m = maximum(prev_m, rowmax(score_j))
    expdiff_j = exp(prev_m - curr_m)
    p_j = exp(score_j - curr_m)
    row_sum = row_sum * expdiff_j + p_j.sum(-1)
    pv_j = p_j.half().float() @ v_j.float()
    out = out * expdiff_j + pv_j
    out = out / row_sum
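As a host-side reference for the formula above, the following is a minimal pure-Python sketch for a single query row. It is not EasyAsc code. The helper names `to_half` and `online_softmax_row`, the list-based data layout, and the initial state (`prev_m = -inf`, `row_sum = 0`) are assumptions made for illustration. The half round trip is modeled with the standard `struct` format `"e"`.

```python
import math
import struct


def to_half(x):
    """Round a float through IEEE 754 binary16 (models p_j.half().float())."""
    return struct.unpack("e", struct.pack("e", x))[0]


def online_softmax_row(scores_by_tile, v_by_tile):
    """Stream one query row over key/value tiles.

    scores_by_tile[j] is the already-scaled score slice for tile j.
    v_by_tile[j] holds one value vector (length D) per key in tile j.
    """
    d = len(v_by_tile[0][0])
    prev_m = -math.inf          # assumed initial running max
    row_sum = 0.0               # assumed initial running sum
    out = [0.0] * d
    for score_j, v_j in zip(scores_by_tile, v_by_tile):
        curr_m = max(prev_m, max(score_j))
        expdiff_j = math.exp(prev_m - curr_m)
        p_j = [math.exp(s - curr_m) for s in score_j]
        row_sum = row_sum * expdiff_j + sum(p_j)   # uses float p_j, before the cast
        p_half = [to_half(p) for p in p_j]         # what the second cube stage sees
        pv_j = [sum(p * v[c] for p, v in zip(p_half, v_j)) for c in range(d)]
        out = [o * expdiff_j + x for o, x in zip(out, pv_j)]
        prev_m = curr_m
    return [o / row_sum for o in out]              # single divide after the loop
```

Called with two tiles of two keys each, the result should agree with a direct softmax-weighted average up to half-precision rounding of the probabilities.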
This is the normalized counterpart to a2-cube-vec-cube-vec.md. Use that older pattern only when the kernel stops at the unnormalized numerator.
One-page route for the common case
If this file matches your contract, do not preload all of:
- agent/references/constraints/reduction.md
- agent/references/constraints/vec-reduction-a2.md
- agent/references/constraints/vec-stride.md
- agent/references/constraints/online-softmax-tail.md
This page now owns the common normalized-online-softmax authoring rules. Open the smaller constraint pages only when a specific failure mode still remains unclear after this file.
Why this needs its own a2 pattern
The a2 hardware constraints are the same as the unnormalized case:
- cube -> vec cannot use l0c_to_ub
- vec -> cube cannot use ub_to_l1_*
- delayed cube output must come back to vec for final accumulation
But normalized online softmax adds two stability-sensitive requirements:
- running row_sum must be updated from the float exp(...) tile before any cast to half
- the final divide must happen only once, after all delayed numerator tiles have been accumulated
So the stable a2 flow is:
    GM(q,k,v) -> L1 -> L0 -> L0C(score) -> GM(score_ws) -> UB(score)
    -> vec(max, expdiff, exp, row_sum, cast p) -> GM(p_ws) -> L1 -> L0 -> L0C(pv)
    -> GM(pv_ws) -> UB(pv) -> UB(accum) -> final UB divide by row_sum -> GM(out)
Workspaces and ownership edges
Use the same three GM workspaces as the unnormalized pattern:
- score_ws
  - dtype: float
  - shape: [GetCubeNum(), 2, TILE_M, TILE_N]
  - purpose: L0C(score) -> UB(score)
- p_ws
  - dtype: half
  - shape: [GetCubeNum(), 2, TILE_M, TILE_N]
  - purpose: UB(p_j.half()) -> L1(p_j)
- pv_ws
  - dtype: float
  - shape: [GetCubeNum(), 2, TILE_M, D]
  - purpose: L0C(pv_j) -> UB(pv_j)
Ownership edges:
- stage 1 cube -> vec: CvMutex(0, src_end_pipe=Pipe.FIX, dst_end_pipe=Pipe.MTE2)
- stage 1 vec -> stage 2 cube: VcMutex(1, src_end_pipe=Pipe.MTE3, dst_end_pipe=Pipe.FIX)
- stage 2 cube -> stage 3 vec: CvMutex(2, src_end_pipe=Pipe.FIX, dst_end_pipe=Pipe.MTE2)
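To get a feel for how much GM the three workspaces take, the shapes above can be turned into byte counts. This is a small host-side sketch, not DSL code. `workspace_plan` is a hypothetical helper, and the `cube_num = 20` and `TILE_M = 128` values in the usage line are placeholders rather than values stated by this page.

```python
def workspace_plan(cube_num, tile_m, tile_n, d):
    """Return {name: (dtype, shape, bytes)} for score_ws, p_ws and pv_ws."""
    itemsize = {"float": 4, "half": 2}
    specs = {
        "score_ws": ("float", [cube_num, 2, tile_m, tile_n]),
        "p_ws": ("half", [cube_num, 2, tile_m, tile_n]),
        "pv_ws": ("float", [cube_num, 2, tile_m, d]),
    }
    plan = {}
    for name, (dtype, shape) in specs.items():
        count = 1
        for dim in shape:
            count *= dim
        plan[name] = (dtype, shape, count * itemsize[dtype])
    return plan


# Placeholder sizes, only to show the arithmetic.
plan = workspace_plan(cube_num=20, tile_m=128, tile_n=128, d=128)
```

With these placeholder sizes, p_ws takes half the bytes of score_ws because it has the same shape but stores half instead of float.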
Stable schedule
Use the same one-tile lookahead loop as the unnormalized pattern:
    for ni in range(0, tiles_n + 1):
        if ni < tiles_n:
            # stage 1: produce tile j = ni
            pass
        if ni > 0:
            # stage 2 + stage 3: consume tile j = ni - 1
            pass
That gives:
- warmup: first iteration only produces
- steady state: produce j while consuming j - 1
- drain: final iteration only consumes the last delayed tile
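The warmup, steady-state, and drain phases can be made visible by tracing the loop on the host. The sketch below is plain Python with a hypothetical helper name, `lookahead_trace`. It only records which tile each iteration produces and consumes.

```python
def lookahead_trace(tiles_n):
    """List the (iteration, action, tile) events of the one-tile lookahead loop."""
    events = []
    for ni in range(0, tiles_n + 1):
        if ni < tiles_n:
            events.append((ni, "produce", ni))       # stage 1 for tile j = ni
        if ni > 0:
            events.append((ni, "consume", ni - 1))   # stage 2 + stage 3 for tile j = ni - 1
    return events


trace = lookahead_trace(3)
```

For `tiles_n = 3`, iteration 0 only produces tile 0, iterations 1 and 2 produce one tile and consume the previous one, and iteration 3 only consumes tile 2.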
Shared L0C rule
Reuse one physical L0C family across the two cube stages.
This is the same capacity-driven choice as the unnormalized pattern:
- stage 1 needs float [TILE_M, TILE_N]
- stage 2 needs float [TILE_M, D] with validated D == 128
- a2 still has only 128 KB L0C
Keep one shared l0c_cnt, but do not merge unrelated counters just because L0C is shared.
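The capacity argument can be checked with a few lines of arithmetic. The sketch below rests on two assumptions that this page does not state: `TILE_M = 128`, and that each L0C family is double-buffered. Only the 128 KB capacity, `TILE_N = 128`, and `D = 128` come from the text.

```python
L0C_BYTES = 128 * 1024          # a2 L0C capacity stated above
FLOAT_BYTES = 4


def l0c_tile_bytes(rows, cols):
    """Bytes of one float accumulator tile in L0C."""
    return rows * cols * FLOAT_BYTES


TILE_M = 128                    # assumption: not fixed by this page
TILE_N = 128
D = 128

score_tile = l0c_tile_bytes(TILE_M, TILE_N)   # stage 1 result tile
pv_tile = l0c_tile_bytes(TILE_M, D)           # stage 2 result tile

# One shared family with two slots (assumed double buffering), sized for the larger tile.
shared_family = 2 * max(score_tile, pv_tile)
# Two private families, each with two slots.
separate_families = 2 * score_tile + 2 * pv_tile

fits_shared = shared_family <= L0C_BYTES
fits_separate = separate_families <= L0C_BYTES
```

Under these assumptions the shared family fills L0C exactly, while two private families would need twice the available capacity.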
Counter layout
Keep these lifetimes separate:
- l1qk_cnt: stage-1 q/k loads
- l1pv_cnt: stage-2 p/v loads
- l0c_cnt: shared physical L0C family across the two cube stages
- stage1_cnt: delayed slot rhythm for score_ws, p_ws, and expdiff
- stage2_cnt: delayed slot rhythm for p_ws consumption and pv_ws
Running row_sum does not need its own delayed counter. It stays vec-resident for the whole inner loop and updates immediately in stage 1.
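One way to see why stage1_cnt and stage2_cnt stay in step is to simulate the two delayed slots on the host. The sketch below assumes that a counter selects its slot as `cnt % 2`, matching the size-2 axis of the workspaces. That mapping, and the helper name `simulate_slots`, are illustrative assumptions and not the DSL's actual mechanism.

```python
def simulate_slots(tiles_n):
    """Track which tile each of the two delayed slots holds during the loop."""
    slots = [None, None]      # models the size-2 slot axis of score_ws / p_ws / expdiff
    stage1_cnt = 0            # advances once per produced tile
    stage2_cnt = 0            # advances once per consumed tile
    consumed = []
    for ni in range(tiles_n + 1):
        if ni < tiles_n:
            slot = stage1_cnt % 2
            assert slots[slot] is None, "producer would overwrite an unconsumed slot"
            slots[slot] = ni
            stage1_cnt += 1
        if ni > 0:
            slot = stage2_cnt % 2
            consumed.append(slots[slot])
            slots[slot] = None   # slot is free again for the producer
            stage2_cnt += 1
    return consumed


order = simulate_slots(4)
```

In this model the consumer sees tiles in production order and the producer never overwrites a slot that is still pending, which is the property the separate counters are meant to preserve.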
Vec-resident persistent state
Keep these values in per-subblock UB across the whole inner loop:
- running row max: [HALF_M, 1]
- running row sum: [HALF_M, 1]
- delayed expdiff slots: DBuff(DT.float, [HALF_M, 1], Position.UB)
- final numerator accumulation: [HALF_M, D]
Use GetSubBlockIdx() so each vec lane owns only its own HALF_M rows.
Stable stage-1 update order
The normalized online update order matters:
- compute rowmax(score_j) in [HALF_M, 1]
- snapshot prev_m into the delayed expdiff slot with add(..., zero)
- update running_max = maximum(running_max, tile_max)
- turn the delayed slot into exp(prev_m - curr_m)
- broadcast running_max and subtract from the score tile
- compute the float probability tile p_j = exp(score_j - curr_m)
- reduce sum_j from that float tile with add + cadd
- update running_sum = running_sum * expdiff_j + sum_j in [HALF_M, 1]
- cast p_j to half only now, because stage 2 wants the exact p_j.half().float() contract
Do not move the row-sum update after the cast. That would silently change the reference contract.
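The update order above can be walked through for a single row in plain Python. This is a scalar model and not vec code. The `state` dictionary, the `slot` variable standing in for the delayed expdiff slot, the helper names, and the `-inf` / `0.0` starting state are all assumptions for illustration.

```python
import math
import struct


def to_half(x):
    """Round a float through IEEE 754 binary16."""
    return struct.unpack("e", struct.pack("e", x))[0]


def stage1_row(score_j, state):
    """Apply the stage-1 order to one row; returns (expdiff_j, p_j as half values)."""
    tile_max = max(score_j)                                       # rowmax(score_j)
    slot = state["running_max"] + 0.0                             # snapshot prev_m via add(..., zero)
    state["running_max"] = max(state["running_max"], tile_max)    # curr_m
    slot = math.exp(slot - state["running_max"])                  # slot now holds exp(prev_m - curr_m)
    shifted = [s - state["running_max"] for s in score_j]         # subtract broadcast max
    p_float = [math.exp(x) for x in shifted]                      # float probability tile
    sum_j = sum(p_float)                                          # reduced from the float tile
    state["running_sum"] = state["running_sum"] * slot + sum_j
    p_half = [to_half(p) for p in p_float]                        # cast only after the sum
    return slot, p_half


state = {"running_max": -math.inf, "running_sum": 0.0}   # assumed initial state
expdiff_0, p_0 = stage1_row([0.1, 0.7, -1.3, 2.05], state)
```

Summing `p_0` instead of the float tile would generally give a slightly different `running_sum`, because the half rounding is applied before the reduction. That is the contract change the warning above refers to.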
Vec rules you usually need without extra docs
For the common TILE_N = 128, D = 128 path, the usual extra questions are already answered here:
- keep running_max, running_sum, and delayed expdiff in scalar format [HALF_M, 1]
- snapshot scalar state with add(dst, src, zero), not ub_to_ub
- cmax / cadd output dense scalars, so broadcast them with:
  - brcb(dst, src, dst_blk_stride=1, dst_rep_stride=8)
- when a wide [HALF_M, 128] buffer is paired with a narrow [HALF_M, 8] broadcast row, operate on:
  - buf[:, 0:64]
  - buf[:, 64:128]
  rather than on the full 128-column view in one vec call
- update running_sum from the float p_j tile before any cast to half or hif8
- for non-aligned S2, invalidate score columns before cmax with a sufficiently negative finite sentinel; valid_n on the GM load alone is not enough
These six rules cover the usual reasons people would otherwise open the separate reduction, vec-reduction, vec-stride, and tail files.
Critical scalar-state rule on a2
Do not copy [HALF_M, 1] scalar-format state with ub_to_ub.
That applies to both:
- prev_m
- any temporary scalar snapshot you might be tempted to use for row_sum
Use add(dst, src, zero) for scalar-format copies, and keep both running_max and running_sum in [M,1] format until you explicitly need a broadcast.
Final vec accumulation and divide
Stage 3 still matches the unnormalized pattern:
- load delayed pv_j back into UB
- brcb the delayed expdiff slot to [HALF_M, 8]
- scale the two 64-column halves of accum
- add(accum, accum, pv_j)
After the inner loop finishes:
- brcb the final running_sum to [HALF_M, 8]
- div(accum[:, 0:64], accum[:, 0:64], row_sum_broadcast)
- div(accum[:, 64:128], accum[:, 64:128], row_sum_broadcast)
- write the normalized result to GM
Why the divide happens at the end:
- accum must finish all delayed pv_j contributions first
- row_sum is the denominator for the whole streamed softmax, not one tile
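The stage-3 accumulation and the final divide can be modeled on the host with nested lists. The functions below are a semantic sketch only: `brcb` here is a plain-Python stand-in that widens each row scalar to eight identical values, and `apply_half`, `stage3_step`, and `final_divide` are hypothetical names. None of this is the EasyAsc API.

```python
def brcb(col):
    """Model of brcb: widen each [HALF_M, 1] scalar to an 8-value block per row."""
    return [[x] * 8 for x in col]


def apply_half(buf, lo, hi, bcast, op):
    """Apply op(element, row scalar) in place to columns lo..hi-1 of buf."""
    for r, row in enumerate(buf):
        s = bcast[r][0]                  # every value in the block is the same scalar
        for c in range(lo, hi):
            row[c] = op(row[c], s)


def stage3_step(accum, pv_j, expdiff):
    """accum = accum * expdiff + pv_j, scaling the two 64-column halves separately."""
    b = brcb(expdiff)
    apply_half(accum, 0, 64, b, lambda a, s: a * s)
    apply_half(accum, 64, 128, b, lambda a, s: a * s)
    for row, pv_row in zip(accum, pv_j):
        for c in range(len(row)):
            row[c] += pv_row[c]          # add(accum, accum, pv_j)


def final_divide(accum, running_sum):
    """Single normalization after the last delayed tile has been accumulated."""
    b = brcb(running_sum)
    apply_half(accum, 0, 64, b, lambda a, s: a / s)
    apply_half(accum, 64, 128, b, lambda a, s: a / s)
```

In use, `stage3_step` would run once per consumed tile inside the loop and `final_divide` exactly once after it, mirroring the two lists above.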
Extending the pattern to non-aligned S2
The initial validated contract for this pattern kept S2 % 128 == 0 so the first implementation could ignore score-tail masking.
When S2 is not aligned, do not stop at GM-boundary valid_n slicing. For normalized online softmax, padded score columns can still corrupt:
- rowmax(score_j)
- curr_m
- delayed expdiff
- row_sum
Stable rule:
- load k/v through valid_n
- keep local score buffers full-sized
- before cmax, force invalid score columns to behave like -inf
- when materializing that mask, use a sufficiently large finite negative fill value instead of literal -inf
- after exp, those same columns naturally behave like 0
For the current TILE_N = 128 layout, the simplest a2 implementation is:
- split the score tile into two [HALF_M, 64] halves
- use vec mask + finite-negative dup(...) on the affected half
- recompute prev_valid_n for the delayed v load in stage 2
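A small host-side example shows why padded columns matter. The sketch assumes that padded score columns hold zeros and uses `-1.0e30` as the finite fill value. Both are illustrative assumptions, as are the helper names `mask_tail` and `row_stats`.

```python
import math

NEG_FILL = -1.0e30     # assumption: any finite value far below real scores


def mask_tail(score_row, valid_n, fill=NEG_FILL):
    """Overwrite columns >= valid_n so they cannot win the row max."""
    return [s if c < valid_n else fill for c, s in enumerate(score_row)]


def row_stats(score_row):
    """Row max and exp-sum, as stage 1 would compute them for one tile."""
    m = max(score_row)
    return m, sum(math.exp(s - m) for s in score_row)


# Three valid (all negative) scores followed by five padded columns.
row = [-4.0, -2.5, -3.0] + [0.0] * 5
bad_max, bad_sum = row_stats(row)                    # padding wins the max
ok_max, ok_sum = row_stats(mask_tail(row, valid_n=3))
```

Without the mask the padded zeros become the row max and add to the row sum. With the finite fill value the max comes from the valid columns and the filled columns contribute `exp(...) == 0.0`.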
Read next for the exact rule and mask-construction trick:
agent/references/constraints/online-softmax-tail.md
Validation target
Keep the first validated contract narrow:
- D == 128
- S1 % 128 == 0
- S2 % 128 == 0
- input q/k/v are float16
- output is float32
Suggested cases:
- (1, 3, 256, 256, 128) for the smallest two-tile online update
- (1, 1, 256, 512, 128)
- (1, 3, 256, 512, 128)
- (1, 3, 2048, 4096, 128)
For non-aligned S2 extensions, add at least:
- one aligned baseline: S2 % 128 == 0
- one left-half tail: S2 % 128 == 10
- one cross-boundary case: S2 % 128 == 65
- one mid-right-half case: S2 % 128 == 96
- one last-column case: S2 % 128 == 127
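To see which 64-column half each of these remainders exercises, the last-tile layout can be computed on the host. `tail_layout` is a hypothetical helper, and the `mask_left` / `mask_right` flags reflect one reading of "affected half" (a half needs masking whenever it is not fully valid), which is an assumption.

```python
def tail_layout(s2, tile_n=128, half=64):
    """Describe the last score tile: valid columns overall and per 64-column half."""
    rem = s2 % tile_n
    valid_n = tile_n if rem == 0 else rem
    left_valid = min(valid_n, half)
    right_valid = max(valid_n - half, 0)
    return {
        "valid_n": valid_n,
        "left_valid": left_valid,          # valid columns in buf[:, 0:64]
        "right_valid": right_valid,        # valid columns in buf[:, 64:128]
        "mask_left": left_valid < half,
        "mask_right": right_valid < half,
    }


cases = {rem: tail_layout(256 + rem) for rem in (0, 10, 65, 96, 127)}
```

Under this reading, remainder 10 touches both halves (a partial left half and a fully invalid right half), while 65, 96, and 127 leave the left half intact and mask only part of the right half.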
Files to study / deeper fallbacks
- agent/example/kernels/a2/flash_attn_full.py
- agent/example/kernels/a2/flash_attn_unnorm.py
- agent/example/kernels/a2/flash_attn_score_pv.py
- agent/references/patterns/a2-cube-vec-cube-vec.md
- agent/references/constraints/reduction.md: fallback only when the online update order is still unclear
- agent/references/constraints/vec-reduction-a2.md: fallback only when the cmax/cadd -> brcb detail is still unclear
- agent/references/constraints/vec-stride.md: fallback only when a sliced wide/narrow vec op is still unclear
- agent/references/constraints/online-softmax-tail.md: fallback only when the non-aligned S2 mask construction itself is the question
Authoring statement: parts of this article were generated with AI assistance (AIGC) and are for reference only.
