diff --git a/db/db.go b/db/db.go index 1c3ad5b..710a683 100644 --- a/db/db.go +++ b/db/db.go @@ -1,3 +1,5 @@ +/// This package is in charge of connecting to the DB and migrations + package db import ( @@ -10,61 +12,95 @@ import ( var __version = 1 -type LogCtx struct { +type DbCtx struct { db *sql.DB version int } -func Init(db_location string) (*LogCtx, error) { - db, err := sql.Open("sqlite3", db_location) - if err != nil { - return nil, fmt.Errorf("Could not connect to DB. %w", err) - } - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("Could not ping DB. %w", err) +func initVersion(tx *sql.Tx) (int, error) { + var version int + // Check the version + if err := tx.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + tx.Rollback() + return 0, fmt.Errorf("Could not select user_version. %w", err) } - ctx := context.Background() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err + if version == 0 { + _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", __version)) + if err != nil { + tx.Rollback() + return 0, fmt.Errorf("Could not update pragma version. %w", err) + } + version = __version + } else if version != __version { + tx.Rollback() + return 0, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version)) } + return version, nil +} +func migration0(tx *sql.Tx) error { // Optionally add the version table if _, err := tx.Exec(` create table if not exists eng_log_version (id integer not null); `); err != nil { tx.Rollback() - return nil, fmt.Errorf("Could not create eng_log_version. %w", err) + return fmt.Errorf("Could not create eng_log_version. %w", err) } // Optionally add the log table if _, err := tx.Exec(` create table if not exists log (id integer not null primary key, title string) `); err != nil { tx.Rollback() - return nil, fmt.Errorf("Could not create log. %w", err) - } - - var version int - // Check the version - if err := tx.QueryRow("SELECT IFNULL((SELECT id FROM eng_log_version LIMIT 1), 0)").Scan(&version); err != nil { - tx.Rollback() - return nil, fmt.Errorf("Could not query for eng_log_version id. %w", err) + return fmt.Errorf("Could not create log. %w", err) } + return nil +} - if version == 0 { - _, err := tx.Exec("INSERT INTO eng_log_version (id) VALUES (?)", __version) +func initMigrations(tx *sql.Tx, start_from_versiono int) error { + var migrations = []func(*sql.Tx) error{migration0} + for migration_idx, migration := range migrations { + // Version is 1 indexed, while the migration_idx is 0 indexed + var migration_num = migration_idx + 1 + if migration_num < start_from_versiono { + continue + } + err := migration(tx) if err != nil { tx.Rollback() - return nil, fmt.Errorf("Could not insert log-version. %w", err) + return fmt.Errorf("Failed migration %d. %w", migration_num, err) } - version = __version - } else if version != __version { - tx.Rollback() - return nil, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version)) + } + return nil +} + +func Init(db_location string) (*DbCtx, error) { + db, err := sql.Open("sqlite3", db_location) + if err != nil { + return nil, fmt.Errorf("Could not connect to DB. %w", err) + } + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("Could not ping DB. %w", err) + } + + ctx := context.Background() + tx, err := db.BeginTx(ctx, nil) + + if err != nil { + return nil, err + } + + version, err := initVersion(tx) + if err != nil { + return nil, err + } + + err = initMigrations(tx, version) + if err != nil { + return nil, err } tx.Commit() - return &LogCtx{db, version}, nil + return &DbCtx{db, version}, nil } diff --git a/db/db_test.go b/db/db_test.go index b5dc306..643391e 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -29,11 +29,6 @@ func TestInitBasic(t *testing.T) { // check that the tables exist var table_name string - ctx.db.QueryRow( - "SELECT name FROM sqlite_master WHERE type='table' AND name='eng_log_version';", - ).Scan(&table_name) - assert_string(t, "eng_log_version", table_name) - ctx.db.QueryRow( "SELECT name FROM sqlite_master WHERE type='table' AND name='log';", ).Scan(&table_name) @@ -41,26 +36,12 @@ func TestInitBasic(t *testing.T) { // Check that the version stored is correct var version int - ctx.db.QueryRow("SELECT id FROM eng_log_version").Scan(&version) + ctx.db.QueryRow("PRAGMA user_version").Scan(&version) assert_int(t, __version, ctx.version) assert_int(t, __version, version) } -func TestInitLogVersionTableExists(t *testing.T) { - var db_location = ":memory:" - db, err := sql.Open("sqlite3", db_location) - - db.Exec(` - CREATE TABLE eng_log_version (id INTEGER NOT NULL); - INSERT INTO eng_log_version (id) VALUES (1); - `) - _, err = Init(db_location) - if err != nil { - t.Fatalf(err.Error()) - } -} - -func TestInitWrongLogVersionExists(t *testing.T) { +func TestInitWrongVersion(t *testing.T) { file, err := ioutil.TempFile("", "log_test_init_wrong.*.db") defer os.Remove(file.Name()) // clean up @@ -69,15 +50,10 @@ func TestInitWrongLogVersionExists(t *testing.T) { log.Fatal(err.Error()) } - db.Exec(` - CREATE TABLE eng_log_version (id INTEGER NOT NULL); - `) - db.Exec(` - INSERT INTO eng_log_version (id) VALUES (?); - `, __version+1) + db.Exec(fmt.Sprintf(`PRAGMA user_version=%d`, __version+1)) var version int - db.QueryRow("SELECT id FROM eng_log_version LIMIT 1").Scan(&version) + db.QueryRow("PRAGMA user_version").Scan(&version) _, err = Init(file.Name()) if err == nil {