"""Fine-tune a Kronos tokenizer on a custom K-line dataset.

The module exposes small utilities (seeding, logging setup, data loader
construction), the training loop ``train_tokenizer`` and a command line
entry point ``main``.
"""

import contextlib
import os
import sys
import json
import time
import random
import datetime
import logging
from logging.handlers import RotatingFileHandler
from time import gmtime, strftime

import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Project modules live one directory up (relative to the working directory).
sys.path.append("../")
from model import KronosTokenizer
from finetune_base_model import CustomKlineDataset
from config_loader import CustomFinetuneConfig


def set_seed(seed: int, rank: int = 0):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Seed applied to every random number generator.
        rank: Accepted for call-site compatibility; it is not used, so every
            rank that passes the same ``seed`` gets the same RNG state.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_model_size(model: torch.nn.Module) -> str:
    """Return the number of trainable parameters as a short string.

    The count is rendered with one decimal and a ``B``/``M``/``K`` suffix
    (billions, millions, thousands). Only parameters with
    ``requires_grad=True`` are counted.
    """
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total_params >= 1e9:
        return f"{total_params / 1e9:.1f}B"
    if total_params >= 1e6:
        return f"{total_params / 1e6:.1f}M"
    return f"{total_params / 1e3:.1f}K"


def format_time(seconds: float) -> str:
    """Format a duration in seconds as ``H:MM:SS`` (fractions are truncated)."""
    return str(datetime.timedelta(seconds=int(seconds)))


def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    """Create (or fetch) the per-rank training logger.

    A rotating file handler writing to
    ``<log_dir>/tokenizer_training_rank_<rank>.log`` is always attached; a
    console handler is attached only for rank 0. If the logger already has
    handlers it is returned untouched, so repeated calls do not duplicate
    output.

    Args:
        exp_name: Experiment name, written to the log header.
        log_dir: Directory for log files; created if missing.
        rank: Process rank used in the logger and file names.

    Returns:
        The configured ``logging.Logger``.
    """
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
    logger.setLevel(logging.INFO)

    # Already configured by an earlier call: do not add handlers twice.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10*1024*1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Only rank 0 echoes log records to the console.
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    logger.info("=== Tokenizer Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return logger


def create_dataloaders(config):
    """Build the train/validation datasets, samplers and data loaders.

    When ``torch.distributed`` is initialized, each loader gets a
    ``DistributedSampler``; otherwise the samplers are ``None`` and the
    training loader shuffles itself. The training loader drops the last
    incomplete batch, the validation loader keeps it.

    Args:
        config: Configuration object providing ``data_path``,
            ``lookback_window``, ``predict_window``, ``clip``, ``seed``,
            ``train_ratio``, ``val_ratio``, ``test_ratio``, ``batch_size``
            and ``num_workers``.

    Returns:
        ``(train_loader, val_loader, train_dataset, val_dataset,
        train_sampler, val_sampler)``.
    """
    use_ddp = dist.is_available() and dist.is_initialized()
    is_main_process = (not use_ddp) or dist.get_rank() == 0

    if is_main_process:
        print("Creating tokenizer training data loaders...")

    train_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type="train",
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )

    # The validation split uses a different seed than the training split.
    val_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type="val",
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed + 1,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )

    train_sampler = None
    val_sampler = None
    if use_ddp:
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True
        )
        val_sampler = DistributedSampler(
            val_dataset, num_replicas=world_size, rank=rank,
            shuffle=False, drop_last=False
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler
    )

    if is_main_process:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler


