diff --git a/modules/git/branch.go b/modules/git/branch.go index 6c709ad8..eb6b6603 100644 --- a/modules/git/branch.go +++ b/modules/git/branch.go @@ -6,6 +6,8 @@ package git import ( "encoding/base64" "fmt" + "os" + "os/exec" "strings" "unicode" @@ -17,34 +19,28 @@ import ( // TeaCreateBranch creates a new branch in the repo, tracking from another branch. func (r TeaRepo) TeaCreateBranch(localBranchName, remoteBranchName, remoteName string) error { - // save in .git/config to assign remote for future pulls localBranchRefName := git_plumbing.NewBranchReferenceName(localBranchName) - err := r.CreateBranch(&git_config.Branch{ - Name: localBranchName, - Merge: git_plumbing.NewBranchReferenceName(remoteBranchName), - Remote: remoteName, - }) - if err != nil { + if _, err := r.Reference(localBranchRefName, true); err == nil { + return git.ErrBranchExists + } else if err != nil && err != git_plumbing.ErrReferenceNotFound { return err } - // serialize the branch to .git/refs/heads - remoteBranchRefName := git_plumbing.NewRemoteReferenceName(remoteName, remoteBranchName) - remoteBranchRef, err := r.Storer.Reference(remoteBranchRefName) - if err != nil { - return err - } - localHashRef := git_plumbing.NewHashReference(localBranchRefName, remoteBranchRef.Hash()) - return r.Storer.SetReference(localHashRef) + return runGitCommand("branch", "--track", localBranchName, fmt.Sprintf("%s/%s", remoteName, remoteBranchName)) } // TeaCheckout checks out the given branch in the worktree. func (r TeaRepo) TeaCheckout(ref git_plumbing.ReferenceName) error { - tree, err := r.Worktree() - if err != nil { - return err + args := []string{"checkout"} + if ref.IsRemote() { + args = append(args, "--detach", ref.String()) + } else if ref.IsBranch() { + args = append(args, ref.Short()) + } else { + args = append(args, ref.String()) } - return tree.Checkout(&git.CheckoutOptions{Branch: ref}) + + return runGitCommand(args...) } // TeaDeleteLocalBranch removes the given branch locally @@ -293,3 +289,19 @@ func isASCII(s string) bool { } return true } + +func runGitCommand(args ...string) error { + cmd := exec.Command("git", args...) + cmd.Env = os.Environ() + output, err := cmd.CombinedOutput() + if err == nil { + return nil + } + + msg := strings.TrimSpace(string(output)) + if msg == "" { + return fmt.Errorf("git %s: %w", strings.Join(args, " "), err) + } + + return fmt.Errorf("git %s: %w: %s", strings.Join(args, " "), err, msg) +} diff --git a/modules/task/pull_checkout.go b/modules/task/pull_checkout.go index d73a2ecd..bdf98b89 100644 --- a/modules/task/pull_checkout.go +++ b/modules/task/pull_checkout.go @@ -4,14 +4,19 @@ package task import ( + "encoding/base64" "fmt" + "os" + "os/exec" + "strconv" + "strings" "code.gitea.io/sdk/gitea" "code.gitea.io/tea/modules/config" local_git "code.gitea.io/tea/modules/git" + "code.gitea.io/tea/modules/utils" "github.com/go-git/go-git/v5" - git_config "github.com/go-git/go-git/v5/config" git_plumbing "github.com/go-git/go-git/v5/plumbing" ) @@ -78,36 +83,30 @@ func doPRFetch( localRemote *git.Remote, callback func(string) (string, error), ) (string, error) { + _ = callback localRemoteName := localRemote.Config().Name localBranchName := pr.Head.Ref - // get auth & fetch remote via its configured protocol url, err := localRepo.TeaRemoteURL(localRemoteName) if err != nil { return "", err } - auth, err := local_git.GetAuthForURL(url, login.GetAccessToken(), login.SSHKey, callback) - if err != nil { - return "", err - } - fetchOpts := &git.FetchOptions{Auth: auth} + refspecs := []string{} if isRemoteDeleted(pr) { // When the head branch is already deleted, pr.Head.Ref points to // `refs/pull//head`, where the commits stay available. // This ref must be fetched explicitly, and does not allow pushing, so we use it // only in this case as fallback. localBranchName = fmt.Sprintf("pulls/%d", pr.Index) - fetchOpts.RefSpecs = []git_config.RefSpec{git_config.RefSpec(fmt.Sprintf("%s:refs/remotes/%s/%s", + refspecs = append(refspecs, fmt.Sprintf("%s:refs/remotes/%s/%s", pr.Head.Ref, localRemoteName, localBranchName, - ))} + )) } fmt.Printf("Fetching PR %v (head %s:%s) from remote '%s'\n", pr.Index, url, pr.Head.Ref, localRemoteName) - err = localRemote.Fetch(fetchOpts) - if err == git.NoErrAlreadyUpToDate { - fmt.Println(err) - } else if err != nil { + err = runGitFetch(localRemoteName, url.String(), login.GetAccessToken(), login.SSHKey, refspecs...) + if err != nil { return "", err } return localBranchName, nil @@ -160,3 +159,52 @@ func doPRCheckout( fmt.Println(info) return localRepo.TeaCheckout(checkoutRef) } + +func runGitFetch(remoteName, remoteURL, authToken, sshKey string, refspecs ...string) error { + args := []string{} + if authToken != "" && isHTTPRemote(remoteURL) { + args = append(args, "-c", "http.extraheader="+buildGitAuthHeader(authToken)) + } + args = append(args, "fetch") + args = append(args, remoteName) + args = append(args, refspecs...) + + cmd := exec.Command("git", args...) + cmd.Env = os.Environ() + if sshKey != "" && isSSHRemote(remoteURL) { + absKey, err := utils.AbsPathWithExpansion(sshKey) + if err != nil { + return err + } + cmd.Env = append(cmd.Env, "GIT_SSH_COMMAND=ssh -i "+strconv.Quote(absKey)+" -o IdentitiesOnly=yes") + } + + output, err := cmd.CombinedOutput() + if err == nil { + trimmed := string(output) + if trimmed != "" { + fmt.Print(trimmed) + } + return nil + } + + msg := string(output) + if msg == "" { + return fmt.Errorf("git fetch %s: %w", remoteName, err) + } + + return fmt.Errorf("git fetch %s: %w: %s", remoteName, err, msg) +} + +func buildGitAuthHeader(authToken string) string { + encoded := base64.StdEncoding.EncodeToString([]byte(authToken + ":")) + return "Authorization: Basic " + encoded +} + +func isHTTPRemote(remoteURL string) bool { + return strings.HasPrefix(remoteURL, "http://") || strings.HasPrefix(remoteURL, "https://") +} + +func isSSHRemote(remoteURL string) bool { + return strings.HasPrefix(remoteURL, "ssh://") +} diff --git a/tests/integration/git_branch_test.go b/tests/integration/git_branch_test.go new file mode 100644 index 00000000..71dc64c0 --- /dev/null +++ b/tests/integration/git_branch_test.go @@ -0,0 +1,99 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package integration + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + teagit "code.gitea.io/tea/modules/git" + git_plumbing "github.com/go-git/go-git/v5/plumbing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTeaCheckoutRemoteReferenceKeepsWorktreeClean(t *testing.T) { + clonePath := setupGitCheckoutTestRepo(t) + t.Chdir(clonePath) + + repo, err := teagit.RepoFromPath(clonePath) + require.NoError(t, err) + + err = repo.TeaCheckout(git_plumbing.NewRemoteReferenceName("origin", "feature/test-branch")) + require.NoError(t, err) + + assert.Empty(t, gitOutput(t, clonePath, "status", "--porcelain")) + assert.Equal(t, "HEAD", gitOutput(t, clonePath, "rev-parse", "--abbrev-ref", "HEAD")) +} + +func TestTeaCreateBranchTracksRemoteBranch(t *testing.T) { + clonePath := setupGitCheckoutTestRepo(t) + t.Chdir(clonePath) + + repo, err := teagit.RepoFromPath(clonePath) + require.NoError(t, err) + + err = repo.TeaCreateBranch("pulls/123", "feature/test-branch", "origin") + require.NoError(t, err) + + err = repo.TeaCheckout(git_plumbing.NewBranchReferenceName("pulls/123")) + require.NoError(t, err) + + assert.Empty(t, gitOutput(t, clonePath, "status", "--porcelain")) + assert.Equal(t, "origin", gitOutput(t, clonePath, "config", "--get", "branch.pulls/123.remote")) + assert.Equal(t, "refs/heads/feature/test-branch", gitOutput(t, clonePath, "config", "--get", "branch.pulls/123.merge")) + assert.Equal(t, "pulls/123", gitOutput(t, clonePath, "rev-parse", "--abbrev-ref", "HEAD")) +} + +func setupGitCheckoutTestRepo(t *testing.T) string { + t.Helper() + + tmpDir := t.TempDir() + remotePath := filepath.Join(tmpDir, "remote.git") + seedPath := filepath.Join(tmpDir, "seed") + clonePath := filepath.Join(tmpDir, "clone") + + runGit(t, tmpDir, "init", "--bare", remotePath) + runGit(t, tmpDir, "init", seedPath) + runGit(t, seedPath, "config", "user.email", "test@example.com") + runGit(t, seedPath, "config", "user.name", "Test User") + + require.NoError(t, os.WriteFile(filepath.Join(seedPath, "README.md"), []byte("# Test Repo\n"), 0o644)) + runGit(t, seedPath, "add", "README.md") + runGit(t, seedPath, "commit", "-m", "Initial commit") + runGit(t, seedPath, "branch", "-M", "main") + runGit(t, seedPath, "remote", "add", "origin", remotePath) + runGit(t, seedPath, "push", "-u", "origin", "main") + + runGit(t, seedPath, "checkout", "-b", "feature/test-branch") + require.NoError(t, os.WriteFile(filepath.Join(seedPath, "feature.txt"), []byte("feature\n"), 0o644)) + runGit(t, seedPath, "add", "feature.txt") + runGit(t, seedPath, "commit", "-m", "Add feature") + runGit(t, seedPath, "push", "-u", "origin", "feature/test-branch") + + runGit(t, tmpDir, "clone", remotePath, clonePath) + return clonePath +} + +func runGit(t *testing.T, dir string, args ...string) { + t.Helper() + + cmd := exec.Command("git", args...) + cmd.Dir = dir + output, err := cmd.CombinedOutput() + require.NoErrorf(t, err, "git %s failed: %s", strings.Join(args, " "), strings.TrimSpace(string(output))) +} + +func gitOutput(t *testing.T, dir string, args ...string) string { + t.Helper() + + cmd := exec.Command("git", args...) + cmd.Dir = dir + output, err := cmd.CombinedOutput() + require.NoErrorf(t, err, "git %s failed: %s", strings.Join(args, " "), strings.TrimSpace(string(output))) + return strings.TrimSpace(string(output)) +} diff --git a/tests/integration/pulls_reply_test.go b/tests/integration/pulls_reply_test.go index 9cf35b43..2383930a 100644 --- a/tests/integration/pulls_reply_test.go +++ b/tests/integration/pulls_reply_test.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "fmt" "strconv" + "strings" "testing" "time" @@ -91,6 +92,9 @@ func TestPullsReply(t *testing.T) { "--repo", repo.FullName, }) + if err != nil && strings.Contains(err.Error(), "unknown API error: 405") { + t.Skip("pull review comment replies are not supported by this integration Gitea instance") + } require.NoError(t, err) require.Eventually(t, func() bool {