-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmigration.go
225 lines (204 loc) · 5.58 KB
/
migration.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
// Package migration contains functions for generating and running DB migrations.
package migration // import "code.soquee.net/migration"
import (
"context"
"database/sql"
"fmt"
"io"
"io/fs"
"os"
"path"
"strings"
"time"
)
// Generator returns a function that creates migration files at the given base
// path.
//
// Each call to the returned function creates a directory under basePath whose
// name is the current local time formatted as "2006-01-02-150405_" followed by
// name. Before use, name is trimmed of surrounding whitespace; spaces and tabs
// are replaced with underscores, and single and double quotes are removed. The
// directory receives placeholder up.sql and down.sql files.
//
// Any error from creating the directory, creating or writing either file, or
// closing either file is returned. Files created before a failure are not
// removed.
func Generator(basePath string) func(name string) error {
	replacer := strings.NewReplacer(
		" ", "_",
		"\t", "_",
		"'", "",
		"\"", "",
	)
	// err is a named result so the deferred Close calls below can surface
	// their errors to the caller; with an unnamed result those assignments
	// would be lost.
	return func(name string) (err error) {
		name = time.Now().Format("2006-01-02-150405_") + replacer.Replace(strings.TrimSpace(name))
		relPath := path.Join(basePath, name)
		// TODO: perform file creation operations in a temporary directory and then
		// move everything to the final location.
		if err = os.MkdirAll(relPath, 0750); err != nil {
			return err
		}

		var upfile *os.File
		upfile, err = os.Create(path.Join(relPath, "up.sql"))
		if err != nil {
			return err
		}
		defer func() {
			// Only report the close error if nothing earlier failed.
			if cerr := upfile.Close(); cerr != nil && err == nil {
				err = fmt.Errorf("error closing new up.sql: %w", cerr)
			}
		}()
		if _, err = upfile.WriteString("-- Your SQL goes here"); err != nil {
			return err
		}

		var downfile *os.File
		downfile, err = os.Create(path.Join(relPath, "down.sql"))
		if err != nil {
			return err
		}
		defer func() {
			if cerr := downfile.Close(); cerr != nil && err == nil {
				err = fmt.Errorf("error closing new down.sql: %w", cerr)
			}
		}()
		_, err = downfile.WriteString("-- This file should undo anything in `up.sql'")
		return err
	}
}
// RunStatus describes whether a migration has been applied: it has been run,
// it has not been run, or its status could not be determined. The zero value
// is StatusUnknown.
type RunStatus int

