|
| 1 | +package tokenizer |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "io" |
| 6 | + "log" |
| 7 | + "net/http" |
| 8 | + "os" |
| 9 | + "path" |
| 10 | + "strconv" |
| 11 | + "strings" |
| 12 | +) |
| 13 | + |
| 14 | +// This file provides functions to work with local dataset cache, ... |
| 15 | +// Copied from `transformer` package |
| 16 | + |
| 17 | +const ( |
| 18 | + WeightName = "pytorch_model.gt" |
| 19 | + ConfigName = "config.json" |
| 20 | + TokenizerName = "tokenizer.json" |
| 21 | + |
| 22 | + // NOTE. URL form := `$HFpath/ModelName/resolve/main/WeightName` |
| 23 | + HFpath = "https://huggingface.co" |
| 24 | +) |
| 25 | + |
| 26 | +var ( |
| 27 | + DUMMY_INPUT [][]int64 = [][]int64{ |
| 28 | + {7, 6, 0, 0, 1}, |
| 29 | + {1, 2, 3, 0, 0}, |
| 30 | + {0, 0, 0, 4, 5}, |
| 31 | + } |
| 32 | +) |
| 33 | + |
| 34 | +// CachedPath resolves and caches data based on input string, then returns fullpath to the cached data. |
| 35 | +// |
| 36 | +// Parameters: |
| 37 | +// - `modelNameOrPath`: model name e.g., "bert-base-uncase" or path to directory contains model/config files. |
| 38 | +// - `fileName`: model or config file name. E.g., "pytorch_model.py", "config.json" |
| 39 | +// |
| 40 | +// CachedPath does several things consequently: |
| 41 | +// 1. Resolves input string to a fullpath cached filename candidate. |
| 42 | +// 2. Check it at `CachedPath`, if exists, then return the candidate. If not |
| 43 | +// 3. Retrieves and Caches data to `CachedPath` and returns path to cached data |
| 44 | +// |
| 45 | +// NOTE. default `CachedDir` is at "{$HOME}/.cache/transformer" |
| 46 | +// Custom `CachedDir` can be changed by setting with environment `GO_TRANSFORMER` |
| 47 | +func CachedPath(modelNameOrPath, fileName string) (resolvedPath string, err error) { |
| 48 | + |
| 49 | + // Resolves to "candidate" filename at `CacheDir` |
| 50 | + cachedFileCandidate := fmt.Sprintf("%s/%s/%s", CachedDir, modelNameOrPath, fileName) |
| 51 | + |
| 52 | + // 1. Cached candidate file exists |
| 53 | + if _, err := os.Stat(cachedFileCandidate); err == nil { |
| 54 | + return cachedFileCandidate, nil |
| 55 | + } |
| 56 | + |
| 57 | + // 2. If valid fullpath to local file, caches it and return cached filename |
| 58 | + filepath := fmt.Sprintf("%s/%s", modelNameOrPath, fileName) |
| 59 | + if _, err := os.Stat(filepath); err == nil { |
| 60 | + err := copyFile(filepath, cachedFileCandidate) |
| 61 | + if err != nil { |
| 62 | + err := fmt.Errorf("CachedPath() failed at copying file: %w", err) |
| 63 | + return "", err |
| 64 | + } |
| 65 | + return cachedFileCandidate, nil |
| 66 | + } |
| 67 | + |
| 68 | + // 3. Cached candidate file NOT exist. Try to download it and save to `CachedDir` |
| 69 | + url := fmt.Sprintf("%s/%s/resolve/main/%s", HFpath, modelNameOrPath, fileName) |
| 70 | + // url := fmt.Sprintf("%s/%s/raw/main/%s", HFpath, modelNameOrPath, fileName) |
| 71 | + if isValidURL(url) { |
| 72 | + if _, err := http.Get(url); err == nil { |
| 73 | + err := downloadFile(url, cachedFileCandidate) |
| 74 | + if err != nil { |
| 75 | + err = fmt.Errorf("CachedPath() failed at trying to download file: %w", err) |
| 76 | + return "", err |
| 77 | + } |
| 78 | + |
| 79 | + return cachedFileCandidate, nil |
| 80 | + } else { |
| 81 | + err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url) |
| 82 | + return "", err |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + // Not resolves |
| 87 | + err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url) |
| 88 | + return "", err |
| 89 | +} |
| 90 | + |
| 91 | +func isValidURL(url string) bool { |
| 92 | + |
| 93 | + // TODO: implement |
| 94 | + return true |
| 95 | +} |
| 96 | + |
| 97 | +// downloadFile downloads file from URL and stores it in local filepath. |
| 98 | +// It writes to the destination file as it downloads it, without loading |
| 99 | +// the entire file into memory. An `io.TeeReader` is passed into Copy() |
| 100 | +// to report progress on the download. |
| 101 | +func downloadFile(url string, filepath string) error { |
| 102 | + // Create path if not existing |
| 103 | + dir := path.Dir(filepath) |
| 104 | + filename := path.Base(filepath) |
| 105 | + if _, err := os.Stat(dir); os.IsNotExist(err) { |
| 106 | + if err := os.MkdirAll(dir, 0755); err != nil { |
| 107 | + log.Fatal(err) |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + // Create the file with .tmp extension, so that we won't overwrite a |
| 112 | + // file until it's downloaded fully |
| 113 | + out, err := os.Create(filepath + ".tmp") |
| 114 | + if err != nil { |
| 115 | + return err |
| 116 | + } |
| 117 | + defer out.Close() |
| 118 | + |
| 119 | + // Get the data |
| 120 | + resp, err := http.Get(url) |
| 121 | + if err != nil { |
| 122 | + return err |
| 123 | + } |
| 124 | + defer resp.Body.Close() |
| 125 | + |
| 126 | + // Check server response |
| 127 | + if resp.StatusCode != http.StatusOK { |
| 128 | + err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode) |
| 129 | + |
| 130 | + if resp.StatusCode == 404 { |
| 131 | + // if filename == "rust_model.ot" { |
| 132 | + // msg := fmt.Sprintf("model weight file not found. That means a compatible pretrained model weight file for Go is not available.\n") |
| 133 | + // msg = msg + fmt.Sprintf("You might need to manually convert a 'pytorch_model.bin' for Go. ") |
| 134 | + // msg = msg + fmt.Sprintf("See tutorial at: 'example/convert'") |
| 135 | + // err = fmt.Errorf(msg) |
| 136 | + // } else { |
| 137 | + // err = fmt.Errorf("download file not found: %q for downloading", url) |
| 138 | + // } |
| 139 | + err = fmt.Errorf("download file not found: %q for downloading", url) |
| 140 | + } else { |
| 141 | + err = fmt.Errorf("download file failed: %q", url) |
| 142 | + } |
| 143 | + return err |
| 144 | + } |
| 145 | + |
| 146 | + // the total file size to download |
| 147 | + size, _ := strconv.Atoi(resp.Header.Get("Content-Length")) |
| 148 | + downloadSize := uint64(size) |
| 149 | + |
| 150 | + // Create our bytes counter and pass it to be used alongside our writer |
| 151 | + counter := &writeCounter{FileSize: downloadSize} |
| 152 | + _, err = io.Copy(out, io.TeeReader(resp.Body, counter)) |
| 153 | + if err != nil { |
| 154 | + return err |
| 155 | + } |
| 156 | + |
| 157 | + fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize)) |
| 158 | + // The progress use the same line so print a new line once it's finished downloading |
| 159 | + fmt.Println() |
| 160 | + |
| 161 | + // Rename the tmp file back to the original file |
| 162 | + err = os.Rename(filepath+".tmp", filepath) |
| 163 | + if err != nil { |
| 164 | + return err |
| 165 | + } |
| 166 | + |
| 167 | + return nil |
| 168 | +} |
| 169 | + |
| 170 | +// writeCounter counts the number of bytes written to it. By implementing the Write method, |
| 171 | +// it is of the io.Writer interface and we can pass this into io.TeeReader() |
| 172 | +// Every write to this writer, will print the progress of the file write. |
| 173 | +type writeCounter struct { |
| 174 | + Total uint64 |
| 175 | + FileSize uint64 |
| 176 | +} |
| 177 | + |
| 178 | +func (wc *writeCounter) Write(p []byte) (int, error) { |
| 179 | + n := len(p) |
| 180 | + wc.Total += uint64(n) |
| 181 | + wc.printProgress() |
| 182 | + return n, nil |
| 183 | +} |
| 184 | + |
| 185 | +// PrintProgress prints the progress of a file write |
| 186 | +func (wc writeCounter) printProgress() { |
| 187 | + // Clear the line by using a character return to go back to the start and remove |
| 188 | + // the remaining characters by filling it with spaces |
| 189 | + fmt.Printf("\r%s", strings.Repeat(" ", 50)) |
| 190 | + |
| 191 | + // Return again and print current status of download |
| 192 | + fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize)) |
| 193 | +} |
| 194 | + |
| 195 | +// byteCountIEC converts bytes to human-readable string in binary (IEC) format. |
| 196 | +func byteCountIEC(b uint64) string { |
| 197 | + const unit = 1024 |
| 198 | + if b < unit { |
| 199 | + return fmt.Sprintf("%d B", b) |
| 200 | + } |
| 201 | + div, exp := uint64(unit), 0 |
| 202 | + for n := b / unit; n >= unit; n /= unit { |
| 203 | + div *= unit |
| 204 | + exp++ |
| 205 | + } |
| 206 | + return fmt.Sprintf("%.1f %ciB", |
| 207 | + float64(b)/float64(div), "KMGTPE"[exp]) |
| 208 | +} |
| 209 | + |
| 210 | +func copyFile(src, dst string) error { |
| 211 | + sourceFileStat, err := os.Stat(src) |
| 212 | + if err != nil { |
| 213 | + return err |
| 214 | + } |
| 215 | + |
| 216 | + if !sourceFileStat.Mode().IsRegular() { |
| 217 | + return fmt.Errorf("%s is not a regular file", src) |
| 218 | + } |
| 219 | + |
| 220 | + source, err := os.Open(src) |
| 221 | + if err != nil { |
| 222 | + return err |
| 223 | + } |
| 224 | + defer source.Close() |
| 225 | + |
| 226 | + destination, err := os.Create(dst) |
| 227 | + if err != nil { |
| 228 | + return err |
| 229 | + } |
| 230 | + defer destination.Close() |
| 231 | + _, err = io.Copy(destination, source) |
| 232 | + return err |
| 233 | +} |
| 234 | + |
| 235 | +// CleanCache removes all files cached in transformer cache directory `CachedDir`. |
| 236 | +// |
| 237 | +// NOTE. custom `CachedDir` can be changed by setting environment `GO_TRANSFORMER` |
| 238 | +func CleanCache() error { |
| 239 | + err := os.RemoveAll(CachedDir) |
| 240 | + if err != nil { |
| 241 | + err = fmt.Errorf("CleanCache() failed: %w", err) |
| 242 | + return err |
| 243 | + } |
| 244 | + |
| 245 | + return nil |
| 246 | +} |
0 commit comments