|
|
|
// Package db connects to the SQLite database and applies schema migrations.
|
|
|
|
|
|
|
|
package db
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
)
|
|
|
|
|
|
|
|
var migrations = []func(*sql.Tx) error{migration0, migration1}
|
|
|
|
var __version = len(migrations)
|
|
|
|
|
|
|
|
// DbCtx bundles the database handle opened by Init with the schema version
// the database was left at after Init ran its migrations.
//
// NOTE: Init constructs this with a positional literal, so field order matters.
type DbCtx struct {
	// Db is the handle returned by sql.Open("sqlite3", ...) in Init.
	Db *sql.DB
	// Version is the schema version (PRAGMA user_version) after migrations.
	Version int
}
|
|
|
|
|
|
|
|
func initVersion(tx *sql.Tx) (int, int, error) {
|
|
|
|
var version int
|
|
|
|
// Check the version
|
|
|
|
if err := tx.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return 0, 0, fmt.Errorf("Could not select user_version. %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if version < __version {
|
|
|
|
_, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", __version))
|
|
|
|
if err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return 0, 0, fmt.Errorf("Could not update pragma version. %w", err)
|
|
|
|
}
|
|
|
|
// Start from 1 so that all migrations are applied
|
|
|
|
return version + 1, __version, nil
|
|
|
|
} else if version > __version {
|
|
|
|
tx.Rollback()
|
|
|
|
return 0, 0, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version))
|
|
|
|
}
|
|
|
|
// User is on the latest version, migrations do not need to be ran.
|
|
|
|
return version, version, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func migration0(tx *sql.Tx) error {
|
|
|
|
if _, err := tx.Exec(`
|
|
|
|
create table if not exists diary_log (
|
|
|
|
id integer not null primary key,
|
|
|
|
title text,
|
|
|
|
created_at int not null default (strftime('%s','now')),
|
|
|
|
version int not null default 0
|
|
|
|
)
|
|
|
|
`); err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return fmt.Errorf("Could not create diary_log. %w", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func migration1(tx *sql.Tx) error {
|
|
|
|
if _, err := tx.Exec(`
|
|
|
|
create table diary_log_note (
|
|
|
|
id integer not null primary key,
|
|
|
|
log_id integer,
|
|
|
|
body text,
|
|
|
|
created_at int not null default (strftime('%s','now')),
|
|
|
|
version int not null default 0,
|
|
|
|
foreign key(log_id) references diary_log(id)
|
|
|
|
)
|
|
|
|
`); err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return fmt.Errorf("Could not create diary_log_note. %w", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func initMigrations(tx *sql.Tx, start_from_version int) error {
|
|
|
|
if start_from_version == __version {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
for migration_idx, migration := range migrations {
|
|
|
|
// Version is 1 indexed, while the migration_idx is 0 indexed
|
|
|
|
var migration_num = migration_idx + 1
|
|
|
|
if migration_num < start_from_version {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
err := migration(tx)
|
|
|
|
if err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return fmt.Errorf("Failed migration %d. %w", migration_num, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func Init(db_location string) (*DbCtx, error) {
|
|
|
|
db, err := sql.Open("sqlite3", db_location)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("Could not connect to DB. %w", err)
|
|
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
|
|
return nil, fmt.Errorf("Could not ping DB. %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx := context.Background()
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
old_version, new_version, err := initVersion(tx)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
err = initMigrations(tx, old_version)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
tx.Commit()
|
|
|
|
|
|
|
|
return &DbCtx{db, new_version}, nil
|
|
|
|
}
|