/// This package is in charge of connecting to the DB and migrations package db import ( "context" "database/sql" "errors" "fmt" _ "github.com/mattn/go-sqlite3" ) var migrations = []func(*sql.Tx) error{migration0, migration1} var __version = len(migrations) type DbCtx struct { Db *sql.DB Version int } func initVersion(tx *sql.Tx) (int, int, error) { var version int // Check the version if err := tx.QueryRow("PRAGMA user_version").Scan(&version); err != nil { tx.Rollback() return 0, 0, fmt.Errorf("Could not select user_version. %w", err) } if version < __version { _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", __version)) if err != nil { tx.Rollback() return 0, 0, fmt.Errorf("Could not update pragma version. %w", err) } // Start from 1 so that all migrations are applied return version + 1, __version, nil } else if version > __version { tx.Rollback() return 0, 0, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version)) } // User is on the latest version, migrations do not need to be ran. return version, version, nil } func migration0(tx *sql.Tx) error { if _, err := tx.Exec(` create table if not exists diary_log ( id integer not null primary key, title text, created_at int not null default (strftime('%s','now')), version int not null default 0 ) `); err != nil { tx.Rollback() return fmt.Errorf("Could not create diary_log. %w", err) } return nil } func migration1(tx *sql.Tx) error { if _, err := tx.Exec(` create table diary_log_note ( id integer not null primary key, log_id integer, body text, created_at int not null default (strftime('%s','now')), version int not null default 0, foreign key(log_id) references diary_log(id) ) `); err != nil { tx.Rollback() return fmt.Errorf("Could not create diary_log_note. %w", err) } return nil } func initMigrations(tx *sql.Tx, start_from_version int) error { if start_from_version == __version { return nil } for migration_idx, migration := range migrations { // Version is 1 indexed, while the migration_idx is 0 indexed var migration_num = migration_idx + 1 if migration_num < start_from_version { continue } err := migration(tx) if err != nil { tx.Rollback() return fmt.Errorf("Failed migration %d. %w", migration_num, err) } } return nil } func Init(db_location string) (*DbCtx, error) { db, err := sql.Open("sqlite3", db_location) if err != nil { return nil, fmt.Errorf("Could not connect to DB. %w", err) } if err := db.Ping(); err != nil { return nil, fmt.Errorf("Could not ping DB. %w", err) } ctx := context.Background() tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } old_version, new_version, err := initVersion(tx) if err != nil { return nil, err } err = initMigrations(tx, old_version) if err != nil { return nil, err } tx.Commit() return &DbCtx{db, new_version}, nil }