diff --git a/log.go b/log.go new file mode 100644 index 0000000..f21c80f --- /dev/null +++ b/log.go @@ -0,0 +1,99 @@ +package main + +import ( + "context" + "database/sql" + "errors" + "fmt" + _ "github.com/mattn/go-sqlite3" + "log" + "os" + "path" +) + +var __version = 1 + +type LogCtx struct { + db *sql.DB + version int +} + +func Init(db_location string) (*LogCtx, error) { + db, err := sql.Open("sqlite3", db_location) + if err != nil { + return nil, fmt.Errorf("Could not connect to DB. %w", err) + } + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("Could not ping DB. %w", err) + } + + ctx := context.Background() + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + // Optionally add the version table + if _, err := tx.Exec(` + create table if not exists eng_log_version (id integer not null); + `); err != nil { + tx.Rollback() + return nil, fmt.Errorf("Could not create eng_log_version. %w", err) + } + // Optionally add the log table + if _, err := tx.Exec(` + create table if not exists log (id integer not null primary key, title string) + `); err != nil { + tx.Rollback() + return nil, fmt.Errorf("Could not create log. %w", err) + } + + var version int + // Check the version + if err := tx.QueryRow("SELECT IFNULL((SELECT id FROM eng_log_version LIMIT 1), 0)").Scan(&version); err != nil { + tx.Rollback() + return nil, fmt.Errorf("Could not query for eng_log_version id. %w", err) + } + + if version == 0 { + _, err := tx.Exec("INSERT INTO eng_log_version (id) VALUES (?)", __version) + if err != nil { + tx.Rollback() + return nil, fmt.Errorf("Could not insert log-version. %w", err) + } + version = __version + } else if version != __version { + tx.Rollback() + return nil, errors.New(fmt.Sprintf("Wrong version. Expected %d got %d", __version, version)) + } + + tx.Commit() + + return &LogCtx{db, version}, nil +} + +func (db *LogCtx) AddEntry(title string) int { + return 1 +} + +func main() { + + home_dir, err := os.UserHomeDir() + if err != nil { + log.Fatal(err.Error()) + } + + var db_path = path.Join(home_dir, ".dev-log.sql") + + if _, err := os.Stat(db_path); errors.Is(err, os.ErrNotExist) { + _, err := os.Create(db_path) + if err != nil { + log.Fatal(err.Error()) + } + } + + _, err = Init(db_path) + if err != nil { + log.Fatal(err.Error()) + } +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 0000000..811663d --- /dev/null +++ b/log_test.go @@ -0,0 +1,91 @@ +package main + +import ( + "database/sql" + "fmt" + "io/ioutil" + "log" + "os" + "testing" +) + +func assert_string(t *testing.T, expected string, actual string) { + if actual != expected { + t.Fatalf("(%s, %s)", expected, actual) + } +} +func assert_int(t *testing.T, expected int, actual int) { + if actual != expected { + t.Fatalf("(%d, %d)", expected, actual) + } +} + +func TestInitBasic(t *testing.T) { + var db_location = ":memory:" + ctx, err := Init(db_location) + if err != nil { + t.Fatalf(err.Error()) + } + // check that the tables exist + var table_name string + + ctx.db.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name='eng_log_version';", + ).Scan(&table_name) + assert_string(t, "eng_log_version", table_name) + + ctx.db.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name='log';", + ).Scan(&table_name) + assert_string(t, "log", table_name) + + // Check that the version stored is correct + var version int + ctx.db.QueryRow("SELECT id FROM eng_log_version").Scan(&version) + assert_int(t, __version, ctx.version) + assert_int(t, __version, version) +} + +func TestInitLogVersionTableExists(t *testing.T) { + var db_location = ":memory:" + db, err := sql.Open("sqlite3", db_location) + + db.Exec(` + CREATE TABLE eng_log_version (id INTEGER NOT NULL); + INSERT INTO eng_log_version (id) VALUES (1); + `) + _, err = Init(db_location) + if err != nil { + t.Fatalf(err.Error()) + } +} + +func TestInitWrongLogVersionExists(t *testing.T) { + file, err := ioutil.TempFile("", "log_test_init_wrong.*.db") + defer os.Remove(file.Name()) // clean up + + db, err := sql.Open("sqlite3", file.Name()) + if err != nil { + log.Fatal(err.Error()) + } + + db.Exec(` + CREATE TABLE eng_log_version (id INTEGER NOT NULL); + `) + db.Exec(` + INSERT INTO eng_log_version (id) VALUES (?); + `, __version+1) + + var version int + db.QueryRow("SELECT id FROM eng_log_version LIMIT 1").Scan(&version) + + _, err = Init(file.Name()) + if err == nil { + t.Fatal("No error when error expected") + } + assert_string( + t, + fmt.Sprintf("Wrong version. Expected %d got %d", __version, __version+1), + err.Error(), + ) +}