Commit af620a1

feat: flash attention support for hexagon-npu (#45)
* add flash attn op
* expend src tensor size
* add flash attn sources
* add quantize row functions
* make a separated file for vec_dot
* wip
* wip
* refactor: rename quants.hpp includes and add vec_dot to type traits
* add flash_attn impl
* split vec_scale_f32
* move vec_reduction_qf32 to vec_ops
* add vec_scale_f16
* opt
* add vec_mad
* implement vec_mad_f16
* opt
* add op template
* opt
* add align version
* enable flash attn
* wip
* log print improve
* add profiler log
* wip
* wip
* add multi sub proc perf tracker
* increase log buffer
* remove sub prov pcycle
* wip
* wip
* add prefetch for vec_dot
* wip
* wip
* opt f16 vec dot
* opt f16 vecdot
* reuse vec_dot_product_impl in vec dot f32
* small opt to unblock pipeline
* opt on aligned address wip
* Revert "opt on aligned address" — This reverts commit 27be1eb.
* add profiler log at thread_pool
* wip
* invalidate all...
* Reapply "opt on aligned address" — This reverts commit f075a4c.
* add is_constant for tensor config
* disable align tensor opt in mul_mat
* wip
* wip
* vec_scale_impl: unrolling the loop
* wip
* wip
* replace reinterpret_cast with direct pointer access for write/read buffers
* add fetch
* wip
* wip
* wip
* add log
* check tensor shape at flash_attn
* wip
* wip
* fix: update tensor type handling in flash_attn_impl
* wip
* fix: align cache size
* fix: qf16->hf
* fix: swap order of elements in vector combine for correct scaling
* fix: opt f16 scale and mad
* fix leftover fetch
* wip
* load into vector pair
* opt cache size calculation in flash_attn_impl
* refactoring: hold vtcm at thread local object
* wip
* add profiler log
* mark tensors as modified
* restrict tensor invalidation to the first thread in compute_impl
* Revert "restrict tensor invalidation to the first thread in compute_impl" — This reverts commit 0a8ff2b.
* invalidate last tensor in compute_impl
* invalidate last tensor in compute function
* wip
* refactor dequantize_row_q4_0 to simplify vector alignment
* wip
* refactoring: move VTCM quota calculation to thread pool
* wip
* fix: correct condition check for HEXAGON_SDK_ROOT existence
* wip
* wip
* wip
* wip
* fix: update condition checks match the naming
* fix: improve tensor handling checks and logging in graph and operation implementations
* wip
1 parent da5dc57 commit af620a1

27 files changed: +1860 −801 lines changed
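The log above is largely a chronological list of work items, but it names the primitives the flash-attention kernel is built from: vec_dot for the q·k scores, vec_scale for rescaling the running accumulator, and vec_mad for accumulating p·v. A minimal scalar sketch of how such primitives compose in a streaming-softmax loop (illustrative only; the function names and shapes here are assumptions, not the HVX code added by this commit):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar stand-ins for the vec_scale / vec_mad kernels named above.
static void vec_scale_f32(float * dst, float s, size_t n) {
    for (size_t i = 0; i < n; ++i) dst[i] *= s;
}

static void vec_mad_f32(float * dst, const float * src, float s, size_t n) {
    for (size_t i = 0; i < n; ++i) dst[i] += s * src[i];
}

// One query row of a streaming (flash-attention style) accumulation: keep a running
// max and normalizer, and rescale the accumulator whenever the max grows.
void flash_attn_row(const float * scores, const float * const * v_rows,
                    size_t n_kv, size_t head_dim, float * out) {
    float              running_max = -INFINITY;
    float              running_sum = 0.0f;
    std::vector<float> acc(head_dim, 0.0f);

    for (size_t j = 0; j < n_kv; ++j) {
        const float new_max    = std::max(running_max, scores[j]);
        const float correction = std::exp(running_max - new_max);  // rescale old accumulator
        const float p          = std::exp(scores[j] - new_max);

        vec_scale_f32(acc.data(), correction, head_dim);  // scale previous partial sums
        vec_mad_f32(acc.data(), v_rows[j], p, head_dim);  // accumulate p * V[j]

        running_sum = running_sum * correction + p;
        running_max = new_max;
    }

    const float inv_sum = running_sum > 0.0f ? 1.0f / running_sum : 0.0f;
    vec_scale_f32(acc.data(), inv_sum, head_dim);
    std::copy(acc.begin(), acc.end(), out);
}

The real kernels operate on HVX vector registers and f16 data, and fold in the prefetching, alignment, and VTCM cache-size handling that the later log entries refer to.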

ggml/src/ggml-qnn/npu/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@ cmake_policy(SET CMP0115 OLD)
 
 if(DEFINED ENV{HEXAGON_SDK_ROOT})
     set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
+    message("HEXAGON_SDK_ROOT (from environment): ${HEXAGON_SDK_ROOT}")
+elseif(DEFINED HEXAGON_SDK_ROOT)
     message("HEXAGON_SDK_ROOT: ${HEXAGON_SDK_ROOT}")
 else()
     message(FATAL_ERROR "HEXAGON_SDK_ROOT not defined")

ggml/src/ggml-qnn/npu/device/device.cpp

Lines changed: 14 additions & 14 deletions
@@ -9,10 +9,10 @@
 #include "graph.hpp"
 #include "hexagon_npu.h"
 #include "op_impl.hpp"
-#include "quants.hpp"
 #include "remote.h"
 #include "tensor.hpp"
 #include "thread_pool.hpp"
+#include "type_traits.hpp"
 #include "util.hpp"
 
 namespace {
@@ -124,21 +124,20 @@ int npu_device_close(remote_handle64 h) {
 
 AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignment) {
     NPU_UNUSED(_h);
-    *alignment = sizeof(HVX_Vector);
+    *alignment = sizeof(HVX_VectorPair);
     return AEE_SUCCESS;
 }
 
-AEEResult npu_device_device_support_op(remote_handle64 _h, const npu_device_tensor_spec * src0,
-                                       const npu_device_tensor_spec * src1, const npu_device_tensor_spec * dst,
-                                       npu_device_tensor_op op, boolean * is_supported) {
+AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst,
+                                       const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) {
     NPU_UNUSED(_h);
 
-    if (!src0 || !src1 || !dst || !is_supported) {
+    if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
         DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments");
         return AEE_EINVARGS;
     }
 
-    *is_supported = hexagon::support_op(*src0, *src1, *dst, op);
+    *is_supported = hexagon::support_op(op, dst, srcs, srcsLen);
     return AEE_SUCCESS;
 }
 
@@ -208,19 +207,20 @@ AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_
                                                  int tensor_paramsLen) {
     NPU_UNUSED(_h);
     auto * graph = graph_from_handle(graph_handle);
-    if (!graph || !tensor_handles || tensor_handlesLen <= 0 || !tensor_params ||
-        tensor_handlesLen != tensor_paramsLen) {
+    if (!graph || tensor_handlesLen != tensor_paramsLen || tensor_handlesLen < 0) {
         return AEE_EINVHANDLE;
     }
 
-    graph->set_tensor(tensor_handles, tensor_handlesLen);
-    for (int i = 0; i < tensor_handlesLen; ++i) {
-        auto * tensor = tensor_from_handle(tensor_handles[i]);
-        if (tensor) {
-            tensor->update_config(tensor_params[i]);
+    if (tensor_params && tensor_handles) {
+        for (int i = 0; i < tensor_handlesLen; ++i) {
+            auto * tensor = tensor_from_handle(tensor_handles[i]);
+            if (tensor) {
+                tensor->update_config(tensor_params[i]);
+            }
         }
     }
 
+    graph->set_tensor(tensor_handles, tensor_handlesLen);
     return AEE_SUCCESS;
 }
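The support-op entry point switches from a fixed (src0, src1) pair to a source array plus a length, so ops with one, two, or three inputs (flash attention takes q/k/v) share a single code path. A hypothetical host-side wrapper around the new signature might look like the following; only the prototype comes from the diff above, everything else (names, fallback policy) is assumed:

#include "hexagon_npu.h"  // generated stub header that declares npu_device_device_support_op

// Ask the NPU whether it can run `op` on the given dst/src specs before offloading.
// Treat an RPC failure the same as "unsupported" so the caller can fall back to CPU.
static bool npu_supports(remote_handle64 handle, npu_device_tensor_op op,
                         const npu_device_tensor_spec & dst,
                         const npu_device_tensor_spec * srcs, int srcs_len) {
    boolean supported = false;
    if (npu_device_device_support_op(handle, op, &dst, srcs, srcs_len, &supported) != AEE_SUCCESS) {
        return false;
    }
    return supported != 0;
}

Note that the alignment query in the same hunk now reports sizeof(HVX_VectorPair), matching the "load into vector pair" work mentioned in the commit log.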

ggml/src/ggml-qnn/npu/device/graph.cpp

Lines changed: 17 additions & 12 deletions
@@ -10,8 +10,7 @@
 namespace hexagon {
 
 graph::graph() noexcept {
-    _vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size(); // TODO: move to device init?
-    DEVICE_LOG_DEBUG("graph(%p) created: vtcm quota size: %zu\n", (void *) this, _vtcm_quota_size);
+    DEVICE_LOG_DEBUG("graph(%p) created\n", (void *) this);
 }
 
 graph::~graph() noexcept {
@@ -20,9 +19,10 @@ graph::~graph() noexcept {
 }
 
 void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_count) {
-    if (tensor_count <= 0) {
+    if (tensor_count <= 0 || !tensors) {
         _tensors.reset();
         _tensor_count = 0;
+        DEVICE_LOG_DEBUG("graph(%p) set_tensor: no tensors to set\n", (void *) this);
         return;
     }
 
@@ -50,21 +50,27 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
     DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]compute", (void *) this);
     _f16_to_f32_table = f16_to_f32_table;
     if (thread_pool) {
-        thread_pool->sync_execute(reinterpret_cast<default_thread_pool::task_type>(&graph::thread_pool_task), this);
+        thread_pool->sync_execute(&graph::thread_pool_task, this);
     } else {
-        compute_impl(nullptr, 0, 1);
+        default_thread_pool::thread_params param = {
+            0, 1, nullptr, hexagon::vtcm_mem::get_avail_block_size()
+        }; // TODO: should have a better way to initialize thread_params
+
+        compute_impl(nullptr, &param);
     }
 
+    _tensors[_tensor_count - 1]->invalidate();
     _f16_to_f32_table = nullptr;
     return true;
 }
 
-void graph::thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph) {
-    graph->compute_impl(pool, thread_idx, thread_count);
+void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
+                             void * graph) {
+    reinterpret_cast<hexagon::graph *>(graph)->compute_impl(pool, thread_params);
 }
 
-void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count) {
-    hexagon::compute_params params = { thread_idx, thread_count, _vtcm_quota_size / thread_count, _f16_to_f32_table };
+void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params) {
+    hexagon::compute_params params = { thread_params, _f16_to_f32_table };
 
     for (size_t i = 0; i < _tensor_count; ++i) {
         auto * dst = _tensors[i];
@@ -78,13 +84,12 @@ void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t t
             DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op);
         }
 
-        DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu", (void *) this, thread_idx);
-
         const bool should_sync = requires_thread_barrier(op);
         if (pool && should_sync && i < _tensor_count - 1) {
+            DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
+                                              params.get_thread_index(), i, _tensor_count);
            pool->sync_thread();
        }
-        dst->invalidate();
     }
 }

ggml/src/ggml-qnn/npu/device/graph.hpp

Lines changed: 3 additions & 3 deletions

@@ -20,12 +20,12 @@ class graph {
     bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table);
 
   private:
-    static void thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph);
-    void compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count);
+    static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
+                                 void * graph);
+    void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params);
 
     std::unique_ptr<tensor *[]> _tensors;
     size_t _tensor_count = 0;
-    size_t _vtcm_quota_size = 0;
     const float * _f16_to_f32_table = nullptr;
 
     DISABLE_COPY_AND_MOVE(graph);
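The graph.cpp/graph.hpp changes above replace the (thread_idx, thread_count) arguments and the graph-owned _vtcm_quota_size with a thread_params object handed out by the thread pool, and compute_params now just carries a pointer to it. The structs themselves live in thread_pool.hpp and op headers that are not part of this excerpt; below is a rough sketch of the shape implied by the initializer { 0, 1, nullptr, vtcm_mem::get_avail_block_size() } and the params.get_thread_index() call, with member names that are guesses rather than the actual definitions:

#include <cstddef>

// Hypothetical per-thread parameter block owned by the thread pool.
struct thread_params {
    size_t thread_idx   = 0;        // index of this worker thread
    size_t thread_count = 1;        // total worker threads
    void * pool         = nullptr;  // placeholder for the third initializer field
    size_t vtcm_quota   = 0;        // per-thread VTCM budget, computed by the pool
};

// Hypothetical per-op compute context built in graph::compute_impl.
struct compute_params {
    thread_params * thread           = nullptr;  // shared per-thread state
    const float *   f16_to_f32_table = nullptr;  // conversion table passed through graph::compute

    size_t get_thread_index() const { return thread ? thread->thread_idx : 0; }
};

Moving the quota into the pool-owned block matches the "refactoring: move VTCM quota calculation to thread pool" entry in the commit log, and the per-tensor dst->invalidate() in the loop is replaced by a single invalidation of the last tensor in graph::compute.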