import os
import yaml
from typing import Any, Dict, Optional


class ConfigLoader:
    """Loads a YAML configuration file and resolves dynamic path templates."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        config = self._resolve_dynamic_paths(config)
        return config

    def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        exp_name = config.get('model_paths', {}).get('exp_name', '')
        if not exp_name:
            return config
        base_path = config.get('model_paths', {}).get('base_path', '')
        path_templates = {
            'base_save_path': f"{base_path}/{exp_name}",
            'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model",
        }
        if 'model_paths' in config:
            for key, template in path_templates.items():
                if key in config['model_paths']:
                    current_value = config['model_paths'][key]
                    if current_value == "" or current_value is None:
                        # Empty value: fall back to the generated template.
                        config['model_paths'][key] = template
                    elif isinstance(current_value, str) and '{exp_name}' in current_value:
                        # Non-empty value: substitute only the {exp_name} placeholder.
                        config['model_paths'][key] = current_value.format(exp_name=exp_name)
        return config
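    # Illustrative behavior (hypothetical values, assuming the YAML sets
    # model_paths.base_path='outputs' and model_paths.exp_name='run1'):
    # an empty 'base_save_path' resolves to 'outputs/run1', while a value
    # such as 'ckpts/{exp_name}/tok' resolves to 'ckpts/run1/tok'.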
    def get(self, key: str, default=None):
        """Look up a value by dot-separated key, e.g. 'training.batch_size'."""
        keys = key.split('.')
        value = self.config
        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default

    def get_data_config(self) -> Dict[str, Any]:
        return self.config.get('data', {})

    def get_training_config(self) -> Dict[str, Any]:
        return self.config.get('training', {})

    def get_model_paths(self) -> Dict[str, str]:
        return self.config.get('model_paths', {})

    def get_experiment_config(self) -> Dict[str, Any]:
        return self.config.get('experiment', {})

    def get_device_config(self) -> Dict[str, Any]:
        return self.config.get('device', {})

    def get_distributed_config(self) -> Dict[str, Any]:
        return self.config.get('distributed', {})

    def update_config(self, updates: Dict[str, Any]):
        """Recursively merge `updates` into the loaded configuration."""
        def update_nested_dict(d, u):
            for k, v in u.items():
                if isinstance(v, dict):
                    d[k] = update_nested_dict(d.get(k, {}), v)
                else:
                    d[k] = v
            return d

        self.config = update_nested_dict(self.config, updates)

    def save_config(self, save_path: Optional[str] = None):
        if save_path is None:
            save_path = self.config_path
        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)

    def print_config(self):
        print("=" * 50)
        print("Current configuration:")
        print("=" * 50)
        # yaml.dump returns a string when no stream is given; print it,
        # otherwise the dump is silently discarded.
        print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
        print("=" * 50)
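
# Minimal usage sketch (illustrative; assumes a config.yaml whose top-level
# sections match the getters above, e.g. 'data', 'training', 'model_paths'):
#
#   loader = ConfigLoader('config.yaml')
#   batch_size = loader.get('training.batch_size', 160)     # dotted lookup
#   loader.update_config({'training': {'batch_size': 64}})  # deep merge
#   loader.save_config('config_updated.yaml')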
class CustomFinetuneConfig:
    """Typed view over the YAML config for the Kronos finetuning scripts."""

    def __init__(self, config_path: Optional[str] = None):
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
        self.loader = ConfigLoader(config_path)
        self._load_all_configs()

    def _load_all_configs(self):
        # Data configuration
        data_config = self.loader.get_data_config()
        self.data_path = data_config.get('data_path')
        self.lookback_window = data_config.get('lookback_window', 512)
        self.predict_window = data_config.get('predict_window', 48)
        self.max_context = data_config.get('max_context', 512)
        self.clip = data_config.get('clip', 5.0)
        self.train_ratio = data_config.get('train_ratio', 0.9)
        self.val_ratio = data_config.get('val_ratio', 0.1)
        self.test_ratio = data_config.get('test_ratio', 0.0)

        # Training configuration
        training_config = self.loader.get_training_config()
        # The tokenizer and the base model support separate epoch counts; a
        # legacy 'epochs' key serves as a fallback for whichever one is unset.
        self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
        self.basemodel_epochs = training_config.get('basemodel_epochs', 30)
        if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
            self.tokenizer_epochs = training_config.get('epochs', 30)
        if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
            self.basemodel_epochs = training_config.get('epochs', 30)
        self.batch_size = training_config.get('batch_size', 160)
        self.log_interval = training_config.get('log_interval', 50)
        self.num_workers = training_config.get('num_workers', 6)
        self.seed = training_config.get('seed', 100)
        self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
        self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
        self.adam_beta1 = training_config.get('adam_beta1', 0.9)
        self.adam_beta2 = training_config.get('adam_beta2', 0.95)
        self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
        self.accumulation_steps = training_config.get('accumulation_steps', 1)

        # Model paths
        model_paths = self.loader.get_model_paths()
        self.exp_name = model_paths.get('exp_name', 'default_experiment')
        self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
        self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
        self.base_save_path = model_paths.get('base_save_path')
        self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
        self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
        self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')

        # Experiment configuration
        experiment_config = self.loader.get_experiment_config()
        self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
        self.experiment_description = experiment_config.get('description', '')
        self.use_comet = experiment_config.get('use_comet', False)
        self.train_tokenizer = experiment_config.get('train_tokenizer', True)
        self.train_basemodel = experiment_config.get('train_basemodel', True)
        self.skip_existing = experiment_config.get('skip_existing', False)
        # A unified 'pre_trained' flag provides the default for both per-component
        # flags; each component can still be overridden individually.
        unified_pretrained = experiment_config.get('pre_trained', None)
        default_pretrained = unified_pretrained if unified_pretrained is not None else True
        self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', default_pretrained)
        self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', default_pretrained)

        # Device configuration
        device_config = self.loader.get_device_config()
        self.use_cuda = device_config.get('use_cuda', True)
        self.device_id = device_config.get('device_id', 0)

        # Distributed training configuration
        distributed_config = self.loader.get_distributed_config()
        self.use_ddp = distributed_config.get('use_ddp', False)
        self.ddp_backend = distributed_config.get('backend', 'nccl')

        self._compute_full_paths()

    def _compute_full_paths(self):
        self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
        self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')
        self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
        self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')
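    # For example (hypothetical value, assuming base_save_path='outputs/run1'),
    # this yields 'outputs/run1/tokenizer/best_model' and
    # 'outputs/run1/basemodel/best_model', which matches the
    # 'finetuned_tokenizer' template generated by ConfigLoader.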
    def get_tokenizer_config(self):
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.tokenizer_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'accumulation_steps': self.accumulation_steps,
            'pretrained_model_path': self.pretrained_tokenizer_path,
            'save_path': self.tokenizer_save_path,
            'use_comet': self.use_comet,
        }

    def get_basemodel_config(self):
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.basemodel_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'predictor_learning_rate': self.predictor_learning_rate,
            'tokenizer_learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
            'pretrained_predictor_path': self.pretrained_predictor_path,
            'save_path': self.basemodel_save_path,
            'use_comet': self.use_comet,
        }

    def print_config_summary(self):
        print("=" * 60)
        print("Kronos finetuning configuration summary")
        print("=" * 60)
        print(f"Experiment name: {self.exp_name}")
        print(f"Data path: {self.data_path}")
        print(f"Lookback window: {self.lookback_window}")
        print(f"Predict window: {self.predict_window}")
        print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
        print(f"Basemodel training epochs: {self.basemodel_epochs}")
        print(f"Batch size: {self.batch_size}")
        print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
        print(f"Predictor learning rate: {self.predictor_learning_rate}")
        print(f"Train tokenizer: {self.train_tokenizer}")
        print(f"Train basemodel: {self.train_basemodel}")
        print(f"Skip existing: {self.skip_existing}")
        print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
        print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
        print(f"Base save path: {self.base_save_path}")
        print(f"Tokenizer save path: {self.tokenizer_save_path}")
        print(f"Basemodel save path: {self.basemodel_save_path}")
        print("=" * 60)
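
# Quick usage sketch (illustrative; assumes a config.yaml sits next to this
# file and contains the sections read above; no keys are required beyond those
# the loaders already consume):
if __name__ == "__main__":
    cfg = CustomFinetuneConfig()          # loads ./config.yaml by default
    cfg.print_config_summary()
    tok_cfg = cfg.get_tokenizer_config()  # flat dict, presumably consumed by the tokenizer training script
    print(f"Tokenizer checkpoints go to: {tok_cfg['save_path']}")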