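"""Fine-tune a KronosTokenizer on custom K-line data.

The script loads a CustomFinetuneConfig from a YAML file, builds train/val loaders
over CustomKlineDataset, trains the tokenizer with AdamW + OneCycleLR, and saves the
best checkpoint by validation MSE via save_pretrained().

Invocation sketch (the script file name is not given in the source; use whatever
this file is saved as):

    python <this_file>.py --config config.yaml
"""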
import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import datetime
import logging
from logging.handlers import RotatingFileHandler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

sys.path.append("../")
from model import KronosTokenizer
from finetune_base_model import CustomKlineDataset
from config_loader import CustomFinetuneConfig
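
# Reproducibility helper: seeds Python, NumPy, and PyTorch RNGs and pins cuDNN to
# deterministic kernels. The rank argument is accepted but not mixed into the seed,
# so every process currently seeds identically.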
def set_seed(seed: int, rank: int = 0):
    actual_seed = seed
    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(actual_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
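
# Human-readable count of trainable parameters for logging, e.g. "24.6M" or "1.2B".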
def get_model_size(model: torch.nn.Module) -> str:
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total_params >= 1e9:
        return f"{total_params / 1e9:.1f}B"
    elif total_params >= 1e6:
        return f"{total_params / 1e6:.1f}M"
    else:
        return f"{total_params / 1e3:.1f}K"


def format_time(seconds: float) -> str:
    return str(datetime.timedelta(seconds=int(seconds)))
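
# Per-rank logger: every rank writes to its own rotating log file (10 MB x 5 backups);
# only rank 0 also echoes to the console. If handlers are already attached, the
# existing logger is returned unchanged, so the function is safe to call repeatedly.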
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
    logger.setLevel(logging.INFO)
    if logger.handlers:
        return logger
    log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)
    console_handler = None
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    if console_handler is not None:
        console_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    if console_handler is not None:
        logger.addHandler(console_handler)
    logger.info("=== Tokenizer Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    return logger
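
# Builds the train/val datasets and loaders. When torch.distributed is initialized,
# a DistributedSampler shards each split across ranks (shuffled for train, in order
# for val); otherwise the train loader shuffles locally.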
def create_dataloaders(config):
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print("Creating tokenizer training data loaders...")
    train_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type="train",
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )
    val_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type="val",
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed + 1,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )
    use_ddp = dist.is_available() and dist.is_initialized()
    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
    val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler
    )
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
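
# Training loop for the tokenizer. For each micro-batch the loss is
#   ((MSE(z_pre, x) + MSE(z, x)) + bsq_loss) / 2,
# accumulated over `accumulation_steps` micro-batches before one clipped optimizer
# step. Validation MSE is averaged per sample (all-reduced across ranks under DDP),
# and rank 0 saves the best checkpoint via save_pretrained().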
def train_tokenizer(model, device, config, save_dir, logger):
    logger.info("Starting tokenizer training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1
    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.tokenizer_learning_rate,
        weight_decay=config.adam_weight_decay
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.tokenizer_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.tokenizer_epochs,
        pct_start=0.03,
        div_factor=10
    )
    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
    best_val_loss = float("inf")
    batch_idx_global = 0
    accumulation_steps = getattr(config, 'accumulation_steps', 1)
    training_start_time = time.time()  # for the cumulative "Total Training Time" report
    for epoch in range(config.tokenizer_epochs):
        epoch_start_time = time.time()
        model.train()
        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
            ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
            current_batch_total_loss = 0.0
            for j in range(accumulation_steps):
                start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
                end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
                batch_x = ori_batch_x[start_idx:end_idx]
                zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
                z_pre, z = zs
                recon_loss_pre = F.mse_loss(z_pre, batch_x)
                recon_loss_all = F.mse_loss(z, batch_x)
                recon_loss = recon_loss_pre + recon_loss_all
                loss = (recon_loss + bsq_loss) / 2
                loss_scaled = loss / accumulation_steps
                current_batch_total_loss += loss.item()
                loss_scaled.backward()
            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if (batch_idx_global + 1) % config.log_interval == 0:
                avg_loss = current_batch_total_loss / accumulation_steps
                lr = optimizer.param_groups[0]["lr"]
                log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)
                detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
                              f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
                              f" - Recon Loss All: {recon_loss_all.item():.4f}")
                logger.info(detail_msg)
                if rank == 0:
                    print(detail_msg)
            batch_idx_global += 1
        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_sample_count_rank = 0
        with torch.no_grad():
            for ori_batch_x, _ in val_loader:
                ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
                zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
                _, z = zs
                val_loss_item = F.mse_loss(z, ori_batch_x)
                tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
                val_sample_count_rank += ori_batch_x.size(0)
        if use_ddp:
            tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            tot_val_loss_all = tensor_sum[0].item()
            val_count_all = int(tensor_sum[1].item())
            avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
        else:
            avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0.0
        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {format_time(epoch_time)}\n"
                         f"Total Training Time: {format_time(time.time() - training_start_time)}\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)
    return best_val_loss
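
# Entry point: parses --config, sets up logging and seeding, loads (or randomly
# initializes) the tokenizer, and runs fine-tuning on a single device.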
def main():
    import argparse
    parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()
    config = CustomFinetuneConfig(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.makedirs(config.tokenizer_save_path, exist_ok=True)
    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)
    set_seed(config.seed)
    # Load the pretrained tokenizer, or rebuild the architecture from its config.json.
    if getattr(config, 'pre_trained_tokenizer', True):
        logger.info("Loading pretrained tokenizer...")
        print("Loading pretrained tokenizer...")
        tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing the tokenizer architecture")
        cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4)
        )
    tokenizer = tokenizer.to(device)
    model_size = get_model_size(tokenizer)
    logger.info(f"Tokenizer parameters: {model_size}")
    print(f"Tokenizer parameters: {model_size}")
    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
    logger.info(f"Training epochs: {config.tokenizer_epochs}")
    logger.info(f"Device: {device}")
    logger.info("Distributed training: False")
    logger.info("Starting tokenizer fine-tuning training...")
    print("Starting tokenizer fine-tuning training...")
    best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)
    final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
    logger.info(final_msg)
    print(final_msg)
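
# Note: main() never calls dist.init_process_group, so this entry point runs
# single-process; the DDP branches in create_dataloaders/train_tokenizer only
# activate when a process group has been initialized before they are called
# (e.g. by a distributed launcher such as torchrun, which also sets LOCAL_RANK).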
if __name__ == "__main__":
    main()