Skip to content

Commit bd92aad

Browse files
committed
WIP: tokenizer.FromFile
1 parent 425d38d commit bd92aad

File tree

6 files changed

+501
-18
lines changed

6 files changed

+501
-18
lines changed

config.go

+38-18
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
package tokenizer
22

33
import (
4-
// "encoding/json"
4+
"encoding/json"
5+
"fmt"
6+
"os"
57
)
68

79
// Config construct configuration for creating Tokenizer.
810
type Config struct {
9-
Version string `json:"version"`
10-
Truncation interface{} `json:"truncation"`
11-
Padding interface{} `json:"padding"`
12-
AddedTokens []TokenConfig `json:"added_tokens"`
13-
Normalizer NormalizerConfig `json:"normalizer"`
14-
PreTokenizer PreTokenizerConfig `json:"pre_tokenizer"`
15-
PostProcessor PostProcessorConfig `json:"post_processor"`
16-
Decoder DecoderConfig `json:"decoder"`
17-
Model ModelConfig `json:"model"`
11+
Version string `json:"version"`
12+
Truncation map[string]interface{} `json:"truncation"`
13+
Padding map[string]interface{} `json:"padding"`
14+
AddedTokens []TokenConfig `json:"added_tokens"`
15+
Normalizer NormalizerConfig `json:"normalizer"`
16+
PreTokenizer PreTokenizerConfig `json:"pre_tokenizer"`
17+
PostProcessor PostProcessorConfig `json:"post_processor"`
18+
Decoder DecoderConfig `json:"decoder"`
19+
Model ModelConfig `json:"model"`
1820
}
1921

