diff --git a/cmd/list.go b/cmd/list.go index aada6e1..3746bcd 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -8,6 +8,7 @@ import ( "git.nagee.dev/isthisnagee/diary/model" "github.com/spf13/cobra" + "github.com/spf13/viper" ) // listCmd represents the list command @@ -20,11 +21,22 @@ var listCmd = &cobra.Command{ $ diary list today `, Run: func(cmd *cobra.Command, args []string) { - results := App.Db.GetDiaryEntries(model.GetDiaryEntriesQuery{}) + var num_entries *int64 = new(int64) + + q := model.GetDiaryEntriesQuery{} + + *num_entries = viper.GetInt64("listNumEntries") + if (*num_entries > 0) { + q.NumEntries = num_entries + } + + results := App.Db.GetDiaryEntries(q) PrintEntries(results) }, } func init() { + listCmd.PersistentFlags().Int64P("num-entries", "n", 20, "The number of entries to list") + viper.BindPFlag("listNumEntries", listCmd.PersistentFlags().Lookup("num-entries")) rootCmd.AddCommand(listCmd) } diff --git a/cmd/root.go b/cmd/root.go index 8d80033..8fccc0e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,10 +8,9 @@ import ( "os" "github.com/spf13/cobra" + "github.com/spf13/viper" ) -var cfgFile string - // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ Use: "diary", @@ -37,6 +36,7 @@ func Execute() { func init() { cobra.OnInitialize(InitApp) - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.diary.toml)") + rootCmd.PersistentFlags().String("config", "", "config file (default is $HOME/.diary.toml)") + viper.BindPFlag("config", rootCmd.Flags().Lookup("config")) rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") } diff --git a/cmd/today.go b/cmd/today.go index 79e1765..bea83bf 100644 --- a/cmd/today.go +++ b/cmd/today.go @@ -7,6 +7,7 @@ package cmd import ( "time" + "github.com/spf13/viper" "git.nagee.dev/isthisnagee/diary/model" "github.com/spf13/cobra" ) @@ -23,6 +24,7 @@ var todayCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { var created_after_ts *int64 = new(int64) var created_before_ts *int64 = new(int64) + var num_entries *int64 = new(int64) var now = time.Now() var startOfToday = time.Date( @@ -39,9 +41,13 @@ var todayCmd = &cobra.Command{ var endOfToday = startOfToday.AddDate(0, 0, 1) *created_before_ts = endOfToday.Unix() + if viper.GetInt64("listNumEntries") > 0 { + *num_entries = viper.GetInt64("listNumEntries") + } results := App.Db.GetDiaryEntries(model.GetDiaryEntriesQuery{ CreatedBeforeTs: created_before_ts, CreatedAfterTs: created_after_ts, + NumEntries: num_entries, }) PrintEntries(results) diff --git a/cmd/util.go b/cmd/util.go index e1789a7..85f910d 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -8,7 +8,6 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "github.com/spf13/viper" - "log" "os" "os/exec" "path" @@ -35,13 +34,14 @@ func initConfig() Cfg { home, err := os.UserHomeDir() cobra.CheckErr(err) + cfgFile := viper.GetString("config") if cfgFile != "" { // Use config file from the flag. viper.SetConfigFile(cfgFile) } else { viper.AddConfigPath(path.Join(home, ".config", "diary")) viper.SetConfigType("toml") - viper.SetConfigName("diary.toml") + viper.SetConfigName("diary") } if viper.Get("db_path") == nil { @@ -51,8 +51,13 @@ func initConfig() Cfg { viper.AutomaticEnv() // read in environment variables that match - err = viper.ReadInConfig() - cobra.CheckErr(err) + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + // Config file not found. That's OK + } else { + cobra.CheckErr(err) + } + } return Cfg{viper.GetString("db_path")} } diff --git a/model/entry.go b/model/entry.go index 862eb37..bd778da 100644 --- a/model/entry.go +++ b/model/entry.go @@ -5,7 +5,6 @@ package model import ( "database/sql" "git.nagee.dev/isthisnagee/diary/db" - "log" "strings" ) @@ -105,6 +104,8 @@ type GetDiaryEntriesQuery struct { CreatedBeforeTs *int64 /// Inclusive CreatedAfterTs *int64 + + NumEntries *int64 } func (app *App) GetDiaryEntries(q GetDiaryEntriesQuery) []*DiaryEntry { @@ -121,6 +122,10 @@ func (app *App) GetDiaryEntries(q GetDiaryEntriesQuery) []*DiaryEntry { whereParams = append(whereParams, *q.CreatedAfterTs) } query += " ORDER BY created_at desc, id desc" + if q.NumEntries != nil { + query += " LIMIT ?" + whereParams = append(whereParams, *q.NumEntries) + } rows, err := app.Db.Query( query, diff --git a/model/entry_test.go b/model/entry_test.go index 39f5f6b..5acb3e2 100644 --- a/model/entry_test.go +++ b/model/entry_test.go @@ -3,28 +3,33 @@ package model import ( "git.nagee.dev/isthisnagee/diary/db" "testing" + "runtime/debug" ) func assert_string(t *testing.T, expected string, actual string) { if actual != expected { + t.Log(string(debug.Stack())) t.Fatalf("(%v, %v)", expected, actual) } } func assert_int(t *testing.T, expected int64, actual int64) { if actual != expected { + t.Log(string(debug.Stack())) t.Fatalf("(%v, %v)", expected, actual) } } func assert_bool(t *testing.T, expected bool, actual bool) { if actual != expected { + t.Log(string(debug.Stack())) t.Fatalf("(%v, %v)", expected, actual) } } func assert_exists(t *testing.T, actual interface{}) { if actual == nil { + t.Log(string(debug.Stack())) t.Fatalf("Unexpected nil: %s", actual) } } @@ -98,5 +103,35 @@ func DeleteDiaryEntryNotFound(t *testing.T) { t.Fatalf("Expected NotFoundError, got %s", err_type) } + teardown(app) +} + +func TestGetDiaryEntries(t *testing.T) { + var app = setup() + var result_1 = app.NewDiaryEntry("Met with Nagee @ 1PM") + var result_2 = app.NewDiaryEntry("Met with Nagee @ 2PM") + + // no numEntries + entries := app.GetDiaryEntries( + GetDiaryEntriesQuery{}, + ) + + assert_int(t, int64(len(entries)), 2) + assert_int(t, result_2.Id, entries[0].Id) + assert_int(t, result_1.Id, entries[1].Id) + + + var numEntries = new(int64) + *numEntries = 1 + entries = app.GetDiaryEntries( + GetDiaryEntriesQuery{ + NumEntries: numEntries, + }, + ) + + assert_int(t, int64(len(entries)), 1) + assert_int(t, result_2.Id, entries[0].Id) + + teardown(app) }