You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

267 lines
11 KiB

import os
import yaml
from typing import Dict, Any
class ConfigLoader:
def __init__(self, config_path: str):
self.config_path = config_path
self.config = self._load_config()
def _load_config(self) -> Dict[str, Any]:
if not os.path.exists(self.config_path):
raise FileNotFoundError(f"config file not found: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
config = self._resolve_dynamic_paths(config)
return config
def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
exp_name = config.get('model_paths', {}).get('exp_name', '')
if not exp_name:
return config
base_path = config.get('model_paths', {}).get('base_path', '')
path_templates = {
'base_save_path': f"{base_path}/{exp_name}",
'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
}
if 'model_paths' in config:
for key, template in path_templates.items():
if key in config['model_paths']:
# only use template when the original value is empty string
current_value = config['model_paths'][key]
if current_value == "" or current_value is None:
config['model_paths'][key] = template
else:
# if the original value is not empty, use template to replace the {exp_name} placeholder
if isinstance(current_value, str) and '{exp_name}' in current_value:
config['model_paths'][key] = current_value.format(exp_name=exp_name)
return config
def get(self, key: str, default=None):
keys = key.split('.')
value = self.config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
def get_data_config(self) -> Dict[str, Any]:
return self.config.get('data', {})
def get_training_config(self) -> Dict[str, Any]:
return self.config.get('training', {})
def get_model_paths(self) -> Dict[str, str]:
return self.config.get('model_paths', {})
def get_experiment_config(self) -> Dict[str, Any]:
return self.config.get('experiment', {})
def get_device_config(self) -> Dict[str, Any]:
return self.config.get('device', {})
def get_distributed_config(self) -> Dict[str, Any]:
return self.config.get('distributed', {})
def update_config(self, updates: Dict[str, Any]):
def update_nested_dict(d, u):
for k, v in u.items():
if isinstance(v, dict):
d[k] = update_nested_dict(d.get(k, {}), v)
else:
d[k] = v
return d
self.config = update_nested_dict(self.config, updates)
def save_config(self, save_path: str = None):
if save_path is None:
save_path = self.config_path
with open(save_path, 'w', encoding='utf-8') as f:
yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)
def print_config(self):
print("=" * 50)
print("Current configuration:")
print("=" * 50)
yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2)
print("=" * 50)
class CustomFinetuneConfig:
def __init__(self, config_path: str = None):
if config_path is None:
config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
self.loader = ConfigLoader(config_path)
self._load_all_configs()
def _load_all_configs(self):
data_config = self.loader.get_data_config()
self.data_path = data_config.get('data_path')
self.lookback_window = data_config.get('lookback_window', 512)
self.predict_window = data_config.get('predict_window', 48)
self.max_context = data_config.get('max_context', 512)
self.clip = data_config.get('clip', 5.0)
self.train_ratio = data_config.get('train_ratio', 0.9)
self.val_ratio = data_config.get('val_ratio', 0.1)
self.test_ratio = data_config.get('test_ratio', 0.0)
# training configuration
training_config = self.loader.get_training_config()
# support training epochs of tokenizer and basemodel separately
self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
self.basemodel_epochs = training_config.get('basemodel_epochs', 30)
if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
self.tokenizer_epochs = training_config.get('epochs', 30)
if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
self.basemodel_epochs = training_config.get('epochs', 30)
self.batch_size = training_config.get('batch_size', 160)
self.log_interval = training_config.get('log_interval', 50)
self.num_workers = training_config.get('num_workers', 6)
self.seed = training_config.get('seed', 100)
self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
self.adam_beta1 = training_config.get('adam_beta1', 0.9)
self.adam_beta2 = training_config.get('adam_beta2', 0.95)
self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
self.accumulation_steps = training_config.get('accumulation_steps', 1)
model_paths = self.loader.get_model_paths()
self.exp_name = model_paths.get('exp_name', 'default_experiment')
self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
self.base_save_path = model_paths.get('base_save_path')
self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')
experiment_config = self.loader.get_experiment_config()
self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
self.experiment_description = experiment_config.get('description', '')
self.use_comet = experiment_config.get('use_comet', False)
self.train_tokenizer = experiment_config.get('train_tokenizer', True)
self.train_basemodel = experiment_config.get('train_basemodel', True)
self.skip_existing = experiment_config.get('skip_existing', False)
unified_pretrained = experiment_config.get('pre_trained', None)
self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)
device_config = self.loader.get_device_config()
self.use_cuda = device_config.get('use_cuda', True)
self.device_id = device_config.get('device_id', 0)
distributed_config = self.loader.get_distributed_config()
self.use_ddp = distributed_config.get('use_ddp', False)
self.ddp_backend = distributed_config.get('backend', 'nccl')
self._compute_full_paths()
def _compute_full_paths(self):
self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')
self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')
def get_tokenizer_config(self):
return {
'data_path': self.data_path,
'lookback_window': self.lookback_window,
'predict_window': self.predict_window,
'max_context': self.max_context,
'clip': self.clip,
'train_ratio': self.train_ratio,
'val_ratio': self.val_ratio,
'test_ratio': self.test_ratio,
'epochs': self.tokenizer_epochs,
'batch_size': self.batch_size,
'log_interval': self.log_interval,
'num_workers': self.num_workers,
'seed': self.seed,
'learning_rate': self.tokenizer_learning_rate,
'adam_beta1': self.adam_beta1,
'adam_beta2': self.adam_beta2,
'adam_weight_decay': self.adam_weight_decay,
'accumulation_steps': self.accumulation_steps,
'pretrained_model_path': self.pretrained_tokenizer_path,
'save_path': self.tokenizer_save_path,
'use_comet': self.use_comet
}
def get_basemodel_config(self):
return {
'data_path': self.data_path,
'lookback_window': self.lookback_window,
'predict_window': self.predict_window,
'max_context': self.max_context,
'clip': self.clip,
'train_ratio': self.train_ratio,
'val_ratio': self.val_ratio,
'test_ratio': self.test_ratio,
'epochs': self.basemodel_epochs,
'batch_size': self.batch_size,
'log_interval': self.log_interval,
'num_workers': self.num_workers,
'seed': self.seed,
'predictor_learning_rate': self.predictor_learning_rate,
'tokenizer_learning_rate': self.tokenizer_learning_rate,
'adam_beta1': self.adam_beta1,
'adam_beta2': self.adam_beta2,
'adam_weight_decay': self.adam_weight_decay,
'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
'pretrained_predictor_path': self.pretrained_predictor_path,
'save_path': self.basemodel_save_path,
'use_comet': self.use_comet
}
def print_config_summary(self):
print("=" * 60)
print("Kronos finetuning configuration summary")
print("=" * 60)
print(f"Experiment name: {self.exp_name}")
print(f"Data path: {self.data_path}")
print(f"Lookback window: {self.lookback_window}")
print(f"Predict window: {self.predict_window}")
print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
print(f"Basemodel training epochs: {self.basemodel_epochs}")
print(f"Batch size: {self.batch_size}")
print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
print(f"Predictor learning rate: {self.predictor_learning_rate}")
print(f"Train tokenizer: {self.train_tokenizer}")
print(f"Train basemodel: {self.train_basemodel}")
print(f"Skip existing: {self.skip_existing}")
print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
print(f"Base save path: {self.base_save_path}")
print(f"Tokenizer save path: {self.tokenizer_save_path}")
print(f"Basemodel save path: {self.basemodel_save_path}")
print("=" * 60)