package main import ( "os" "reflect" "testing" "golang.org/x/tools/go/vcs" ) var ( root = &vcs.RepoRoot{ VCS: vcs.ByCmd("git"), Repo: "https://github.com/bazeltest/rules_go", Root: "github.com/bazeltest/rules_go", } ) func TestMain(m *testing.M) { // Replace vcs.RepoRootForImportPath to disable any network calls. repoRootForImportPath = func(_ string, _ bool) (*vcs.RepoRoot, error) { return root, nil } os.Exit(m.Run()) } func TestGetRepoRoot(t *testing.T) { for _, tc := range []struct { label string remote string cmd string importpath string r *vcs.RepoRoot }{ { label: "all", remote: "https://github.com/bazeltest/rules_go", cmd: "git", importpath: "github.com/bazeltest/rules_go", r: root, }, { label: "different remote", remote: "https://example.com/rules_go", cmd: "git", importpath: "github.com/bazeltest/rules_go", r: &vcs.RepoRoot{ VCS: vcs.ByCmd("git"), Repo: "https://example.com/rules_go", Root: "github.com/bazeltest/rules_go", }, }, { label: "only importpath", importpath: "github.com/bazeltest/rules_go", r: root, }, } { r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath) if err != nil { t.Errorf("[%s] %v", tc.label, err) } if !reflect.DeepEqual(r, tc.r) { t.Errorf("[%s] Expected %+v, got %+v", tc.label, tc.r, r) } } } func TestGetRepoRoot_error(t *testing.T) { for _, tc := range []struct { label string remote string cmd string importpath string }{ { label: "importpath as remote", remote: "github.com/bazeltest/rules_go", }, { label: "missing vcs", remote: "https://github.com/bazeltest/rules_go", importpath: "github.com/bazeltest/rules_go", }, { label: "missing remote", cmd: "git", importpath: "github.com/bazeltest/rules_go", }, } { r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath) if err == nil { t.Errorf("[%s] expected error. Got %+v", tc.label, r) } } }