2022
type TokenConfig struct {
@@ -28,20 +30,20 @@ type TokenConfig struct {
2830
}
2931

3032
type NormalizerConfig struct {
31-
Type string `json:"type"`
32-
Normalizers map[string]interface{} `json:"normalizers"`
33+
Type string `json:"type"`
34+
Normalizers []map[string]interface{} `json:"normalizers"`
3335
}
3436
type PreTokenizerConfig struct{}
3537
type PostProcessorConfig struct {
36-
Type string `json:"type"`
37-
Single map[string]interface{} `json:"single"`
38-
Pair map[string]interface{} `json:"pair"`
39-
SpecialTokens map[string]interface{} `json:"speical_tokens"`
38+
Type string `json:"type"`
39+
Single []map[string]interface{} `json:"single"`
40+
Pair []map[string]interface{} `json:"pair"`
41+
SpecialTokens map[string]interface{} `json:"speical_tokens"`
4042
}
4143

4244
type DecoderConfig struct {
43-
Type string `json:"type"`
44-
Decoders map[string]interface{} `json:"decoders"`
45+
Type string `json:"type"`
46+
Decoders []map[string]interface{} `json:"decoders"`
4547
}
4648

4749
type ModelConfig struct {
@@ -55,3 +57,21 @@ type ModelConfig struct {
5557
Vocab map[string]int `json:"vocab"`
5658
Merges []string `json:"merges"`
5759
}
60+
61+
// ConfigFromFile loads config from file.
62+
func ConfigFromFile(file string) (*Config, error) {
63+
f, err := os.Open(file)
64+
if err != nil {
65+
return nil, err
66+
}
67+
68+
dec := json.NewDecoder(f)
69+
70+
var config *Config
71+
err = dec.Decode(&config)
72+
if err != nil {
73+
return nil, err
74+
}
75+
76+
return config, nil
77+
}

config_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package tokenizer
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
)
8+
9+
func ExampleConfig() {
10+
tokFile, err := CachedPath("hf-internal-testing/llama-tokenizer", "tokenizer.json")
11+
if err != nil {
12+
panic(err)
13+
}
14+
15+
f, err := os.Open(tokFile)
16+
if err != nil {
17+
panic(err)
18+
}
19+
20+
dec := json.NewDecoder(f)
21+
22+
var config *Config
23+
24+
err = dec.Decode(&config)
25+
if err != nil {
26+
panic(err)
27+
}
28+
29+
modelType := config.Model.Type
30+
fmt.Println(modelType)
31+
32+
// Output:
33+
// BPE
34+
}

file-util.go

+246
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
package tokenizer
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"log"
7+
"net/http"
8+
"os"
9+
"path"
10+
"strconv"
11+
"strings"
12+
)
13+
14+
// This file provides functions to work with local dataset cache, ...
15+
// Copied from `transformer` package
16+
17+
const (
18+
WeightName = "pytorch_model.gt"
19+
ConfigName = "config.json"
20+
TokenizerName = "tokenizer.json"
21+
22+
// NOTE. URL form := `$HFpath/ModelName/resolve/main/WeightName`
23+
HFpath = "https://huggingface.co"
24+
)
25+
26+
var (
27+
DUMMY_INPUT [][]int64 = [][]int64{
28+
{7, 6, 0, 0, 1},
29+
{1, 2, 3, 0, 0},
30+
{0, 0, 0, 4, 5},
31+
}
32+
)
33+
34+
// CachedPath resolves and caches data based on input string, then returns fullpath to the cached data.
35+
//
36+
// Parameters:
37+
// - `modelNameOrPath`: model name e.g., "bert-base-uncase" or path to directory contains model/config files.
38+
// - `fileName`: model or config file name. E.g., "pytorch_model.py", "config.json"
39+
//
40+
// CachedPath does several things consequently:
41+
// 1. Resolves input string to a fullpath cached filename candidate.
42+
// 2. Check it at `CachedPath`, if exists, then return the candidate. If not
43+
// 3. Retrieves and Caches data to `CachedPath` and returns path to cached data
44+
//
45+
// NOTE. default `CachedDir` is at "{$HOME}/.cache/transformer"
46+
// Custom `CachedDir` can be changed by setting with environment `GO_TRANSFORMER`
47+
func CachedPath(modelNameOrPath, fileName string) (resolvedPath string, err error) {
48+
49+
// Resolves to "candidate" filename at `CacheDir`
50+
cachedFileCandidate := fmt.Sprintf("%s/%s/%s", CachedDir, modelNameOrPath, fileName)
51+
52+
// 1. Cached candidate file exists
53+
if _, err := os.Stat(cachedFileCandidate); err == nil {
54+
return cachedFileCandidate, nil
55+
}
56+
57+
// 2. If valid fullpath to local file, caches it and return cached filename
58+
filepath := fmt.Sprintf("%s/%s", modelNameOrPath, fileName)
59+
if _, err := os.Stat(filepath); err == nil {
60+
err := copyFile(filepath, cachedFileCandidate)
61+
if err != nil {
62+
err := fmt.Errorf("CachedPath() failed at copying file: %w", err)
63+
return "", err
64+
}
65+
return cachedFileCandidate, nil
66+
}
67+
68+
// 3. Cached candidate file NOT exist. Try to download it and save to `CachedDir`
69+
url := fmt.Sprintf("%s/%s/resolve/main/%s", HFpath, modelNameOrPath, fileName)
70+
// url := fmt.Sprintf("%s/%s/raw/main/%s", HFpath, modelNameOrPath, fileName)
71+
if isValidURL(url) {
72+
if _, err := http.Get(url); err == nil {
73+
err := downloadFile(url, cachedFileCandidate)
74+
if err != nil {
75+
err = fmt.Errorf("CachedPath() failed at trying to download file: %w", err)
76+
return "", err
77+
}
78+
79+
return cachedFileCandidate, nil
80+
} else {
81+
err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url)
82+
return "", err
83+
}
84+
}
85+
86+
// Not resolves
87+
err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url)
88+
return "", err
89+
}
90+
91+
func isValidURL(url string) bool {
92+
93+
// TODO: implement
94+
return true
95+
}
96+
97+
// downloadFile downloads file from URL and stores it in local filepath.
98+
// It writes to the destination file as it downloads it, without loading
99+
// the entire file into memory. An `io.TeeReader` is passed into Copy()
100+
// to report progress on the download.
101+
func downloadFile(url string, filepath string) error {
102+
// Create path if not existing
103+
dir := path.Dir(filepath)
104+
filename := path.Base(filepath)
105+
if _, err := os.Stat(dir); os.IsNotExist(err) {
106+
if err := os.MkdirAll(dir, 0755); err != nil {
107+
log.Fatal(err)
108+
}
109+
}
110+
111+
// Create the file with .tmp extension, so that we won't overwrite a
112+
// file until it's downloaded fully
113+
out, err := os.Create(filepath + ".tmp")
114+
if err != nil {
115+
return err
116+
}
117+
defer out.Close()
118+
119+
// Get the data
120+
resp, err := http.Get(url)
121+
if err != nil {
122+
return err
123+
}
124+
defer resp.Body.Close()
125+
126+
// Check server response
127+
if resp.StatusCode != http.StatusOK {
128+
err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode)
129+
130+
if resp.StatusCode == 404 {
131+
// if filename == "rust_model.ot" {
132+
// msg := fmt.Sprintf("model weight file not found. That means a compatible pretrained model weight file for Go is not available.\n")
133+
// msg = msg + fmt.Sprintf("You might need to manually convert a 'pytorch_model.bin' for Go. ")
134+
// msg = msg + fmt.Sprintf("See tutorial at: 'example/convert'")
135+
// err = fmt.Errorf(msg)
136+
// } else {
137+
// err = fmt.Errorf("download file not found: %q for downloading", url)
138+
// }
139+
err = fmt.Errorf("download file not found: %q for downloading", url)
140+
} else {
141+
err = fmt.Errorf("download file failed: %q", url)
142+
}
143+
return err
144+
}
145+
146+
// the total file size to download
147+
size, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
148+
downloadSize := uint64(size)
149+
150+
// Create our bytes counter and pass it to be used alongside our writer
151+
counter := &writeCounter{FileSize: downloadSize}
152+
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
153+
if err != nil {
154+
return err
155+
}
156+
157+
fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize))
158+
// The progress use the same line so print a new line once it's finished downloading
159+
fmt.Println()
160+
161+
// Rename the tmp file back to the original file
162+
err = os.Rename(filepath+".tmp", filepath)
163+
if err != nil {
164+
return err
165+
}
166+
167+
return nil
168+
}
169+
170+
// writeCounter counts the number of bytes written to it. By implementing the Write method,
171+
// it is of the io.Writer interface and we can pass this into io.TeeReader()
172+
// Every write to this writer, will print the progress of the file write.
173+
type writeCounter struct {
174+
Total uint64
175+
FileSize uint64
176+
}
177+
178+
func (wc *writeCounter) Write(p []byte) (int, error) {
179+
n := len(p)
180+
wc.Total += uint64(n)
181+
wc.printProgress()
182+
return n, nil
183+
}
184+
185+
// PrintProgress prints the progress of a file write
186+
func (wc writeCounter) printProgress() {
187+
// Clear the line by using a character return to go back to the start and remove
188+
// the remaining characters by filling it with spaces
189+
fmt.Printf("\r%s", strings.Repeat(" ", 50))
190+
191+
// Return again and print current status of download
192+
fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize))
193+
}
194+
195+
// byteCountIEC converts bytes to human-readable string in binary (IEC) format.
196+
func byteCountIEC(b uint64) string {
197+
const unit = 1024
198+
if b < unit {
199+
return fmt.Sprintf("%d B", b)
200+
}
201+
div, exp := uint64(unit), 0
202+
for n := b / unit; n >= unit; n /= unit {
203+
div *= unit
204+
exp++
205+
}
206+
return fmt.Sprintf("%.1f %ciB",
207+
float64(b)/float64(div), "KMGTPE"[exp])
208+
}
209+
210+
func copyFile(src, dst string) error {
211+
sourceFileStat, err := os.Stat(src)
212+
if err != nil {
213+
return err
214+
}
215+
216+
if !sourceFileStat.Mode().IsRegular() {
217+
return fmt.Errorf("%s is not a regular file", src)
218+
}
219+
220+
source, err := os.Open(src)
221+
if err != nil {
222+
return err
223+
}
224+
defer source.Close()
225+
226+
destination, err := os.Create(dst)
227+
if err != nil {
228+
return err
229+
}
230+
defer destination.Close()
231+
_, err = io.Copy(destination, source)
232+
return err
233+
}
234+
235+
// CleanCache removes all files cached in transformer cache directory `CachedDir`.
236+
//
237+
// NOTE. custom `CachedDir` can be changed by setting environment `GO_TRANSFORMER`
238+
func CleanCache() error {
239+
err := os.RemoveAll(CachedDir)
240+
if err != nil {
241+
err = fmt.Errorf("CleanCache() failed: %w", err)
242+
return err
243+
}
244+
245+
return nil
246+
}

0 commit comments

Comments
 (0)