You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
17 lines
395 B
17 lines
395 B
from .kronos import KronosTokenizer, Kronos, KronosPredictor
|
|
|
|
# Registry that maps string keys to the Kronos classes imported from
# .kronos; get_model_class() resolves model names through this mapping.
model_dict = dict(
    kronos_tokenizer=KronosTokenizer,
    kronos=Kronos,
    kronos_predictor=KronosPredictor,
)
|
|
|
|
|
|
def get_model_class(model_name):
    """Return the class registered in ``model_dict`` under ``model_name``.

    Args:
        model_name: Registry key, e.g. ``'kronos'`` or ``'kronos_tokenizer'``.

    Returns:
        The class object stored in ``model_dict`` for that key.

    Raises:
        NotImplementedError: If ``model_name`` is not a key of ``model_dict``.
            The message names the requested key and lists the registered
            keys. (The exception type is kept for backward compatibility with
            callers that catch it.)
    """
    try:
        return model_dict[model_name]
    except KeyError:
        available = ", ".join(repr(key) for key in model_dict)
        # Carry the diagnostic in the exception itself instead of printing it
        # to stdout; `from None` hides the internal KeyError from the traceback.
        raise NotImplementedError(
            f"Model {model_name} not found in model_dict "
            f"(available: {available})"
        ) from None
|
|
|
|
|