diff --git a/cmd/cheat/main.go b/cmd/cheat/main.go index 30321d9..c608d73 100755 --- a/cmd/cheat/main.go +++ b/cmd/cheat/main.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/docopt/docopt-go" + "github.com/mitchellh/go-homedir" "github.com/cheat/cheat/internal/cheatpath" "github.com/cheat/cheat/internal/config" @@ -34,6 +35,13 @@ func main() { os.Exit(0) } + // get the user's home directory + home, err := homedir.Dir() + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get user home directory: %v\n", err) + os.Exit(1) + } + // read the envvars into a map of strings envvars := map[string]string{} for _, e := range os.Environ() { @@ -42,7 +50,7 @@ func main() { } // load the os-specifc paths at which the config file may be located - confpaths, err := config.Paths(runtime.GOOS, envvars) + confpaths, err := config.Paths(runtime.GOOS, home, envvars) if err != nil { fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err) os.Exit(1) diff --git a/internal/config/paths.go b/internal/config/paths.go index b5752c4..f333c1a 100644 --- a/internal/config/paths.go +++ b/internal/config/paths.go @@ -9,13 +9,11 @@ import ( // Paths returns config file paths that are appropriate for the operating // system -func Paths(sys string, envvars map[string]string) ([]string, error) { - - // get the user's home directory - home, err := homedir.Dir() - if err != nil { - return []string{}, fmt.Errorf("failed to get user home directory: %v", err) - } +func Paths( + sys string, + home string, + envvars map[string]string, +) ([]string, error) { // if `CHEAT_CONFIG_PATH` is set, expand ~ and return it if confpath, ok := envvars["CHEAT_CONFIG_PATH"]; ok { diff --git a/internal/config/paths_test.go b/internal/config/paths_test.go index 0838c3e..4cb2cb9 100644 --- a/internal/config/paths_test.go +++ b/internal/config/paths_test.go @@ -11,9 +11,11 @@ import ( // *nix platforms func TestValidatePathsNix(t *testing.T) { + // mock the user's home directory + home := "/home/foo" + // mock some envvars envvars := map[string]string{ - "HOME": "/home/foo", "XDG_CONFIG_HOME": "/home/bar", } @@ -27,7 +29,7 @@ func TestValidatePathsNix(t *testing.T) { // test each *nix os for _, os := range oses { // get the paths for the platform - paths, err := Paths(os, envvars) + paths, err := Paths(os, home, envvars) if err != nil { t.Errorf("paths returned an error: %v", err) } @@ -54,10 +56,11 @@ func TestValidatePathsNix(t *testing.T) { // on *nix platforms when `XDG_CONFIG_HOME is not set func TestValidatePathsNixNoXDG(t *testing.T) { + // mock the user's home directory + home := "/home/foo" + // mock some envvars - envvars := map[string]string{ - "HOME": "/home/foo", - } + envvars := map[string]string{} // specify the platforms to test oses := []string{ @@ -69,7 +72,7 @@ func TestValidatePathsNixNoXDG(t *testing.T) { // test each *nix os for _, os := range oses { // get the paths for the platform - paths, err := Paths(os, envvars) + paths, err := Paths(os, home, envvars) if err != nil { t.Errorf("paths returned an error: %v", err) } @@ -95,6 +98,9 @@ func TestValidatePathsNixNoXDG(t *testing.T) { // on Windows platforms func TestValidatePathsWindows(t *testing.T) { + // mock the user's home directory + home := "not-used-on-windows" + // mock some envvars envvars := map[string]string{ "APPDATA": "/apps", @@ -102,7 +108,7 @@ func TestValidatePathsWindows(t *testing.T) { } // get the paths for the platform - paths, err := Paths("windows", envvars) + paths, err := Paths("windows", home, envvars) if err != nil { t.Errorf("paths returned an error: %v", err) } @@ -126,7 +132,7 @@ func TestValidatePathsWindows(t *testing.T) { // TestValidatePathsUnsupported asserts that an error is returned on // unsupported platforms func TestValidatePathsUnsupported(t *testing.T) { - _, err := Paths("unsupported", map[string]string{}) + _, err := Paths("unsupported", "", map[string]string{}) if err == nil { t.Errorf("failed to return error on unsupported platform") } @@ -136,15 +142,17 @@ func TestValidatePathsUnsupported(t *testing.T) { // returned when `CHEAT_CONFIG_PATH` is explicitly specified. func TestValidatePathsCheatConfigPath(t *testing.T) { + // mock the user's home directory + home := "/home/foo" + // mock some envvars envvars := map[string]string{ - "HOME": "/home/foo", "XDG_CONFIG_HOME": "/home/bar", "CHEAT_CONFIG_PATH": "/home/baz/conf.yml", } // get the paths for the platform - paths, err := Paths("linux", envvars) + paths, err := Paths("linux", home, envvars) if err != nil { t.Errorf("paths returned an error: %v", err) }