// Valid RunStatus values. For more information see RunStatus.
const (
	// StatusUnknown means the run status could not be determined.
	StatusUnknown RunStatus = 0
	// StatusNotRun means the migration has not been run.
	StatusNotRun RunStatus = 1
	// StatusRun means the migration has been run.
	StatusRun RunStatus = 2
)
// contains returns the index of the first element of sl equal to s, or -1 if
// there is none. Despite its name it returns a position, not a bool.
func contains(sl []string, s string) int {
	pos := -1
	for i := 0; pos == -1 && i < len(sl); i++ {
		if sl[i] == s {
			pos = i
		}
	}
	return pos
}
// getRunMigrations returns the versions stored in migrationsTable, ordered
// ascending, using the transaction tx.
//
// The table name is passed through sanitize before being interpolated into the
// query. An error from the query, from scanning a row, or from iterating the
// result set is returned with a nil slice; a partial list is never returned.
func getRunMigrations(ctx context.Context, tx *sql.Tx, migrationsTable string) ([]string, error) {
	rows, err := tx.QueryContext(ctx,
		fmt.Sprintf(`SELECT version FROM %s ORDER BY version ASC`, sanitize(migrationsTable)),
	)
	if err != nil {
		return nil, err
	}
	/* #nosec */
	defer rows.Close()

	var ran []string
	for rows.Next() {
		var version string
		if err := rows.Scan(&version); err != nil {
			return nil, err
		}
		ran = append(ran, version)
	}
	// rows.Next returns false both at the end of the results and when
	// iteration fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return ran, nil
}
// LastRun returns the last migration directory by lexical order that exists in
// the database and on disk.
//
// When tx is non-nil, ident is the greatest version stored in migrationsTable
// and name is the base name of the directory in vfs whose migration identifier
// equals it. If that version has no directory on disk, name is empty. An error
// from the version query (for example sql.ErrNoRows when the table is empty)
// is returned with ident and name empty.
//
// When tx is nil the database is not consulted: ident is empty and name is the
// base name of the last migration directory visited while walking vfs.
func LastRun(ctx context.Context, migrationsTable string, vfs fs.FS, tx *sql.Tx) (ident, name string, err error) {
	var version string
	if tx != nil {
		err = tx.QueryRowContext(ctx,
			fmt.Sprintf(`SELECT version FROM %s ORDER BY version DESC LIMIT 1`, sanitize(migrationsTable)),
		).Scan(&version)
		if err != nil {
			return version, "", err
		}
	}

	var fpath string
	walker, err := NewWalker(ctx, migrationsTable, tx)
	if err != nil {
		return version, fpath, err
	}
	err = walker(vfs, func(migration string, info fs.DirEntry, _ RunStatus) error {
		if tx != nil && migration != version {
			return nil
		}
		// info is nil for a migration that is recorded in the database but has
		// no directory on disk; there is no directory name to report then.
		if info != nil {
			fpath = info.Name()
		}
		if tx != nil {
			// Found the last-run migration; io.EOF is used only to stop the
			// walk early and is filtered out below.
			return io.EOF
		}
		return nil
	})
	if err != nil && err != io.EOF {
		return version, fpath, err
	}
	return version, fpath, nil
}
type (
	// WalkFunc is called by a Walker once per migration it visits. name is the
	// migration's identifier, info is its directory entry (it may be nil, e.g.
	// for a migration NewWalker only found in the database), and status says
	// whether the migration has been run. A non-nil error stops the walk.
	WalkFunc func(name string, info fs.DirEntry, status RunStatus) error

	// Walker walks the migrations found in vfs and invokes f for each of them.
	Walker func(vfs fs.FS, f WalkFunc) error
)
// NewWalker queries the database for migration status information and returns a
// function that walks the migrations it finds on the filesystem in lexical
// order (mostly, keep reading) and calls a function for each discovered
// migration, passing in its name, status, and file information.
//
// Every directory in vfs is visited, including nested ones. A directory is a
// migration if its name contains an underscore. The name passed to f is the
// part before the first underscore with all dashes removed, so
// "2006-01-02-150405_add_users" becomes "20060102150405".
//
// If a migration exists in the database but not on the filesystem, info will be
// nil and f will be called for it after the migrations that exist on the
// filesystem.
// No particular order is guaranteed for calls to f for migrations that do not
// exist on the filesystem.
//
// If tx is nil, or the database query fails, every migration found on disk is
// reported with StatusUnknown.
//
// If NewWalker returns an error and a non-nil function, the function may still
// be used to walk the migrations on the filesystem but the status information
// may be wrong since the DB may not have been queried successfully. The error
// wraps the underlying query error, so errors.Is and errors.As work on it.
//
// The returned Walker may be called more than once. Each call starts from the
// database snapshot taken by NewWalker.
func NewWalker(ctx context.Context, migrationsTable string, tx *sql.Tx) (Walker, error) {
	var queryErr error
	var ran []string
	if tx != nil {
		ran, queryErr = getRunMigrations(ctx, tx, migrationsTable)
		if queryErr != nil {
			queryErr = fmt.Errorf("error querying existing migrations: %w", queryErr)
			// Without a trustworthy list, report no database status at all.
			tx = nil
			ran = nil
		}
	}
	haveStatus := tx != nil

	return func(vfs fs.FS, f WalkFunc) error {
		// pending counts recorded versions not yet matched on disk. It is
		// rebuilt on every call so that a previous walk cannot consume entries.
		pending := make(map[string]int, len(ran))
		for _, v := range ran {
			pending[v]++
		}

		err := fs.WalkDir(vfs, ".", func(p string, info fs.DirEntry, walkErr error) error {
			// Check the error first: failures to stat or read the root are
			// reported with p == "." and must not be skipped.
			if walkErr != nil {
				return walkErr
			}
			if p == "." || !info.IsDir() {
				return nil
			}
			dirName := info.Name()
			idx := strings.Index(dirName, "_")
			if idx == -1 {
				return nil
			}
			name := strings.ReplaceAll(dirName[:idx], "-", "")
			var status RunStatus
			if haveStatus {
				if pending[name] > 0 {
					// On disk and in the database: mark it found.
					pending[name]--
					status = StatusRun
				} else {
					// Only on the filesystem.
					status = StatusNotRun
				}
			}
			return f(name, info, status)
		})
		if err != nil {
			return err
		}

		// Report migrations recorded in the database but absent from disk, in
		// the order the database returned them.
		for _, v := range ran {
			if pending[v] == 0 {
				continue
			}
			pending[v]--
			if err := f(v, nil, StatusRun); err != nil {
				return err
			}
		}
		return nil
	}, queryErr
}