pulumi/sdk/go/common/util/retry/until_test.go

131 lines
2.8 KiB
Go

package retry
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestUntil_exhaustAttempts checks that Until keeps retrying while Accept
// reports failure, requests the expected sequence of sleeps in between,
// and finally surfaces the error that Accept returns when it gives up.
func TestUntil_exhaustAttempts(t *testing.T) {
	t.Parallel()

	// Round numbers keep the expected sleep sequence easy to verify by hand.
	var (
		initialDelay = time.Second
		factor       = 2.0
		ceiling      = 100 * time.Second
	)
	errTooManyTries := errors.New("too many tries")

	recorder := newAfterRecorder(time.Now())
	retryer := &Retryer{After: recorder.After}

	var attempts int
	acceptor := Acceptor{
		Delay:    &initialDelay,
		Backoff:  &factor,
		MaxDelay: &ceiling,
		Accept: func(try int, _ time.Duration) (bool, interface{}, error) {
			if try <= 3 {
				attempts++
				return false, nil, nil // operation failed
			}
			// Tries are exhausted: abort the retry loop with an error.
			return false, nil, errTooManyTries
		},
	}

	_, _, err := retryer.Until(context.Background(), acceptor)

	assert.ErrorIs(t, err, errTooManyTries)

	// One sleep is expected after each of the four failed attempts,
	// doubling every time.
	wantSleeps := []time.Duration{
		1 * time.Second,
		2 * time.Second,
		4 * time.Second,
		8 * time.Second,
	}
	assert.Equal(t, wantSleeps, recorder.Sleeps)
	assert.Equal(t, 4, attempts)
}
// TestUntil_contextExpired checks that Until does not report success
// when its context is canceled while the operation keeps failing.
func TestUntil_contextExpired(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithCancel(context.Background())
	// Always release the context, even if Until returns before Accept
	// has had the chance to cancel it. Calling cancel twice is harmless.
	defer cancel()

	ok, _, _ := (&Retryer{
		After: newAfterRecorder(time.Now()).After,
	}).Until(ctx, Acceptor{
		Accept: func(try int, delay time.Duration) (bool, interface{}, error) {
			// Let the first few tries fail normally, then cancel the
			// context so that Until has to stop retrying.
			if try > 2 {
				cancel()
			}
			return false, nil, nil
		},
	})
	assert.False(t, ok, "should not have succeeded")

	// The returned error is deliberately not checked.
	// This is surprising behavior (no error instead of ctx.Err())
	// but it's risky to change it right now.
	// Ideally, the assertion here would be:
	// assert.ErrorIs(t, err, context.Canceled)
}
// TestUntil_maxDelay checks that no sleep requested by Until
// exceeds the configured MaxDelay, however many retries are needed.
func TestUntil_maxDelay(t *testing.T) {
	t.Parallel()

	// Round numbers keep the arithmetic easy to follow.
	var (
		initialDelay = time.Second
		factor       = 2.0
		ceiling      = 10 * time.Second
	)

	recorder := newAfterRecorder(time.Now())
	retryer := &Retryer{After: recorder.After}

	ok, _, err := retryer.Until(context.Background(), Acceptor{
		Delay:    &initialDelay,
		Backoff:  &factor,
		MaxDelay: &ceiling,
		Accept: func(try int, _ time.Duration) (bool, interface{}, error) {
			// 100 tries should be enough to reach maxDelay,
			// so fail every try before that and succeed afterwards.
			done := try >= 100
			return done, nil, nil
		},
	})

	assert.True(t, ok)
	assert.NoError(t, err)

	// Every failed try is followed by exactly one sleep.
	assert.Len(t, recorder.Sleeps, 100)
	for i := range recorder.Sleeps {
		assert.LessOrEqual(t, recorder.Sleeps[i], ceiling)
	}
}
// afterRecorder implements a time.After variant
// that records all requested sleeps,
// and advances the current time instantly.
//
// Nothing ever blocks: the channel returned by After
// already holds the simulated wake-up time.
type afterRecorder struct {
	// mu is held by After while it updates now and Sleeps.
	mu sync.Mutex

	// now is the simulated current time;
	// each call to After moves it forward by the requested duration.
	now time.Time

	// Sleeps is the list of all sleeps requested.
	Sleeps []time.Duration
}

// newAfterRecorder returns an afterRecorder
// whose simulated clock starts at now.
func newAfterRecorder(now time.Time) *afterRecorder {
	return &afterRecorder{
		now: now,
	}
}

// After records d as a requested sleep, advances the simulated clock by d,
// and returns a channel that already holds the resulting time,
// so receiving from it never blocks.
func (r *afterRecorder) After(d time.Duration) <-chan time.Time {
	r.mu.Lock()
	r.Sleeps = append(r.Sleeps, d)
	r.now = r.now.Add(d)
	// Snapshot the wake-up time while the lock is still held.
	// Reading r.now after Unlock would race with concurrent After calls
	// and could hand back a time that belongs to another caller.
	wake := r.now
	r.mu.Unlock()

	// Buffer of one lets the value be sent without a receiver.
	afterc := make(chan time.Time, 1)
	afterc <- wake
	return afterc
}