Skip to content

Commit c51c64a

Browse files
data-angel authored and ggerganov committed
main : make reverse prompt option act as a stop token in non-interactive mode (ggml-org#1032)
* Make reverse prompt option act as a stop token in non-interactive scenarios
* Making requested review changes
* Update gpt_params_parse and fix a merge error
* Revert "Update gpt_params_parse and fix a merge error" (this reverts commit 2bb2ff1)
* Update gpt_params_parse and fix a merge error, take 2
1 parent 75c017f commit c51c64a

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

examples/common.cpp

+3 -3
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
351351
}
352352
if (params.prompt_cache_all &&
353353
(params.interactive || params.interactive_first ||
354-
params.instruct || params.antiprompt.size())) {
354+
params.instruct)) {
355355
fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
356356
gpt_print_usage(argc, argv, default_params);
357357
exit(1);
@@ -373,8 +373,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
373373
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
374374
fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
375375
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
376-
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
377-
fprintf(stderr, " specified more than once for multiple prompts).\n");
376+
fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n");
377+
fprintf(stderr, " (can be specified more than once for multiple prompts).\n");
378378
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
379379
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
380380
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);

examples/main/main.cpp

+18 -8
Original file line number | Diff line number | Diff line change
@@ -208,8 +208,8 @@ int main(int argc, char ** argv) {
208208
params.antiprompt.push_back("### Instruction:\n\n");
209209
}
210210

211-
// enable interactive mode if reverse prompt or interactive start is specified
212-
if (params.antiprompt.size() != 0 || params.interactive_first) {
211+
// enable interactive mode if interactive start is specified
212+
if (params.interactive_first) {
213213
params.interactive = true;
214214
}
215215

@@ -305,7 +305,7 @@ int main(int argc, char ** argv) {
305305

306306
std::vector<llama_token> embd;
307307

308-
while (n_remain != 0 || params.interactive) {
308+
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
309309
// predict
310310
if (embd.size() > 0) {
311311
// infinite text generation via context swapping
@@ -503,9 +503,8 @@ int main(int argc, char ** argv) {
503503
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
504504
}
505505

506-
// in interactive mode, and not currently processing queued inputs;
507-
// check if we should prompt the user for more
508-
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
506+
// if not currently processing queued inputs;
507+
if ((int) embd_inp.size() <= n_consumed) {
509508

510509
// check for reverse prompt
511510
if (params.antiprompt.size()) {
@@ -516,10 +515,21 @@ int main(int argc, char ** argv) {
516515

517516
is_antiprompt = false;
518517
// Check if each of the reverse prompts appears at the end of the output.
518+
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
519+
// so we'll compensate for that by widening the search window a bit.
519520
for (std::string & antiprompt : params.antiprompt) {
520-
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
521-
is_interacting = true;
521+
size_t extra_padding = params.interactive ? 0 : 2;
522+
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
523+
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
524+
: 0;
525+
526+
if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
527+
if (params.interactive) {
528+
is_interacting = true;
529+
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
530+
}
522531
is_antiprompt = true;
532+
fflush(stdout);
523533
break;
524534
}
525535
}

0 commit comments

Comments
 (0)