
Commit ae178ab

llama : make tensor_split ptr instead of array (#2272)

1 parent 54e3bc7

File tree

4 files changed, +8 -4 lines

examples/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -586,7 +586,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     lparams.n_batch      = params.n_batch;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu     = params.main_gpu;
-    memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
+    lparams.tensor_split = params.tensor_split;
     lparams.low_vram     = params.low_vram;
     lparams.seed         = params.seed;
     lparams.f16_kv       = params.memory_f16;
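
With the array replaced by a pointer, this assignment copies only an address: gpt_params still owns the storage, and the context params now borrow it. A minimal sketch of the resulting ownership, assuming gpt_params still declares tensor_split as a fixed-size array:

    // Sketch; assumes gpt_params declares: float tensor_split[LLAMA_MAX_DEVICES];
    gpt_params params;
    params.tensor_split[0] = 0.7f; // 70% of the split weight to GPU 0
    params.tensor_split[1] = 0.3f; // 30% to GPU 1

    llama_context_params lparams = llama_context_params_from_gpt_params(params);
    // lparams.tensor_split now points at params.tensor_split (no copy),
    // so 'params' must outlive every use of 'lparams'.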

ggml-cuda.cu

Lines changed: 3 additions & 0 deletions
@@ -2512,6 +2512,9 @@ void ggml_init_cublas() {
 }
 
 void ggml_cuda_set_tensor_split(const float * tensor_split) {
+    if (tensor_split == nullptr) {
+        return;
+    }
     bool all_zero = true;
     for (int i = 0; i < g_device_count; ++i) {
         if (tensor_split[i] != 0.0f) {
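
The guard makes nullptr a valid "keep the backend defaults" signal. For context, here is an illustrative sketch of the shape of a split setter after this guard; it is not the actual ggml-cuda.cu implementation, and the hypothetical g_device_count_sketch / g_split_sketch globals stand in for the backend's real state:

    // Illustrative sketch only, not the real ggml-cuda.cu code.
    static int   g_device_count_sketch = 2;   // hypothetical device count
    static float g_split_sketch[16]    = {0}; // hypothetical normalized fractions

    void set_tensor_split_sketch(const float * tensor_split) {
        if (tensor_split == nullptr) {
            return; // the new guard: nullptr keeps the backend defaults
        }
        float sum = 0.0f;
        for (int i = 0; i < g_device_count_sketch; ++i) {
            sum += tensor_split[i];
        }
        if (sum == 0.0f) {
            return; // an all-zero array also keeps the defaults
        }
        for (int i = 0; i < g_device_count_sketch; ++i) {
            g_split_sketch[i] = tensor_split[i] / sum; // e.g. {3, 1} -> {0.75, 0.25}
        }
    }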

llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -849,7 +849,7 @@ struct llama_context_params llama_context_default_params() {
     /*.n_batch          =*/ 512,
     /*.gpu_layers       =*/ 0,
     /*.main_gpu         =*/ 0,
-    /*.tensor_split     =*/ {0},
+    /*.tensor_split     =*/ nullptr,
     /*.rope_freq_base   =*/ 10000.0f,
     /*.rope_freq_scale  =*/ 1.0f,
     /*.progress_callback =*/ nullptr,

@@ -1289,7 +1289,7 @@ static bool llama_model_load(
         int n_batch,
         int n_gpu_layers,
         int main_gpu,
-        float * tensor_split,
+        const float * tensor_split,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
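
The nullptr default gives the split an explicit "unset" state, and the const qualifier on llama_model_load lets callers pass read-only arrays. A small sketch of both, with illustrative values, assuming a build where LLAMA_MAX_DEVICES >= 2:

    // nullptr now means "no user-specified split"; ggml_cuda_set_tensor_split
    // accepts it and falls back to the backend's defaults.
    llama_context_params p = llama_context_default_params();
    // p.tensor_split == nullptr here

    // A const array is accepted now that the field and llama_model_load
    // take const float *.
    static const float split[LLAMA_MAX_DEVICES] = {0.5f, 0.5f};
    p.tensor_split = split;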

llama.h

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ extern "C" {
     int32_t n_batch;      // prompt processing batch size
     int32_t n_gpu_layers; // number of layers to store in VRAM
     int32_t main_gpu;     // the GPU that is used for scratch and small tensors
-    float   tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+
+    const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
 
     // ref: https://github.com/ggerganov/llama.cpp/pull/2054
     float rope_freq_base; // RoPE base frequency
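
An end-to-end sketch of the new field in use. Assumptions: llama_init_from_file was the context-creation entry point at this point in the API, the model path is a placeholder, and this is a CUDA build where LLAMA_MAX_DEVICES >= 2:

    #include "llama.h"

    int main() {
        // Borrowed, not copied: the array must outlive context creation.
        static const float split[LLAMA_MAX_DEVICES] = {0.75f, 0.25f};

        llama_context_params params = llama_context_default_params();
        params.n_gpu_layers = 35;    // illustrative value
        params.tensor_split = split; // split weight 0.75 to GPU 0, 0.25 to GPU 1

        // Placeholder model path; llama_init_from_file took params by value here.
        llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", params);
        if (ctx == nullptr) {
            return 1;
        }
        llama_free(ctx);
        return 0;
    }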