def train_tokenizer(model, device, config, save_dir, logger):
    """Fine-tune the tokenizer and save the best checkpoint.

    Each epoch runs one pass over the training loader (AdamW + OneCycleLR,
    gradient clipping at norm 2.0, optional gradient accumulation) followed
    by a validation pass. Whenever the validation loss improves, rank 0
    saves the model to ``<save_dir>/best_model`` via ``save_pretrained``.

    Args:
        model: Tokenizer module. Its forward is expected to return a
            4-tuple whose first item is ``(z_pre, z)`` and whose second
            item is the quantizer loss.
        device: Device that batches are moved to.
        config: Configuration providing ``tokenizer_learning_rate``,
            ``adam_weight_decay``, ``tokenizer_epochs``, ``log_interval``,
            optionally ``accumulation_steps`` (default 1), plus everything
            ``create_dataloaders`` needs.
        save_dir: Directory under which ``best_model`` is written.
        logger: Logger receiving progress messages.

    Returns:
        The best (lowest) validation loss observed.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1, or a
            training batch has fewer samples than ``accumulation_steps``.
    """
    logger.info("Starting tokenizer training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0

    accumulation_steps = int(getattr(config, 'accumulation_steps', 1))
    if accumulation_steps < 1:
        raise ValueError(
            f"accumulation_steps must be >= 1, got {accumulation_steps}"
        )

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.tokenizer_learning_rate,
        weight_decay=config.adam_weight_decay
    )

    # The scheduler is stepped once per training batch (see below).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.tokenizer_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.tokenizer_epochs,
        pct_start=0.03,
        div_factor=10
    )

    # Keep a handle on the unwrapped module for evaluation, gradient
    # clipping and checkpointing; ``model`` may become a DDP wrapper.
    raw_model = model
    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float("inf")
    batch_idx_global = 0
    training_start_time = time.time()

    for epoch in range(config.tokenizer_epochs):
        epoch_start_time = time.time()
        model.train()

        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
            # NOTE(review): squeeze(0) drops the batch dimension when a batch
            # holds exactly one sample -- confirm this is intended.
            ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)

            micro_batch_size = ori_batch_x.shape[0] // accumulation_steps
            if micro_batch_size == 0:
                raise ValueError(
                    f"Batch of {ori_batch_x.shape[0]} samples cannot be split "
                    f"into {accumulation_steps} accumulation steps"
                )

            current_batch_total_loss = 0.0
            for j in range(accumulation_steps):
                start_idx = j * micro_batch_size
                end_idx = start_idx + micro_batch_size
                batch_x = ori_batch_x[start_idx:end_idx]

                # Forward through the DDP wrapper so gradients are
                # all-reduced across ranks; skip the all-reduce on every
                # micro-step but the last one of the accumulation window.
                is_last_micro_step = j == accumulation_steps - 1
                if use_ddp and not is_last_micro_step:
                    sync_context = model.no_sync()
                else:
                    sync_context = contextlib.nullcontext()

                with sync_context:
                    zs, bsq_loss, _, _ = model(batch_x)
                    z_pre, z = zs

                    recon_loss_pre = F.mse_loss(z_pre, batch_x)
                    recon_loss_all = F.mse_loss(z, batch_x)
                    recon_loss = recon_loss_pre + recon_loss_all
                    loss = (recon_loss + bsq_loss) / 2

                    # Scale so accumulated gradients average over micro-steps.
                    loss_scaled = loss / accumulation_steps
                    current_batch_total_loss += loss.item()
                    loss_scaled.backward()

            torch.nn.utils.clip_grad_norm_(raw_model.parameters(), max_norm=2.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if (batch_idx_global + 1) % config.log_interval == 0:
                avg_loss = current_batch_total_loss / accumulation_steps
                lr = optimizer.param_groups[0]["lr"]
                log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

                # Component losses are those of the last micro-batch.
                detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
                              f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
                              f" - Recon Loss All: {recon_loss_all.item():.4f}")
                logger.info(detail_msg)
                if rank == 0:
                    print(detail_msg)

            batch_idx_global += 1

        # ---- validation: sample-weighted mean reconstruction loss ----
        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_sample_count_rank = 0

        with torch.no_grad():
            for ori_batch_x, _ in val_loader:
                ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
                zs, _, _, _ = raw_model(ori_batch_x)
                _, z = zs
                val_loss_item = F.mse_loss(z, ori_batch_x)

                tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
                val_sample_count_rank += ori_batch_x.size(0)

        if use_ddp:
            # Sum loss totals and sample counts over all ranks.
            tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            tot_val_loss_all = tensor_sum[0].item()
            val_count_all = int(tensor_sum[1].item())
            avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
        else:
            avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {format_time(epoch_time)}\n"
                         f"Total Training Time: {format_time(time.time() - training_start_time)}\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Only rank 0 writes the checkpoint.
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                raw_model.save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss


def main():
    """Command line entry point: load the config, build the tokenizer, train.

    Reads ``--config`` (default ``config.yaml``), sets up logging under
    ``<base_save_path>/logs``, seeds the RNGs, then either loads the
    pretrained tokenizer or builds a randomly initialized one from the
    ``config.json`` found in ``pretrained_tokenizer_path``, and finally
    runs ``train_tokenizer``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    config = CustomFinetuneConfig(args.config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(config.tokenizer_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    set_seed(config.seed)

    # Load the pretrained tokenizer, or build the architecture from scratch.
    if getattr(config, 'pre_trained_tokenizer', True):
        logger.info("Loading pretrained tokenizer...")
        print("Loading pretrained tokenizer...")
        tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
        # Uses the module-level ``json``/``os`` imports. Re-importing them
        # here would make ``os`` local to this function and break the
        # ``os`` calls above with UnboundLocalError.
        cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4)
        )
    tokenizer = tokenizer.to(device)

    model_size = get_model_size(tokenizer)
    logger.info(f"Tokenizer parameters: {model_size}")
    print(f"Tokenizer parameters: {model_size}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
    logger.info(f"Training epochs: {config.tokenizer_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Distributed training: False")

    logger.info("Starting tokenizer fine-tuning training...")
    print("Starting tokenizer fine-tuning training...")
    best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)

    final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()