diff --git a/modules/auth/oauth.go b/modules/auth/oauth.go index 96b5168..6311248 100644 --- a/modules/auth/oauth.go +++ b/modules/auth/oauth.go @@ -10,6 +10,7 @@ import ( "crypto/tls" "encoding/base64" "fmt" + "net" "net/http" "net/url" "strconv" @@ -39,7 +40,7 @@ const ( authTimeout = 60 * time.Second // local server settings to receive the callback - redirectPort = 3333 + redirectPort = 0 redirectHost = "127.0.0.1" ) @@ -220,6 +221,7 @@ func startLocalServerAndOpenBrowser(authURL, expectedState string, opts OAuthOpt codeChan := make(chan string, 1) stateChan := make(chan string, 1) errChan := make(chan error, 1) + portChan := make(chan int, 1) // Parse the redirect URL to get the path parsedURL, err := url.Parse(opts.RedirectURL) @@ -245,12 +247,9 @@ func startLocalServerAndOpenBrowser(authURL, expectedState string, opts OAuthOpt if parsedPort := parsedURL.Port(); parsedPort != "" { port, _ = strconv.Atoi(parsedPort) } - if port == 0 { - port = redirectPort - } } - // Server address with explicit port + // Server address with port (may be dynamic if port=0) serverAddr := fmt.Sprintf("%s:%d", hostname, port) // Start local server @@ -297,10 +296,36 @@ func startLocalServerAndOpenBrowser(authURL, expectedState string, opts OAuthOpt }), } + // Listener for getting the actual port when using port 0 + listener, err := net.Listen("tcp", serverAddr) + if err != nil { + return "", "", fmt.Errorf("failed to start local server: %s", err) + } + + // Get the actual port if we used port 0 + if port == 0 { + addr := listener.Addr().(*net.TCPAddr) + port = addr.Port + portChan <- port + + // Update redirect URL with actual port + parsedURL.Host = fmt.Sprintf("%s:%d", hostname, port) + opts.RedirectURL = parsedURL.String() + + // Update the auth URL with the new redirect URL + authURLParsed, err := url.Parse(authURL) + if err == nil { + query := authURLParsed.Query() + query.Set("redirect_uri", opts.RedirectURL) + authURLParsed.RawQuery = query.Encode() + authURL = authURLParsed.String() + } + } + // Start server in a goroutine go func() { - fmt.Printf("Starting local server on %s...\n", server.Addr) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + fmt.Printf("Starting local server on %s:%d...\n", hostname, port) + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { errChan <- err } }()