
feat: flash attention support for hexagon-npu #45

Merged

merged 91 commits on Jun 18, 2025
Changes from all commits
Commits (91)
0bb7a5f
add flash attn op
chraac May 27, 2025
85156a4
expend src tensor size
chraac May 28, 2025
8744400
add flash attn sources
chraac May 28, 2025
60c79cf
add quantize row functions
chraac May 28, 2025
b908b9e
make a separated file for vec_dot
chraac May 28, 2025
2beda85
wip
chraac May 29, 2025
6a00c93
wip
chraac May 30, 2025
9846b78
refactor: rename quants.hpp includes and add vec_dot to type traits
chraac May 30, 2025
c357a15
add flash_attn impl
chraac May 30, 2025
f63b99e
split vec_scale_f32
chraac May 30, 2025
5cecdd7
move vec_reduction_qf32 to vec_ops
chraac May 31, 2025
0e4ea1e
add vec_scale_f16
chraac May 31, 2025
aae0d33
opt
chraac May 31, 2025
baa1b6b
add vec_mad
chraac May 31, 2025
62bf246
implement vec_mad_f16
chraac May 31, 2025
6f8dd34
opt
chraac May 31, 2025
f2ec1e8
add op template
chraac May 31, 2025
39a8c42
opt
chraac May 31, 2025
5b339d4
add align version
chraac May 31, 2025
c527f2a
enable flash attn
chraac May 31, 2025
0d40cfa
wip
chraac May 31, 2025
3a1c0b0
log print improve
chraac Jun 1, 2025
5e01705
add profiler log
chraac Jun 1, 2025
9175ec6
wip
chraac Jun 1, 2025
8bba769
wip
chraac Jun 1, 2025
300afef
add multi sub proc perf tracker
chraac Jun 1, 2025
3e7662b
increase log buffer
chraac Jun 2, 2025
9fc46be
remove sub prov pcycle
chraac Jun 2, 2025
56bcc76
wip
chraac Jun 2, 2025
cdeb534
wip
chraac Jun 2, 2025
0967eea
add prefetch for vec_dot
chraac Jun 2, 2025
f2582dc
wip
chraac Jun 2, 2025
24c6e86
wip
chraac Jun 2, 2025
ee5c19d
opt f16 vec dot
chraac Jun 2, 2025
93e721f
opt f16 vecdot
chraac Jun 3, 2025
8d4dad3
reuse vec_dot_product_impl in vec dot f32
chraac Jun 3, 2025
e7c9e2a
small opt to unblock pipeline
chraac Jun 3, 2025
27be1eb
opt on aligned address
chraac Jun 3, 2025
f075a4c
Revert "opt on aligned address"
chraac Jun 3, 2025
ea582fe
add profiler log at thread_pool
chraac Jun 3, 2025
594644d
wip
chraac Jun 3, 2025
3ccedc3
invalidate all...
chraac Jun 3, 2025
62d6790
Reapply "opt on aligned address"
chraac Jun 3, 2025
c1583b4
add is_constant for tensor config
chraac Jun 3, 2025
2ba39f2
disable align tensor opt in mul_mat
chraac Jun 3, 2025
b803a63
wip
chraac Jun 4, 2025
b8c6462
wip
chraac Jun 4, 2025
830e67a
vec_scale_impl: unrolling the loop
chraac Jun 4, 2025
18c975b
wip
chraac Jun 4, 2025
559e5b8
wip
chraac Jun 4, 2025
c9ee6c8
replace reinterpret_cast with direct pointer access for write/read bu…
chraac Jun 4, 2025
56d9f3a
add fetch
chraac Jun 4, 2025
9dece44
wip
chraac Jun 4, 2025
30dcf58
wip
chraac Jun 4, 2025
5701374
wip
chraac Jun 4, 2025
9ad1165
add log
chraac Jun 5, 2025
627b504
check tensor shape at flash_attn
chraac Jun 6, 2025
ec4ec34
wip
chraac Jun 6, 2025
c5ec077
wip
chraac Jun 6, 2025
bc6db4f
fix: update tensor type handling in flash_attn_impl
chraac Jun 6, 2025
a3ac36a
wip
chraac Jun 6, 2025
2e68f1c
fix: align cache size
chraac Jun 7, 2025
52d6ee1
fix: qf16->hf
chraac Jun 7, 2025
4231c5e
fix: swap order of elements in vector combine for correct scaling
chraac Jun 8, 2025
0734627
fix: opt f16 scale and mad
chraac Jun 8, 2025
ddac620
Merge branch 'dev-refactoring' into dev-flash-attn
chraac Jun 9, 2025
b676497
fix leftover fetch
chraac Jun 9, 2025
5182ba1
wip
chraac Jun 9, 2025
8320d7c
load into vector pair
chraac Jun 9, 2025
efd5254
opt cache size calculation in flash_attn_impl
chraac Jun 9, 2025
7db0816
refactoring: hold vtcm at thread local object
chraac Jun 10, 2025
8639ea4
wip
chraac Jun 10, 2025
7d49c49
add profiler log
chraac Jun 10, 2025
49a6c27
mark tensors as modified
chraac Jun 10, 2025
0a8ff2b
restrict tensor invalidation to the first thread in compute_impl
chraac Jun 10, 2025
8da9e8e
Revert "restrict tensor invalidation to the first thread in compute_i…
chraac Jun 10, 2025
e2ba224
invalidate last tensor in compute_impl
chraac Jun 10, 2025
faa47bd
invalidate last tensor in compute function
chraac Jun 10, 2025
0347ee7
wip
chraac Jun 10, 2025
8c6e298
refactor dequantize_row_q4_0 to simplify vector alignment
chraac Jun 10, 2025
54b3c2a
wip
chraac Jun 11, 2025
0809df6
refactoring: move VTCM quota calculation to thread pool
chraac Jun 11, 2025
e7a92ba
wip
chraac Jun 11, 2025
095c811
fix: correct condition check for HEXAGON_SDK_ROOT existence
chraac Jun 13, 2025
8793f0c
wip
chraac Jun 13, 2025
28ec32e
wip
chraac Jun 13, 2025
62b77a8
wip
chraac Jun 13, 2025
a651dcf
wip
chraac Jun 13, 2025
2265fe4
fix: update condition checks match the naming
chraac Jun 13, 2025
6384928
fix: improve tensor handling checks and logging in graph and operatio…
chraac Jun 14, 2025
37d97e5
wip
chraac Jun 17, 2025
2 changes: 2 additions & 0 deletions ggml/src/ggml-qnn/npu/CMakeLists.txt
@@ -3,6 +3,8 @@ cmake_policy(SET CMP0115 OLD)

if(DEFINED ENV{HEXAGON_SDK_ROOT})
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
message("HEXAGON_SDK_ROOT (from environment): ${HEXAGON_SDK_ROOT}")
elseif(DEFINED HEXAGON_SDK_ROOT)
message("HEXAGON_SDK_ROOT: ${HEXAGON_SDK_ROOT}")
else()
message(FATAL_ERROR "HEXAGON_SDK_ROOT not defined")
28 changes: 14 additions & 14 deletions ggml/src/ggml-qnn/npu/device/device.cpp
@@ -9,10 +9,10 @@
#include "graph.hpp"
#include "hexagon_npu.h"
#include "op_impl.hpp"
#include "quants.hpp"
#include "remote.h"
#include "tensor.hpp"
#include "thread_pool.hpp"
#include "type_traits.hpp"
#include "util.hpp"

namespace {
@@ -124,21 +124,20 @@ int npu_device_close(remote_handle64 h) {

AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignment) {
NPU_UNUSED(_h);
*alignment = sizeof(HVX_Vector);
*alignment = sizeof(HVX_VectorPair);
return AEE_SUCCESS;
}

AEEResult npu_device_device_support_op(remote_handle64 _h, const npu_device_tensor_spec * src0,
const npu_device_tensor_spec * src1, const npu_device_tensor_spec * dst,
npu_device_tensor_op op, boolean * is_supported) {
AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) {
NPU_UNUSED(_h);

if (!src0 || !src1 || !dst || !is_supported) {
if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments");
return AEE_EINVARGS;
}

*is_supported = hexagon::support_op(*src0, *src1, *dst, op);
*is_supported = hexagon::support_op(op, dst, srcs, srcsLen);
return AEE_SUCCESS;
}
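As context for the alignment change above: HVX_VectorPair is simply two HVX vectors back to back, so reporting sizeof(HVX_VectorPair) doubles the advertised buffer alignment; in the 128-byte HVX mode this backend builds for (an assumption here), that is 256 bytes instead of 128. A minimal sketch, compilable only with the Hexagon toolchain:

#include <hexagon_types.h>  // Hexagon SDK header providing HVX_Vector / HVX_VectorPair

// A vector pair is always exactly two vectors, so the new alignment is twice the old one
// (256 vs. 128 bytes when built for 128-byte HVX vectors, which is an assumption about the build mode).
static_assert(sizeof(HVX_VectorPair) == 2 * sizeof(HVX_Vector), "a pair must be two HVX vectors");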

@@ -208,19 +207,20 @@ AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_
int tensor_paramsLen) {
NPU_UNUSED(_h);
auto * graph = graph_from_handle(graph_handle);
if (!graph || !tensor_handles || tensor_handlesLen <= 0 || !tensor_params ||
tensor_handlesLen != tensor_paramsLen) {
if (!graph || tensor_handlesLen != tensor_paramsLen || tensor_handlesLen < 0) {
return AEE_EINVHANDLE;
}

graph->set_tensor(tensor_handles, tensor_handlesLen);
for (int i = 0; i < tensor_handlesLen; ++i) {
auto * tensor = tensor_from_handle(tensor_handles[i]);
if (tensor) {
tensor->update_config(tensor_params[i]);
if (tensor_params && tensor_handles) {
for (int i = 0; i < tensor_handlesLen; ++i) {
auto * tensor = tensor_from_handle(tensor_handles[i]);
if (tensor) {
tensor->update_config(tensor_params[i]);
}
}
}

graph->set_tensor(tensor_handles, tensor_handlesLen);
return AEE_SUCCESS;
}

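The device.cpp changes above also replace the fixed src0/src1 parameters of npu_device_device_support_op with an array of source specs plus a length, which is what lets flash attention pass more than two inputs. Below is a hedged host-side sketch of how a caller might use the new signature; the op enum name and the four-input (Q, K, V, mask) layout are illustrative assumptions, not code from this PR:

// Hypothetical caller; only the npu_device_device_support_op signature is taken from the diff above.
static bool flash_attn_supported(remote_handle64 h,
                                 const npu_device_tensor_spec & q, const npu_device_tensor_spec & k,
                                 const npu_device_tensor_spec & v, const npu_device_tensor_spec & mask,
                                 const npu_device_tensor_spec & dst) {
    const npu_device_tensor_spec srcs[] = { q, k, v, mask };  // all source tensors in one array
    const int srcs_len = (int) (sizeof(srcs) / sizeof(srcs[0]));
    boolean supported = 0;
    // NPU_DEVICE_OP_FLASH_ATTN is a placeholder name; the real enum value lives in hexagon_npu.h.
    if (npu_device_device_support_op(h, NPU_DEVICE_OP_FLASH_ATTN, &dst, srcs, srcs_len, &supported) != AEE_SUCCESS) {
        return false;
    }
    return supported != 0;
}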
29 changes: 17 additions & 12 deletions ggml/src/ggml-qnn/npu/device/graph.cpp
@@ -10,8 +10,7 @@
namespace hexagon {

graph::graph() noexcept {
_vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size(); // TODO: move to device init?
DEVICE_LOG_DEBUG("graph(%p) created: vtcm quota size: %zu\n", (void *) this, _vtcm_quota_size);
DEVICE_LOG_DEBUG("graph(%p) created\n", (void *) this);
}

graph::~graph() noexcept {
@@ -20,9 +19,10 @@ graph::~graph() noexcept {
}

void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_count) {
if (tensor_count <= 0) {
if (tensor_count <= 0 || !tensors) {
_tensors.reset();
_tensor_count = 0;
DEVICE_LOG_DEBUG("graph(%p) set_tensor: no tensors to set\n", (void *) this);
return;
}

@@ -50,21 +50,27 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]compute", (void *) this);
_f16_to_f32_table = f16_to_f32_table;
if (thread_pool) {
thread_pool->sync_execute(reinterpret_cast<default_thread_pool::task_type>(&graph::thread_pool_task), this);
thread_pool->sync_execute(&graph::thread_pool_task, this);
} else {
compute_impl(nullptr, 0, 1);
default_thread_pool::thread_params param = {
0, 1, nullptr, hexagon::vtcm_mem::get_avail_block_size()
}; // TODO: should have a better way to initialize thread_params

compute_impl(nullptr, &param);
}

_tensors[_tensor_count - 1]->invalidate();
_f16_to_f32_table = nullptr;
return true;
}

void graph::thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph) {
graph->compute_impl(pool, thread_idx, thread_count);
void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
void * graph) {
reinterpret_cast<hexagon::graph *>(graph)->compute_impl(pool, thread_params);
}

void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count) {
hexagon::compute_params params = { thread_idx, thread_count, _vtcm_quota_size / thread_count, _f16_to_f32_table };
void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params) {
hexagon::compute_params params = { thread_params, _f16_to_f32_table };

for (size_t i = 0; i < _tensor_count; ++i) {
auto * dst = _tensors[i];
@@ -78,13 +84,12 @@ void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t t
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op);
}

DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu", (void *) this, thread_idx);

const bool should_sync = requires_thread_barrier(op);
if (pool && should_sync && i < _tensor_count - 1) {
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
params.get_thread_index(), i, _tensor_count);
pool->sync_thread();
}
dst->invalidate();
}
}

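The graph.cpp changes above fold the old (thread_idx, thread_count) pair and the per-thread VTCM quota into a single thread_params object handed to each worker (see the { 0, 1, nullptr, hexagon::vtcm_mem::get_avail_block_size() } initializer on the single-threaded path). Here is a minimal sketch of how a pool could pass such a per-worker block, assuming a plain std::thread pool; the field names and the void-pointer first argument are guesses, not the actual thread_pool.hpp:

#include <cstddef>
#include <thread>
#include <vector>

// Assumed shape of the per-thread parameter block, mirroring the 4-value initializer above.
struct thread_params_sketch {
    size_t tidx;              // this worker's index
    size_t tcnt;              // total number of workers
    void * pool;              // owning pool, or nullptr on the single-threaded path
    size_t vtcm_quota_bytes;  // share of VTCM this worker may use
};

using task_type = void (*)(void * pool, thread_params_sketch * params, void * user_data);

// Stand-in for sync_execute: give every worker its own params block, then wait for all of them.
inline void sync_execute_sketch(task_type task, void * user_data, size_t n_threads, size_t vtcm_total) {
    std::vector<std::thread> workers;
    for (size_t i = 0; i < n_threads; ++i) {
        workers.emplace_back([=]() {
            thread_params_sketch p = { i, n_threads, nullptr, vtcm_total / n_threads };
            task(nullptr, &p, user_data);
        });
    }
    for (auto & w : workers) {
        w.join();
    }
}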
6 changes: 3 additions & 3 deletions ggml/src/ggml-qnn/npu/device/graph.hpp
@@ -20,12 +20,12 @@ class graph {
bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table);

private:
static void thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph);
void compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count);
static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
void * graph);
void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params);

std::unique_ptr<tensor *[]> _tensors;
size_t _tensor_count = 0;
size_t _vtcm_quota_size = 0;
const float * _f16_to_f32_table = nullptr;

DISABLE_COPY_AND_MOVE(graph);