From 5f0aacde52062569accea27c37adc551a2c2d185 Mon Sep 17 00:00:00 2001 From: Joshua Krstic Date: Thu, 16 Oct 2025 22:25:59 +0000 Subject: [PATCH] Refactor container runner to not depend on file system or use a real timer Add test for the wellknownfilewriter --- launcher/container_runner.go | 113 ++++++--- launcher/container_runner_test.go | 278 ++++++++++----------- launcher/launcher/clock/fake_timer.go | 36 +++ launcher/launcher/clock/real_timer_test.go | 26 ++ launcher/launcher/clock/timer.go | 41 +++ 5 files changed, 304 insertions(+), 190 deletions(-) create mode 100644 launcher/launcher/clock/fake_timer.go create mode 100644 launcher/launcher/clock/real_timer_test.go create mode 100644 launcher/launcher/clock/timer.go diff --git a/launcher/container_runner.go b/launcher/container_runner.go index b96741bec..2ad4093d7 100644 --- a/launcher/container_runner.go +++ b/launcher/container_runner.go @@ -31,6 +31,7 @@ import ( "github.com/google/go-tpm-tools/launcher/internal/healthmonitoring/nodeproblemdetector" "github.com/google/go-tpm-tools/launcher/internal/logging" "github.com/google/go-tpm-tools/launcher/internal/signaturediscovery" + "github.com/google/go-tpm-tools/launcher/launcher/clock" "github.com/google/go-tpm-tools/launcher/launcherfile" "github.com/google/go-tpm-tools/launcher/registryauth" "github.com/google/go-tpm-tools/launcher/spec" @@ -269,13 +270,14 @@ func enableMonitoring(enabled spec.MonitoringType, logger logging.Logger) error if enabled != spec.None { logger.Info("Health Monitoring is enabled by the VM operator") - if enabled == spec.All { + switch enabled { + case spec.All: logger.Info("All health monitoring metrics enabled") if err := nodeproblemdetector.EnableAllConfig(); err != nil { logger.Error("Failed to enable full monitoring config: %v", err) return err } - } else if enabled == spec.MemoryOnly { + case spec.MemoryOnly: logger.Info("memory/bytes_used enabled") } @@ -418,10 +420,29 @@ func (r *ContainerRunner) measureMemoryMonitor() error { return nil } +type wellKnownFileLocationWriter struct{} + +func (d wellKnownFileLocationWriter) Write(p []byte) (n int, err error) { + // Write to a temp file first. + tmpTokenPath := path.Join(launcherfile.HostTmpPath, tokenFileTmp) + if err = os.WriteFile(tmpTokenPath, p, 0644); err != nil { + return 0, fmt.Errorf("failed to write a tmp token file: %v", err) + } + + // Rename the temp file to the token file (to avoid race conditions). + if err = os.Rename(tmpTokenPath, path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)); err != nil { + return 0, fmt.Errorf("failed to rename the token file: %v", err) + } + + return len(p), nil +} + +var _ io.Writer = (*wellKnownFileLocationWriter)(nil) + // Retrieves the default OIDC token from the attestation service, and returns how long // to wait before attemping to refresh it. // The token file will be written to a tmp file and then renamed. -func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, error) { +func (r *ContainerRunner) refreshToken(ctx context.Context, writer io.Writer) (time.Duration, error) { if err := r.attestAgent.Refresh(ctx); err != nil { return 0, fmt.Errorf("failed to refresh attestation agent: %v", err) } @@ -444,15 +465,9 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro return 0, errors.New("token is expired") } - // Write to a temp file first. 
-	tmpTokenPath := path.Join(launcherfile.HostTmpPath, tokenFileTmp)
-	if err = os.WriteFile(tmpTokenPath, token, 0644); err != nil {
-		return 0, fmt.Errorf("failed to write a tmp token file: %v", err)
-	}
-
-	// Rename the temp file to the token file (to avoid race conditions).
-	if err = os.Rename(tmpTokenPath, path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)); err != nil {
-		return 0, fmt.Errorf("failed to rename the token file: %v", err)
+	_, err = writer.Write(token)
+	if err != nil {
+		return 0, fmt.Errorf("failed to write token: %v", err)
 	}
 
 	// Print out the claims in the jwt payload
@@ -467,56 +482,67 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
 	return getNextRefreshFromExpiration(time.Until(claims.ExpiresAt.Time), rand.Float64()), nil
 }
 
-// ctx must be a cancellable context.
-func (r *ContainerRunner) fetchAndWriteToken(ctx context.Context) error {
-	return r.fetchAndWriteTokenWithRetry(ctx, defaultRetryPolicy)
-}
-
 // ctx must be a cancellable context.
 // retry specifies the refresher goroutine's retry policy.
-func (r *ContainerRunner) fetchAndWriteTokenWithRetry(ctx context.Context,
-	retry func() *backoff.ExponentialBackOff) error {
-	if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil {
-		return err
-	}
-	duration, err := r.refreshToken(ctx)
-	if err != nil {
-		return err
-	}
+func (r *ContainerRunner) startTokenRefresher(ctx context.Context, retry func() *backoff.ExponentialBackOff,
+	newTimer func(d time.Duration) clock.Timer, tokenWriter io.Writer) <-chan error {
+	r.logger.Info("Starting token refresh goroutine")
+
+	initComplete := make(chan error, 1)
 
-	// Set a timer to refresh the token before it expires.
-	timer := time.NewTimer(duration)
 	go func() {
+		r.logger.Info("Token refresher goroutine started")
+		defer close(initComplete)
+
+		isInitialized := false // A flag to ensure we only send the initialization signal once.
+		signalDone := func(err error) {
+			if !isInitialized {
+				initComplete <- err // Signal to the calling function that the first refresh is done.
+				isInitialized = true
+			}
+		}
+
+		// Start with a timer that fires immediately to get the first token.
+		timer := newTimer(0)
+		defer timer.Stop()
+
 		for {
 			select {
 			case <-ctx.Done():
-				timer.Stop()
-				r.logger.Info("token refreshing stopped")
+				r.logger.Info("Token refreshing stopped")
 				return
-			case <-timer.C:
-				r.logger.Info("refreshing attestation verifier OIDC token")
+			case <-timer.C():
+				r.logger.Info("Refreshing attestation verifier OIDC token")
 				var duration time.Duration
-				// Refresh token with default retry policy.
+
 				err := backoff.RetryNotify(
 					func() error {
-						duration, err = r.refreshToken(ctx)
-						return err
+						var refreshErr error
+						duration, refreshErr = r.refreshToken(ctx, tokenWriter)
+						return refreshErr
					},
 					retry(),
 					func(err error, t time.Duration) {
-						r.logger.Error(fmt.Sprintf("failed to refresh attestation service token at time %v: %v", t, err))
+						r.logger.Error(fmt.Sprintf("failed to refresh token at time %v: %v", t, err))
 					})
+
+				// After the first attempt, signal to the calling function that the refresh is done.
+				signalDone(err)
+
 				if err != nil {
-					r.logger.Error(fmt.Sprintf("failed all attempts to refresh attestation service token, stopping refresher: %v", err))
+					// If all retry attempts fail, stop the refresher.
+					r.logger.Error(fmt.Sprintf("failed all attempts to get/refresh token, stopping refresher: %v", err))
					return
 				}
+				// On success, reset the timer for the next refresh.
+ r.logger.Info("Resetting token refresh timer") timer.Reset(duration) } } }() - return nil + return initComplete } // getNextRefreshFromExpiration returns the Duration for the next run of the @@ -583,7 +609,15 @@ func (r *ContainerRunner) Run(ctx context.Context) error { // Only refresh token if agent has a default GCA client (not ITA use case). if r.launchSpec.ITARegion == "" { - if err := r.fetchAndWriteToken(ctx); err != nil { + // Create the well known token file location. + if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil { + return err + } + r.logger.Info("Created directory", "path", launcherfile.HostTmpPath) + + errchan := r.startTokenRefresher(ctx, defaultRetryPolicy, clock.NewRealTimer, wellKnownFileLocationWriter{}) + err := <-errchan + if err != nil { return fmt.Errorf("failed to fetch and write OIDC token: %v", err) } } @@ -607,7 +641,6 @@ func (r *ContainerRunner) Run(ctx context.Context) error { attestClients.GCA = gcaClient } - teeServer, err := teeserver.New(ctx, path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger, r.launchSpec, attestClients) if err != nil { return fmt.Errorf("failed to create the TEE server: %v", err) diff --git a/launcher/container_runner_test.go b/launcher/container_runner_test.go index 90dd9ce60..fbf688c7e 100644 --- a/launcher/container_runner_test.go +++ b/launcher/container_runner_test.go @@ -7,10 +7,10 @@ import ( "crypto/rsa" "errors" "fmt" + "io" "os" "path" "strconv" - "sync" "testing" "time" @@ -24,6 +24,7 @@ import ( "github.com/google/go-tpm-tools/cel" "github.com/google/go-tpm-tools/launcher/agent" "github.com/google/go-tpm-tools/launcher/internal/logging" + "github.com/google/go-tpm-tools/launcher/launcher/clock" "github.com/google/go-tpm-tools/launcher/launcherfile" "github.com/google/go-tpm-tools/launcher/spec" "github.com/google/go-tpm-tools/verifier" @@ -43,10 +44,6 @@ type fakeAttestationAgent struct { attestFunc func(context.Context, agent.AttestAgentOpts) ([]byte, error) sigsCache []string sigsFetcherFunc func(context.Context) []string - - // attMu sits on top of attempts field and protects attempts. 
-	attMu sync.Mutex
-	attempts int
 }
 
 func (f *fakeAttestationAgent) MeasureEvent(event cel.Content) error {
@@ -81,11 +78,47 @@ func (f *fakeAttestationAgent) Close() error {
 	return nil
 }
 
+type fakeClient struct{}
+
+var _ verifier.Client = (*fakeClient)(nil)
+
+func (f *fakeClient) CreateChallenge(_ context.Context) (*verifier.Challenge, error) {
+	return nil, fmt.Errorf("unimplemented")
+}
+
+func (f *fakeClient) VerifyAttestation(_ context.Context, _ verifier.VerifyAttestationRequest) (*verifier.VerifyAttestationResponse, error) {
+	return nil, fmt.Errorf("unimplemented")
+}
+
+func (f *fakeClient) VerifyConfidentialSpace(_ context.Context, _ verifier.VerifyAttestationRequest) (*verifier.VerifyAttestationResponse, error) {
+	return nil, fmt.Errorf("unimplemented")
+}
+
 type fakeClaims struct {
 	jwt.RegisteredClaims
 	Signatures []string
 }
 
+type fakeWriter struct {
+	writes [][]byte
+}
+
+func (f *fakeWriter) Write(p []byte) (n int, err error) {
+	f.writes = append(f.writes, p)
+	return len(p), nil
+}
+
+func (f *fakeWriter) pop() ([]byte, error) {
+	if len(f.writes) == 0 {
+		return nil, fmt.Errorf("no more writes to return")
+	}
+	p := f.writes[0]
+	f.writes = f.writes[1:]
+	return p, nil
+}
+
+var _ io.Writer = (*fakeWriter)(nil)
+
 func createJWT(t *testing.T, ttl time.Duration) []byte {
 	return createJWTWithID(t, "test token", ttl)
 }
@@ -164,19 +197,21 @@ func TestRefreshToken(t *testing.T) {
 		t.Fatalf("Error creating host token path directory: %v", err)
 	}
 
-	refreshTime, err := runner.refreshToken(ctx)
+	fw := fakeWriter{}
+	refreshTime, err := runner.refreshToken(ctx, &fw)
 	if err != nil {
 		t.Fatalf("refreshToken returned with error: %v", err)
 	}
 
-	filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)
-	data, err := os.ReadFile(filepath)
+	// TODO - check file path was created
+	//filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)
+	data, err := fw.pop()
 	if err != nil {
-		t.Fatalf("Failed to read from %s: %v", filepath, err)
+		t.Fatalf("Failed to get a written token %v", err)
 	}
 
 	if !bytes.Equal(data, expectedToken) {
-		t.Errorf("Initial token written to file does not match expected token: got %v, want %v", data, expectedToken)
+		t.Errorf("Initial token written to file does not match expected token: got %v\nwant %v", string(data), string(expectedToken))
 	}
 
 	// Expect refreshTime to be no greater than expectedTTL.
@@ -209,10 +244,12 @@ func TestRefreshTokenWithSignedContainerCacheEnabled(t *testing.T) {
 		t.Fatalf("Error creating host token path directory: %v", err)
 	}
 
-	_, err := runner.refreshToken(ctx)
+	fw := &fakeWriter{}
+	_, err := runner.refreshToken(ctx, fw)
 	if err != nil {
 		t.Fatalf("refreshToken returned with error: %v", err)
 	}
+	fw.pop() // Don't need to check first token
 
 	// Simulate adding signatures.
 	newCache := []string{"old sigs cache", "new sigs cache"}
@@ -221,16 +258,15 @@ func TestRefreshTokenWithSignedContainerCacheEnabled(t *testing.T) {
 	}
 
 	// Refresh token again to get the updated token.
-	_, err = runner.refreshToken(ctx)
+	_, err = runner.refreshToken(ctx, fw)
 	if err != nil {
 		t.Fatalf("refreshToken returned with error: %v", err)
 	}
 
 	// Read the token to check if claims contain the updated signatures.
- filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename) - token, err := os.ReadFile(filepath) + token, err := fw.pop() if err != nil { - t.Fatalf("Failed to read from %s: %v", filepath, err) + t.Fatalf("Failed to get a written token %v", err) } gotClaims := &fakeClaims{} @@ -278,7 +314,8 @@ func TestRefreshTokenError(t *testing.T) { logger: logging.SimpleLogger(), } - if _, err := runner.refreshToken(context.Background()); err == nil { + fw := fakeWriter{} + if _, err := runner.refreshToken(context.Background(), &fw); err == nil { t.Error("refreshToken succeeded, expected error.") } @@ -286,7 +323,7 @@ func TestRefreshTokenError(t *testing.T) { } } -func TestFetchAndWriteTokenSucceeds(t *testing.T) { +func TestStartTokenRefresherSucceeds(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -301,69 +338,34 @@ func TestFetchAndWriteTokenSucceeds(t *testing.T) { logger: logging.SimpleLogger(), } - if err := runner.fetchAndWriteToken(ctx); err != nil { - t.Fatalf("fetchAndWriteToken failed: %v", err) - } - - filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename) - data, err := os.ReadFile(filepath) - if err != nil { - t.Fatalf("Failed to read from %s: %v", filepath, err) - } - - if !bytes.Equal(data, expectedToken) { - t.Errorf("Token written to file does not match expected token: got %v, want %v", data, expectedToken) + fakeTimer := clock.NewFakeTimer() + fakeTimerFactory := func(_ time.Duration) clock.Timer { + return &fakeTimer } -} - -func TestTokenIsNotChangedIfRefreshFails(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - expectedToken := createJWT(t, 5*time.Second) - ttl := 5 * time.Second + fw := fakeWriter{} + errchan := runner.startTokenRefresher(ctx, defaultRetryPolicy, fakeTimerFactory, &fw) - attestAgent := &fakeAttestationAgent{} - attestAgent.attestFunc = func(context.Context, agent.AttestAgentOpts) ([]byte, error) { - attestAgent.attMu.Lock() - defer func() { - attestAgent.attempts = attestAgent.attempts + 1 - attestAgent.attMu.Unlock() - }() - if attestAgent.attempts%2 == 0 { - return expectedToken, nil - } - return nil, errors.New("attest unsuccessful") - } + // Hit timer to trigger the goroutine logic + fakeTimer.OutChan <- time.Now() + // Block on timer resetting, signaling a token has been written + <-fakeTimer.ResetChan - runner := ContainerRunner{ - attestAgent: attestAgent, - logger: logging.SimpleLogger(), - } - - if err := runner.fetchAndWriteToken(ctx); err != nil { - t.Fatalf("fetchAndWriteToken failed: %v", err) - } - - filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename) - data, err := os.ReadFile(filepath) + // TODO - test the directory actually gets made + //filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename) + data, err := fw.pop() if err != nil { - t.Fatalf("Failed to read from %s: %v", filepath, err) + t.Fatalf("Failed to get a written token %v", err) } if !bytes.Equal(data, expectedToken) { - t.Errorf("Initial token written to file does not match expected token: got %v, want %v", data, expectedToken) + t.Errorf("Token written to file does not match expected token: got %v\nwant %v", string(data), string(expectedToken)) } - time.Sleep(ttl) - - data, err = os.ReadFile(filepath) + // Ensure the goroutine reported no error after the first sync. 
+ err = <-errchan if err != nil { - t.Fatalf("Failed to read from %s: %v", filepath, err) - } - - if !bytes.Equal(data, expectedToken) { - t.Errorf("Expected token to remain the same after unsuccessful refresh attempt: got %v", data) + t.Fatalf("startTokenRefresher failed: %v", err) } } @@ -398,9 +400,8 @@ func testRetryPolicyWithNTries(t *testing.T, numTries int, expectRefresh bool) { defer cancel() expectedInitialToken := createJWTWithID(t, "initial token"+strNum, 5*time.Second) - expectedRefreshToken := createJWTWithID(t, "refresh token"+strNum, 100*time.Second) + expectedRefreshedToken := createJWTWithID(t, "refresh token"+strNum, 100*time.Second) // Wait the initial token's 5s plus a second per retry (MaxInterval). - ttl := time.Duration(numTries)*time.Second + 5*time.Second retry := -1 attestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) { retry++ @@ -409,7 +410,7 @@ func testRetryPolicyWithNTries(t *testing.T, numTries int, expectRefresh bool) { return expectedInitialToken, nil } if retry == numTries { - return expectedRefreshToken, nil + return expectedRefreshedToken, nil } return nil, errors.New("attest unsuccessful") } @@ -417,102 +418,59 @@ func testRetryPolicyWithNTries(t *testing.T, numTries int, expectRefresh bool) { attestAgent: &fakeAttestationAgent{attestFunc: attestFunc}, logger: logging.SimpleLogger(), } - if err := runner.fetchAndWriteTokenWithRetry(ctx, testRetryPolicyThreeTimes); err != nil { - t.Fatalf("fetchAndWriteTokenWithRetry failed: %v", err) - } - filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename) - data, err := os.ReadFile(filepath) - if err != nil { - t.Fatalf("failed to read from %s: %v", filepath, err) - } - - if !bytes.Equal(data, expectedInitialToken) { - gotClaims := extractJWTClaims(t, data) - wantClaims := extractJWTClaims(t, expectedInitialToken) - t.Errorf("initial token written to file does not match expected token: got ID %v, want ID %v", gotClaims.ID, wantClaims.ID) - } - time.Sleep(ttl) - data, err = os.ReadFile(filepath) - if err != nil { - t.Fatalf("failed to read from %s: %v", filepath, err) + fakeTimer := clock.NewFakeTimer() + fakeTimerFactory := func(_ time.Duration) clock.Timer { + return &fakeTimer } + fw := fakeWriter{} - // No refresh: the token should match initial token. - if !expectRefresh && !bytes.Equal(data, expectedInitialToken) { - gotClaims := extractJWTClaims(t, data) - wantClaims := extractJWTClaims(t, expectedInitialToken) - t.Errorf("token refresher should fail and received token should be the initial token: got ID %v, want ID %v", gotClaims.ID, wantClaims.ID) - } + errchan := runner.startTokenRefresher(ctx, testRetryPolicyThreeTimes, fakeTimerFactory, &fw) - // Should Refresh: the token should match refreshed token. 
- if expectRefresh && !bytes.Equal(data, expectedRefreshToken) { - gotClaims := extractJWTClaims(t, data) - wantClaims := extractJWTClaims(t, expectedRefreshToken) - t.Errorf("refreshed token did not match expected token: got ID %v, want ID %v", gotClaims.ID, wantClaims.ID) - } -} - -func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - expectedToken := createJWT(t, 5*time.Second) - expectedRefreshedToken := createJWT(t, 10*time.Second) - - ttl := 5 * time.Second - - attestAgent := &fakeAttestationAgent{} - attestAgent.attestFunc = func(context.Context, agent.AttestAgentOpts) ([]byte, error) { - attestAgent.attMu.Lock() - defer func() { - attestAgent.attempts = attestAgent.attempts + 1 - attestAgent.attMu.Unlock() - }() - if attestAgent.attempts%2 == 0 { - return expectedToken, nil + // Trigger timer twice, always draining the response chan + for range 2 { + // Hit timer to trigger the goroutine logic + fakeTimer.OutChan <- time.Now() + // Block on timer action. + select { + case <-fakeTimer.ResetChan: // token was written successfully + t.Log("Got a reset signal from the refresh goroutine") + case <-fakeTimer.StopChan: // refresh failed completely + if expectRefresh { + t.Log("Got a stop signal from the refresh goroutine") + t.Fatalf("Failed to get a refreshed token from the goroutine") + } } - return expectedRefreshedToken, nil - } - runner := ContainerRunner{ - attestAgent: attestAgent, - logger: logging.SimpleLogger(), - } - if err := runner.fetchAndWriteToken(ctx); err != nil { - t.Fatalf("fetchAndWriteToken failed: %v", err) } - filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename) - data, err := os.ReadFile(filepath) + data, err := fw.pop() if err != nil { - t.Fatalf("Failed to read from %s: %v", filepath, err) + t.Fatalf("Failed to get first written token %v", err) } - if !bytes.Equal(data, expectedToken) { - t.Errorf("Initial token written to file does not match expected token: got %v, want %v", data, expectedToken) + if !bytes.Equal(data, expectedInitialToken) { + gotClaims := extractJWTClaims(t, data) + wantClaims := extractJWTClaims(t, expectedInitialToken) + t.Errorf("Initial token written to file does not match expected token: got ID %v, want ID %v", gotClaims.ID, wantClaims.ID) } - // Check that token has not been refreshed yet. - data, err = os.ReadFile(filepath) - if err != nil { - t.Fatalf("Failed to read from %s: %v", filepath, err) + data, err = fw.pop() + if err != nil && expectRefresh { + t.Fatalf("Failed to get second written token %v", err) } - if !bytes.Equal(data, expectedToken) { - t.Errorf("Token unexpectedly refreshed: got %v, want %v", data, expectedRefreshedToken) + // Should Refresh: the token should match refreshed token. + if expectRefresh && !bytes.Equal(data, expectedRefreshedToken) { + gotClaims := extractJWTClaims(t, data) + wantClaims := extractJWTClaims(t, expectedRefreshedToken) + t.Errorf("refreshed token did not match expected token: got ID %v, want ID %v", gotClaims.ID, wantClaims.ID) } - time.Sleep(ttl) - - // Check that token has changed. - data, err = os.ReadFile(filepath) + // Ensure the goroutine reported no error after the first sync. 
+	err = <-errchan
 	if err != nil {
-		t.Fatalf("Failed to read from %s: %v", filepath, err)
-	}
-
-	if !bytes.Equal(data, expectedRefreshedToken) {
-		t.Errorf("Refreshed token written to file does not match expected token: got %v, want %v", data, expectedRefreshedToken)
+		t.Fatalf("startTokenRefresher failed: %v", err)
 	}
 }
 
@@ -694,6 +652,26 @@ func TestPullImageWithRetries(t *testing.T) {
 	}
 }
 
+func TestWellKnownFileLocationWriter(t *testing.T) {
+	if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil {
+		t.Fatalf("Error creating host token path directory: %v", err)
+	}
+
+	fw := &wellKnownFileLocationWriter{}
+	testData := []byte("testing")
+	fw.Write(testData)
+
+	filepath := path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)
+	data, err := os.ReadFile(filepath)
+	if err != nil {
+		t.Fatalf("Failed to read from %s: %v", filepath, err)
+	}
+
+	if !bytes.Equal(data, testData) {
+		t.Errorf("Data written to file does not match test data: got %v, want %v", data, testData)
+	}
+}
+
 // This ensures fakeContainer implements containerd.Container interface.
 var _ containerd.Container = &fakeContainer{}
 
diff --git a/launcher/launcher/clock/fake_timer.go b/launcher/launcher/clock/fake_timer.go
new file mode 100644
index 000000000..b748fbfb7
--- /dev/null
+++ b/launcher/launcher/clock/fake_timer.go
@@ -0,0 +1,36 @@
+package clock
+
+import "time"
+
+// FakeTimer is a fake implementation of Timer for testing purposes.
+type FakeTimer struct {
+	OutChan   chan time.Time
+	ResetChan chan int
+	StopChan  chan int
+}
+
+// NewFakeTimer creates a new instance of FakeTimer.
+func NewFakeTimer() FakeTimer {
+	return FakeTimer{
+		OutChan:   make(chan time.Time, 1),
+		ResetChan: make(chan int, 1),
+		StopChan:  make(chan int, 1),
+	}
+}
+
+// C returns the channel on which the timer will send the current time when it fires.
+func (f *FakeTimer) C() <-chan time.Time {
+	return f.OutChan
+}
+
+// Reset notifies the reset channel to signal that Reset was called.
+func (f *FakeTimer) Reset(_ time.Duration) bool {
+	f.ResetChan <- 1
+	return true
+}
+
+// Stop notifies the stop channel to signal that Stop was called.
+func (f *FakeTimer) Stop() bool {
+	f.StopChan <- 1
+	return true
+}
diff --git a/launcher/launcher/clock/real_timer_test.go b/launcher/launcher/clock/real_timer_test.go
new file mode 100644
index 000000000..caa7a8322
--- /dev/null
+++ b/launcher/launcher/clock/real_timer_test.go
@@ -0,0 +1,26 @@
+package clock
+
+import (
+	"fmt"
+	"testing"
+	"time"
+)
+
+func TestRealTimerFires(_ *testing.T) {
+	timer := NewRealTimer(100 * time.Millisecond)
+	time.Sleep(125 * time.Millisecond)
+	<-timer.C()
+
+	fmt.Printf("%v\n", timer.Reset(100*time.Millisecond))
+	time.Sleep(125 * time.Millisecond)
+	<-timer.C()
+}
+
+func TestRealTimerFiresInstantly(_ *testing.T) {
+	timer := NewRealTimer(0)
+	<-timer.C()
+
+	timer.Reset(100 * time.Millisecond)
+	time.Sleep(125 * time.Millisecond)
+	<-timer.C()
+}
diff --git a/launcher/launcher/clock/timer.go b/launcher/launcher/clock/timer.go
new file mode 100644
index 000000000..7e6f3dcd9
--- /dev/null
+++ b/launcher/launcher/clock/timer.go
@@ -0,0 +1,41 @@
+// Package clock contains time-related utilities.
+package clock
+
+import "time"
+
+// Timer is an interface that wraps the basic methods of time.Timer.
+type Timer interface { + C() <-chan time.Time + Stop() bool + Reset(time.Duration) bool +} + +// Ensure conformance on Timer +var _ Timer = (*RealTimer)(nil) + +// RealTimer is an implementation of Timer that uses a real time.Timer. +type RealTimer struct { + inner *time.Timer +} + +// NewRealTimer creates and returns a reference to a RealTimer. +func NewRealTimer(d time.Duration) Timer { + return &RealTimer{ + inner: time.NewTimer(d), + } +} + +// C returns the channel on which the timer will send the current time when it fires. +func (t *RealTimer) C() <-chan time.Time { + return t.inner.C +} + +// Stop stops the timer and returns true if the timer was active. +func (t *RealTimer) Stop() bool { + return t.inner.Stop() +} + +// Reset resets the timer to the specified duration and returns true if the timer was active. +func (t *RealTimer) Reset(d time.Duration) bool { + return t.inner.Reset(d) +}