From c6cd9d8be7a01999a28942544902f622007232e3 Mon Sep 17 00:00:00 2001 From: SUDEESH JOHN Date: Mon, 29 Jun 2026 20:36:23 +0530 Subject: [PATCH 1/3] fix: security hardening and stability fixes across orchestrator, transport, and services - infra/hmclog.go: add sync.Mutex to gate raw file writes against concurrent goroutine interleaving; add response-header and regex scrubbing to maskSession so X-API-Session tokens are redacted on login responses and in XML bodies - validation/validator.go, orchestrator/orchestrator.go: replace raw .LoadBalancer.VIP field access with nil-safe .GetVIP() calls in findClusterUsingVIP and all call-sites to prevent nil pointer dereference panic when load_balancer block is omitted from config.yaml - localexec/client.go: embed os.Getpid() in the WriteFile temp path to prevent cross-cluster race condition when two shiftlaunch processes run concurrently - cmd/create.go: wrap deferred logger Close in anonymous func so it evaluates the current orchestrator pointer at exit time rather than at defer registration - cmd/createtemplate.go: change the default of the --boot flag from 'agent' to 'netboot'; sort imports - cmd/scale.go, infra/compute/hmc.go: drop the trailing debug argument from all HMC client and helper calls - services/downloader.go: wire extractHashFromManifest into DownloadRHCOSImages and DownloadOpenShiftTools; remove test -s skip guard so curl -C - resumes partial downloads instead of accepting truncated files after Ctrl+C - services/registry.go: rewrite pull-secret update command to pass credentials via env variable and jq --arg to prevent bash injection from special chars; add shellQuote helper; fix error string casing and punctuation - infra/controller/network.go: replace awk column-split with grep -oP regex in RemoveVIPAlias to be robust against shifted ip -o output columns - go.mod, go.sum: bump the IBM/infra-go-sdk requirement to pseudo-version v0.0.0-20260630034014-da6ea8cba64d (see IBM/infra-go-sdk#12) Signed-off-by: SUDEESH JOHN --- cmd/create.go | 7 +- cmd/createtemplate.go | 4 +- cmd/scale.go | 8 +-- go.mod | 2 +- go.sum | 4 +- infra/compute/hmc.go | 135 +++++++++++++++++------------------ infra/compute/provider.go | 12 ++-- infra/controller/network.go | 6 +- infra/hmclog.go | 
105 ++++++++++++++++++++------- localexec/client.go | 5 +- logger/logger.go | 13 ++-- orchestrator/orchestrator.go | 22 +++--- services/downloader.go | 31 ++++---- services/registry.go | 28 +++++--- validation/validator.go | 14 ++-- 15 files changed, 228 insertions(+), 168 deletions(-) diff --git a/cmd/create.go b/cmd/create.go index 923214f..e794bec 100644 --- a/cmd/create.go +++ b/cmd/create.go @@ -52,8 +52,11 @@ func runCreate(cmd *cobra.Command, args []string) error { return err } - // Ensure logger file descriptor is closed when command completes - defer orch.GetLogger().Close() + // Ensure logger file descriptor is closed when command completes. + // The anonymous func is required so the pointer is evaluated at exit time, + // not now — orch may be replaced with a new orchestrator further below if + // the workspace was previously deleted. + defer func() { orch.GetLogger().Close() }() ctx := GetContext() log := orch.GetLogger() diff --git a/cmd/createtemplate.go b/cmd/createtemplate.go index 77427bb..769adf9 100644 --- a/cmd/createtemplate.go +++ b/cmd/createtemplate.go @@ -7,8 +7,8 @@ import ( "strings" "text/template" - "github.com/spf13/cobra" "github.com/IBM/shiftlaunch/logger" + "github.com/spf13/cobra" ) var ( @@ -47,7 +47,7 @@ func init() { rootCmd.AddCommand(generateConfigCmd) generateConfigCmd.Flags().StringVarP(&genConfigType, "type", "t", "multi", "Cluster topology: 'sno' or 'multi'") - generateConfigCmd.Flags().StringVarP(&genBootMethod, "boot", "b", "agent", "Boot method: 'agent' or 'netboot'") + generateConfigCmd.Flags().StringVarP(&genBootMethod, "boot", "b", "netboot", "Boot method: 'agent' or 'netboot'") generateConfigCmd.Flags().StringVarP(&genOutputPath, "output", "o", "config.yaml", "Path to save the generated file") generateConfigCmd.Flags().StringVar(&genReleaseType, "release-type", "official", "Payload type: 'official' or 'ci'") diff --git a/cmd/scale.go b/cmd/scale.go index 97f2429..b260f59 100644 --- a/cmd/scale.go +++ b/cmd/scale.go 
@@ -319,7 +319,7 @@ func runScale(cmd *cobra.Command, args []string) error { if targetUUID != "" { // Power off LPAR immediately without deleting it - _, err := hmcProvider.GetHMCClient().PowerOffPartition(ctx, targetUUID, "Immediate", false, true) + _, err := hmcProvider.GetHMCClient().PowerOffPartition(ctx, targetUUID, "Immediate", false) if err != nil && !strings.Contains(strings.ToLower(err.Error()), "unavailable in the current partition state") { log.Warn("Failed to power off LPAR", "error", err) } @@ -329,17 +329,17 @@ func runScale(cmd *cobra.Command, args []string) error { for i, mapping := range state.ISOMappings { if mapping.NodeName == target.Hostname { // Resolve System UUID required for unmapping - sysUUID, _, _ := hmcProvider.GetHMCClient().GetManagedSystemByName(ctx, mapping.SystemName, debug) + sysUUID, _, _ := hmcProvider.GetHMCClient().GetManagedSystemByName(ctx, mapping.SystemName) if sysUUID != "" { // Unmap ISO from LPAR - _, err := hmcProvider.GetHMCClient().DeleteVirtualOpticalMaps(ctx, sysUUID, mapping.VIOSUUID, targetUUID, []string{mapping.MediaName}, debug) + _, err := hmcProvider.GetHMCClient().DeleteVirtualOpticalMaps(ctx, sysUUID, mapping.VIOSUUID, targetUUID, []string{mapping.MediaName}) if err != nil { log.Warn("Failed to unmap ISO from LPAR", "error", err) } else { time.Sleep(3 * time.Second) // Let VIOS digest the unmap // Delete ISO from repository - err = hmcProvider.GetHMCClient().DeleteVirtualOpticalMedia(ctx, mapping.SystemName, mapping.VIOSName, mapping.MediaName, debug) + err = hmcProvider.GetHMCClient().DeleteVirtualOpticalMedia(ctx, mapping.SystemName, mapping.VIOSName, mapping.MediaName) if err != nil && !strings.Contains(err.Error(), "not found") { log.Warn("Failed to delete ISO from VIOS repository", "error", err) } diff --git a/go.mod b/go.mod index ee96c42..c399f9c 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( ) require ( - github.com/IBM/infra-go-sdk v0.0.0-20260626121933-c8707cdac649 + 
github.com/IBM/infra-go-sdk v0.0.0-20260630034014-da6ea8cba64d github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 go.yaml.in/yaml/v3 v3.0.4 diff --git a/go.sum b/go.sum index c0c02b4..eee0930 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ atomicgo.dev/keyboard v0.2.10 h1:v7mvUKUZLHIggxULEIuWbT+WkkyQSgdbA201EziAhHU= atomicgo.dev/keyboard v0.2.10/go.mod h1:ap/z5ilnhLqYq852m6kPeTq5Z6aESGWu5mzRpJlC6aI= atomicgo.dev/schedule v0.1.0 h1:nTthAbhZS5YZmgYbb2+DH8uQIZcTlIrd4eYr3UQxEjs= atomicgo.dev/schedule v0.1.0/go.mod h1:xeUa3oAkiuHYh8bKiQBRojqAMq3PXXbJujjb0hw8pEU= -github.com/IBM/infra-go-sdk v0.0.0-20260626121933-c8707cdac649 h1:02DjfvDqLdA5GDMkuSmE5jQ/6tlluwjPcuTkK7dGkio= -github.com/IBM/infra-go-sdk v0.0.0-20260626121933-c8707cdac649/go.mod h1:bDk0yc6n1Wx+HtxwlZpsXqWx8x3Cag9YV9DkogrdueA= +github.com/IBM/infra-go-sdk v0.0.0-20260630034014-da6ea8cba64d h1:ANPjiwzkHkQDmon7mwbnn/6e5sDbluUmmN+96yuJvDE= +github.com/IBM/infra-go-sdk v0.0.0-20260630034014-da6ea8cba64d/go.mod h1:ig0leiRacKahGoO8KloXH8yJI3z6a3d6d6GYk6EtX08= github.com/MarvinJWendt/testza v0.5.2 h1:53KDo64C1z/h/d/stCYCPY69bt/OSwjq5KpFNwi+zB4= github.com/MarvinJWendt/testza v0.5.2/go.mod h1:xu53QFE5sCdjtMCKk8YMQ2MnymimEctc4n3EjyIYvEY= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= diff --git a/infra/compute/hmc.go b/infra/compute/hmc.go index 72d0175..95b6ade 100644 --- a/infra/compute/hmc.go +++ b/infra/compute/hmc.go @@ -55,14 +55,14 @@ func (h *HMCProvider) DiscoverMetadata(ctx context.Context) error { for _, node := range nodes { if _, exists := systemData[node.SystemName]; !exists { apiTrafficMutex.RLock() - sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName, false) + sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName) apiTrafficMutex.RUnlock() if err != nil { return fmt.Errorf("failed to find system %s: %w", node.SystemName, err) } - + apiTrafficMutex.RLock() - lpars, err := 
h.hmcClient.GetLogicalPartitionsQuickAll(ctx, sysUUID, false) + lpars, err := h.hmcClient.GetLogicalPartitionsQuickAll(ctx, sysUUID) apiTrafficMutex.RUnlock() if err != nil { return err @@ -128,7 +128,7 @@ func (h *HMCProvider) DiscoverMetadata(ctx context.Context) error { testErr := call(shieldedCtx) if testErr != nil && strings.Contains(testErr.Error(), "406") { _ = h.hmcClient.Logoff(shieldedCtx) - _ = h.hmcClient.Login(shieldedCtx, h.cfg.HMC.Username, h.cfg.HMC.Password, h.debug) + _ = h.hmcClient.Login(shieldedCtx, h.cfg.HMC.Username, h.cfg.HMC.Password) } apiTrafficMutex.Unlock() continue @@ -141,7 +141,7 @@ func (h *HMCProvider) DiscoverMetadata(ctx context.Context) error { var profiles []hmc.LogicalPartitionProfile err := apiCallWithRetry(func(c context.Context) error { var e error - profiles, e = h.hmcClient.GetLogicalPartitionProfiles(c, node.UUID, false) + profiles, e = h.hmcClient.GetLogicalPartitionProfiles(c, node.UUID) return e }) if err != nil || len(profiles) == 0 { @@ -153,7 +153,7 @@ func (h *HMCProvider) DiscoverMetadata(ctx context.Context) error { var adapters []hmc.ClientNetworkAdapter err = apiCallWithRetry(func(c context.Context) error { var e error - adapters, e = h.hmcClient.GetClientNetworkAdapters(c, sysUUID, node.UUID, false) + adapters, e = h.hmcClient.GetClientNetworkAdapters(c, sysUUID, node.UUID) return e }) if err != nil || len(adapters) == 0 { @@ -166,7 +166,7 @@ func (h *HMCProvider) DiscoverMetadata(ctx context.Context) error { var volumes []hmc.StorageMap _ = apiCallWithRetry(func(c context.Context) error { var e error - volumes, e = h.hmcClient.GetAttachedVolumes(c, sysUUID, node.UUID, false) + volumes, e = h.hmcClient.GetAttachedVolumes(c, sysUUID, node.UUID) return e }) if len(volumes) == 0 { @@ -327,7 +327,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi // Retry loop for 406/Intermittent HMC errors with re-authentication for i := 0; i < maxRetries; i++ { - lparDetailed, err = 
h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID, true) + lparDetailed, err = h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID) if err == nil { break } @@ -340,7 +340,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi apiTrafficMutex.Lock() // Do a quick test to see if another thread already fixed the token while we were waiting in line! - _, testErr := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID, false) + _, testErr := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID) if testErr != nil && strings.Contains(testErr.Error(), "406") { // Logout from old session first if logoutErr := h.hmcClient.Logoff(ctx); logoutErr != nil { @@ -351,7 +351,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi time.Sleep(2 * time.Second) // Re-authenticate with fresh session - if loginErr := h.hmcClient.Login(ctx, h.cfg.HMC.Username, h.cfg.HMC.Password, h.debug); loginErr != nil { + if loginErr := h.hmcClient.Login(ctx, h.cfg.HMC.Username, h.cfg.HMC.Password); loginErr != nil { h.logger.Warn("Re-authentication failed", "error", loginErr) } else { h.logger.Info("Successfully re-authenticated with HMC") @@ -381,10 +381,9 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi h.cfg.HMC.Password, node.SystemName, node.ExistingLPARName, - true, ) - _, err = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false, true) + _, err = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false) if err != nil { return fmt.Errorf("failed to power off LPAR: %w", err) } @@ -405,7 +404,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi _, err = h.hmcClient.PowerOnPartition(ctx, node.UUID, &hmc.PowerOnOptions{ ProfileUUID: profileUUID, BootMode: "of", // Boot to Open Firmware - }, true) + }) if err != nil { return fmt.Errorf("failed to power on LPAR for adapter registration: %w", err) } @@ -414,7 +413,7 @@ func (h *HMCProvider) 
networkBootLpar(ctx context.Context, node *types.NodeConfi time.Sleep(20 * time.Second) h.logger.Info("Powering off LPAR for profile query...") - _, err = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false, true) + _, err = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false) if err != nil { return fmt.Errorf("failed to power off LPAR: %w", err) } @@ -428,7 +427,6 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi h.cfg.HMC.Password, node.SystemName, node.ExistingLPARName, - true, ) // ========================================================================= @@ -438,7 +436,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi var bootDevices []hmc.NetworkBootDevice for i := 0; i < maxRetries; i++ { - bootDevices, err = h.hmcClient.GetNetworkBootDevicesForLpar(ctx, node.UUID, profileUUID, true) + bootDevices, err = h.hmcClient.GetNetworkBootDevicesForLpar(ctx, node.UUID, profileUUID) if err == nil && len(bootDevices) > 0 { break } @@ -466,7 +464,6 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi h.cfg.HMC.Password, node.SystemName, node.ExistingLPARName, - true, ) time.Sleep(3 * time.Second) // Give the HMC SSH daemon a moment to drop the connection @@ -482,7 +479,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi Netmask: "0.0.0.0", } - status, err := h.hmcClient.PowerOnPartition(ctx, node.UUID, options, true) + status, err := h.hmcClient.PowerOnPartition(ctx, node.UUID, options) if err != nil { return fmt.Errorf("failed to execute network boot: %w", err) } @@ -491,7 +488,7 @@ func (h *HMCProvider) networkBootLpar(ctx context.Context, node *types.NodeConfi h.logger.Info("Saving profile to persist configuration...") // Shield from cancellation to prevent profile corruption - _ = h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), node.UUID, "default_profile", true, true) + _ = 
h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), node.UUID, "default_profile", true) return nil } @@ -532,14 +529,14 @@ func (h *HMCProvider) PowerOffNodes(ctx context.Context) error { h.logger.Info("UUID not found in state file, querying HMC fallback...", "lpar", node.ExistingLPARName) // 1. Get System UUID quietly - sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName, false) + sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName) if err != nil { h.logger.Debug("Could not resolve system UUID during teardown", "system", node.SystemName) continue } // 2. Get LPARs quietly - lpars, err := h.hmcClient.GetLogicalPartitionsQuickAll(ctx, sysUUID, false) + lpars, err := h.hmcClient.GetLogicalPartitionsQuickAll(ctx, sysUUID) if err != nil { continue } @@ -570,7 +567,7 @@ func (h *HMCProvider) PowerOffNodes(ctx context.Context) error { h.logger.Info("Attempting to power off LPAR", "lpar", node.ExistingLPARName, "uuid", node.UUID) // Send the immediate power off signal. 
- _, err := h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false, true) + _, err := h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false) if err != nil { // Extract just the first line of the error for cleaner logging errMsg := strings.Split(err.Error(), "\n")[0] @@ -603,7 +600,7 @@ func (h *HMCProvider) PowerOffNodes(ctx context.Context) error { } // Natively query the lightweight JSON endpoint for the exact LPAR state - lpar, err := h.hmcClient.GetLogicalPartitionQuick(node.UUID, false) + lpar, err := h.hmcClient.GetLogicalPartitionQuick(node.UUID) if err == nil && lpar != nil { state := strings.ToLower(lpar.PartitionState) if state != "not activated" { @@ -652,10 +649,10 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi // ======================================================================== // STEP 1.5: ENSURE LPAR IS POWERED OFF // ======================================================================== - lparDetails, err := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID, h.debug) + lparDetails, err := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID) if err == nil && (lparDetails.PartitionState == "running" || lparDetails.PartitionState == "open firmware") { h.logger.Info("LPAR is active. 
Powering off before ISO boot...", "state", lparDetails.PartitionState) - _, _ = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false, true) + _, _ = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false) h.logger.Info("Waiting 15 seconds for LPAR to fully power off...") time.Sleep(15 * time.Second) } @@ -663,7 +660,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi // Step 2: Ensure viosadmin user exists (required for VIOS operations) h.logger.Info("Checking viosadmin user on HMC") // Shield from cancellation to prevent partially created VIOS admin account - viosUsername, viosPassword, viosUserCreated, err := h.hmcClient.EnsureVIOSAdminUser(context.WithoutCancel(ctx), h.cfg.HMC.Username, h.cfg.HMC.Password, h.debug) + viosUsername, viosPassword, viosUserCreated, err := h.hmcClient.EnsureVIOSAdminUser(context.WithoutCancel(ctx), h.cfg.HMC.Username, h.cfg.HMC.Password) if err != nil { return fmt.Errorf("failed to ensure viosadmin user: %w", err) } @@ -698,7 +695,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi } // Step 4: Get system UUID - sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName, true) + sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName) if err != nil { return fmt.Errorf("failed to get system UUID: %w", err) } @@ -777,7 +774,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi // Create mount directory mkdirCmd := fmt.Sprintf(`viosvrcmd -m %s -p %s -c "mkdir -p %s" --admin`, node.SystemName, viosName, mountPoint) - if _, err := hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, mkdirCmd, h.debug); err != nil { + if _, err := hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, mkdirCmd); err != nil { return fmt.Errorf("failed to create mount directory: %w", err) } @@ -788,7 +785,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node 
*types.NodeConfi maxRetries := 3 for i := 0; i < maxRetries; i++ { // Shield from cancellation to prevent locked VIOS mount daemon - _, mountErr = hmc.MountNFS(context.WithoutCancel(ctx), h.hmcClient, node.SystemName, viosName, nfsServer, exportPath, mountPoint, "3", h.debug) + _, mountErr = hmc.MountNFS(context.WithoutCancel(ctx), h.hmcClient, node.SystemName, viosName, nfsServer, exportPath, mountPoint, "3") if mountErr == nil || strings.Contains(mountErr.Error(), "already mounted") { mountErr = nil break @@ -797,7 +794,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi h.logger.Warn(fmt.Sprintf("HMC session corrupted during NFS mount (attempt %d/%d). Re-authenticating...", i+1, maxRetries)) _ = h.hmcClient.Logoff(ctx) time.Sleep(2 * time.Second) - _ = h.hmcClient.Login(ctx, h.cfg.HMC.Username, h.cfg.HMC.Password, h.debug) + _ = h.hmcClient.Login(ctx, h.cfg.HMC.Username, h.cfg.HMC.Password) time.Sleep(3 * time.Second) continue } @@ -873,7 +870,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi h.logger.Info("Refreshing HMC session before ISO upload...") _ = h.hmcClient.Logoff(ctx) time.Sleep(2 * time.Second) - if err := h.hmcClient.Login(ctx, h.cfg.HMC.Username, h.cfg.HMC.Password, h.debug); err != nil { + if err := h.hmcClient.Login(ctx, h.cfg.HMC.Username, h.cfg.HMC.Password); err != nil { return fmt.Errorf("failed to refresh HMC session: %w", err) } time.Sleep(3 * time.Second) @@ -890,7 +887,6 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi 0, // sizeMB (not used when sourceFile is provided) true, // readOnly (create with -ro flag) false, // nfsLink (MUST be false to allow concurrent node booting - VIOS copies ISO locally) - h.debug, // debug ) if err != nil { return fmt.Errorf("failed to create optical media on VIOS '%s' (System: '%s'): %w", viosName, node.SystemName, err) @@ -905,7 +901,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node 
*types.NodeConfi alreadyMapped := false var mediaToUnmap []string - mappings, mapCheckErr := h.hmcClient.GetViosSCSIMappings(ctx, viosUUID, h.debug) + mappings, mapCheckErr := h.hmcClient.GetViosSCSIMappings(ctx, viosUUID) if mapCheckErr != nil { h.logger.Warn("Failed to fetch VIOS mappings for verification, proceeding with map attempt", "error", mapCheckErr) } else { @@ -928,7 +924,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi if len(mediaToUnmap) > 0 { h.logger.Info("Found stale optical media mapped to LPAR. Unmapping...", "lpar", node.ExistingLPARName, "stale_media", mediaToUnmap) // Shield from cancellation to prevent orphaned vSCSI adapters - _, err = h.hmcClient.DeleteVirtualOpticalMaps(context.WithoutCancel(ctx), sysUUID, viosUUID, node.UUID, mediaToUnmap, h.debug) + _, err = h.hmcClient.DeleteVirtualOpticalMaps(context.WithoutCancel(ctx), sysUUID, viosUUID, node.UUID, mediaToUnmap) if err != nil { h.logger.Warn("Failed to unmap stale media. 
The mapping step may fail.", "error", err) } else { @@ -943,7 +939,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi h.logger.Info("Mapping optical media to LPAR", "lpar", node.ExistingLPARName, "media", mediaName) // Shield from cancellation to prevent orphaned vSCSI adapters - _, err = h.hmcClient.CreateVirtualOpticalMaps(context.WithoutCancel(ctx), sysUUID, viosUUID, node.UUID, []string{mediaName}, h.debug) + _, err = h.hmcClient.CreateVirtualOpticalMaps(context.WithoutCancel(ctx), sysUUID, viosUUID, node.UUID, []string{mediaName}) if err != nil { return fmt.Errorf("failed to map optical media: %w", err) } @@ -1009,7 +1005,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi h.logger.Info("Saving partition profile", "profile", profileName) // Shield from cancellation to prevent profile corruption - err = h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), node.UUID, profileName, true, h.debug) + err = h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), node.UUID, profileName, true) if err != nil { return fmt.Errorf("failed to save partition profile: %w", err) } @@ -1020,7 +1016,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi h.logger.Info("Setting Pending Boot String to 'cd/dvd-all'...") // Shield from cancellation to prevent boot definition corruption - err = h.hmcClient.SetPartitionBootString(context.WithoutCancel(ctx), node.UUID, "cd/dvd-all", h.debug) + err = h.hmcClient.SetPartitionBootString(context.WithoutCancel(ctx), node.UUID, "cd/dvd-all") if err != nil { h.logger.Warn("Failed to set boot string (may require manual SMS boot)", "error", err) } else { @@ -1030,7 +1026,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi // ======================================================================== // STEP 10: GET PROFILE UUID AND POWER ON // 
======================================================================== - lparDetails2, err2 := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID, h.debug) + lparDetails2, err2 := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID) if err2 != nil { return fmt.Errorf("failed to get LPAR details: %w", err2) } @@ -1051,7 +1047,7 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi Keylock: "normal", } - _, err = h.hmcClient.PowerOnPartition(ctx, node.UUID, powerOnOpts, h.debug) + _, err = h.hmcClient.PowerOnPartition(ctx, node.UUID, powerOnOpts) if err != nil { if strings.Contains(err.Error(), "already running") { h.logger.Info("LPAR already running") @@ -1067,12 +1063,12 @@ func (h *HMCProvider) bootNodeWithISO(ctx context.Context, node *types.NodeConfi // getActiveVIOS discovers and returns the first active VIOS on the system func (h *HMCProvider) getActiveVIOS(ctx context.Context, systemName string) (uuid, name string, err error) { - sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, systemName, h.debug) + sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, systemName) if err != nil { return "", "", err } - viosList, err := h.hmcClient.GetVirtualIOServersQuick(ctx, sysUUID, h.debug) + viosList, err := h.hmcClient.GetVirtualIOServersQuick(ctx, sysUUID) if err != nil { return "", "", err } @@ -1086,7 +1082,7 @@ func (h *HMCProvider) getActiveVIOS(ctx context.Context, systemName string) (uui viosUUIDs[i] = v.UUID } - activeVIOSMap, err := h.hmcClient.GetActiveVIOSServers(ctx, sysUUID, viosUUIDs, h.debug) + activeVIOSMap, err := h.hmcClient.GetActiveVIOSServers(ctx, sysUUID, viosUUIDs) if err != nil { return "", "", err } @@ -1139,7 +1135,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { } h.logger.Info("Checking viosadmin user on HMC") - viosUsername, viosPassword, viosUserCreated, err := h.hmcClient.EnsureVIOSAdminUser(context.WithoutCancel(ctx), h.cfg.HMC.Username, h.cfg.HMC.Password, 
h.debug) + viosUsername, viosPassword, viosUserCreated, err := h.hmcClient.EnsureVIOSAdminUser(context.WithoutCancel(ctx), h.cfg.HMC.Username, h.cfg.HMC.Password) if err != nil { return fmt.Errorf("failed to ensure viosadmin user: %w", err) } @@ -1180,10 +1176,10 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { h.logger.Info("Preparing node for ISO boot...", "node", node.ExistingLPARName) // Ensure LPAR is powered off - lparDetails, err := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID, h.debug) + lparDetails, err := h.hmcClient.GetLogicalPartitionDetailed(ctx, node.UUID) if err == nil && (lparDetails.PartitionState == "running" || lparDetails.PartitionState == "open firmware") { h.logger.Info("LPAR is active. Powering off before ISO boot...", "state", lparDetails.PartitionState) - _, _ = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false, true) + _, _ = h.hmcClient.PowerOffPartition(ctx, node.UUID, "Immediate", false) time.Sleep(15 * time.Second) } @@ -1201,7 +1197,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { h.systemVIOSNames[node.SystemName] = viosName } - sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName, true) + sysUUID, _, err := h.hmcClient.GetManagedSystemByName(ctx, node.SystemName) if err != nil { return fmt.Errorf("failed to get system UUID: %w", err) } @@ -1264,10 +1260,10 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { h.logger.Info("Creating mount directory on VIOS", "path", mountPoint) mkdirCmd := fmt.Sprintf(`viosvrcmd -m %s -p %s -c "mkdir -p %s" --admin`, node.SystemName, viosName, mountPoint) - hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, mkdirCmd, h.debug) - + hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, mkdirCmd) + h.logger.Info("Mounting NFS on VIOS", "server", nfsServer, "export", exportPath) - _, err = hmc.MountNFS(context.WithoutCancel(ctx), h.hmcClient, node.SystemName, 
viosName, nfsServer, exportPath, mountPoint, "3", h.debug) + _, err = hmc.MountNFS(context.WithoutCancel(ctx), h.hmcClient, node.SystemName, viosName, nfsServer, exportPath, mountPoint, "3") if err != nil && !strings.Contains(err.Error(), "already mounted") { return fmt.Errorf("failed to mount NFS: %w", err) } @@ -1284,7 +1280,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { isoPath := fmt.Sprintf("%s/agent.ppc64le.iso", mountPoint) h.logger.Info("Uploading ISO to VIOS repository (this copies ~1GB and may take a few minutes)...", "iso", isoPath) err = h.hmcClient.CreateVirtualOpticalMedia( - context.WithoutCancel(ctx), node.SystemName, viosUUID, viosName, mediaName, isoPath, 0, true, false, h.debug) + context.WithoutCancel(ctx), node.SystemName, viosUUID, viosName, mediaName, isoPath, 0, true, false) if err != nil && !strings.Contains(err.Error(), "already exists") { return fmt.Errorf("failed to create optical media on VIOS: %w", err) @@ -1340,7 +1336,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { h.logger.Info("Bulk mapping on VIOS", "viosUUID", viosUUID, "lpar_count", len(lparMediaMap)) _, err := h.hmcClient.CreateVirtualOpticalMapsMultiLpar( - context.WithoutCancel(ctx), sysUUID, viosUUID, lparMediaMap, h.debug) + context.WithoutCancel(ctx), sysUUID, viosUUID, lparMediaMap) if err != nil { return fmt.Errorf("failed to bulk map optical media: %w", err) } @@ -1380,7 +1376,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { h.logger.Info("Unmounting NFS from VIOS", "vios", h.isoMappings[i].VIOSName, "mount", mp) // Shield from cancellation so we don't leave the VIOS hanging! 
- _, err := hmc.UnmountNFS(context.WithoutCancel(ctx), h.hmcClient, h.isoMappings[i].SystemName, h.isoMappings[i].VIOSName, mp, h.debug) + _, err := hmc.UnmountNFS(context.WithoutCancel(ctx), h.hmcClient, h.isoMappings[i].SystemName, h.isoMappings[i].VIOSName, mp) if err != nil && !strings.Contains(err.Error(), "Could not find anything to unmount") && !strings.Contains(err.Error(), "not mounted") { h.logger.Warn("Failed to cleanly unmount NFS (will retry during teardown)", "error", err) @@ -1389,7 +1385,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { h.logger.Info("Removing mount directory from VIOS", "mount", mp) rmdirCmd := fmt.Sprintf(`viosvrcmd -m %s -p %s -c "rmdir %s" --admin`, h.isoMappings[i].SystemName, h.isoMappings[i].VIOSName, mp) - _, _ = hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, rmdirCmd, h.debug) + _, _ = hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, rmdirCmd) mountsCleaned[mp] = true } @@ -1422,10 +1418,10 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { // 🛡️ SHIELDED API CALLS: Protect the HMC session token from concurrent corruption apiTrafficMutex.RLock() - _ = h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), targetNode.UUID, "default_profile", true, h.debug) - _ = h.hmcClient.SetPartitionBootString(context.WithoutCancel(ctx), targetNode.UUID, "cd/dvd-all", h.debug) - - lparDetails, err := h.hmcClient.GetLogicalPartitionDetailed(ctx, targetNode.UUID, h.debug) + _ = h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), targetNode.UUID, "default_profile", true) + _ = h.hmcClient.SetPartitionBootString(context.WithoutCancel(ctx), targetNode.UUID, "cd/dvd-all") + + lparDetails, err := h.hmcClient.GetLogicalPartitionDetailed(ctx, targetNode.UUID) apiTrafficMutex.RUnlock() if err != nil { @@ -1447,7 +1443,7 @@ func (h *HMCProvider) bootNodesWithISOBulk(ctx context.Context) error { } apiTrafficMutex.RLock() - _, err = 
h.hmcClient.PowerOnPartition(ctx, targetNode.UUID, powerOnOpts, h.debug) + _, err = h.hmcClient.PowerOnPartition(ctx, targetNode.UUID, powerOnOpts) apiTrafficMutex.RUnlock() if err != nil && !strings.Contains(err.Error(), "already running") { @@ -1509,7 +1505,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { var created bool var apiErr error // Fixed: explicit declaration prevents undefined 'err' // Shield from cancellation to prevent partially created VIOS admin account - viosUsername, viosPassword, created, apiErr = h.hmcClient.EnsureVIOSAdminUser(context.WithoutCancel(ctx), h.cfg.HMC.Username, h.cfg.HMC.Password, h.debug) + viosUsername, viosPassword, created, apiErr = h.hmcClient.EnsureVIOSAdminUser(context.WithoutCancel(ctx), h.cfg.HMC.Username, h.cfg.HMC.Password) if apiErr != nil || viosPassword == "" { h.logger.Warn("Failed to get viosadmin credentials via API, falling back to default", "error", apiErr) viosUsername, viosPassword = h.hmcClient.GetVIOSAdminCredentials() @@ -1530,7 +1526,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { unmapTargets := make(map[string]map[string]map[string][]string) for _, mapping := range h.isoMappings { - sysUUID, _, err := h.hmcClient.GetManagedSystemByName(context.WithoutCancel(ctx), mapping.SystemName, h.debug) + sysUUID, _, err := h.hmcClient.GetManagedSystemByName(context.WithoutCancel(ctx), mapping.SystemName) if err != nil { h.logger.Warn("Failed to get system UUID for cleanup", "system", mapping.SystemName, "error", err) continue @@ -1552,7 +1548,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { h.logger.Info("Bulk unmapping optical media from VIOS...", "viosUUID", viosUUID, "lpar_count", len(lparMediaMap)) _, err := h.hmcClient.DeleteVirtualOpticalMapsMultiLpar( - context.WithoutCancel(ctx), sysUUID, viosUUID, lparMediaMap, h.debug) + context.WithoutCancel(ctx), sysUUID, viosUUID, lparMediaMap) if err != nil { h.logger.Error("Failed to bulk 
unmap optical media", "error", err) @@ -1575,7 +1571,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { h.logger.Info("Saving LPAR profile", "node", targetMapping.NodeName) - err := h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), targetMapping.LparUUID, "default_profile", true, h.debug) + err := h.hmcClient.SaveCurrentLparConfig(context.WithoutCancel(ctx), targetMapping.LparUUID, "default_profile", true) if err != nil { h.logger.Warn("Failed to save LPAR profile", "node", targetMapping.NodeName, "error", err) } @@ -1604,7 +1600,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { h.logger.Info(fmt.Sprintf("Checking repository for media: %s", mapping.MediaName)) // Shield prerequisite lookup - teardown must not bypass deletion! - _, err := h.hmcClient.GetVirtualOpticalMedia(context.WithoutCancel(ctx), mapping.SystemName, mapping.VIOSName, mapping.MediaName, h.debug) + _, err := h.hmcClient.GetVirtualOpticalMedia(context.WithoutCancel(ctx), mapping.SystemName, mapping.VIOSName, mapping.MediaName) if err == nil { h.logger.Info(fmt.Sprintf("Destroying optical payload: %s", mapping.MediaName)) @@ -1613,8 +1609,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { context.WithoutCancel(ctx), mapping.SystemName, mapping.VIOSName, - mapping.MediaName, - h.debug) + mapping.MediaName) if delErr != nil { h.logger.Warn("Failed to delete optical media", "media", mapping.MediaName, "error", delErr) @@ -1649,7 +1644,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { h.logger.Info("Unmounting shared NFS from VIOS...", "mount_point", mapping.MountPoint, "vios", mapping.VIOSName) // Shield from cancellation to prevent locked VIOS mount daemon - _, err := hmc.UnmountNFS(context.WithoutCancel(ctx), h.hmcClient, mapping.SystemName, mapping.VIOSName, mapping.MountPoint, h.debug) + _, err := hmc.UnmountNFS(context.WithoutCancel(ctx), h.hmcClient, mapping.SystemName, mapping.VIOSName, 
mapping.MountPoint) if err != nil && (strings.Contains(err.Error(), "Could not find anything to unmount") || strings.Contains(err.Error(), "not mounted")) { h.logger.Info("Directory is already unmounted from VIOS", "mount_point", mapping.MountPoint) @@ -1661,7 +1656,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { h.logger.Info("Removing mount directory from VIOS", "mount_point", mapping.MountPoint) rmdirCmd := fmt.Sprintf(`viosvrcmd -m %s -p %s -c "rmdir %s" --admin`, mapping.SystemName, mapping.VIOSName, mapping.MountPoint) - _, err = hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, rmdirCmd, h.debug) + _, err = hmc.CliRunnerViaSSH(h.cfg.HMC.IP, viosUsername, viosPassword, rmdirCmd) if err != nil && (strings.Contains(err.Error(), "No such file or directory") || strings.Contains(err.Error(), "not found")) { h.logger.Info("Mount directory already removed", "mount_point", mapping.MountPoint) @@ -1691,7 +1686,7 @@ func (h *HMCProvider) CleanupISOMappings(ctx context.Context) error { // ensureMediaRepository checks if the VIOS Media Repository exists, and auto-creates it if missing func (h *HMCProvider) ensureMediaRepository(ctx context.Context, systemName, viosUUID, viosName string) error { - repoInfo, err := h.hmcClient.GetMediaRepositoryInfo(ctx, systemName, viosName, h.debug) + repoInfo, err := h.hmcClient.GetMediaRepositoryInfo(ctx, systemName, viosName) // The HMC API can return success (err == nil) but SizeMB = 0 if the repository isn't created, // OR it can return an error if the repository doesn't exist. Handle both cases. 
@@ -1718,7 +1713,7 @@ func (h *HMCProvider) ensureMediaRepository(ctx context.Context, systemName, vio requiredGB := float64(requiredMB) / 1024.0 // Find suitable Volume Group - vgs, vgErr := h.hmcClient.GetVolumeGroups(ctx, viosUUID, h.debug) + vgs, vgErr := h.hmcClient.GetVolumeGroups(ctx, viosUUID) if vgErr != nil { return fmt.Errorf("failed to list volume groups: %w", vgErr) } @@ -1752,7 +1747,7 @@ func (h *HMCProvider) ensureMediaRepository(ctx context.Context, systemName, vio h.logger.Info("Creating Media Repository", "size_mb", requiredMB, "vg", targetVG) // Shield from cancellation - modifying VIOS Volume Group must complete to prevent corruption - if createErr := h.hmcClient.CreateMediaRepository(context.WithoutCancel(ctx), systemName, viosUUID, viosName, targetVG, requiredMB, h.debug); createErr != nil { + if createErr := h.hmcClient.CreateMediaRepository(context.WithoutCancel(ctx), systemName, viosUUID, viosName, targetVG, requiredMB); createErr != nil { // If repository already exists, that's actually OK - just log and continue if strings.Contains(createErr.Error(), "already exists") { h.logger.Info("Media Repository already exists (detected during creation attempt)", "vios", viosName) diff --git a/infra/compute/provider.go b/infra/compute/provider.go index e971481..5445150 100644 --- a/infra/compute/provider.go +++ b/infra/compute/provider.go @@ -36,14 +36,14 @@ func NewProvider(cfg *types.AgentConfig, log *logger.Logger, debug bool) (Provid func NewProviderWithState(cfg *types.AgentConfig, log *logger.Logger, debug bool, stateManager *types.StateManager) (Provider, error) { log.Debug("Connecting to HMC...", "ip", cfg.HMC.IP, "user", cfg.HMC.Username) - client := hmc.NewRestClient(cfg.HMC.IP) - - // Configure HMC logger to write API traffic to deployment log only (not terminal) - hmcLogger := infra.NewHMCLoggerAdapter(log, debug) - client.SetLogger(hmcLogger) + // Always attach the debug transport so the deployment log captures full HMC + // traffic 
regardless of --debug. The console only shows it when debug==true + // because the console logger is set to InfoLevel in that case. + baseClient := hmc.NewRestClient(cfg.HMC.IP).WithTLSInsecure() + client := baseClient.WithTransport(infra.HMCDebugTransport(log)(baseClient.HTTPTransport())) log.Debug("Authenticating with HMC...") - if err := client.Login(context.Background(), cfg.HMC.Username, cfg.HMC.Password, debug); err != nil { + if err := client.Login(context.Background(), cfg.HMC.Username, cfg.HMC.Password); err != nil { return nil, fmt.Errorf("HMC login failed for user %s at %s: %w. Please verify HMC is accessible and credentials are correct", cfg.HMC.Username, cfg.HMC.IP, err) } diff --git a/infra/controller/network.go b/infra/controller/network.go index 2a96e70..da82bc2 100644 --- a/infra/controller/network.go +++ b/infra/controller/network.go @@ -105,8 +105,10 @@ func (nm *NetworkManager) RemoveVIPAlias(ctx context.Context, iface, ip, cidr, c return nil } - // 3. Get ALL IPv4 addresses currently configured on the interface - getAllCmd := fmt.Sprintf("ip -o -4 addr show dev %s | awk '{print $4}'", iface) + // 3. Get ALL IPv4 addresses currently configured on the interface. + // grep -oP is used instead of awk column-splitting to be robust against + // secondary interface labels or extra routing flags that would shift columns. 
+ getAllCmd := fmt.Sprintf("ip -o -4 addr show dev %s | grep -oP '\\d+\\.\\d+\\.\\d+\\.\\d+/\\d+'", iface) allIPsOut, err := nm.executor.Execute(ctx, getAllCmd) if err != nil { return fmt.Errorf("failed to retrieve IP addresses for interface %s: %v", iface, err) diff --git a/infra/hmclog.go b/infra/hmclog.go index 07be7c3..c75f498 100644 --- a/infra/hmclog.go +++ b/infra/hmclog.go @@ -1,45 +1,100 @@ package infra import ( + "bytes" + "fmt" + "io" + "net/http" + "regexp" "strings" + "sync" + "time" - "github.com/charmbracelet/log" - hmc "github.com/IBM/infra-go-sdk/phmc" "github.com/IBM/shiftlaunch/logger" ) -// logWriter acts as a bridge between the standard io.Writer and our custom logger -type logWriter struct { +// logMu protects the raw file descriptor from concurrent interleaved writes. +var logMu sync.Mutex + +// sessionRegex targets the HMC token inside XML response payloads as a catch-all +// safety net when the token is not yet present in any header (e.g. login response). +var sessionRegex = regexp.MustCompile(`(?i)()(.*?)()`) + +// debugRoundTripper is an http.RoundTripper that logs every HMC request and +// response through shiftlaunch's logger. The base transport (TLS) is provided +// by hmc.NewRestClientWithTransport via the TransportWrapper factory. +type debugRoundTripper struct { + base http.RoundTripper logger *logger.Logger } -func (lw *logWriter) Write(p []byte) (n int, err error) { - msg := strings.TrimSpace(string(p)) - if msg != "" { - // Route through Debug so it respects the spinner UI and file outputs - lw.logger.Debug("[HMC API] " + msg) +func (d *debugRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Console + file: compact one-liner. 
+ d.logger.Debug("[HMC →] " + req.Method + " " + req.URL.String()) + + resp, err := d.base.RoundTrip(req) + if err != nil { + d.logger.Debug("[HMC ✗] " + req.Method + " " + req.URL.String() + " — " + err.Error()) + return nil, err } - return len(p), nil -} -// NewHMCLoggerAdapter creates an HMC logger that integrates safely with the spinner UI -func NewHMCLoggerAdapter(shiftlaunchLogger *logger.Logger, debug bool) *hmc.Logger { - var level log.Level + // Read and buffer the body so the caller can still consume it. + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(body)) + + // Console + file: one-liner with status (no body — avoids charmbracelet truncation). + d.logger.Debug("[HMC ←] " + resp.Status + " " + req.URL.String()) + + // File only: full untruncated body, written directly to bypass charmbracelet. + // logMu gates the raw *os.File so parallel goroutines cannot interleave writes. + bodyStr := maskSession(req, resp, string(body)) + logMu.Lock() + fmt.Fprintf(d.logger.FileOnly(), + "%s DEBU [HMC body] %s %s\n%s\n", + time.Now().Format("2006/01/02 15:04:05"), resp.Status, req.URL.String(), bodyStr, + ) + logMu.Unlock() - if debug { - level = log.DebugLevel - } else { - // Suppress completely if not in debug mode - level = log.ErrorLevel + if readErr != nil { + return nil, readErr + } + return resp, nil +} + +// maskSession redacts the X-API-Session token from a log payload using three +// complementary strategies: +// 1. Request header — covers all authenticated API calls. +// 2. Response header — covers the login response where the token is first issued. +// 3. Regex scrub — removes the token directly from the XML body as an absolute +// safety net, regardless of header availability. +func maskSession(req *http.Request, resp *http.Response, payload string) string { + // 1. Redact token carried on outgoing requests (standard authenticated calls). 
+ if token := req.Header.Get("X-API-Session"); token != "" { + payload = strings.ReplaceAll(payload, token, "***[REDACTED]***") } - // Route all HMC SDK traffic securely through our custom logging engine - safeOutput := &logWriter{logger: shiftlaunchLogger} + // 2. Redact token returned on login — present in the response header, not the request. + if resp != nil { + if token := resp.Header.Get("X-API-Session"); token != "" { + payload = strings.ReplaceAll(payload, token, "***[REDACTED]***") + } + } - hmc.ReinitLogger(safeOutput) + // 3. Hard scrub any remaining element in the XML body. + payload = sessionRegex.ReplaceAllString(payload, "${1}***[REDACTED]***${3}") - hmcLogger := hmc.NewLogger(level, safeOutput) - hmcLogger.SetPrefix("") // Prefix handled by the logWriter + return payload +} - return hmcLogger +// HMCDebugTransport returns an http.RoundTripper middleware factory that +// wraps the SDK's TLS transport with request/response logging. Pass the result +// to hmc.RestClient.WithTransport: +// +// base := hmc.NewRestClient(ip).WithTLSInsecure() +// client := base.WithTransport(infra.HMCDebugTransport(log)(base.HTTPTransport())) +func HMCDebugTransport(log *logger.Logger) func(http.RoundTripper) http.RoundTripper { + return func(base http.RoundTripper) http.RoundTripper { + return &debugRoundTripper{base: base, logger: log} + } } diff --git a/localexec/client.go b/localexec/client.go index 5d322da..6306953 100644 --- a/localexec/client.go +++ b/localexec/client.go @@ -59,8 +59,9 @@ func (l *LocalClient) Execute(ctx context.Context, command string) (string, erro func (l *LocalClient) WriteFile(ctx context.Context, path string, content []byte, perms os.FileMode) error { l.logger.Debug("Writing local file", "path", path) - // Create temp file - tmpPath := filepath.Join("/tmp", filepath.Base(path)+".tmp") + // Create temp file with the PID embedded so concurrent shiftlaunch processes + // (e.g. 
two simultaneous `shiftlaunch create` runs) never share the same path. + tmpPath := filepath.Join("/tmp", fmt.Sprintf("%s-%d.tmp", filepath.Base(path), os.Getpid())) if err := os.WriteFile(tmpPath, content, perms); err != nil { return fmt.Errorf("failed to write temp file: %w", err) } diff --git a/logger/logger.go b/logger/logger.go index 0ee1f82..575de68 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -32,14 +32,11 @@ func New(debug bool, logPath string) (*Logger, error) { if err == nil && file != nil { // Create file logger (colors will be auto-disabled for file) fileOpts := log.Options{ - ReportTimestamp: true, - Prefix: "ShiftLaunch", - } - if debug { - fileOpts.Level = log.DebugLevel - } else { - fileOpts.Level = log.InfoLevel - } + ReportTimestamp: true, + Prefix: "ShiftLaunch", + // File always captures full debug — the debug flag only gates the terminal. + Level: log.DebugLevel, + } fileLogger = log.NewWithOptions(file, fileOpts) } } diff --git a/orchestrator/orchestrator.go b/orchestrator/orchestrator.go index 0a916c6..502a10a 100644 --- a/orchestrator/orchestrator.go +++ b/orchestrator/orchestrator.go @@ -241,33 +241,33 @@ func (o *Orchestrator) Deploy(ctx context.Context, resume bool) (err error) { output, err := o.executor.Execute(ctx, fmt.Sprintf("ip addr show %s", iface)) if err == nil && strings.Contains(output, o.cfg.Services.LoadBalancer.GetVIP()+"/") { // VIP is configured - check which cluster is using it - conflictingCluster := o.findClusterUsingVIP(o.cfg.Services.LoadBalancer.VIP) + conflictingCluster := o.findClusterUsingVIP(o.cfg.Services.LoadBalancer.GetVIP()) if conflictingCluster != "" { o.logger.Error("VIP is already in use by another cluster", - "vip", o.cfg.Services.LoadBalancer.VIP, + "vip", o.cfg.Services.LoadBalancer.GetVIP(), "cluster", conflictingCluster) return fmt.Errorf("VIP %s is already in use by cluster '%s'. 
Please choose a different VIP or delete the conflicting cluster first", - o.cfg.Services.LoadBalancer.VIP, conflictingCluster) + o.cfg.Services.LoadBalancer.GetVIP(), conflictingCluster) } o.logger.Error("VIP is already configured on interface", - "vip", o.cfg.Services.LoadBalancer.VIP, + "vip", o.cfg.Services.LoadBalancer.GetVIP(), "interface", iface) return fmt.Errorf("VIP %s is already configured on interface %s. Please remove the VIP alias manually or choose a different VIP", - o.cfg.Services.LoadBalancer.VIP, iface) + o.cfg.Services.LoadBalancer.GetVIP(), iface) } } // Check if VIP is defined in another cluster's config - conflictingCluster := o.findClusterUsingVIP(o.cfg.Services.LoadBalancer.VIP) + conflictingCluster := o.findClusterUsingVIP(o.cfg.Services.LoadBalancer.GetVIP()) if conflictingCluster != "" { o.logger.Error("VIP is already configured for another cluster", - "vip", o.cfg.Services.LoadBalancer.VIP, + "vip", o.cfg.Services.LoadBalancer.GetVIP(), "cluster", conflictingCluster) return fmt.Errorf("VIP %s is already configured for cluster '%s'. 
Please choose a different VIP", - o.cfg.Services.LoadBalancer.VIP, conflictingCluster) + o.cfg.Services.LoadBalancer.GetVIP(), conflictingCluster) } - - o.logger.Info("VIP is available", "vip", o.cfg.Services.LoadBalancer.VIP) + + o.logger.Info("VIP is available", "vip", o.cfg.Services.LoadBalancer.GetVIP()) } o.logger.EndPhase(true, "[Phase 0/6] Pre-Deployment Validation Complete") @@ -979,7 +979,7 @@ func (o *Orchestrator) findClusterUsingVIP(vip string) string { // Safely parse the YAML to guarantee an exact value match var tempCfg types.AgentConfig if err := yaml.Unmarshal(data, &tempCfg); err == nil { - if tempCfg.Services.LoadBalancer.VIP == vip { + if tempCfg.Services.LoadBalancer != nil && tempCfg.Services.LoadBalancer.GetVIP() == vip { return clusterName } } diff --git a/services/downloader.go b/services/downloader.go index 6ec5e80..b2cf08f 100644 --- a/services/downloader.go +++ b/services/downloader.go @@ -60,8 +60,8 @@ func (d *Downloader) DownloadRHCOSImages(ctx context.Context, workspaceDir strin urls := d.cfg.OpenShift.RHCOSImages timeout := d.daemonCfg.Timeouts.DownloadTimeoutSec // Get timeout from config - // Note: Checksum validation is optional and not configured in the new config structure - // Files will be downloaded without integrity verification unless checksums are added to config + // Manifest path for checksum verification (populated when ChecksumURL is configured) + manifestPath := filepath.Join(workspaceDir, "rhcos", "sha256sum.txt") images := []struct { url string @@ -78,7 +78,9 @@ func (d *Downloader) DownloadRHCOSImages(ctx context.Context, workspaceDir strin return fmt.Errorf("%s URL not provided in configuration", img.desc) } destPath := filepath.Join(rhcosDir, img.filename) - expectedHash := "" // Checksum validation disabled in new config structure + + // Attempt to resolve the expected hash from the manifest; fall back gracefully. + expectedHash, _ := d.extractHashFromManifest(ctx, img.url, manifestPath) // 3. 
Conditional Flow based on checksum availability and force_ocp_download flag forceDownload := d.cfg.OpenShift.ForceOCPDownload @@ -97,13 +99,11 @@ func (d *Downloader) DownloadRHCOSImages(ctx context.Context, workspaceDir strin d.logger.Warn("Checksum mismatch. Wiping corrupted file and re-downloading...", "image", img.desc) d.exec.Execute(ctx, fmt.Sprintf("rm -f %s", destPath)) } - } else { - checkCmd := fmt.Sprintf("test -s %s", destPath) - if _, err := d.exec.Execute(ctx, checkCmd); err == nil { - d.logger.Info("File already exists, skipping download (no checksum validation)", "image", img.desc) - continue - } } + // No checksum available: always run curl. curl -C - (resume) will issue a + // range request and exit 0 immediately if the file is already complete, + // but will resume and finish a truncated partial download rather than + // silently accepting a corrupted file left behind by a previous Ctrl+C. } // 4. Download the file @@ -181,8 +181,8 @@ func (d *Downloader) DownloadOpenShiftTools(ctx context.Context, workspaceDir st continue } destPath := filepath.Join(toolsDir, tool.filename) - // Checksum validation disabled in new config structure - expectedHash := "" + // Attempt to resolve the expected hash from the manifest downloaded above. + expectedHash, _ := d.extractHashFromManifest(ctx, tool.url, manifestPath) forceDownload := d.cfg.OpenShift.ForceOCPDownload @@ -200,13 +200,10 @@ func (d *Downloader) DownloadOpenShiftTools(ctx context.Context, workspaceDir st d.logger.Warn("Checksum mismatch. Wiping corrupted file and re-downloading...", "tool", tool.desc) d.exec.Execute(ctx, fmt.Sprintf("rm -f %s", destPath)) } - } else { - checkCmd := fmt.Sprintf("test -s %s", destPath) - if _, err := d.exec.Execute(ctx, checkCmd); err == nil { - d.logger.Info("File already exists, skipping download (no checksum validation)", "tool", tool.desc) - continue - } } + // No checksum available: always run curl. 
curl -C - (resume) will exit 0 + // immediately if the file is already complete, and will resume a partial + // download rather than accepting a truncated file from a prior Ctrl+C. } // 4. Download diff --git a/services/registry.go b/services/registry.go index 7278a79..824cffa 100644 --- a/services/registry.go +++ b/services/registry.go @@ -167,7 +167,7 @@ func (r *RegistryManager) Setup(ctx context.Context, workspaceDir string) error // Ensure the port isn't hogged by something we can't authenticate to portCheck, _ := r.executor.Execute(shieldedCtx, "ss -tlpn | grep ':5000 ' 2>/dev/null || true") if strings.TrimSpace(portCheck) != "" { - return fmt.Errorf("Port 5000 is in use, but registry authentication failed (HTTP %s). Cannot proceed.", httpCode) + return fmt.Errorf("port 5000 is in use but registry authentication failed (HTTP %s)", httpCode) } r.logger.Debug("Starting fresh local registry service...") @@ -232,14 +232,17 @@ func (r *RegistryManager) Setup(ctx context.Context, workspaceDir string) error pullSecretPath := os.ExpandEnv(strings.ReplaceAll(r.cfg.OpenShift.PullSecretFile, "~", "$HOME")) updatedSecretPath := filepath.Join(workspaceDir, "pull-secret-updated.json") - /*updateSecretCmd := fmt.Sprintf(`registry_token=$(echo -n "%s:%s" | base64 -w0) && \ - jq '.auths += {"%s": {"auth": "'$registry_token'","email": "noemail@localhost"}}' \ - < %s > %s`, - username, password, registryURL, pullSecretPath, updatedSecretPath)*/ - updateSecretCmd := fmt.Sprintf(`registry_token=$(echo -n "%s:%s" | base64 -w0) && \ - jq -c '.auths += {"%s": {"auth": "'$registry_token'","email": "noemail@localhost"}}' \ - < %s > %s`, - username, password, registryURL, pullSecretPath, updatedSecretPath) + // export REG_TOKEN so jq can read it via env.REG_TOKEN — it never appears in + // the filter string itself. registryURL is bound via --arg so it is an opaque + // jq variable ($host) with no quoting issues. 
All credentials and paths are + // wrapped in shellQuote so bash metacharacters cannot escape the shell command. + updateSecretCmd := fmt.Sprintf( + `export REG_TOKEN=$(printf '%%s:%%s' %s %s | base64 -w0) && `+ + `jq -c --arg host %s '.auths += {($host): {"auth": env.REG_TOKEN}}' `+ + `< %s > %s`, + shellQuote(username), shellQuote(password), shellQuote(registryURL), + shellQuote(pullSecretPath), shellQuote(updatedSecretPath), + ) if _, err := r.executor.Execute(shieldedCtx, updateSecretCmd); err != nil { return fmt.Errorf("failed to update pull secret: %w", err) } @@ -520,3 +523,10 @@ func contains(slice []string, item string) bool { } return false } + +// shellQuote wraps s in single quotes and escapes any literal single quotes +// inside it using the '\'' idiom, making the value safe to embed in a bash +// command string regardless of the characters it contains. +func shellQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'" +} diff --git a/validation/validator.go b/validation/validator.go index f7ab831..07a1eef 100644 --- a/validation/validator.go +++ b/validation/validator.go @@ -283,7 +283,7 @@ func (v *Validator) findClusterUsingVIP(vip string) string { var tempCfg types.AgentConfig if err := yaml.Unmarshal(data, &tempCfg); err == nil { - if tempCfg.Services.LoadBalancer.VIP == vip { + if tempCfg.Services.LoadBalancer != nil && tempCfg.Services.LoadBalancer.GetVIP() == vip { return clusterName } } @@ -763,7 +763,7 @@ func (v *Validator) validateBYOILPARs() { var systemUUID string var err error v.log.Capture(func() { - systemUUID, _, err = v.hmcClient.GetManagedSystemByName(context.Background(), node.SystemName, true) + systemUUID, _, err = v.hmcClient.GetManagedSystemByName(context.Background(), node.SystemName) }) if err != nil { v.errors = append(v.errors, fmt.Sprintf("failed to get system '%s' for LPAR validation: %v", node.SystemName, err)) @@ -772,7 +772,7 @@ func (v *Validator) validateBYOILPARs() { var lpars 
[]hmc.LogicalPartitionQuick v.log.Capture(func() { - lpars, err = v.hmcClient.GetLogicalPartitionsQuickAll(context.Background(), systemUUID, true) + lpars, err = v.hmcClient.GetLogicalPartitionsQuickAll(context.Background(), systemUUID) }) if err != nil { v.errors = append(v.errors, fmt.Sprintf("failed to get LPARs for system '%s': %v", node.SystemName, err)) @@ -975,14 +975,14 @@ func (v *Validator) validateMediaRepositorySpace() { for systemName, count := range systemNodeCount { v.log.Info(fmt.Sprintf("Validating Media Repository on system '%s' for %d node(s)...", systemName, count)) - _, sysUUID, err := v.hmcClient.GetManagedSystemByNameQuick(context.Background(), systemName, v.debug) + _, sysUUID, err := v.hmcClient.GetManagedSystemByNameQuick(context.Background(), systemName) if err != nil { v.warnings = append(v.warnings, fmt.Sprintf("Could not resolve system UUID for repository check on '%s': %v", systemName, err)) continue } // Find the active VIOS - viosList, err := v.hmcClient.GetVirtualIOServersQuick(context.Background(), sysUUID, v.debug) + viosList, err := v.hmcClient.GetVirtualIOServersQuick(context.Background(), sysUUID) if err != nil || len(viosList) == 0 { v.warnings = append(v.warnings, fmt.Sprintf("Could not retrieve VIOS list for repository check on '%s'", systemName)) continue @@ -1014,14 +1014,14 @@ func (v *Validator) validateMediaRepositorySpace() { requiredGB := float64(createRequestMB) / 1024.0 // 1. Try to fetch the existing repository info - repoInfo, err := v.hmcClient.GetMediaRepositoryInfo(context.Background(), systemName, activeViosName, v.debug) + repoInfo, err := v.hmcClient.GetMediaRepositoryInfo(context.Background(), systemName, activeViosName) // 2. If it fails OR SizeMB is 0, the repository is missing. Verify we HAVE the capacity to auto-create it later. if err != nil || repoInfo.SizeMB == 0 { v.log.Info(fmt.Sprintf("Media Repository not found on VIOS '%s' (or size is 0). 
Verifying auto-creation capacity...", activeViosName)) // Discover a suitable Volume Group - vgs, vgErr := v.hmcClient.GetVolumeGroups(context.Background(), activeViosUUID, v.debug) + vgs, vgErr := v.hmcClient.GetVolumeGroups(context.Background(), activeViosUUID) if vgErr != nil { v.warnings = append(v.warnings, fmt.Sprintf("Failed to list Volume Groups to verify auto-creation on '%s': %v", activeViosName, vgErr)) continue From 2ad700dca61349d2458d5ca533b888b270f93541 Mon Sep 17 00:00:00 2001 From: SUDEESH JOHN Date: Tue, 30 Jun 2026 09:21:52 +0530 Subject: [PATCH 2/3] fix: bump Go to 1.26.4 to resolve CVE-2026-27145 and CVE-2026-42504 Both CVEs are HIGH severity fixes in Go stdlib 1.26.4: - CVE-2026-27145: x509 VerifyHostname hostname matching flaw - CVE-2026-42504: MIME header decoding DoS via malformed encoded-words Signed-off-by: SUDEESH JOHN --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c399f9c..9cfea68 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/IBM/shiftlaunch -go 1.26.3 +go 1.26.4 require ( golang.org/x/crypto v0.53.0 // indirect From 4944c42706360e1618e3186f3f98e60c3ea6fed4 Mon Sep 17 00:00:00 2001 From: SUDEESH JOHN Date: Tue, 30 Jun 2026 17:05:28 +0530 Subject: [PATCH 3/3] fix: harden logger, downloader, and HMC round-tripper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit logger/logger.go: - New() no longer propagates log-file open error to caller; file log is optional — warn to stderr and continue rather than aborting - Remove duplicate UpdateText() call in Info() spinner path (first call was dead code, immediately overwritten by the second) - Simplify 'err == nil && file != nil' guard to 'err == nil' (the nil file check is always redundant when err is nil) - TerminalOnly() returned os.Stdout; corrected to os.Stderr to match every other output path in the logger - Fix gofmt indentation on fileOpts struct literal (fields were 
double-indented relative to their enclosing block) services/downloader.go: - Replace silent exec.Execute(mkdir -p) with os.MkdirAll so directory creation failures are surfaced immediately with a clear error - Quote all paths and URLs passed to shell via shellQuote() to prevent shell injection and handle paths with spaces or special characters; add -- end-of-options to curl invocations - Replace shell-out sha256sum|awk with pure-Go crypto/sha256+io.Copy in verifyFileHash — no injection surface, portable, faster - Replace shell-out test -f + grep|awk in extractHashFromManifest with os.Stat + os.ReadFile + strings.Fields — eliminates regex metachar injection risk from filenames - Replace exec rm -f for stale manifest with os.Remove - Fix extractOpenShiftTools to use 'tar -xzf -C ' instead of 'cd && tar -xzf ' — safer with quoted paths - Reuse rhcosDir in manifestPath join (was duplicating the base path) - Remove stale commit-message comment from DownloadAll - Fix missing space after // in two inline comments (golint) - Drop 'FATAL:' prefix from checksum mismatch error (non-idiomatic) infra/hmclog.go: - Check io.ReadAll error before restoring resp.Body; previously a partial read would hand a corrupt body back to the caller before the error was eventually checked and returned Signed-off-by: SUDEESH JOHN --- infra/hmclog.go | 16 ++-- logger/logger.go | 22 +++-- services/downloader.go | 202 +++++++++++++++++++++++------------------ 3 files changed, 139 insertions(+), 101 deletions(-) diff --git a/infra/hmclog.go b/infra/hmclog.go index c75f498..ff7d306 100644 --- a/infra/hmclog.go +++ b/infra/hmclog.go @@ -39,8 +39,16 @@ func (d *debugRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) } // Read and buffer the body so the caller can still consume it. - body, readErr := io.ReadAll(resp.Body) - _ = resp.Body.Close() + // Check the read error immediately — a partial body must not be logged or returned. 
+ body, err := io.ReadAll(resp.Body) + if closeErr := resp.Body.Close(); closeErr != nil { + d.logger.Debug("[HMC ✗] body close: " + closeErr.Error()) + } + if err != nil { + return nil, err + } + + // Restore the body for the caller before any further processing. resp.Body = io.NopCloser(bytes.NewReader(body)) // Console + file: one-liner with status (no body — avoids charmbracelet truncation). @@ -55,10 +63,6 @@ func (d *debugRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) time.Now().Format("2006/01/02 15:04:05"), resp.Status, req.URL.String(), bodyStr, ) logMu.Unlock() - - if readErr != nil { - return nil, readErr - } return resp, nil } diff --git a/logger/logger.go b/logger/logger.go index 575de68..6bf8599 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -29,14 +29,14 @@ func New(debug bool, logPath string) (*Logger, error) { // 1. Attempt to open the log file if a path is provided if logPath != "" { file, err = os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil && file != nil { + if err == nil { // Create file logger (colors will be auto-disabled for file) fileOpts := log.Options{ - ReportTimestamp: true, - Prefix: "ShiftLaunch", - // File always captures full debug — the debug flag only gates the terminal. - Level: log.DebugLevel, - } + ReportTimestamp: true, + Prefix: "ShiftLaunch", + // File always captures full debug — the debug flag only gates the terminal. + Level: log.DebugLevel, + } fileLogger = log.NewWithOptions(file, fileOpts) } } @@ -53,12 +53,17 @@ func New(debug bool, logPath string) (*Logger, error) { } consoleLogger := log.NewWithOptions(os.Stderr, consoleOpts) + if err != nil { + // Log file is optional — report to stderr but do not abort the caller. 
+ fmt.Fprintf(os.Stderr, "warning: could not open log file %q: %v\n", logPath, err) + } + return &Logger{ consoleLogger: consoleLogger, fileLogger: fileLogger, file: file, debug: debug, - }, err + }, nil } func (l *Logger) Info(msg string, keyvals ...interface{}) { @@ -88,7 +93,6 @@ func (l *Logger) Info(msg string, keyvals ...interface{}) { plainText = plainText[:maxLen-3] + "..." } - l.activeSpinner.UpdateText(plainText) l.activeSpinner.UpdateText(plainText + "\033[K") return } @@ -157,7 +161,7 @@ func (l *Logger) Capture(f func()) { // TerminalOnly returns an io.Writer that only writes to the console func (l *Logger) TerminalOnly() io.Writer { - return os.Stdout + return os.Stderr } // FileOnly returns an io.Writer that only writes to the log file diff --git a/services/downloader.go b/services/downloader.go index b2cf08f..926a6ef 100644 --- a/services/downloader.go +++ b/services/downloader.go @@ -2,7 +2,10 @@ package services import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" + "io" "os" "path/filepath" "strings" @@ -33,7 +36,6 @@ func NewDownloader(cfg *types.AgentConfig, daemonCfg *config.AgentDaemonConfig, // DownloadAll downloads all required artifacts into the local workspace func (d *Downloader) DownloadAll(ctx context.Context, workspaceDir string) error { - // --- Removed the duplicate unconditional download call --- if d.cfg.Nodes.BootMethod == "agent" { d.logger.Info("Skipping RHCOS image downloads (Agent ISO handles payload dynamically)") } else { @@ -55,13 +57,14 @@ func (d *Downloader) DownloadRHCOSImages(ctx context.Context, workspaceDir strin d.logger.Info("Downloading RHCOS images to local workspace...") rhcosDir := filepath.Join(workspaceDir, "rhcos") - d.exec.Execute(ctx, fmt.Sprintf("mkdir -p %s", rhcosDir)) + if err := os.MkdirAll(rhcosDir, 0o755); err != nil { + return fmt.Errorf("failed to create RHCOS directory %s: %w", rhcosDir, err) + } urls := d.cfg.OpenShift.RHCOSImages - timeout := d.daemonCfg.Timeouts.DownloadTimeoutSec // 
Get timeout from config + timeout := d.daemonCfg.Timeouts.DownloadTimeoutSec - // Manifest path for checksum verification (populated when ChecksumURL is configured) - manifestPath := filepath.Join(workspaceDir, "rhcos", "sha256sum.txt") + manifestPath := filepath.Join(rhcosDir, "sha256sum.txt") images := []struct { url string @@ -73,6 +76,8 @@ func (d *Downloader) DownloadRHCOSImages(ctx context.Context, workspaceDir strin {urls.RootfsURL, "rootfs.img", "RHCOS rootfs"}, } + forceDownload := d.cfg.OpenShift.ForceOCPDownload + for _, img := range images { if img.url == "" { return fmt.Errorf("%s URL not provided in configuration", img.desc) @@ -82,43 +87,25 @@ func (d *Downloader) DownloadRHCOSImages(ctx context.Context, workspaceDir strin // Attempt to resolve the expected hash from the manifest; fall back gracefully. expectedHash, _ := d.extractHashFromManifest(ctx, img.url, manifestPath) - // 3. Conditional Flow based on checksum availability and force_ocp_download flag - forceDownload := d.cfg.OpenShift.ForceOCPDownload - - if forceDownload { - d.logger.Info("Force download requested. Wiping existing file...", "file", destPath) - d.exec.Execute(ctx, fmt.Sprintf("rm -f %s", destPath)) - } else { - if expectedHash != "" { - existsCmd := fmt.Sprintf("test -f %s", destPath) - if _, err := d.exec.Execute(ctx, existsCmd); err == nil { - if d.verifyFileHash(ctx, destPath, expectedHash) { - d.logger.Info("Checksum matches, skipping download", "image", img.desc) - continue - } - d.logger.Warn("Checksum mismatch. Wiping corrupted file and re-downloading...", "image", img.desc) - d.exec.Execute(ctx, fmt.Sprintf("rm -f %s", destPath)) - } + skip, wipe := d.resolveDownloadAction(ctx, destPath, expectedHash, forceDownload) + if skip { + continue + } + if wipe { + if err := os.Remove(destPath); err != nil && !os.IsNotExist(err) { + d.logger.Warn("Failed to remove stale file", "file", destPath, "error", err) } - // No checksum available: always run curl. 
curl -C - (resume) will issue a - // range request and exit 0 immediately if the file is already complete, - // but will resume and finish a truncated partial download rather than - // silently accepting a corrupted file left behind by a previous Ctrl+C. } - // 4. Download the file d.logger.Info("Downloading image...", "image", img.desc) - - // Use the dynamic timeout here! - downloadCmd := fmt.Sprintf("curl -sSL -C - --retry 3 --retry-delay 5 --max-time %d -o %s '%s'", timeout, destPath, img.url) + downloadCmd := fmt.Sprintf("curl -sSL -C - --retry 3 --retry-delay 5 --max-time %d -o %s -- %s", timeout, shellQuote(destPath), shellQuote(img.url)) if _, err := d.exec.Execute(ctx, downloadCmd); err != nil { return fmt.Errorf("failed to download %s from %s: %w", img.desc, img.url, err) } - // 5. Final Verification (if checksum is available) if expectedHash != "" { if !d.verifyFileHash(ctx, destPath, expectedHash) { - return fmt.Errorf("FATAL: %s checksum mismatch after download", img.desc) + return fmt.Errorf("%s checksum mismatch after download", img.desc) } d.logger.Info("Downloaded and verified", "image", img.desc) } else { @@ -134,7 +121,9 @@ func (d *Downloader) DownloadOpenShiftTools(ctx context.Context, workspaceDir st d.logger.Info("Downloading OpenShift tools...") toolsDir := filepath.Join(workspaceDir, "tools") - d.exec.Execute(ctx, fmt.Sprintf("mkdir -p %s", toolsDir)) + if err := os.MkdirAll(toolsDir, 0o755); err != nil { + return fmt.Errorf("failed to create tools directory %s: %w", toolsDir, err) + } // Check if the extracted binaries are already here (Airgap mode safety) installerPath := filepath.Join(toolsDir, "openshift-install") @@ -155,10 +144,10 @@ func (d *Downloader) DownloadOpenShiftTools(ctx context.Context, workspaceDir st if ocpConfig.ChecksumURL != "" { d.logger.Info("Integrity Mode: Fetching fresh checksum manifest", "url", ocpConfig.ChecksumURL) - // Force wipe any stale manifest to guarantee we get the latest - d.exec.Execute(ctx, 
fmt.Sprintf("rm -f %s", manifestPath)) + // Force wipe any stale manifest to guarantee we get the latest. + _ = os.Remove(manifestPath) - dlManifestCmd := fmt.Sprintf("curl -sSL --fail --max-time %d -o %s '%s'", timeout, manifestPath, ocpConfig.ChecksumURL) + dlManifestCmd := fmt.Sprintf("curl -sSL --fail --max-time %d -o %s -- %s", timeout, shellQuote(manifestPath), shellQuote(ocpConfig.ChecksumURL)) if _, err := d.exec.Execute(ctx, dlManifestCmd); err != nil { d.logger.Warn("Failed to fetch checksum manifest", "error", err) } else { @@ -176,47 +165,34 @@ func (d *Downloader) DownloadOpenShiftTools(ctx context.Context, workspaceDir st {ocpConfig.MirrorClient, "oc-mirror.tar.gz", "OpenShift mirror plugin"}, } + forceDownload := d.cfg.OpenShift.ForceOCPDownload + for _, tool := range tools { if tool.url == "" { continue } destPath := filepath.Join(toolsDir, tool.filename) + // Attempt to resolve the expected hash from the manifest downloaded above. expectedHash, _ := d.extractHashFromManifest(ctx, tool.url, manifestPath) - forceDownload := d.cfg.OpenShift.ForceOCPDownload - - if forceDownload { - d.logger.Info("Force download requested. Wiping existing file...", "file", destPath) - d.exec.Execute(ctx, fmt.Sprintf("rm -f %s", destPath)) - } else { - if expectedHash != "" { - existsCmd := fmt.Sprintf("test -f %s", destPath) - if _, err := d.exec.Execute(ctx, existsCmd); err == nil { - if d.verifyFileHash(ctx, destPath, expectedHash) { - d.logger.Info("Matches checksum, skipping download", "tool", tool.desc) - continue - } - d.logger.Warn("Checksum mismatch. 
Wiping corrupted file and re-downloading...", "tool", tool.desc) - d.exec.Execute(ctx, fmt.Sprintf("rm -f %s", destPath)) - } + skip, wipe := d.resolveDownloadAction(ctx, destPath, expectedHash, forceDownload) + if skip { + continue + } + if wipe { + if err := os.Remove(destPath); err != nil && !os.IsNotExist(err) { + d.logger.Warn("Failed to remove stale file", "file", destPath, "error", err) } - // No checksum available: always run curl. curl -C - (resume) will exit 0 - // immediately if the file is already complete, and will resume a partial - // download rather than accepting a truncated file from a prior Ctrl+C. } - // 4. Download d.logger.Info("Downloading tool...", "tool", tool.desc) - - // Use the dynamic timeout here! - downloadCmd := fmt.Sprintf("curl -sSL -C - --retry 3 --retry-delay 5 --max-time %d -o %s '%s'", timeout, destPath, tool.url) + downloadCmd := fmt.Sprintf("curl -sSL -C - --retry 3 --retry-delay 5 --max-time %d -o %s -- %s", timeout, shellQuote(destPath), shellQuote(tool.url)) if _, err := d.exec.Execute(ctx, downloadCmd); err != nil { d.logger.Warn("Failed to download tool", "tool", tool.desc, "error", err) continue } - // 5. Final Verification if expectedHash != "" { if !d.verifyFileHash(ctx, destPath, expectedHash) { d.logger.Warn("Checksum mismatch after download", "tool", tool.desc) @@ -234,60 +210,114 @@ func (d *Downloader) DownloadOpenShiftTools(ctx context.Context, workspaceDir st func (d *Downloader) extractOpenShiftTools(ctx context.Context, toolsDir string) error { shieldedCtx := context.WithoutCancel(ctx) - //Add oc-mirror.tar.gz to extraction targets + // Extract each archive that is present and non-empty. 
tools := []string{"openshift-install-linux.tar.gz", "openshift-client-linux.tar.gz", "oc-mirror.tar.gz"} for _, tool := range tools { tarPath := filepath.Join(toolsDir, tool) - if _, err := d.exec.Execute(shieldedCtx, fmt.Sprintf("test -s %s", tarPath)); err != nil { + if _, err := d.exec.Execute(shieldedCtx, "test -s "+shellQuote(tarPath)); err != nil { continue } - extractCmd := fmt.Sprintf("cd %s && tar -xzf %s", toolsDir, tool) + extractCmd := fmt.Sprintf("tar -xzf %s -C %s", shellQuote(tarPath), shellQuote(toolsDir)) if _, err := d.exec.Execute(shieldedCtx, extractCmd); err != nil { return fmt.Errorf("failed to extract %s: %w", tool, err) } } - //Add oc-mirror to the chmod list - makeExecCmd := fmt.Sprintf("cd %s && chmod +x openshift-install oc kubectl oc-mirror 2>/dev/null || true", toolsDir) + // Make all extracted binaries executable. + makeExecCmd := fmt.Sprintf( + "chmod +x %s %s %s %s 2>/dev/null || true", + shellQuote(filepath.Join(toolsDir, "openshift-install")), + shellQuote(filepath.Join(toolsDir, "oc")), + shellQuote(filepath.Join(toolsDir, "kubectl")), + shellQuote(filepath.Join(toolsDir, "oc-mirror")), + ) _, err := d.exec.Execute(shieldedCtx, makeExecCmd) return err } // extractHashFromManifest parses sha256sum.txt for a specific filename // Uses precise grep pattern to avoid partial matches (e.g., "kernel" vs "my-kernel") -func (d *Downloader) extractHashFromManifest(ctx context.Context, originalURL, manifestPath string) (string, error) { - // Strip any query parameters from the URL (e.g., ?signature=123) - cleanURL := strings.Split(originalURL, "?")[0] +func (d *Downloader) extractHashFromManifest(_ context.Context, originalURL, manifestPath string) (string, error) { + // Strip query parameters (e.g. signed S3 URLs) before extracting the basename. 
+ cleanURL := strings.SplitN(originalURL, "?", 2)[0] filename := filepath.Base(cleanURL) - // Ensure the manifest file actually exists before grepping - if _, err := d.exec.Execute(ctx, fmt.Sprintf("test -f %s", manifestPath)); err != nil { - return "", fmt.Errorf("manifest file not found on disk") + if _, err := os.Stat(manifestPath); err != nil { + return "", fmt.Errorf("manifest not found: %s", manifestPath) } - // Use [[:space:]] to match whitespace and $ to anchor end of line - extractCmd := fmt.Sprintf("grep -E '[[:space:]]%s$' %s | awk '{print $1}'", filename, manifestPath) - hash, err := d.exec.Execute(ctx, extractCmd) + data, err := os.ReadFile(manifestPath) if err != nil { - return "", fmt.Errorf("grep command failed: %w", err) + return "", fmt.Errorf("failed to read manifest %s: %w", manifestPath, err) } - hash = strings.TrimSpace(hash) - if hash == "" { - return "", fmt.Errorf("filename '%s' not found inside the manifest", filename) + // Each sha256sum line: "HASH  FILENAME" — match last field to avoid + // partial name collisions (e.g. "kernel" matching "my-kernel"). + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + fields := strings.Fields(line) + if len(fields) >= 2 && fields[len(fields)-1] == filename { + return fields[0], nil + } } - return hash, nil + return "", fmt.Errorf("filename %q not found in manifest", filename) } -// verifyFileHash calculates SHA256 hash of a file and compares it to expected hash -func (d *Downloader) verifyFileHash(ctx context.Context, filePath, expectedHash string) bool { - calcCmd := fmt.Sprintf("sha256sum %s | awk '{print $1}'", filePath) - actualHash, err := d.exec.Execute(ctx, calcCmd) +// resolveDownloadAction inspects the file at destPath and returns the action +// the caller should take before attempting a download, based on the force flag, +// checksum availability, and the file's current state on disk. 
+// +// Returns: +// - skip=true → file is present and verified; the caller should skip the download. +// - wipe=true → file is stale or corrupt; the caller must remove it before downloading. +func (d *Downloader) resolveDownloadAction(ctx context.Context, destPath, expectedHash string, forceDownload bool) (skip, wipe bool) { + if forceDownload { + d.logger.Info("Force download requested. Wiping existing file...", "file", destPath) + return false, true + } + + fi, err := os.Stat(destPath) + fileExists := err == nil + + if expectedHash != "" { + if !fileExists { + return false, false // nothing on disk yet; proceed to download + } + if d.verifyFileHash(ctx, destPath, expectedHash) { + d.logger.Info("Checksum matches, skipping download", "file", destPath) + return true, false + } + d.logger.Warn("Checksum mismatch. Wiping corrupted file and re-downloading...", "file", destPath) + return false, true + } + + // No checksum available. Guard against S3-backed mirrors that return HTTP 200 + // on a Range request, causing curl -C - to silently re-download the whole file. + // Skip only if the file is already non-empty; a truncated partial can be + // recovered with ForceOCPDownload=true. + if fileExists && fi.Size() > 0 { + d.logger.Info("File exists, no checksum configured — skipping re-download", "file", destPath) + return true, false + } + + return false, false +} + +// verifyFileHash computes the SHA-256 digest of filePath in pure Go and +// compares it against expectedHash (lowercase hex). Returns false on any error. 
+func (d *Downloader) verifyFileHash(_ context.Context, filePath, expectedHash string) bool { + f, err := os.Open(filePath) if err != nil { return false } - actual := strings.TrimSpace(actualHash) - expected := strings.TrimSpace(expectedHash) - return actual == expected + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return false + } + + actual := hex.EncodeToString(h.Sum(nil)) + return actual == strings.TrimSpace(expectedHash) } +