You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

17 lines
395 B

1 month ago
  1. from .kronos import KronosTokenizer, Kronos, KronosPredictor
  2. model_dict = {
  3. 'kronos_tokenizer': KronosTokenizer,
  4. 'kronos': Kronos,
  5. 'kronos_predictor': KronosPredictor
  6. }
  7. def get_model_class(model_name):
  8. if model_name in model_dict:
  9. return model_dict[model_name]
  10. else:
  11. print(f"Model {model_name} not found in model_dict")
  12. raise NotImplementedError