Refactor db code

The versioning now happens using the user_version pragma, provided by
sqlite. From the docs:

> The user-version is an integer that is available to applications to use
> however they want. SQLite makes no use of the user-version itself.

Also the migrations happen in their own loop. Each migratin should
get tested(?)
pull/8/head
isthisnagee 3 years ago
parent 06891059ac
commit 42a61dcddb

@ -1,3 +1,5 @@
/// This package is in charge of connecting to the DB and migrations
package db package db
import ( import (
@ -10,61 +12,95 @@ import (
var __version = 1 var __version = 1
type LogCtx struct { type DbCtx struct {
db *sql.DB db *sql.DB
version int version int
} }
func Init(db_location string) (*LogCtx, error) { func initVersion(tx *sql.Tx) (int, error) {
db, err := sql.Open("sqlite3", db_location) var version int
if err != nil { // Check the version
return nil, fmt.Errorf("Could not connect to DB. %w", err) if err := tx.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
} tx.Rollback()
if err := db.Ping(); err != nil { return 0, fmt.Errorf("Could not select user_version. %w", err)
return nil, fmt.Errorf("Could not ping DB. %w", err)
} }
ctx := context.Background() if version == 0 {
tx, err := db.BeginTx(ctx, nil) _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", __version))
if err != nil { if err != nil {
return nil, err tx.Rollback()
return 0, fmt.Errorf("Could not update pragma version. %w", err)
}
version = __version
} else if version != __version {
tx.Rollback()
return 0, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version))
} }
return version, nil
}
func migration0(tx *sql.Tx) error {
// Optionally add the version table // Optionally add the version table
if _, err := tx.Exec(` if _, err := tx.Exec(`
create table if not exists eng_log_version (id integer not null); create table if not exists eng_log_version (id integer not null);
`); err != nil { `); err != nil {
tx.Rollback() tx.Rollback()
return nil, fmt.Errorf("Could not create eng_log_version. %w", err) return fmt.Errorf("Could not create eng_log_version. %w", err)
} }
// Optionally add the log table // Optionally add the log table
if _, err := tx.Exec(` if _, err := tx.Exec(`
create table if not exists log (id integer not null primary key, title string) create table if not exists log (id integer not null primary key, title string)
`); err != nil { `); err != nil {
tx.Rollback() tx.Rollback()
return nil, fmt.Errorf("Could not create log. %w", err) return fmt.Errorf("Could not create log. %w", err)
}
var version int
// Check the version
if err := tx.QueryRow("SELECT IFNULL((SELECT id FROM eng_log_version LIMIT 1), 0)").Scan(&version); err != nil {
tx.Rollback()
return nil, fmt.Errorf("Could not query for eng_log_version id. %w", err)
} }
return nil
}
if version == 0 { func initMigrations(tx *sql.Tx, start_from_versiono int) error {
_, err := tx.Exec("INSERT INTO eng_log_version (id) VALUES (?)", __version) var migrations = []func(*sql.Tx) error{migration0}
for migration_idx, migration := range migrations {
// Version is 1 indexed, while the migration_idx is 0 indexed
var migration_num = migration_idx + 1
if migration_num < start_from_versiono {
continue
}
err := migration(tx)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return nil, fmt.Errorf("Could not insert log-version. %w", err) return fmt.Errorf("Failed migration %d. %w", migration_num, err)
} }
version = __version }
} else if version != __version { return nil
tx.Rollback() }
return nil, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version))
func Init(db_location string) (*DbCtx, error) {
db, err := sql.Open("sqlite3", db_location)
if err != nil {
return nil, fmt.Errorf("Could not connect to DB. %w", err)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("Could not ping DB. %w", err)
}
ctx := context.Background()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
version, err := initVersion(tx)
if err != nil {
return nil, err
}
err = initMigrations(tx, version)
if err != nil {
return nil, err
} }
tx.Commit() tx.Commit()
return &LogCtx{db, version}, nil return &DbCtx{db, version}, nil
} }

@ -29,11 +29,6 @@ func TestInitBasic(t *testing.T) {
// check that the tables exist // check that the tables exist
var table_name string var table_name string
ctx.db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name='eng_log_version';",
).Scan(&table_name)
assert_string(t, "eng_log_version", table_name)
ctx.db.QueryRow( ctx.db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name='log';", "SELECT name FROM sqlite_master WHERE type='table' AND name='log';",
).Scan(&table_name) ).Scan(&table_name)
@ -41,26 +36,12 @@ func TestInitBasic(t *testing.T) {
// Check that the version stored is correct // Check that the version stored is correct
var version int var version int
ctx.db.QueryRow("SELECT id FROM eng_log_version").Scan(&version) ctx.db.QueryRow("PRAGMA user_version").Scan(&version)
assert_int(t, __version, ctx.version) assert_int(t, __version, ctx.version)
assert_int(t, __version, version) assert_int(t, __version, version)
} }
func TestInitLogVersionTableExists(t *testing.T) { func TestInitWrongVersion(t *testing.T) {
var db_location = ":memory:"
db, err := sql.Open("sqlite3", db_location)
db.Exec(`
CREATE TABLE eng_log_version (id INTEGER NOT NULL);
INSERT INTO eng_log_version (id) VALUES (1);
`)
_, err = Init(db_location)
if err != nil {
t.Fatalf(err.Error())
}
}
func TestInitWrongLogVersionExists(t *testing.T) {
file, err := ioutil.TempFile("", "log_test_init_wrong.*.db") file, err := ioutil.TempFile("", "log_test_init_wrong.*.db")
defer os.Remove(file.Name()) // clean up defer os.Remove(file.Name()) // clean up
@ -69,15 +50,10 @@ func TestInitWrongLogVersionExists(t *testing.T) {
log.Fatal(err.Error()) log.Fatal(err.Error())
} }
db.Exec(` db.Exec(fmt.Sprintf(`PRAGMA user_version=%d`, __version+1))
CREATE TABLE eng_log_version (id INTEGER NOT NULL);
`)
db.Exec(`
INSERT INTO eng_log_version (id) VALUES (?);
`, __version+1)
var version int var version int
db.QueryRow("SELECT id FROM eng_log_version LIMIT 1").Scan(&version) db.QueryRow("PRAGMA user_version").Scan(&version)
_, err = Init(file.Name()) _, err = Init(file.Name())
if err == nil { if err == nil {

Loading…
Cancel
Save