import os


class Config:
    """
    Configuration class for the entire project.
    """

    def __init__(self):
        # =================================================================
        # Data & Feature Parameters
        # =================================================================
        # TODO: Update this path to your Qlib data directory.
        self.qlib_data_path = "~/.qlib/qlib_data/cn_data"
        self.instrument = 'csi300'

        # Overall time range for data loading from Qlib.
        self.dataset_begin_time = "2011-01-01"
        self.dataset_end_time = "2025-06-05"

        # Sliding window parameters for creating samples.
        self.lookback_window = 90  # Number of past time steps for input.
        self.predict_window = 10   # Number of future time steps for prediction.
        self.max_context = 512     # Maximum context length for the model.
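
        # Sanity check (this assumes `max_context` bounds the combined
        # input-plus-prediction length): 90 + 10 = 100 must not exceed 512.
        assert self.lookback_window + self.predict_window <= self.max_context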

        # Features to be used from the raw data.
        self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt']
        # Time-based features to be generated.
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']

        # =================================================================
        # Dataset Splitting & Paths
        # =================================================================
        # Note: each split starts before the previous split ends, so that the
        # first samples of the new split still have a full `lookback_window`
        # of history available.
        self.train_time_range = ["2011-01-01", "2022-12-31"]
        self.val_time_range = ["2022-09-01", "2024-06-30"]
        self.test_time_range = ["2024-04-01", "2025-06-05"]
        self.backtest_time_range = ["2024-07-01", "2025-06-05"]
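
        # Sanity check: ISO date strings compare chronologically, so we can
        # verify that each split starts before the previous split ends,
        # preserving the lookback overlap described above.
        assert self.val_time_range[0] <= self.train_time_range[1]
        assert self.test_time_range[0] <= self.val_time_range[1]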

        # TODO: Directory to save the processed, pickled datasets.
        self.dataset_path = "./data/processed_datasets"

        # =================================================================
        # Training Hyperparameters
        # =================================================================
        self.clip = 5.0  # Clipping value for normalized data to suppress outliers.
        self.epochs = 30
        self.log_interval = 100  # Log training status every N batches.
        self.batch_size = 50     # Batch size per GPU.

        # Number of samples to draw for one "epoch" of training/validation
        # (here 2000 * 50 = 100,000 training samples). This is useful for
        # large datasets where a true epoch would take too long.
        self.n_train_iter = 2000 * self.batch_size
        self.n_val_iter = 400 * self.batch_size

        # Learning rates for different model components.
        self.tokenizer_learning_rate = 2e-4
        self.predictor_learning_rate = 4e-5

        # Gradient accumulation to simulate a larger batch size
        # (effective batch per GPU = batch_size * accumulation_steps).
        self.accumulation_steps = 1

        # AdamW optimizer parameters.
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.95
        self.adam_weight_decay = 0.1

        # Miscellaneous
        self.seed = 100  # Global random seed for reproducibility.

        # =================================================================
        # Experiment Logging & Saving
        # =================================================================
        self.use_comet = True  # Set to False if you don't want to use Comet ML.
        self.comet_config = {
            # Load the secret from an environment variable for security;
            # the placeholder is only a fallback if COMET_API_KEY is unset.
            "api_key": os.getenv("COMET_API_KEY", "YOUR_COMET_API_KEY"),
            "project_name": "Kronos-Finetune-Demo",
            "workspace": "your_comet_workspace"  # TODO: Change to your Comet ML workspace name.
        }
        self.comet_tag = 'finetune_demo'
        self.comet_name = 'finetune_demo'

        # Base directory for saving model checkpoints and results.
        # Using a general 'outputs' directory is a common practice.
        self.save_path = "./outputs/models"
        self.tokenizer_save_folder_name = 'finetune_tokenizer_demo'
        self.predictor_save_folder_name = 'finetune_predictor_demo'
        self.backtest_save_folder_name = 'finetune_backtest_demo'

        # Path for backtesting results.
        self.backtest_result_path = "./outputs/backtest_results"

        # =================================================================
        # Model & Checkpoint Paths
        # =================================================================
        # TODO: Update these paths to your pretrained model locations.
        # These can be local paths or Hugging Face Hub model identifiers.
        self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base"
        self.pretrained_predictor_path = "path/to/your/Kronos-small"

        # Paths to the fine-tuned models, derived from save_path.
        # These will be generated automatically during training.
        self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model"
        self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model"

        # =================================================================
        # Backtesting Parameters
        # =================================================================
        self.backtest_n_symbol_hold = 50  # Number of symbols to hold in the portfolio.
        self.backtest_n_symbol_drop = 5   # Number of symbols to drop from the pool.
        self.backtest_hold_thresh = 5     # Minimum holding period for a stock.

        # Sampling settings used when generating predictions at inference time.
        self.inference_T = 0.6           # Sampling temperature.
        self.inference_top_p = 0.9       # Nucleus (top-p) sampling threshold.
        self.inference_top_k = 0         # Top-k sampling; 0 disables it.
        self.inference_sample_count = 5  # Number of samples drawn per prediction.
        self.backtest_batch_size = 1000

        self.backtest_benchmark = self._set_benchmark(self.instrument)

    def _set_benchmark(self, instrument):
        """Map an instrument universe to its benchmark index code."""
        dt_benchmark = {
            'csi800': "SH000906",
            'csi1000': "SH000852",
            'csi300': "SH000300",
        }
        if instrument in dt_benchmark:
            return dt_benchmark[instrument]
        else:
            raise ValueError(f"Benchmark not defined for instrument: {instrument}")
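

# Minimal usage sketch (an illustration, not part of the training pipeline):
# the training and backtest scripts are assumed to instantiate `Config`
# directly; printing a few derived values is a quick way to verify that the
# paths and benchmark resolve as expected.
if __name__ == "__main__":
    config = Config()
    print("Qlib data path:", os.path.expanduser(config.qlib_data_path))
    print("Benchmark:", config.backtest_benchmark)  # SH000300 for csi300
    print("Fine-tuned tokenizer:", config.finetuned_tokenizer_path)
    print("Fine-tuned predictor:", config.finetuned_predictor_path)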