// Copyright 2016-2022, Pulumi Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package httpstate import ( "context" "fmt" "runtime" "sync" "testing" "time" "github.com/stretchr/testify/assert" ) func TestTokenSource(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Flaky on Windows CI workers due to the use of timer+Sleep") } t.Parallel() ctx := context.Background() dur := 20 * time.Millisecond backend := &testTokenBackend{tokens: map[string]time.Time{}} tok0, tok0Expires := backend.NewToken(dur) ts, err := newTokenSource(ctx, tok0, tok0Expires, dur, backend.Refresh) assert.NoError(t, err) defer ts.Close() for i := 0; i < 32; i++ { tok, err := ts.GetToken(ctx) assert.NoError(t, err) assert.NoError(t, backend.VerifyToken(tok)) t.Logf("STEP: %d, TOKEN: %s", i, tok) // tok0 initially if i == 0 { assert.Equal(t, tok0, tok) } // definitely a fresh token by step 16 // allow some leeway due to time.Sleep concurrency if i > 16 { assert.NotEqual(t, tok0, tok) } time.Sleep(dur / 16) } } func TestTokenSourceWithQuicklyExpiringInitialToken(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Flaky on Windows CI workers due to the use of timer+Sleep") } t.Parallel() ctx := context.Background() dur := 20 * time.Millisecond backend := &testTokenBackend{tokens: map[string]time.Time{}} tok0, tok0Expires := backend.NewToken(dur / 10) ts, err := newTokenSource(ctx, tok0, tok0Expires, dur, backend.Refresh) assert.NoError(t, err) defer ts.Close() for i := 0; i < 8; i++ { tok, err := ts.GetToken(ctx) assert.NoError(t, err) assert.NoError(t, backend.VerifyToken(tok)) t.Logf("STEP: %d, TOKEN: %s", i, tok) time.Sleep(dur / 16) } } type testTokenBackend struct { mu sync.Mutex counter int tokens map[string]time.Time } func (ts *testTokenBackend) NewToken(duration time.Duration) (string, time.Time) { ts.mu.Lock() defer ts.mu.Unlock() return ts.newTokenInner(duration) } func (ts *testTokenBackend) Refresh( ctx context.Context, duration time.Duration, currentToken string, ) (string, time.Time, error) { ts.mu.Lock() defer ts.mu.Unlock() if err := ts.verifyTokenInner(currentToken); err != nil { return "", time.Time{}, err } tok, expires := ts.newTokenInner(duration) return tok, expires, nil } func (ts *testTokenBackend) TokenName(refreshCount int) string { ts.mu.Lock() defer ts.mu.Unlock() return ts.tokenNameInner(refreshCount) } func (ts *testTokenBackend) VerifyToken(token string) error { ts.mu.Lock() defer ts.mu.Unlock() return ts.verifyTokenInner(token) } func (ts *testTokenBackend) newTokenInner(duration time.Duration) (string, time.Time) { now := time.Now() ts.counter++ tok := ts.tokenNameInner(ts.counter) expires := now.Add(duration) ts.tokens[tok] = now.Add(duration) return tok, expires } func (ts *testTokenBackend) tokenNameInner(refreshCount int) string { return fmt.Sprintf("token-%d", ts.counter) } func (ts *testTokenBackend) verifyTokenInner(token string) error { now := time.Now() expires, gotCurrentToken := ts.tokens[token] if !gotCurrentToken { return fmt.Errorf("Unknown token: %v", token) } if now.After(expires) { return fmt.Errorf("Expired token %v (%v past expiration)", token, now.Sub(expires)) } return nil }