-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbinding.cpp
More file actions
128 lines (110 loc) · 3.8 KB
/
binding.cpp
File metadata and controls
128 lines (110 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include "chatglm.h"
#include "binding.h"

#include <algorithm>
#include <cstring>
#include <vector>
/* ===== structure/class definitions. ===== */
struct GenerationConfig : public chatglm::GenerationConfig {
GenerationConfig(
int max_length,
int max_context_length,
bool do_sample,
int top_k,
float top_p,
float temperature,
float repetition_penalty,
int num_threads
): chatglm::GenerationConfig(max_length, max_context_length, do_sample, top_k, top_p,
temperature, repetition_penalty, num_threads) {}
~GenerationConfig() {}
};
struct Pipeline : public chatglm::Pipeline {
Pipeline(char* path): chatglm::Pipeline(path) {}
~Pipeline() {}
};
// CallbackStreamer is much like chatglm::TextStreamer except that it sends the
// generated text to a callback function, which is implemented in Go.
class CallbackStreamer : public chatglm::BaseStreamer {
  public:
    CallbackStreamer(Pipeline* pipeline, chatglm::BaseTokenizer *tokenizer)
        : pipeline_(pipeline), tokenizer_(tokenizer) {}

    // Receives newly generated token ids; decodes and forwards printable text.
    void put(const std::vector<int> &output_ids) override;
    // Flushes remaining cached text and resets state for the next run.
    void end() override;

  private:
    Pipeline* pipeline_;                 // non-owning; passed back to the Go callback
    chatglm::BaseTokenizer *tokenizer_;  // non-owning; turns token ids into text
    bool is_prompt_{true};               // first put() call carries the prompt ids
    std::vector<int> token_cache_;       // ids not yet emitted as stable text
    int print_len_{0};                   // bytes of decoded cache already emitted
};
/* ===== function implementations. ===== */
// Heap-allocates a GenerationConfig for the Go binding.
// The caller owns the result and must release it with DeleteGenerationConfig.
GenerationConfig* NewGenerationConfig(
    int max_length,
    int max_context_length,
    bool do_sample,
    int top_k,
    float top_p,
    float temperature,
    float repetition_penalty,
    int num_threads
) {
    return new GenerationConfig(
        max_length, max_context_length, do_sample, top_k, top_p, temperature,
        repetition_penalty, num_threads);
}
// Releases a config previously created by NewGenerationConfig.
// Passing a null pointer is safe (delete on nullptr is a no-op).
void DeleteGenerationConfig(GenerationConfig* p) { delete p; }
// Heap-allocates a Pipeline for the model file at `path`.
// The caller owns the result and must release it with DeletePipeline.
Pipeline* NewPipeline(char* path) { return new Pipeline(path); }
// Releases a pipeline previously created by NewPipeline.
// Passing a null pointer is safe (delete on nullptr is a no-op).
void DeletePipeline(Pipeline* p) { delete p; }
// Runs generation for `prompt`, streaming partial text to the Go-side
// streamCallback as it is produced and, when `output` is non-null, copying
// the complete result into it.
//
// NOTE(review): `output` must point to a caller-allocated buffer large
// enough for the full result plus the terminating NUL — there is no size
// parameter, so an undersized buffer overflows. Consider adding a sized
// variant to the binding interface.
void Pipeline_Generate(Pipeline* p, char* prompt, GenerationConfig* gen_config, char* output) {
    const GenerationConfig &config = *gen_config;
    // The streamer only needs to live for the duration of this call, so a
    // stack object replaces the original make_shared heap allocation.
    CallbackStreamer streamer(p, p->tokenizer.get());
    std::string result = p->generate(prompt, config, &streamer);
    if (output != nullptr) {
        // std::strcpy requires <cstring>; the original relied on a
        // transitive include.
        std::strcpy(output, result.c_str());
    }
}
// Accumulates generated token ids and forwards newly printable text to the
// Go-side streamCallback. Text is held back while it may still change:
// after a trailing punctuation token or an incomplete UTF-8 sequence.
void CallbackStreamer::put(const std::vector<int> &output_ids) {
    if (is_prompt_) {
        // The first put() carries the prompt's own token ids; skip them so
        // only generated text reaches the callback.
        is_prompt_ = false;
        return;
    }
    // Punctuation characters that are held back when they end the decoded
    // text, so a following token can merge into the same chunk.
    static const std::vector<char> puncts{',', '!', ':', ';', '?'};
    token_cache_.insert(token_cache_.end(), output_ids.begin(), output_ids.end());
    // The whole cache is re-decoded on every call — presumably because
    // decoding is not prefix-stable — and print_len_ tracks how much of the
    // decoded string has already been emitted.
    std::string text = tokenizer_->decode(token_cache_);
    if (text.empty()) {
        return;
    }
    std::string printable_text;
    if (text.back() == '\n') {
        // flush the cache after newline
        printable_text = text.substr(print_len_);
        token_cache_.clear();
        print_len_ = 0;
    } else if (std::find(puncts.begin(), puncts.end(), text.back()) != puncts.end()) {
        // last symbol is a punctuation, hold on
    } else if (text.size() >= 3 && text.compare(text.size() - 3, 3, "�") == 0) {
        // ends with an incomplete token, hold on
        // ("�" is U+FFFD, 3 bytes in UTF-8 — the decoder's stand-in for a
        // partial multi-byte sequence)
    } else {
        printable_text = text.substr(print_len_);
        print_len_ = text.size();
    }
    // Called even when printable_text is empty (the Go side receives "").
    // Final argument 0 means "generation not finished".
    streamCallback(pipeline_, const_cast<char*>(printable_text.c_str()), 0);
}
// Flushes whatever decoded text has not yet been emitted, marking the
// callback as the final one of the run, then resets the streamer so the
// same instance could serve another generation.
void CallbackStreamer::end() {
    const std::string text = tokenizer_->decode(token_cache_);
    std::string remaining = text.substr(print_len_);
    // Final argument 1 signals end-of-generation to the Go side.
    streamCallback(pipeline_, const_cast<char*>(remaining.c_str()), 1);
    is_prompt_ = true;
    token_cache_.clear();
    print_len_ = 0;
}