10
10
namespace hexagon {
11
11
12
12
graph::graph () noexcept {
13
- _vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size (); // TODO: move to device init?
14
- DEVICE_LOG_DEBUG (" graph(%p) created: vtcm quota size: %zu\n " , (void *) this , _vtcm_quota_size);
13
+ DEVICE_LOG_DEBUG (" graph(%p) created\n " , (void *) this );
15
14
}
16
15
17
16
graph::~graph () noexcept {
@@ -20,9 +19,10 @@ graph::~graph() noexcept {
20
19
}
21
20
22
21
void graph::set_tensor (const npu_device_tensor_handle_t * tensors, int tensor_count) {
23
- if (tensor_count <= 0 ) {
22
+ if (tensor_count <= 0 || !tensors ) {
24
23
_tensors.reset ();
25
24
_tensor_count = 0 ;
25
+ DEVICE_LOG_DEBUG (" graph(%p) set_tensor: no tensors to set\n " , (void *) this );
26
26
return ;
27
27
}
28
28
@@ -50,21 +50,27 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
50
50
DEVICE_SCOPED_PERFORMANCE_TRACKER (" [%p]compute" , (void *) this );
51
51
_f16_to_f32_table = f16_to_f32_table;
52
52
if (thread_pool) {
53
- thread_pool->sync_execute (reinterpret_cast <default_thread_pool::task_type>( &graph::thread_pool_task) , this );
53
+ thread_pool->sync_execute (&graph::thread_pool_task, this );
54
54
} else {
55
- compute_impl (nullptr , 0 , 1 );
55
+ default_thread_pool::thread_params param = {
56
+ 0 , 1 , nullptr , hexagon::vtcm_mem::get_avail_block_size ()
57
+ }; // TODO: should have a better way to initialize thread_params
58
+
59
+ compute_impl (nullptr , ¶m);
56
60
}
57
61
62
+ _tensors[_tensor_count - 1 ]->invalidate ();
58
63
_f16_to_f32_table = nullptr ;
59
64
return true ;
60
65
}
61
66
62
- void graph::thread_pool_task (default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph) {
63
- graph->compute_impl (pool, thread_idx, thread_count);
67
+ void graph::thread_pool_task (default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
68
+ void * graph) {
69
+ reinterpret_cast <hexagon::graph *>(graph)->compute_impl (pool, thread_params);
64
70
}
65
71
66
- void graph::compute_impl (default_thread_pool * pool, size_t thread_idx, size_t thread_count ) {
67
- hexagon::compute_params params = { thread_idx, thread_count, _vtcm_quota_size / thread_count , _f16_to_f32_table };
72
+ void graph::compute_impl (default_thread_pool * pool, default_thread_pool::thread_params * thread_params ) {
73
+ hexagon::compute_params params = { thread_params , _f16_to_f32_table };
68
74
69
75
for (size_t i = 0 ; i < _tensor_count; ++i) {
70
76
auto * dst = _tensors[i];
@@ -78,13 +84,12 @@ void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t t
78
84
DEVICE_LOG_ERROR (" graph(%p) tensor[%zu] op %d compute failed\n " , (void *) this , i, op);
79
85
}
80
86
81
- DEVICE_SCOPED_PERFORMANCE_TRACKER (" [%p]sync_thread, tidx: %zu" , (void *) this , thread_idx);
82
-
83
87
const bool should_sync = requires_thread_barrier (op);
84
88
if (pool && should_sync && i < _tensor_count - 1 ) {
89
+ DEVICE_SCOPED_PERFORMANCE_TRACKER (" [%p]sync_thread, tidx: %zu, tensor[%zu/%zu]" , (void *) this ,
90
+ params.get_thread_index (), i, _tensor_count);
85
91
pool->sync_thread ();
86
92
}
87
- dst->invalidate ();
88
93
}
89
94
}
90
95
0 commit comments