import os
import sys
import json
import time
from time import gmtime, strftime

import torch.distributed as dist
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import comet_ml

# Ensure project root is in path
sys.path.append('../')
from config import Config
from dataset import QlibDataset
from model.kronos import KronosTokenizer, Kronos

# Import shared utilities
from utils.training_utils import (
    setup_ddp,
    cleanup_ddp,
    set_seed,
    get_model_size,
    format_time
)


def create_dataloaders(config: dict, rank: int, world_size: int):
    """
    Creates and returns distributed dataloaders for training and validation.

    Args:
        config (dict): A dictionary of configuration parameters.
        rank (int): The global rank of the current process.
        world_size (int): The total number of processes.

    Returns:
        tuple: (train_loader, val_loader, train_dataset, valid_dataset).
    """
    print(f"[Rank {rank}] Creating distributed dataloaders...")
    train_dataset = QlibDataset('train')
    valid_dataset = QlibDataset('val')
    print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)
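    # drop_last=True discards any trailing partial batch during training so every step sees a
    # full batch; validation keeps all samples (drop_last=False).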
    train_loader = DataLoader(
        train_dataset, batch_size=config['batch_size'], sampler=train_sampler,
        num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=True
    )
    val_loader = DataLoader(
        valid_dataset, batch_size=config['batch_size'], sampler=val_sampler,
        num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=False
    )
    return train_loader, val_loader, train_dataset, valid_dataset


def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size):
    """
    The main training and validation loop for the predictor.
    """
    start_time = time.time()
    if rank == 0:
        effective_bs = config['batch_size'] * world_size
        print(f"Batch size per GPU: {config['batch_size']}, Effective total batch size: {effective_bs}")

    train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['predictor_learning_rate'],
        betas=(config['adam_beta1'], config['adam_beta2']),
        weight_decay=config['adam_weight_decay']
    )
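    # One-cycle LR schedule: short warmup (pct_start=0.03) followed by annealing, stepped once per batch.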
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=config['predictor_learning_rate'],
        steps_per_epoch=len(train_loader), epochs=config['epochs'],
        pct_start=0.03, div_factor=10
    )

    best_val_loss = float('inf')
    dt_result = {}
    batch_idx_global = 0
    for epoch_idx in range(config['epochs']):
        epoch_start_time = time.time()
        model.train()
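        # Reshuffle the distributed sampler each epoch and give each rank a distinct per-epoch
        # dataset seed; the validation dataset keeps a fixed seed so its sampling never changes.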
        train_loader.sampler.set_epoch(epoch_idx)
        train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
        valid_dataset.set_epoch_seed(0)

        for i, (batch_x, batch_x_stamp) in enumerate(train_loader):
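            # squeeze(0) removes the leading dimension only when it equals 1; with a regular
            # multi-sample batch it is a no-op.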
            batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)

            # Tokenize input data on-the-fly
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

            # Prepare inputs and targets for the language model
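            # (standard next-token setup: inputs are tokens[:-1], targets are the same tokens shifted by one)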
            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

            # Forward pass and loss calculation
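            # The model is DDP-wrapped, so its loss head is reached via model.module; compute_loss
            # returns the combined loss plus the per-branch (S1/S2) losses used for logging.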
            logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()
            # Logging (Master Process Only)
            if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0:
                lr = optimizer.param_groups[0]['lr']
                print(
                    f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
                    f"LR {lr:.6f}, Loss: {loss.item():.4f}"
                )
            if rank == 0 and logger:
                lr = optimizer.param_groups[0]['lr']
                logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global)
                logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global)
                logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global)
                logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global)

            batch_idx_global += 1
        # --- Validation Loop ---
        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_batches_processed_rank = 0
        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
                logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
                tot_val_loss_sum_rank += val_loss.item()
                val_batches_processed_rank += 1

        # Reduce validation metrics
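        # Each rank contributes its local loss sum and batch count; all_reduce(SUM) makes the totals
        # identical everywhere, so every rank computes the same global average validation loss.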
        val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
        val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device)
        dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM)
        avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0
        # --- End of Epoch Summary & Checkpointing (Master Process Only) ---
        if rank == 0:
            print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
            print(f"Validation Loss: {avg_val_loss:.4f}")
            print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
            print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
            if logger:
                logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_path = f"{save_dir}/checkpoints/best_model"
                model.module.save_pretrained(save_path)
                print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")
        dist.barrier()

    dt_result['best_val_loss'] = best_val_loss
    return dt_result


def main(config: dict):
    """Main function to orchestrate the DDP training process."""
    rank, world_size, local_rank = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")
    set_seed(config['seed'], rank)

    save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name'])

    # Logger and summary setup (master process only)
    comet_logger, master_summary = None, {}
    if rank == 0:
        os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
        master_summary = {
            'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
            'save_directory': save_dir,
            'world_size': world_size,
        }
        if config['use_comet']:
            comet_logger = comet_ml.Experiment(
                api_key=config['comet_config']['api_key'],
                project_name=config['comet_config']['project_name'],
                workspace=config['comet_config']['workspace'],
            )
            comet_logger.add_tag(config['comet_tag'])
            comet_logger.set_name(config['comet_name'])
            comet_logger.log_parameters(config)
            print("Comet Logger Initialized.")
    dist.barrier()

    # Model Initialization
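    # The fine-tuned tokenizer is loaded in eval mode and used only to encode inputs into discrete
    # tokens during training; its weights are not updated here.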
    tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path'])
    tokenizer.eval().to(device)

    model = Kronos.from_pretrained(config['pretrained_predictor_path'])
    model.to(device)
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
    if rank == 0:
        print(f"Predictor Model Size: {get_model_size(model.module)}")

    # Start Training
    dt_result = train_model(
        model, tokenizer, device, config, save_dir, comet_logger, rank, world_size
    )
    if rank == 0:
        master_summary['final_result'] = dt_result
        with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
            json.dump(master_summary, f, indent=4)
        print('Training finished. Summary file saved.')
        if comet_logger:
            comet_logger.end()

    cleanup_ddp()


if __name__ == '__main__':
    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
    if "WORLD_SIZE" not in os.environ:
        raise RuntimeError("This script must be launched with `torchrun`.")

    config_instance = Config()
    main(config_instance.__dict__)