@@ -86,7 +86,7 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
@@ -207,8 +207,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
         ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
 
         ctx_size += (5 + 10*n_layer)*256; // object overhead
 
@@ -293,8 +293,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         const int n_mem      = n_layer*n_ctx;
         const int n_elements = n_embd*n_mem;
 
-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);
 
         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
@@ -814,8 +814,9 @@ int main(int argc, char ** argv) {
     // load the model
     {
+        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
 
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
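
As a rough sense of what the new memory_type parameter buys: each of memory_k and memory_v holds n_ctx*n_layer*n_embd elements (the same count used in the ctx_size estimate above), so the element type directly halves or doubles the KV cache. Below is a standalone sketch, not part of the commit, that works this out; the 7B-scale shape (n_layer = 32, n_embd = 4096) and n_ctx = 512 are assumptions for illustration.

// Illustrative sketch only -- not part of the commit above.
// Mirrors the K/V cache sizing: each of memory_k and memory_v holds
// n_ctx*n_layer*n_embd elements, so the byte size scales with the element type.
#include <cstdio>

int main() {
    // Assumed 7B-scale shape and default context length for the example.
    const long long n_ctx   = 512;
    const long long n_layer = 32;
    const long long n_embd  = 4096;

    const long long n_elements = n_ctx*n_layer*n_embd; // per K or V tensor

    // memory_k + memory_v combined, for f16 (2 bytes/elem) vs f32 (4 bytes/elem)
    const double f16_mib = 2.0*n_elements*2/(1024.0*1024.0);
    const double f32_mib = 2.0*n_elements*4/(1024.0*1024.0);

    printf("KV cache: %.0f MiB with GGML_TYPE_F16, %.0f MiB with GGML_TYPE_F32\n",
           f16_mib, f32_mib);
    return 0;
}

At these assumed sizes the cache goes from 256 MiB at f16 to 512 MiB at f32, which is why f16 stays the default and f32 is opt-in via the flag.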