mirror of
https://gitea.com/gitea/tea.git
synced 2026-02-22 06:13:32 +01:00
Add locking to ensure safe concurrent access to config file (#881)
Reviewed-on: https://gitea.com/gitea/tea/pulls/881 Co-authored-by: techknowlogick <techknowlogick@gitea.com> Co-committed-by: techknowlogick <techknowlogick@gitea.com>
This commit is contained in:
committed by
techknowlogick
parent
0d5bf60632
commit
ae9eb4f2c0
@@ -71,12 +71,24 @@ func OAuthLoginWithOptions(name, giteaURL string, insecure bool) error {
|
||||
|
||||
// OAuthLoginWithFullOptions performs an OAuth2 PKCE login flow with full options control
|
||||
func OAuthLoginWithFullOptions(opts OAuthOptions) error {
|
||||
// Normalize URL
|
||||
serverURL, err := utils.NormalizeURL(opts.URL)
|
||||
serverURL, token, err := performBrowserOAuthFlow(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse URL: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return createLoginFromToken(opts.Name, serverURL, token, opts.Insecure)
|
||||
}
|
||||
|
||||
// performBrowserOAuthFlow performs the browser-based OAuth2 PKCE flow and returns the token.
|
||||
// This is the shared implementation used by both new logins and re-authentication.
|
||||
func performBrowserOAuthFlow(opts OAuthOptions) (serverURL string, token *oauth2.Token, err error) {
|
||||
// Normalize URL
|
||||
normalizedURL, err := utils.NormalizeURL(opts.URL)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("unable to parse URL: %s", err)
|
||||
}
|
||||
serverURL = normalizedURL.String()
|
||||
|
||||
// Set defaults if needed
|
||||
if opts.ClientID == "" {
|
||||
opts.ClientID = config.DefaultClientID
|
||||
@@ -107,7 +119,7 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error {
|
||||
// Generate code verifier (random string)
|
||||
codeVerifier, err := generateCodeVerifier(codeVerifierLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code verifier: %s", err)
|
||||
return "", nil, fmt.Errorf("failed to generate code verifier: %s", err)
|
||||
}
|
||||
|
||||
// Generate code challenge (SHA256 hash of code verifier)
|
||||
@@ -118,8 +130,8 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, createHTTPClient(opts.Insecure))
|
||||
|
||||
// Configure the OAuth2 endpoints
|
||||
authURL := fmt.Sprintf("%s/login/oauth/authorize", serverURL)
|
||||
tokenURL := fmt.Sprintf("%s/login/oauth/access_token", serverURL)
|
||||
authURL := fmt.Sprintf("%s/login/oauth/authorize", normalizedURL)
|
||||
tokenURL := fmt.Sprintf("%s/login/oauth/access_token", normalizedURL)
|
||||
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: opts.ClientID,
|
||||
@@ -141,7 +153,7 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error {
|
||||
// Generate state parameter to protect against CSRF
|
||||
state, err := generateCodeVerifier(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate state: %s", err)
|
||||
return "", nil, fmt.Errorf("failed to generate state: %s", err)
|
||||
}
|
||||
|
||||
// Get the authorization URL
|
||||
@@ -156,7 +168,7 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error {
|
||||
strings.Contains(err.Error(), "redirect") {
|
||||
fmt.Println("\nError: Redirect URL not registered in Gitea")
|
||||
fmt.Println("\nTo fix this, you need to register the redirect URL in Gitea:")
|
||||
fmt.Printf("1. Go to your Gitea instance: %s\n", serverURL)
|
||||
fmt.Printf("1. Go to your Gitea instance: %s\n", normalizedURL)
|
||||
fmt.Println("2. Sign in and go to Settings > Applications")
|
||||
fmt.Println("3. Register a new OAuth2 application with:")
|
||||
fmt.Printf(" - Application Name: tea-cli (or any name)\n")
|
||||
@@ -165,22 +177,21 @@ func OAuthLoginWithFullOptions(opts OAuthOptions) error {
|
||||
fmt.Printf(" tea login add --oauth --client-id YOUR_CLIENT_ID --redirect-url %s\n", opts.RedirectURL)
|
||||
fmt.Println("\nAlternatively, you can use a token-based login: tea login add")
|
||||
}
|
||||
return fmt.Errorf("authorization failed: %s", err)
|
||||
return "", nil, fmt.Errorf("authorization failed: %s", err)
|
||||
}
|
||||
|
||||
// Verify state to prevent CSRF attacks
|
||||
if state != receivedState {
|
||||
return fmt.Errorf("state mismatch, possible CSRF attack")
|
||||
return "", nil, fmt.Errorf("state mismatch, possible CSRF attack")
|
||||
}
|
||||
|
||||
// Exchange authorization code for token
|
||||
token, err := oauth2Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
|
||||
token, err = oauth2Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("token exchange failed: %s", err)
|
||||
return "", nil, fmt.Errorf("token exchange failed: %s", err)
|
||||
}
|
||||
|
||||
// Create login with token data
|
||||
return createLoginFromToken(opts.Name, serverURL.String(), token, opts.Insecure)
|
||||
return serverURL, token, nil
|
||||
}
|
||||
|
||||
// createHTTPClient creates an HTTP client with optional insecure setting
|
||||
@@ -417,3 +428,33 @@ func createLoginFromToken(name, serverURL string, token *oauth2.Token, insecure
|
||||
func RefreshAccessToken(login *config.Login) error {
|
||||
return login.RefreshOAuthToken()
|
||||
}
|
||||
|
||||
// ReauthenticateLogin performs a full browser-based OAuth flow to get new tokens
|
||||
// for an existing login. This is used when the refresh token is expired or invalid.
|
||||
func ReauthenticateLogin(login *config.Login) error {
|
||||
opts := OAuthOptions{
|
||||
Name: login.Name,
|
||||
URL: login.URL,
|
||||
Insecure: login.Insecure,
|
||||
ClientID: config.DefaultClientID,
|
||||
RedirectURL: fmt.Sprintf("http://%s:%d", redirectHost, redirectPort),
|
||||
Port: redirectPort,
|
||||
}
|
||||
|
||||
_, token, err := performBrowserOAuthFlow(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the existing login with new token data
|
||||
login.Token = token.AccessToken
|
||||
if token.RefreshToken != "" {
|
||||
login.RefreshToken = token.RefreshToken
|
||||
}
|
||||
if !token.Expiry.IsZero() {
|
||||
login.TokenExpiry = token.Expiry.Unix()
|
||||
}
|
||||
|
||||
// Save updated login
|
||||
return config.SaveLoginTokens(login)
|
||||
}
|
||||
|
||||
@@ -98,8 +98,33 @@ func loadConfig() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// saveConfig save config to file
|
||||
func saveConfig() error {
|
||||
// reloadConfigFromDisk re-reads the config file from disk, bypassing the sync.Once.
|
||||
// This is used after acquiring a lock to ensure we have the latest config state.
|
||||
// The caller must hold the config lock.
|
||||
func reloadConfigFromDisk() error {
|
||||
ymlPath := GetConfigPath()
|
||||
exist, _ := utils.FileExist(ymlPath)
|
||||
if !exist {
|
||||
// No config file yet, start with empty config
|
||||
config = LocalConfig{}
|
||||
return nil
|
||||
}
|
||||
|
||||
bs, err := os.ReadFile(ymlPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file %s: %w", ymlPath, err)
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(bs, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse config file %s: %w", ymlPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveConfigUnsafe saves config to file without acquiring a lock.
|
||||
// Caller must hold the config lock.
|
||||
func saveConfigUnsafe() error {
|
||||
ymlPath := GetConfigPath()
|
||||
bs, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
|
||||
97
modules/config/lock.go
Normal file
97
modules/config/lock.go
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// LockTimeout is the default timeout for acquiring the config file lock.
|
||||
LockTimeout = 5 * time.Second
|
||||
|
||||
// mutexPollInterval is how often to retry acquiring the in-process mutex.
|
||||
mutexPollInterval = 10 * time.Millisecond
|
||||
|
||||
// fileLockPollInterval is how often to retry acquiring the file lock.
|
||||
fileLockPollInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
// configMutex protects in-process concurrent access to the config.
|
||||
var configMutex sync.Mutex
|
||||
|
||||
// acquireConfigLock acquires both the in-process mutex and a file lock.
|
||||
// Returns an unlock function that must be called to release both locks.
|
||||
// The timeout applies to acquiring the file lock; the mutex acquisition
|
||||
// uses the same timeout via a TryLock loop.
|
||||
func acquireConfigLock(lockPath string, timeout time.Duration) (unlock func() error, err error) {
|
||||
// Try to acquire mutex with timeout
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if configMutex.TryLock() {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("timeout waiting for config mutex")
|
||||
}
|
||||
time.Sleep(mutexPollInterval)
|
||||
}
|
||||
|
||||
// Mutex acquired, now try file lock
|
||||
file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
configMutex.Unlock()
|
||||
return nil, fmt.Errorf("failed to open lock file: %w", err)
|
||||
}
|
||||
|
||||
// Try to acquire file lock with remaining timeout
|
||||
remaining := max(time.Until(deadline), 0)
|
||||
|
||||
if err := lockFile(file, remaining); err != nil {
|
||||
file.Close()
|
||||
configMutex.Unlock()
|
||||
return nil, fmt.Errorf("failed to acquire file lock: %w", err)
|
||||
}
|
||||
|
||||
// Return unlock function
|
||||
return func() error {
|
||||
unlockErr := unlockFile(file)
|
||||
closeErr := file.Close()
|
||||
configMutex.Unlock()
|
||||
if unlockErr != nil {
|
||||
return unlockErr
|
||||
}
|
||||
return closeErr
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getConfigLockPath returns the path to the lock file for the config.
|
||||
func getConfigLockPath() string {
|
||||
return GetConfigPath() + ".lock"
|
||||
}
|
||||
|
||||
// withConfigLock executes the given function while holding the config lock.
|
||||
// It acquires the lock, reloads the config from disk, executes fn, and releases the lock.
|
||||
func withConfigLock(fn func() error) (retErr error) {
|
||||
lockPath := getConfigLockPath()
|
||||
unlock, err := acquireConfigLock(lockPath, LockTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire config lock: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if unlockErr := unlock(); unlockErr != nil && retErr == nil {
|
||||
retErr = fmt.Errorf("failed to release config lock: %w", unlockErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reload config from disk to get latest state
|
||||
if err := reloadConfigFromDisk(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn()
|
||||
}
|
||||
182
modules/config/lock_test.go
Normal file
182
modules/config/lock_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConfigLock_BasicLockUnlock(t *testing.T) {
|
||||
// Create a temp directory for test
|
||||
tmpDir, err := os.MkdirTemp("", "tea-lock-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
lockPath := filepath.Join(tmpDir, "config.yml.lock")
|
||||
|
||||
// Should be able to acquire lock
|
||||
unlock, err := acquireConfigLock(lockPath, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to acquire lock: %v", err)
|
||||
}
|
||||
|
||||
// Should be able to release lock
|
||||
err = unlock()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to release lock: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigLock_MutexProtection(t *testing.T) {
|
||||
// Create a temp directory for test
|
||||
tmpDir, err := os.MkdirTemp("", "tea-lock-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
lockPath := filepath.Join(tmpDir, "config.yml.lock")
|
||||
|
||||
// Acquire lock
|
||||
unlock, err := acquireConfigLock(lockPath, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to acquire lock: %v", err)
|
||||
}
|
||||
|
||||
// Try to acquire again from same process - should block/timeout due to mutex
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
_, err := acquireConfigLock(lockPath, 100*time.Millisecond)
|
||||
done <- (err != nil) // Should timeout/fail
|
||||
}()
|
||||
|
||||
select {
|
||||
case failed := <-done:
|
||||
if !failed {
|
||||
t.Error("second lock acquisition should have failed due to mutex")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("test timed out")
|
||||
}
|
||||
|
||||
if err := unlock(); err != nil {
|
||||
t.Errorf("failed to unlock: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadConfigFromDisk(t *testing.T) {
|
||||
// Save original config state
|
||||
originalConfig := config
|
||||
|
||||
// Create a temp config file
|
||||
tmpDir, err := os.MkdirTemp("", "tea-reload-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// We can't easily change GetConfigPath, so we test that reloadConfigFromDisk
|
||||
// handles a missing file gracefully (returns nil and resets config)
|
||||
config = LocalConfig{Logins: []Login{{Name: "test"}}}
|
||||
|
||||
// Call reload - since the actual config path likely exists or doesn't,
|
||||
// we just verify it doesn't panic and returns without error or with expected error
|
||||
err = reloadConfigFromDisk()
|
||||
// The function should either succeed or return an error, not panic
|
||||
if err != nil {
|
||||
// This is acceptable - config file might not exist in test environment
|
||||
t.Logf("reloadConfigFromDisk returned error (expected in test env): %v", err)
|
||||
}
|
||||
|
||||
// Restore original config
|
||||
config = originalConfig
|
||||
}
|
||||
|
||||
func TestWithConfigLock(t *testing.T) {
|
||||
executed := false
|
||||
err := withConfigLock(func() error {
|
||||
executed = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("withConfigLock returned error: %v", err)
|
||||
}
|
||||
if !executed {
|
||||
t.Error("function was not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithConfigLock_PropagatesError(t *testing.T) {
|
||||
expectedErr := fmt.Errorf("test error")
|
||||
err := withConfigLock(func() error {
|
||||
return expectedErr
|
||||
})
|
||||
|
||||
if err != expectedErr {
|
||||
t.Errorf("expected error %v, got %v", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoubleCheckedLocking_SimulatedRefresh(t *testing.T) {
|
||||
// This test simulates the double-checked locking pattern
|
||||
// by having multiple goroutines try to "refresh" simultaneously
|
||||
|
||||
var (
|
||||
refreshCount int
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
// Simulate what RefreshOAuthToken does with double-check
|
||||
simulatedRefresh := func(tokenExpiry *int64) error {
|
||||
// First check (without lock)
|
||||
if *tokenExpiry > time.Now().Unix() {
|
||||
return nil // Token still valid
|
||||
}
|
||||
|
||||
return withConfigLock(func() error {
|
||||
// Double-check after acquiring lock
|
||||
if *tokenExpiry > time.Now().Unix() {
|
||||
return nil // Another goroutine refreshed it
|
||||
}
|
||||
|
||||
// Simulate refresh
|
||||
mu.Lock()
|
||||
refreshCount++
|
||||
mu.Unlock()
|
||||
|
||||
time.Sleep(50 * time.Millisecond) // Simulate API call
|
||||
*tokenExpiry = time.Now().Add(1 * time.Hour).Unix()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Start with expired token
|
||||
tokenExpiry := time.Now().Add(-1 * time.Hour).Unix()
|
||||
|
||||
// Launch multiple goroutines trying to refresh
|
||||
var wg sync.WaitGroup
|
||||
for range 5 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := simulatedRefresh(&tokenExpiry); err != nil {
|
||||
t.Errorf("refresh failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should only have refreshed once due to double-checked locking
|
||||
if refreshCount != 1 {
|
||||
t.Errorf("expected 1 refresh, got %d", refreshCount)
|
||||
}
|
||||
}
|
||||
39
modules/config/lock_unix.go
Normal file
39
modules/config/lock_unix.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build unix
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// lockFile acquires an exclusive lock on the file using flock.
|
||||
// It polls with non-blocking flock until timeout.
|
||||
func lockFile(file *os.File, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
for {
|
||||
err := syscall.Flock(int(file.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if err != syscall.EWOULDBLOCK {
|
||||
return fmt.Errorf("flock failed: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
return fmt.Errorf("timeout waiting for file lock")
|
||||
}
|
||||
time.Sleep(fileLockPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// unlockFile releases the lock on the file.
|
||||
func unlockFile(file *os.File) error {
|
||||
return syscall.Flock(int(file.Fd()), syscall.LOCK_UN)
|
||||
}
|
||||
82
modules/config/lock_unix_test.go
Normal file
82
modules/config/lock_unix_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build unix
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConfigLock_CrossProcess(t *testing.T) {
|
||||
// Create a temp directory for test
|
||||
tmpDir, err := os.MkdirTemp("", "tea-lock-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
lockPath := filepath.Join(tmpDir, "config.yml.lock")
|
||||
|
||||
// Acquire lock in main process
|
||||
unlock, err := acquireConfigLock(lockPath, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to acquire lock: %v", err)
|
||||
}
|
||||
defer unlock()
|
||||
|
||||
// Spawn a subprocess that tries to acquire the same lock
|
||||
// The subprocess should fail to acquire within timeout
|
||||
script := fmt.Sprintf(`
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
file, err := os.OpenFile(%q, os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
os.Exit(2)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Try non-blocking lock
|
||||
err = syscall.Flock(int(file.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err != nil {
|
||||
// Lock is held - expected behavior
|
||||
os.Exit(0)
|
||||
}
|
||||
// Lock was acquired - unexpected
|
||||
syscall.Flock(int(file.Fd()), syscall.LOCK_UN)
|
||||
os.Exit(1)
|
||||
}
|
||||
`, lockPath)
|
||||
|
||||
// Write and run the test script
|
||||
scriptPath := filepath.Join(tmpDir, "locktest.go")
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
|
||||
t.Fatalf("failed to write test script: %v", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", "run", scriptPath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() == 1 {
|
||||
t.Error("subprocess acquired lock when it should have been held")
|
||||
} else if exitErr.ExitCode() == 2 {
|
||||
t.Errorf("subprocess failed to open lock file: %v", err)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("subprocess execution failed: %v", err)
|
||||
}
|
||||
}
|
||||
// Exit code 0 means lock was properly held - success
|
||||
}
|
||||
48
modules/config/lock_windows.go
Normal file
48
modules/config/lock_windows.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// lockFile acquires an exclusive lock on the file using LockFileEx.
|
||||
// It polls with non-blocking LockFileEx until timeout.
|
||||
func lockFile(file *os.File, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
handle := windows.Handle(file.Fd())
|
||||
|
||||
// LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY
|
||||
const flags = windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY
|
||||
|
||||
for {
|
||||
// Lock the first byte (advisory lock)
|
||||
var overlapped windows.Overlapped
|
||||
err := windows.LockFileEx(handle, flags, 0, 1, 0, &overlapped)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if err != windows.ERROR_LOCK_VIOLATION {
|
||||
return fmt.Errorf("LockFileEx failed: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
return fmt.Errorf("timeout waiting for file lock")
|
||||
}
|
||||
time.Sleep(fileLockPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// unlockFile releases the lock on the file.
|
||||
func unlockFile(file *os.File) error {
|
||||
handle := windows.Handle(file.Fd())
|
||||
var overlapped windows.Overlapped
|
||||
return windows.UnlockFileEx(handle, 0, 1, 0, &overlapped)
|
||||
}
|
||||
@@ -84,24 +84,22 @@ func GetDefaultLogin() (*Login, error) {
|
||||
|
||||
// SetDefaultLogin set the default login by name (case insensitive)
|
||||
func SetDefaultLogin(name string) error {
|
||||
if err := loadConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loginExist := false
|
||||
for i := range config.Logins {
|
||||
config.Logins[i].Default = false
|
||||
if strings.ToLower(config.Logins[i].Name) == strings.ToLower(name) {
|
||||
config.Logins[i].Default = true
|
||||
loginExist = true
|
||||
return withConfigLock(func() error {
|
||||
loginExist := false
|
||||
for i := range config.Logins {
|
||||
config.Logins[i].Default = false
|
||||
if strings.EqualFold(config.Logins[i].Name, name) {
|
||||
config.Logins[i].Default = true
|
||||
loginExist = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !loginExist {
|
||||
return fmt.Errorf("login '%s' not found", name)
|
||||
}
|
||||
if !loginExist {
|
||||
return fmt.Errorf("login '%s' not found", name)
|
||||
}
|
||||
|
||||
return saveConfig()
|
||||
return saveConfigUnsafe()
|
||||
})
|
||||
}
|
||||
|
||||
// GetLoginByName get login by name (case insensitive)
|
||||
@@ -112,7 +110,7 @@ func GetLoginByName(name string) *Login {
|
||||
}
|
||||
|
||||
for _, l := range config.Logins {
|
||||
if strings.ToLower(l.Name) == strings.ToLower(name) {
|
||||
if strings.EqualFold(l.Name, name) {
|
||||
return &l
|
||||
}
|
||||
}
|
||||
@@ -165,64 +163,56 @@ func GetLoginsByHost(host string) []*Login {
|
||||
|
||||
// DeleteLogin delete a login by name from config
|
||||
func DeleteLogin(name string) error {
|
||||
idx := -1
|
||||
for i, l := range config.Logins {
|
||||
if l.Name == name {
|
||||
idx = i
|
||||
break
|
||||
return withConfigLock(func() error {
|
||||
idx := -1
|
||||
for i, l := range config.Logins {
|
||||
if strings.EqualFold(l.Name, name) {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx == -1 {
|
||||
return fmt.Errorf("can not delete login '%s', does not exist", name)
|
||||
}
|
||||
}
|
||||
if idx == -1 {
|
||||
return fmt.Errorf("can not delete login '%s', does not exist", name)
|
||||
}
|
||||
|
||||
config.Logins = append(config.Logins[:idx], config.Logins[idx+1:]...)
|
||||
config.Logins = append(config.Logins[:idx], config.Logins[idx+1:]...)
|
||||
|
||||
return saveConfig()
|
||||
return saveConfigUnsafe()
|
||||
})
|
||||
}
|
||||
|
||||
// AddLogin save a login to config
|
||||
func AddLogin(login *Login) error {
|
||||
if err := loadConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for duplicate login names
|
||||
for _, existing := range config.Logins {
|
||||
if strings.EqualFold(existing.Name, login.Name) {
|
||||
return fmt.Errorf("login name '%s' already exists", login.Name)
|
||||
return withConfigLock(func() error {
|
||||
// Check for duplicate login names
|
||||
for _, existing := range config.Logins {
|
||||
if strings.EqualFold(existing.Name, login.Name) {
|
||||
return fmt.Errorf("login name '%s' already exists", login.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// save login to global var
|
||||
config.Logins = append(config.Logins, *login)
|
||||
// save login to global var
|
||||
config.Logins = append(config.Logins, *login)
|
||||
|
||||
// save login to config file
|
||||
return saveConfig()
|
||||
// save login to config file
|
||||
return saveConfigUnsafe()
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLogin updates an existing login in the config
|
||||
func UpdateLogin(login *Login) error {
|
||||
if err := loadConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find and update the login
|
||||
found := false
|
||||
for i, l := range config.Logins {
|
||||
if l.Name == login.Name {
|
||||
config.Logins[i] = *login
|
||||
found = true
|
||||
break
|
||||
// SaveLoginTokens updates the token fields for an existing login.
|
||||
// This is used after browser-based re-authentication to save new tokens.
|
||||
func SaveLoginTokens(login *Login) error {
|
||||
return withConfigLock(func() error {
|
||||
for i, l := range config.Logins {
|
||||
if strings.EqualFold(l.Name, login.Name) {
|
||||
config.Logins[i].Token = login.Token
|
||||
config.Logins[i].RefreshToken = login.RefreshToken
|
||||
config.Logins[i].TokenExpiry = login.TokenExpiry
|
||||
return saveConfigUnsafe()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf("login %s not found", login.Name)
|
||||
}
|
||||
|
||||
// Save updated config
|
||||
return saveConfig()
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshOAuthTokenIfNeeded refreshes the OAuth token if it's expired or near expiry.
|
||||
@@ -240,22 +230,65 @@ func (l *Login) RefreshOAuthTokenIfNeeded() error {
|
||||
|
||||
// RefreshOAuthToken refreshes the OAuth access token using the refresh token.
|
||||
// It updates the login with new token information and saves it to config.
|
||||
// Uses double-checked locking to avoid unnecessary refresh calls when multiple
|
||||
// processes race to refresh the same token.
|
||||
func (l *Login) RefreshOAuthToken() error {
|
||||
if l.RefreshToken == "" {
|
||||
return fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Create a Token object with current values
|
||||
return withConfigLock(func() error {
|
||||
// Double-check: after acquiring lock, re-read config and check if
|
||||
// another process already refreshed the token
|
||||
for i, login := range config.Logins {
|
||||
if login.Name == l.Name {
|
||||
// Check if token was refreshed by another process
|
||||
if login.TokenExpiry != l.TokenExpiry && login.TokenExpiry > 0 {
|
||||
expiryTime := time.Unix(login.TokenExpiry, 0)
|
||||
if time.Now().Add(TokenRefreshThreshold).Before(expiryTime) {
|
||||
// Token was refreshed by another process, update our copy
|
||||
l.Token = login.Token
|
||||
l.RefreshToken = login.RefreshToken
|
||||
l.TokenExpiry = login.TokenExpiry
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Still need to refresh - proceed with OAuth call
|
||||
newToken, err := doOAuthRefresh(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update login with new token information
|
||||
l.Token = newToken.AccessToken
|
||||
if newToken.RefreshToken != "" {
|
||||
l.RefreshToken = newToken.RefreshToken
|
||||
}
|
||||
if !newToken.Expiry.IsZero() {
|
||||
l.TokenExpiry = newToken.Expiry.Unix()
|
||||
}
|
||||
|
||||
// Update in config slice and save
|
||||
config.Logins[i] = *l
|
||||
return saveConfigUnsafe()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("login %s not found", l.Name)
|
||||
})
|
||||
}
|
||||
|
||||
// doOAuthRefresh performs the actual OAuth token refresh API call.
|
||||
func doOAuthRefresh(l *Login) (*oauth2.Token, error) {
|
||||
currentToken := &oauth2.Token{
|
||||
AccessToken: l.Token,
|
||||
RefreshToken: l.RefreshToken,
|
||||
Expiry: time.Unix(l.TokenExpiry, 0),
|
||||
}
|
||||
|
||||
// Set up the OAuth2 config
|
||||
ctx := context.Background()
|
||||
|
||||
// Create HTTP client, respecting the login's TLS settings
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: l.Insecure},
|
||||
@@ -263,7 +296,6 @@ func (l *Login) RefreshOAuthToken() error {
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
|
||||
// Configure the OAuth2 endpoints
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: DefaultClientID,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
@@ -271,25 +303,12 @@ func (l *Login) RefreshOAuthToken() error {
|
||||
},
|
||||
}
|
||||
|
||||
// Refresh the token
|
||||
newToken, err := oauth2Config.TokenSource(ctx, currentToken).Token()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to refresh token: %w", err)
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Update login with new token information
|
||||
l.Token = newToken.AccessToken
|
||||
|
||||
if newToken.RefreshToken != "" {
|
||||
l.RefreshToken = newToken.RefreshToken
|
||||
}
|
||||
|
||||
if !newToken.Expiry.IsZero() {
|
||||
l.TokenExpiry = newToken.Expiry.Unix()
|
||||
}
|
||||
|
||||
// Save updated login to config
|
||||
return UpdateLogin(l)
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
// Client returns a client to operate Gitea API. You may provide additional modifiers
|
||||
|
||||
@@ -157,7 +157,8 @@ func GetDefaultPRTitle(header string) string {
|
||||
// CreateAgitFlowPull creates a agit flow PR in the given repo and prints the result
|
||||
func CreateAgitFlowPull(ctx *context.TeaContext, remote, head, base, topic string,
|
||||
opts *gitea.CreateIssueOption,
|
||||
callback func(string) (string, error)) (err error) {
|
||||
callback func(string) (string, error),
|
||||
) (err error) {
|
||||
// default is default branch
|
||||
if len(base) == 0 {
|
||||
base, err = GetDefaultPRBase(ctx.Login, ctx.Owner, ctx.Repo)
|
||||
|
||||
Reference in New Issue
Block a user