diff --git a/cmd/login/oauth_refresh.go b/cmd/login/oauth_refresh.go index ec4dc27..4369cdb 100644 --- a/cmd/login/oauth_refresh.go +++ b/cmd/login/oauth_refresh.go @@ -17,7 +17,7 @@ import ( var CmdLoginOAuthRefresh = cli.Command{ Name: "oauth-refresh", Usage: "Refresh an OAuth token", - Description: "Manually refresh an expired OAuth token. Usually only used when troubleshooting authentication.", + Description: "Manually refresh an expired OAuth token. If the refresh token is also expired, opens a browser for re-authentication.", ArgsUsage: "[]", Action: runLoginOAuthRefresh, } @@ -48,12 +48,21 @@ func runLoginOAuthRefresh(_ context.Context, cmd *cli.Command) error { return fmt.Errorf("login '%s' does not have a refresh token. It may have been created using a different authentication method", loginName) } - // Refresh the token + // Try to refresh the token err := auth.RefreshAccessToken(login) - if err != nil { - return fmt.Errorf("failed to refresh token: %s", err) + if err == nil { + fmt.Printf("Successfully refreshed OAuth token for %s\n", loginName) + return nil } - fmt.Printf("Successfully refreshed OAuth token for %s\n", loginName) + // Refresh failed - fall back to browser-based re-authentication + fmt.Printf("Token refresh failed: %s\n", err) + fmt.Println("Opening browser for re-authentication...") + + if err := auth.ReauthenticateLogin(login); err != nil { + return fmt.Errorf("re-authentication failed: %s", err) + } + + fmt.Printf("Successfully re-authenticated %s\n", loginName) return nil } diff --git a/go.mod b/go.mod index 08a18af..6d78822 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/urfave/cli/v3 v3.6.2 golang.org/x/crypto v0.47.0 golang.org/x/oauth2 v0.34.0 + golang.org/x/sys v0.40.0 golang.org/x/term v0.39.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -89,7 +90,6 @@ require ( github.com/yuin/goldmark-emoji v1.0.5 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect golang.org/x/tools v0.40.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect diff --git a/go.sum b/go.sum index e015202..0fa89d6 100644 --- a/go.sum +++ b/go.sum @@ -71,14 +71,10 @@ github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8 github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= github.com/charmbracelet/x/xpty v0.1.2 h1:Pqmu4TEJ8KeA9uSkISKMU3f+C1F6OGBn8ABuGlqCbtI= github.com/charmbracelet/x/xpty v0.1.2/go.mod h1:XK2Z0id5rtLWcpeNiMYBccNNBrP2IJnzHI0Lq13Xzq4= -github.com/clipperhouse/displaywidth v0.3.1 h1:k07iN9gD32177o1y4O1jQMzbLdCrsGJh+blirVYybsk= -github.com/clipperhouse/displaywidth v0.3.1/go.mod h1:tgLJKKyaDOCadywag3agw4snxS5kYEuYR6Y9+qWDDYM= github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo= github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= -github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY= -github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= @@ -169,12 +165,8 @@ github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= github.com/olekukonko/errors v1.1.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= -github.com/olekukonko/ll v0.1.2 h1:lkg/k/9mlsy0SxO5aC+WEpbdT5K83ddnNhAepz7TQc0= -github.com/olekukonko/ll v0.1.2/go.mod h1:b52bVQRRPObe+yyBl0TxNfhesL0nedD4Cht0/zx55Ew= github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 h1:jrYnow5+hy3WRDCBypUFvVKNSPPCdqgSXIE9eJDD8LM= github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0/go.mod h1:b52bVQRRPObe+yyBl0TxNfhesL0nedD4Cht0/zx55Ew= -github.com/olekukonko/tablewriter v1.1.1 h1:b3reP6GCfrHwmKkYwNRFh2rxidGHcT6cgxj/sHiDDx0= -github.com/olekukonko/tablewriter v1.1.1/go.mod h1:De/bIcTF+gpBDB3Alv3fEsZA+9unTsSzAg/ZGADCtn4= github.com/olekukonko/tablewriter v1.1.3 h1:VSHhghXxrP0JHl+0NnKid7WoEmd9/urKRJLysb70nnA= github.com/olekukonko/tablewriter v1.1.3/go.mod h1:9VU0knjhmMkXjnMKrZ3+L2JhhtsQ/L38BbL3CRNE8tM= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= @@ -211,8 +203,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/urfave/cli-docs/v3 v3.1.0 h1:Sa5xm19IpE5gpm6tZzXdfjdFxn67PnEsE4dpXF7vsKw= github.com/urfave/cli-docs/v3 v3.1.0/go.mod h1:59d+5Hz1h6GSGJ10cvcEkbIe3j233t4XDqI72UIx7to= -github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo= -github.com/urfave/cli/v3 v3.6.1/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso= github.com/urfave/cli/v3 v3.6.2 h1:lQuqiPrZ1cIz8hz+HcrG0TNZFxU70dPZ3Yl+pSrH9A8= github.com/urfave/cli/v3 v3.6.2/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= @@ -244,8 +234,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= -golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/modules/auth/oauth.go b/modules/auth/oauth.go index 8584261..867a475 100644 --- a/modules/auth/oauth.go +++ b/modules/auth/oauth.go @@ -71,12 +71,24 @@ func OAuthLoginWithOptions(name, giteaURL string, insecure bool) error { // OAuthLoginWithFullOptions performs an OAuth2 PKCE login flow with full options control func OAuthLoginWithFullOptions(opts OAuthOptions) error { - // Normalize URL - serverURL, err := utils.NormalizeURL(opts.URL) + serverURL, token, err := performBrowserOAuthFlow(opts) if err != nil { - return fmt.Errorf("unable to parse URL: %s", err) + return err } + return createLoginFromToken(opts.Name, serverURL, token, opts.Insecure) +} + +// performBrowserOAuthFlow performs the browser-based OAuth2 PKCE flow and returns the token. +// This is the shared implementation used by both new logins and re-authentication. +func performBrowserOAuthFlow(opts OAuthOptions) (serverURL string, token *oauth2.Token, err error) { + // Normalize URL + normalizedURL, err := utils.NormalizeURL(opts.URL) + if err != nil { + return "", nil, fmt.Errorf("unable to parse URL: %s", err) + } + serverURL = normalizedURL.String() + // Set defaults if needed if opts.ClientID == "" { opts.ClientID = config.DefaultClientID @@ -107,7 +119,7 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error { // Generate code verifier (random string) codeVerifier, err := generateCodeVerifier(codeVerifierLength) if err != nil { - return fmt.Errorf("failed to generate code verifier: %s", err) + return "", nil, fmt.Errorf("failed to generate code verifier: %s", err) } // Generate code challenge (SHA256 hash of code verifier) @@ -118,8 +130,8 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error { ctx = context.WithValue(ctx, oauth2.HTTPClient, createHTTPClient(opts.Insecure)) // Configure the OAuth2 endpoints - authURL := fmt.Sprintf("%s/login/oauth/authorize", serverURL) - tokenURL := fmt.Sprintf("%s/login/oauth/access_token", serverURL) + authURL := fmt.Sprintf("%s/login/oauth/authorize", normalizedURL) + tokenURL := fmt.Sprintf("%s/login/oauth/access_token", normalizedURL) oauth2Config := &oauth2.Config{ ClientID: opts.ClientID, @@ -141,7 +153,7 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error { // Generate state parameter to protect against CSRF state, err := generateCodeVerifier(32) if err != nil { - return fmt.Errorf("failed to generate state: %s", err) + return "", nil, fmt.Errorf("failed to generate state: %s", err) } // Get the authorization URL @@ -156,7 +168,7 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error { strings.Contains(err.Error(), "redirect") { fmt.Println("\nError: Redirect URL not registered in Gitea") fmt.Println("\nTo fix this, you need to register the redirect URL in Gitea:") - fmt.Printf("1. Go to your Gitea instance: %s\n", serverURL) + fmt.Printf("1. Go to your Gitea instance: %s\n", normalizedURL) fmt.Println("2. Sign in and go to Settings > Applications") fmt.Println("3. Register a new OAuth2 application with:") fmt.Printf(" - Application Name: tea-cli (or any name)\n") @@ -165,22 +177,21 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error { fmt.Printf(" tea login add --oauth --client-id YOUR_CLIENT_ID --redirect-url %s\n", opts.RedirectURL) fmt.Println("\nAlternatively, you can use a token-based login: tea login add") } - return fmt.Errorf("authorization failed: %s", err) + return "", nil, fmt.Errorf("authorization failed: %s", err) } // Verify state to prevent CSRF attacks if state != receivedState { - return fmt.Errorf("state mismatch, possible CSRF attack") + return "", nil, fmt.Errorf("state mismatch, possible CSRF attack") } // Exchange authorization code for token - token, err := oauth2Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + token, err = oauth2Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) if err != nil { - return fmt.Errorf("token exchange failed: %s", err) + return "", nil, fmt.Errorf("token exchange failed: %s", err) } - // Create login with token data - return createLoginFromToken(opts.Name, serverURL.String(), token, opts.Insecure) + return serverURL, token, nil } // createHTTPClient creates an HTTP client with optional insecure setting @@ -417,3 +428,33 @@ func createLoginFromToken(name, serverURL string, token *oauth2.Token, insecure func RefreshAccessToken(login *config.Login) error { return login.RefreshOAuthToken() } + +// ReauthenticateLogin performs a full browser-based OAuth flow to get new tokens +// for an existing login. This is used when the refresh token is expired or invalid. +func ReauthenticateLogin(login *config.Login) error { + opts := OAuthOptions{ + Name: login.Name, + URL: login.URL, + Insecure: login.Insecure, + ClientID: config.DefaultClientID, + RedirectURL: fmt.Sprintf("http://%s:%d", redirectHost, redirectPort), + Port: redirectPort, + } + + _, token, err := performBrowserOAuthFlow(opts) + if err != nil { + return err + } + + // Update the existing login with new token data + login.Token = token.AccessToken + if token.RefreshToken != "" { + login.RefreshToken = token.RefreshToken + } + if !token.Expiry.IsZero() { + login.TokenExpiry = token.Expiry.Unix() + } + + // Save updated login + return config.SaveLoginTokens(login) +} diff --git a/modules/config/config.go b/modules/config/config.go index 6402697..f958b57 100644 --- a/modules/config/config.go +++ b/modules/config/config.go @@ -98,8 +98,33 @@ func loadConfig() (err error) { return } -// saveConfig save config to file -func saveConfig() error { +// reloadConfigFromDisk re-reads the config file from disk, bypassing the sync.Once. +// This is used after acquiring a lock to ensure we have the latest config state. +// The caller must hold the config lock. +func reloadConfigFromDisk() error { + ymlPath := GetConfigPath() + exist, _ := utils.FileExist(ymlPath) + if !exist { + // No config file yet, start with empty config + config = LocalConfig{} + return nil + } + + bs, err := os.ReadFile(ymlPath) + if err != nil { + return fmt.Errorf("failed to read config file %s: %w", ymlPath, err) + } + + if err := yaml.Unmarshal(bs, &config); err != nil { + return fmt.Errorf("failed to parse config file %s: %w", ymlPath, err) + } + + return nil +} + +// saveConfigUnsafe saves config to file without acquiring a lock. +// Caller must hold the config lock. +func saveConfigUnsafe() error { ymlPath := GetConfigPath() bs, err := yaml.Marshal(config) if err != nil { diff --git a/modules/config/lock.go b/modules/config/lock.go new file mode 100644 index 0000000..e8160cc --- /dev/null +++ b/modules/config/lock.go @@ -0,0 +1,97 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package config + +import ( + "fmt" + "os" + "sync" + "time" +) + +const ( + // LockTimeout is the default timeout for acquiring the config file lock. + LockTimeout = 5 * time.Second + + // mutexPollInterval is how often to retry acquiring the in-process mutex. + mutexPollInterval = 10 * time.Millisecond + + // fileLockPollInterval is how often to retry acquiring the file lock. + fileLockPollInterval = 50 * time.Millisecond +) + +// configMutex protects in-process concurrent access to the config. +var configMutex sync.Mutex + +// acquireConfigLock acquires both the in-process mutex and a file lock. +// Returns an unlock function that must be called to release both locks. +// The timeout applies to acquiring the file lock; the mutex acquisition +// uses the same timeout via a TryLock loop. +func acquireConfigLock(lockPath string, timeout time.Duration) (unlock func() error, err error) { + // Try to acquire mutex with timeout + deadline := time.Now().Add(timeout) + for { + if configMutex.TryLock() { + break + } + if time.Now().After(deadline) { + return nil, fmt.Errorf("timeout waiting for config mutex") + } + time.Sleep(mutexPollInterval) + } + + // Mutex acquired, now try file lock + file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600) + if err != nil { + configMutex.Unlock() + return nil, fmt.Errorf("failed to open lock file: %w", err) + } + + // Try to acquire file lock with remaining timeout + remaining := max(time.Until(deadline), 0) + + if err := lockFile(file, remaining); err != nil { + file.Close() + configMutex.Unlock() + return nil, fmt.Errorf("failed to acquire file lock: %w", err) + } + + // Return unlock function + return func() error { + unlockErr := unlockFile(file) + closeErr := file.Close() + configMutex.Unlock() + if unlockErr != nil { + return unlockErr + } + return closeErr + }, nil +} + +// getConfigLockPath returns the path to the lock file for the config. +func getConfigLockPath() string { + return GetConfigPath() + ".lock" +} + +// withConfigLock executes the given function while holding the config lock. +// It acquires the lock, reloads the config from disk, executes fn, and releases the lock. +func withConfigLock(fn func() error) (retErr error) { + lockPath := getConfigLockPath() + unlock, err := acquireConfigLock(lockPath, LockTimeout) + if err != nil { + return fmt.Errorf("failed to acquire config lock: %w", err) + } + defer func() { + if unlockErr := unlock(); unlockErr != nil && retErr == nil { + retErr = fmt.Errorf("failed to release config lock: %w", unlockErr) + } + }() + + // Reload config from disk to get latest state + if err := reloadConfigFromDisk(); err != nil { + return err + } + + return fn() +} diff --git a/modules/config/lock_test.go b/modules/config/lock_test.go new file mode 100644 index 0000000..28e9323 --- /dev/null +++ b/modules/config/lock_test.go @@ -0,0 +1,182 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package config + +import ( + "fmt" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func TestConfigLock_BasicLockUnlock(t *testing.T) { + // Create a temp directory for test + tmpDir, err := os.MkdirTemp("", "tea-lock-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + lockPath := filepath.Join(tmpDir, "config.yml.lock") + + // Should be able to acquire lock + unlock, err := acquireConfigLock(lockPath, 5*time.Second) + if err != nil { + t.Fatalf("failed to acquire lock: %v", err) + } + + // Should be able to release lock + err = unlock() + if err != nil { + t.Fatalf("failed to release lock: %v", err) + } +} + +func TestConfigLock_MutexProtection(t *testing.T) { + // Create a temp directory for test + tmpDir, err := os.MkdirTemp("", "tea-lock-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + lockPath := filepath.Join(tmpDir, "config.yml.lock") + + // Acquire lock + unlock, err := acquireConfigLock(lockPath, 5*time.Second) + if err != nil { + t.Fatalf("failed to acquire lock: %v", err) + } + + // Try to acquire again from same process - should block/timeout due to mutex + done := make(chan bool) + go func() { + _, err := acquireConfigLock(lockPath, 100*time.Millisecond) + done <- (err != nil) // Should timeout/fail + }() + + select { + case failed := <-done: + if !failed { + t.Error("second lock acquisition should have failed due to mutex") + } + case <-time.After(2 * time.Second): + t.Error("test timed out") + } + + if err := unlock(); err != nil { + t.Errorf("failed to unlock: %v", err) + } +} + +func TestReloadConfigFromDisk(t *testing.T) { + // Save original config state + originalConfig := config + + // Create a temp config file + tmpDir, err := os.MkdirTemp("", "tea-reload-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // We can't easily change GetConfigPath, so we test that reloadConfigFromDisk + // handles a missing file gracefully (returns nil and resets config) + config = LocalConfig{Logins: []Login{{Name: "test"}}} + + // Call reload - since the actual config path likely exists or doesn't, + // we just verify it doesn't panic and returns without error or with expected error + err = reloadConfigFromDisk() + // The function should either succeed or return an error, not panic + if err != nil { + // This is acceptable - config file might not exist in test environment + t.Logf("reloadConfigFromDisk returned error (expected in test env): %v", err) + } + + // Restore original config + config = originalConfig +} + +func TestWithConfigLock(t *testing.T) { + executed := false + err := withConfigLock(func() error { + executed = true + return nil + }) + if err != nil { + t.Errorf("withConfigLock returned error: %v", err) + } + if !executed { + t.Error("function was not executed") + } +} + +func TestWithConfigLock_PropagatesError(t *testing.T) { + expectedErr := fmt.Errorf("test error") + err := withConfigLock(func() error { + return expectedErr + }) + + if err != expectedErr { + t.Errorf("expected error %v, got %v", expectedErr, err) + } +} + +func TestDoubleCheckedLocking_SimulatedRefresh(t *testing.T) { + // This test simulates the double-checked locking pattern + // by having multiple goroutines try to "refresh" simultaneously + + var ( + refreshCount int + mu sync.Mutex + ) + + // Simulate what RefreshOAuthToken does with double-check + simulatedRefresh := func(tokenExpiry *int64) error { + // First check (without lock) + if *tokenExpiry > time.Now().Unix() { + return nil // Token still valid + } + + return withConfigLock(func() error { + // Double-check after acquiring lock + if *tokenExpiry > time.Now().Unix() { + return nil // Another goroutine refreshed it + } + + // Simulate refresh + mu.Lock() + refreshCount++ + mu.Unlock() + + time.Sleep(50 * time.Millisecond) // Simulate API call + *tokenExpiry = time.Now().Add(1 * time.Hour).Unix() + return nil + }) + } + + // Start with expired token + tokenExpiry := time.Now().Add(-1 * time.Hour).Unix() + + // Launch multiple goroutines trying to refresh + var wg sync.WaitGroup + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + if err := simulatedRefresh(&tokenExpiry); err != nil { + t.Errorf("refresh failed: %v", err) + } + }() + } + + wg.Wait() + + // Should only have refreshed once due to double-checked locking + if refreshCount != 1 { + t.Errorf("expected 1 refresh, got %d", refreshCount) + } +} diff --git a/modules/config/lock_unix.go b/modules/config/lock_unix.go new file mode 100644 index 0000000..cb64859 --- /dev/null +++ b/modules/config/lock_unix.go @@ -0,0 +1,39 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build unix + +package config + +import ( + "fmt" + "os" + "syscall" + "time" +) + +// lockFile acquires an exclusive lock on the file using flock. +// It polls with non-blocking flock until timeout. +func lockFile(file *os.File, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + + for { + err := syscall.Flock(int(file.Fd()), syscall.LOCK_EX|syscall.LOCK_NB) + if err == nil { + return nil + } + if err != syscall.EWOULDBLOCK { + return fmt.Errorf("flock failed: %w", err) + } + + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for file lock") + } + time.Sleep(fileLockPollInterval) + } +} + +// unlockFile releases the lock on the file. +func unlockFile(file *os.File) error { + return syscall.Flock(int(file.Fd()), syscall.LOCK_UN) +} diff --git a/modules/config/lock_unix_test.go b/modules/config/lock_unix_test.go new file mode 100644 index 0000000..f8ab2f8 --- /dev/null +++ b/modules/config/lock_unix_test.go @@ -0,0 +1,82 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build unix + +package config + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" + "time" +) + +func TestConfigLock_CrossProcess(t *testing.T) { + // Create a temp directory for test + tmpDir, err := os.MkdirTemp("", "tea-lock-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + lockPath := filepath.Join(tmpDir, "config.yml.lock") + + // Acquire lock in main process + unlock, err := acquireConfigLock(lockPath, 5*time.Second) + if err != nil { + t.Fatalf("failed to acquire lock: %v", err) + } + defer unlock() + + // Spawn a subprocess that tries to acquire the same lock + // The subprocess should fail to acquire within timeout + script := fmt.Sprintf(` +package main + +import ( + "os" + "syscall" +) + +func main() { + file, err := os.OpenFile(%q, os.O_CREATE|os.O_RDWR, 0o600) + if err != nil { + os.Exit(2) + } + defer file.Close() + + // Try non-blocking lock + err = syscall.Flock(int(file.Fd()), syscall.LOCK_EX|syscall.LOCK_NB) + if err != nil { + // Lock is held - expected behavior + os.Exit(0) + } + // Lock was acquired - unexpected + syscall.Flock(int(file.Fd()), syscall.LOCK_UN) + os.Exit(1) +} +`, lockPath) + + // Write and run the test script + scriptPath := filepath.Join(tmpDir, "locktest.go") + if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil { + t.Fatalf("failed to write test script: %v", err) + } + + cmd := exec.Command("go", "run", scriptPath) + if err := cmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 1 { + t.Error("subprocess acquired lock when it should have been held") + } else if exitErr.ExitCode() == 2 { + t.Errorf("subprocess failed to open lock file: %v", err) + } + } else { + t.Errorf("subprocess execution failed: %v", err) + } + } + // Exit code 0 means lock was properly held - success +} diff --git a/modules/config/lock_windows.go b/modules/config/lock_windows.go new file mode 100644 index 0000000..acf7387 --- /dev/null +++ b/modules/config/lock_windows.go @@ -0,0 +1,48 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build windows + +package config + +import ( + "fmt" + "os" + "time" + + "golang.org/x/sys/windows" +) + +// lockFile acquires an exclusive lock on the file using LockFileEx. +// It polls with non-blocking LockFileEx until timeout. +func lockFile(file *os.File, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + handle := windows.Handle(file.Fd()) + + // LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY + const flags = windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY + + for { + // Lock the first byte (advisory lock) + var overlapped windows.Overlapped + err := windows.LockFileEx(handle, flags, 0, 1, 0, &overlapped) + if err == nil { + return nil + } + if err != windows.ERROR_LOCK_VIOLATION { + return fmt.Errorf("LockFileEx failed: %w", err) + } + + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for file lock") + } + time.Sleep(fileLockPollInterval) + } +} + +// unlockFile releases the lock on the file. +func unlockFile(file *os.File) error { + handle := windows.Handle(file.Fd()) + var overlapped windows.Overlapped + return windows.UnlockFileEx(handle, 0, 1, 0, &overlapped) +} diff --git a/modules/config/login.go b/modules/config/login.go index db8522f..dd5af07 100644 --- a/modules/config/login.go +++ b/modules/config/login.go @@ -84,24 +84,22 @@ func GetDefaultLogin() (*Login, error) { // SetDefaultLogin set the default login by name (case insensitive) func SetDefaultLogin(name string) error { - if err := loadConfig(); err != nil { - return err - } - - loginExist := false - for i := range config.Logins { - config.Logins[i].Default = false - if strings.ToLower(config.Logins[i].Name) == strings.ToLower(name) { - config.Logins[i].Default = true - loginExist = true + return withConfigLock(func() error { + loginExist := false + for i := range config.Logins { + config.Logins[i].Default = false + if strings.EqualFold(config.Logins[i].Name, name) { + config.Logins[i].Default = true + loginExist = true + } } - } - if !loginExist { - return fmt.Errorf("login '%s' not found", name) - } + if !loginExist { + return fmt.Errorf("login '%s' not found", name) + } - return saveConfig() + return saveConfigUnsafe() + }) } // GetLoginByName get login by name (case insensitive) @@ -112,7 +110,7 @@ func GetLoginByName(name string) *Login { } for _, l := range config.Logins { - if strings.ToLower(l.Name) == strings.ToLower(name) { + if strings.EqualFold(l.Name, name) { return &l } } @@ -165,64 +163,56 @@ func GetLoginsByHost(host string) []*Login { // DeleteLogin delete a login by name from config func DeleteLogin(name string) error { - idx := -1 - for i, l := range config.Logins { - if l.Name == name { - idx = i - break + return withConfigLock(func() error { + idx := -1 + for i, l := range config.Logins { + if strings.EqualFold(l.Name, name) { + idx = i + break + } + } + if idx == -1 { + return fmt.Errorf("can not delete login '%s', does not exist", name) } - } - if idx == -1 { - return fmt.Errorf("can not delete login '%s', does not exist", name) - } - config.Logins = append(config.Logins[:idx], config.Logins[idx+1:]...) + config.Logins = append(config.Logins[:idx], config.Logins[idx+1:]...) - return saveConfig() + return saveConfigUnsafe() + }) } // AddLogin save a login to config func AddLogin(login *Login) error { - if err := loadConfig(); err != nil { - return err - } - - // Check for duplicate login names - for _, existing := range config.Logins { - if strings.EqualFold(existing.Name, login.Name) { - return fmt.Errorf("login name '%s' already exists", login.Name) + return withConfigLock(func() error { + // Check for duplicate login names + for _, existing := range config.Logins { + if strings.EqualFold(existing.Name, login.Name) { + return fmt.Errorf("login name '%s' already exists", login.Name) + } } - } - // save login to global var - config.Logins = append(config.Logins, *login) + // save login to global var + config.Logins = append(config.Logins, *login) - // save login to config file - return saveConfig() + // save login to config file + return saveConfigUnsafe() + }) } -// UpdateLogin updates an existing login in the config -func UpdateLogin(login *Login) error { - if err := loadConfig(); err != nil { - return err - } - - // Find and update the login - found := false - for i, l := range config.Logins { - if l.Name == login.Name { - config.Logins[i] = *login - found = true - break +// SaveLoginTokens updates the token fields for an existing login. +// This is used after browser-based re-authentication to save new tokens. +func SaveLoginTokens(login *Login) error { + return withConfigLock(func() error { + for i, l := range config.Logins { + if strings.EqualFold(l.Name, login.Name) { + config.Logins[i].Token = login.Token + config.Logins[i].RefreshToken = login.RefreshToken + config.Logins[i].TokenExpiry = login.TokenExpiry + return saveConfigUnsafe() + } } - } - - if !found { return fmt.Errorf("login %s not found", login.Name) - } - - // Save updated config - return saveConfig() + }) } // RefreshOAuthTokenIfNeeded refreshes the OAuth token if it's expired or near expiry. @@ -240,22 +230,65 @@ func (l *Login) RefreshOAuthTokenIfNeeded() error { // RefreshOAuthToken refreshes the OAuth access token using the refresh token. // It updates the login with new token information and saves it to config. +// Uses double-checked locking to avoid unnecessary refresh calls when multiple +// processes race to refresh the same token. func (l *Login) RefreshOAuthToken() error { if l.RefreshToken == "" { return fmt.Errorf("no refresh token available") } - // Create a Token object with current values + return withConfigLock(func() error { + // Double-check: after acquiring lock, re-read config and check if + // another process already refreshed the token + for i, login := range config.Logins { + if login.Name == l.Name { + // Check if token was refreshed by another process + if login.TokenExpiry != l.TokenExpiry && login.TokenExpiry > 0 { + expiryTime := time.Unix(login.TokenExpiry, 0) + if time.Now().Add(TokenRefreshThreshold).Before(expiryTime) { + // Token was refreshed by another process, update our copy + l.Token = login.Token + l.RefreshToken = login.RefreshToken + l.TokenExpiry = login.TokenExpiry + return nil + } + } + + // Still need to refresh - proceed with OAuth call + newToken, err := doOAuthRefresh(l) + if err != nil { + return err + } + + // Update login with new token information + l.Token = newToken.AccessToken + if newToken.RefreshToken != "" { + l.RefreshToken = newToken.RefreshToken + } + if !newToken.Expiry.IsZero() { + l.TokenExpiry = newToken.Expiry.Unix() + } + + // Update in config slice and save + config.Logins[i] = *l + return saveConfigUnsafe() + } + } + + return fmt.Errorf("login %s not found", l.Name) + }) +} + +// doOAuthRefresh performs the actual OAuth token refresh API call. +func doOAuthRefresh(l *Login) (*oauth2.Token, error) { currentToken := &oauth2.Token{ AccessToken: l.Token, RefreshToken: l.RefreshToken, Expiry: time.Unix(l.TokenExpiry, 0), } - // Set up the OAuth2 config ctx := context.Background() - // Create HTTP client, respecting the login's TLS settings httpClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: l.Insecure}, @@ -263,7 +296,6 @@ func (l *Login) RefreshOAuthToken() error { } ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - // Configure the OAuth2 endpoints oauth2Config := &oauth2.Config{ ClientID: DefaultClientID, Endpoint: oauth2.Endpoint{ @@ -271,25 +303,12 @@ func (l *Login) RefreshOAuthToken() error { }, } - // Refresh the token newToken, err := oauth2Config.TokenSource(ctx, currentToken).Token() if err != nil { - return fmt.Errorf("failed to refresh token: %w", err) + return nil, fmt.Errorf("failed to refresh token: %w", err) } - // Update login with new token information - l.Token = newToken.AccessToken - - if newToken.RefreshToken != "" { - l.RefreshToken = newToken.RefreshToken - } - - if !newToken.Expiry.IsZero() { - l.TokenExpiry = newToken.Expiry.Unix() - } - - // Save updated login to config - return UpdateLogin(l) + return newToken, nil } // Client returns a client to operate Gitea API. You may provide additional modifiers diff --git a/modules/task/pull_create.go b/modules/task/pull_create.go index 424854d..33e2f5a 100644 --- a/modules/task/pull_create.go +++ b/modules/task/pull_create.go @@ -157,7 +157,8 @@ func GetDefaultPRTitle(header string) string { // CreateAgitFlowPull creates a agit flow PR in the given repo and prints the result func CreateAgitFlowPull(ctx *context.TeaContext, remote, head, base, topic string, opts *gitea.CreateIssueOption, - callback func(string) (string, error)) (err error) { + callback func(string) (string, error), +) (err error) { // default is default branch if len(base) == 0 { base, err = GetDefaultPRBase(ctx.Login, ctx.Owner, ctx.Repo)