mirror of
https://github.com/fatedier/frp.git
synced 2026-03-21 08:23:29 +08:00
auth/oidc: fix eager token fetch at startup, add validation and e2e tests (#5234)
This commit is contained in:
@@ -30,6 +30,7 @@ import (
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
)
|
||||
|
||||
@@ -88,6 +89,40 @@ func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.cfg.Token(s.ctx)
|
||||
}
|
||||
|
||||
// oidcTokenSource wraps a caching oauth2.TokenSource and, on the first
|
||||
// successful Token() call, checks whether the provider returns an expiry.
|
||||
// If not, it permanently switches to nonCachingTokenSource so that a fresh
|
||||
// token is fetched every time. This avoids an eager network call at
|
||||
// construction time, letting the login retry loop handle transient IdP
|
||||
// outages.
|
||||
type oidcTokenSource struct {
|
||||
mu sync.Mutex
|
||||
initialized bool
|
||||
source oauth2.TokenSource
|
||||
fallbackCfg *clientcredentials.Config
|
||||
fallbackCtx context.Context
|
||||
}
|
||||
|
||||
func (s *oidcTokenSource) Token() (*oauth2.Token, error) {
|
||||
s.mu.Lock()
|
||||
if !s.initialized {
|
||||
token, err := s.source.Token()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
if token.Expiry.IsZero() {
|
||||
s.source = &nonCachingTokenSource{cfg: s.fallbackCfg, ctx: s.fallbackCtx}
|
||||
}
|
||||
s.initialized = true
|
||||
s.mu.Unlock()
|
||||
return token, nil
|
||||
}
|
||||
source := s.source
|
||||
s.mu.Unlock()
|
||||
return source.Token()
|
||||
}
|
||||
|
||||
type OidcAuthProvider struct {
|
||||
additionalAuthScopes []v1.AuthScope
|
||||
|
||||
@@ -95,6 +130,10 @@ type OidcAuthProvider struct {
|
||||
}
|
||||
|
||||
func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) {
|
||||
if err := validation.ValidateOIDCClientCredentialsConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eps := make(map[string][]string)
|
||||
for k, v := range cfg.AdditionalEndpointParams {
|
||||
eps[k] = []string{v}
|
||||
@@ -127,24 +166,22 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien
|
||||
// Create a persistent TokenSource that caches the token and refreshes
|
||||
// it before expiry. This avoids making a new HTTP request to the OIDC
|
||||
// provider on every heartbeat/ping.
|
||||
tokenSource := tokenGenerator.TokenSource(ctx)
|
||||
|
||||
// Fetch the initial token to check if the provider returns an expiry.
|
||||
// If Expiry is the zero value (provider omitted expires_in), the cached
|
||||
// TokenSource would treat the token as valid forever and never refresh it,
|
||||
// even after the JWT's exp claim passes. In that case, fall back to
|
||||
// fetching a fresh token on every request.
|
||||
initialToken, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err)
|
||||
}
|
||||
if initialToken.Expiry.IsZero() {
|
||||
tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx}
|
||||
}
|
||||
//
|
||||
// We wrap it in an oidcTokenSource so that the first Token() call
|
||||
// (deferred to SetLogin inside the login retry loop) probes whether the
|
||||
// provider returns expires_in. If not, it switches to a non-caching
|
||||
// source. This avoids an eager network call at construction time, which
|
||||
// would prevent loopLoginUntilSuccess from retrying on transient IdP
|
||||
// outages.
|
||||
cachingSource := tokenGenerator.TokenSource(ctx)
|
||||
|
||||
return &OidcAuthProvider{
|
||||
additionalAuthScopes: additionalAuthScopes,
|
||||
tokenSource: tokenSource,
|
||||
tokenSource: &oidcTokenSource{
|
||||
source: cachingSource,
|
||||
fallbackCfg: tokenGenerator,
|
||||
fallbackCtx: ctx,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -91,8 +91,10 @@ func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) {
|
||||
)
|
||||
r.NoError(err)
|
||||
|
||||
// Constructor fetches the initial token (1 request).
|
||||
// Each subsequent call should also fetch a fresh token since there is no expiry.
|
||||
// Constructor no longer fetches a token eagerly.
|
||||
// The first SetLogin triggers the adaptive probe.
|
||||
r.Equal(int32(0), requestCount.Load())
|
||||
|
||||
loginMsg := &msg.Login{}
|
||||
err = provider.SetLogin(loginMsg)
|
||||
r.NoError(err)
|
||||
@@ -105,8 +107,8 @@ func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) {
|
||||
r.Equal("fresh-test-token", pingMsg.PrivilegeKey)
|
||||
}
|
||||
|
||||
// 1 initial (constructor) + 1 login + 3 pings = 5 requests
|
||||
r.Equal(int32(5), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing")
|
||||
// 1 probe (login) + 3 pings = 4 requests (probe doubles as the login token fetch)
|
||||
r.Equal(int32(4), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing")
|
||||
}
|
||||
|
||||
func TestOidcAuthProviderCachesToken(t *testing.T) {
|
||||
@@ -134,10 +136,10 @@ func TestOidcAuthProviderCachesToken(t *testing.T) {
|
||||
)
|
||||
r.NoError(err)
|
||||
|
||||
// Constructor eagerly fetches the initial token (1 request).
|
||||
r.Equal(int32(1), requestCount.Load())
|
||||
// Constructor no longer fetches eagerly; first SetLogin triggers the probe.
|
||||
r.Equal(int32(0), requestCount.Load())
|
||||
|
||||
// SetLogin should reuse the cached token
|
||||
// SetLogin triggers the adaptive probe and caches the token.
|
||||
loginMsg := &msg.Login{}
|
||||
err = provider.SetLogin(loginMsg)
|
||||
r.NoError(err)
|
||||
@@ -153,3 +155,99 @@ func TestOidcAuthProviderCachesToken(t *testing.T) {
|
||||
}
|
||||
r.Equal(int32(1), requestCount.Load(), "token endpoint should only be called once; cached token should be reused")
|
||||
}
|
||||
|
||||
func TestOidcAuthProviderRetriesOnInitialFailure(t *testing.T) {
|
||||
r := require.New(t)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
n := requestCount.Add(1)
|
||||
// The oauth2 library retries once internally, so we need two
|
||||
// consecutive failures to surface an error to the caller.
|
||||
if n <= 2 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": "temporarily_unavailable",
|
||||
"error_description": "service is starting up",
|
||||
})
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
|
||||
"access_token": "retry-test-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
// Constructor succeeds even though the IdP is "down".
|
||||
provider, err := auth.NewOidcAuthSetter(
|
||||
[]v1.AuthScope{v1.AuthScopeHeartBeats},
|
||||
v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
},
|
||||
)
|
||||
r.NoError(err)
|
||||
r.Equal(int32(0), requestCount.Load())
|
||||
|
||||
// First SetLogin hits the IdP, which returns an error (after internal retry).
|
||||
loginMsg := &msg.Login{}
|
||||
err = provider.SetLogin(loginMsg)
|
||||
r.Error(err)
|
||||
r.Equal(int32(2), requestCount.Load())
|
||||
|
||||
// Second SetLogin retries and succeeds.
|
||||
err = provider.SetLogin(loginMsg)
|
||||
r.NoError(err)
|
||||
r.Equal("retry-test-token", loginMsg.PrivilegeKey)
|
||||
r.Equal(int32(3), requestCount.Load())
|
||||
|
||||
// Subsequent calls use cached token.
|
||||
pingMsg := &msg.Ping{}
|
||||
err = provider.SetPing(pingMsg)
|
||||
r.NoError(err)
|
||||
r.Equal("retry-test-token", pingMsg.PrivilegeKey)
|
||||
r.Equal(int32(3), requestCount.Load())
|
||||
}
|
||||
|
||||
func TestNewOidcAuthSetterRejectsInvalidStaticConfig(t *testing.T) {
|
||||
r := require.New(t)
|
||||
tokenServer := httptest.NewServer(http.NotFoundHandler())
|
||||
defer tokenServer.Close()
|
||||
|
||||
_, err := auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: "://bad",
|
||||
})
|
||||
r.Error(err)
|
||||
r.Contains(err.Error(), "auth.oidc.tokenEndpointURL")
|
||||
|
||||
_, err = auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
})
|
||||
r.Error(err)
|
||||
r.Contains(err.Error(), "auth.oidc.clientID is required")
|
||||
|
||||
_, err = auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
AdditionalEndpointParams: map[string]string{
|
||||
"scope": "profile",
|
||||
},
|
||||
})
|
||||
r.Error(err)
|
||||
r.Contains(err.Error(), "auth.oidc.additionalEndpointParams.scope is not allowed; use auth.oidc.scope instead")
|
||||
|
||||
_, err = auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
Audience: "api",
|
||||
AdditionalEndpointParams: map[string]string{"audience": "override"},
|
||||
})
|
||||
r.Error(err)
|
||||
r.Contains(err.Error(), "cannot specify both auth.oidc.audience and auth.oidc.additionalEndpointParams.audience")
|
||||
}
|
||||
|
||||
@@ -88,6 +88,11 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e
|
||||
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
if c.Method == v1.AuthMethodOIDC && c.OIDC.TokenSource == nil {
|
||||
if err := ValidateOIDCClientCredentialsConfig(&c.OIDC); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
|
||||
57
pkg/config/v1/validation/oidc.go
Normal file
57
pkg/config/v1/validation/oidc.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func ValidateOIDCClientCredentialsConfig(c *v1.AuthOIDCClientConfig) error {
|
||||
var errs []string
|
||||
|
||||
if c.ClientID == "" {
|
||||
errs = append(errs, "auth.oidc.clientID is required")
|
||||
}
|
||||
|
||||
if c.TokenEndpointURL == "" {
|
||||
errs = append(errs, "auth.oidc.tokenEndpointURL is required")
|
||||
} else {
|
||||
tokenURL, err := url.Parse(c.TokenEndpointURL)
|
||||
if err != nil || !tokenURL.IsAbs() || tokenURL.Host == "" {
|
||||
errs = append(errs, "auth.oidc.tokenEndpointURL must be an absolute http or https URL")
|
||||
} else if tokenURL.Scheme != "http" && tokenURL.Scheme != "https" {
|
||||
errs = append(errs, "auth.oidc.tokenEndpointURL must use http or https")
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := c.AdditionalEndpointParams["scope"]; ok {
|
||||
errs = append(errs, "auth.oidc.additionalEndpointParams.scope is not allowed; use auth.oidc.scope instead")
|
||||
}
|
||||
|
||||
if c.Audience != "" {
|
||||
if _, ok := c.AdditionalEndpointParams["audience"]; ok {
|
||||
errs = append(errs, "cannot specify both auth.oidc.audience and auth.oidc.additionalEndpointParams.audience")
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors.New(strings.Join(errs, "; "))
|
||||
}
|
||||
78
pkg/config/v1/validation/oidc_test.go
Normal file
78
pkg/config/v1/validation/oidc_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func TestValidateOIDCClientCredentialsConfig(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.NotFoundHandler())
|
||||
defer tokenServer.Close()
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
require.NoError(t, ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
AdditionalEndpointParams: map[string]string{
|
||||
"resource": "api",
|
||||
},
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("invalid token endpoint url", func(t *testing.T) {
|
||||
err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: "://bad",
|
||||
})
|
||||
require.ErrorContains(t, err, "auth.oidc.tokenEndpointURL")
|
||||
})
|
||||
|
||||
t.Run("missing client id", func(t *testing.T) {
|
||||
err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
})
|
||||
require.ErrorContains(t, err, "auth.oidc.clientID is required")
|
||||
})
|
||||
|
||||
t.Run("scope endpoint param is not allowed", func(t *testing.T) {
|
||||
err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
AdditionalEndpointParams: map[string]string{
|
||||
"scope": "email",
|
||||
},
|
||||
})
|
||||
require.ErrorContains(t, err, "auth.oidc.additionalEndpointParams.scope is not allowed; use auth.oidc.scope instead")
|
||||
})
|
||||
|
||||
t.Run("audience conflict", func(t *testing.T) {
|
||||
err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
Audience: "api",
|
||||
AdditionalEndpointParams: map[string]string{
|
||||
"audience": "override",
|
||||
},
|
||||
})
|
||||
require.ErrorContains(t, err, "cannot specify both auth.oidc.audience and auth.oidc.additionalEndpointParams.audience")
|
||||
})
|
||||
}
|
||||
@@ -100,7 +100,11 @@ func (s *Server) Run() error {
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
return s.hs.Close()
|
||||
err := s.hs.Close()
|
||||
if s.ln != nil {
|
||||
_ = s.ln.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type RouterRegisterHelper struct {
|
||||
|
||||
@@ -131,6 +131,9 @@ func (c *Controller) handlePacket(buf []byte) {
|
||||
}
|
||||
|
||||
func (c *Controller) Stop() error {
|
||||
if c.tun == nil {
|
||||
return nil
|
||||
}
|
||||
return c.tun.Close()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user