The eng_log_version table will be used to figure out which migrations to run.pull/8/head
parent
378d0c711d
commit
e3d511db69
@ -0,0 +1,99 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
)
|
||||||
|
|
||||||
|
var __version = 1
|
||||||
|
|
||||||
|
type LogCtx struct {
|
||||||
|
db *sql.DB
|
||||||
|
version int
|
||||||
|
}
|
||||||
|
|
||||||
|
func Init(db_location string) (*LogCtx, error) {
|
||||||
|
db, err := sql.Open("sqlite3", db_location)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not connect to DB. %w", err)
|
||||||
|
}
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not ping DB. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optionally add the version table
|
||||||
|
if _, err := tx.Exec(`
|
||||||
|
create table if not exists eng_log_version (id integer not null);
|
||||||
|
`); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, fmt.Errorf("Could not create eng_log_version. %w", err)
|
||||||
|
}
|
||||||
|
// Optionally add the log table
|
||||||
|
if _, err := tx.Exec(`
|
||||||
|
create table if not exists log (id integer not null primary key, title string)
|
||||||
|
`); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, fmt.Errorf("Could not create log. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var version int
|
||||||
|
// Check the version
|
||||||
|
if err := tx.QueryRow("SELECT IFNULL((SELECT id FROM eng_log_version LIMIT 1), 0)").Scan(&version); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, fmt.Errorf("Could not query for eng_log_version id. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if version == 0 {
|
||||||
|
_, err := tx.Exec("INSERT INTO eng_log_version (id) VALUES (?)", __version)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, fmt.Errorf("Could not insert log-version. %w", err)
|
||||||
|
}
|
||||||
|
version = __version
|
||||||
|
} else if version != __version {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version))
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
return &LogCtx{db, version}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *LogCtx) AddEntry(title string) int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
home_dir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
var db_path = path.Join(home_dir, ".dev-log.sql")
|
||||||
|
|
||||||
|
if _, err := os.Stat(db_path); errors.Is(err, os.ErrNotExist) {
|
||||||
|
_, err := os.Create(db_path)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = Init(db_path)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,91 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assert_string(t *testing.T, expected string, actual string) {
|
||||||
|
if actual != expected {
|
||||||
|
t.Fatalf("(%s, %s)", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func assert_int(t *testing.T, expected int, actual int) {
|
||||||
|
if actual != expected {
|
||||||
|
t.Fatalf("(%d, %d)", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitBasic(t *testing.T) {
|
||||||
|
var db_location = ":memory:"
|
||||||
|
ctx, err := Init(db_location)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
// check that the tables exist
|
||||||
|
var table_name string
|
||||||
|
|
||||||
|
ctx.db.QueryRow(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='eng_log_version';",
|
||||||
|
).Scan(&table_name)
|
||||||
|
assert_string(t, "eng_log_version", table_name)
|
||||||
|
|
||||||
|
ctx.db.QueryRow(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='log';",
|
||||||
|
).Scan(&table_name)
|
||||||
|
assert_string(t, "log", table_name)
|
||||||
|
|
||||||
|
// Check that the version stored is correct
|
||||||
|
var version int
|
||||||
|
ctx.db.QueryRow("SELECT id FROM eng_log_version").Scan(&version)
|
||||||
|
assert_int(t, __version, ctx.version)
|
||||||
|
assert_int(t, __version, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitLogVersionTableExists(t *testing.T) {
|
||||||
|
var db_location = ":memory:"
|
||||||
|
db, err := sql.Open("sqlite3", db_location)
|
||||||
|
|
||||||
|
db.Exec(`
|
||||||
|
CREATE TABLE eng_log_version (id INTEGER NOT NULL);
|
||||||
|
INSERT INTO eng_log_version (id) VALUES (1);
|
||||||
|
`)
|
||||||
|
_, err = Init(db_location)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitWrongLogVersionExists(t *testing.T) {
|
||||||
|
file, err := ioutil.TempFile("", "log_test_init_wrong.*.db")
|
||||||
|
defer os.Remove(file.Name()) // clean up
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", file.Name())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Exec(`
|
||||||
|
CREATE TABLE eng_log_version (id INTEGER NOT NULL);
|
||||||
|
`)
|
||||||
|
db.Exec(`
|
||||||
|
INSERT INTO eng_log_version (id) VALUES (?);
|
||||||
|
`, __version+1)
|
||||||
|
|
||||||
|
var version int
|
||||||
|
db.QueryRow("SELECT id FROM eng_log_version LIMIT 1").Scan(&version)
|
||||||
|
|
||||||
|
_, err = Init(file.Name())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("No error when error expected")
|
||||||
|
}
|
||||||
|
assert_string(
|
||||||
|
t,
|
||||||
|
fmt.Sprintf("Wrong version. Expected %d got %d", __version, __version+1),
|
||||||
|
err.Error(),
|
||||||
|
)
|
||||||
|
}
|
Loading…
Reference in new issue