SFEREWQW commited on
Commit
1a1dbc7
·
verified ·
1 Parent(s): 2de723a

Update pdf_extract_kit/utils/config_loader.py

Browse files
pdf_extract_kit/utils/config_loader.py CHANGED
@@ -1,47 +1,37 @@
1
  import yaml
2
  import warnings
3
- from pdf_extract_kit.registry.registry import TASK_REGISTRY, MODEL_REGISTRY
 
4
 
5
 
6
  def load_config(config_path):
7
  if config_path is None:
8
- warnings.warn(
9
- ("Configuration path is None. Please provide a valid configuration file path. ")
10
- )
11
  return None
12
-
13
  with open(config_path, 'r') as file:
14
  config = yaml.safe_load(file)
15
  return config
16
 
17
 
18
- # def initialize_task_and_model(config):
19
- # task_name = config['task']
20
- # model_name = config['model']
21
- # model_config = config['model_config']
22
-
23
- # TaskClass = TASK_REGISTRY.get(task_name)
24
- # ModelClass = MODEL_REGISTRY.get(model_name)
25
-
26
- # model_instance = ModelClass(model_config)
27
- # task_instance = TaskClass(model_instance)
28
-
29
- # return task_instance
30
-
31
  def initialize_tasks_and_models(config):
32
-
33
  task_instances = {}
34
- for task_name in config['tasks']:
35
 
36
- model_name = config['tasks'][task_name]['model']
37
- model_config = config['tasks'][task_name]['model_config']
 
 
 
38
 
39
- TaskClass = TASK_REGISTRY.get(task_name)
40
- ModelClass = MODEL_REGISTRY.get(model_name)
 
 
41
 
42
- model_instance = ModelClass(model_config)
43
- task_instance = TaskClass(model_instance)
 
44
 
45
  task_instances[task_name] = task_instance
46
 
47
- return task_instances
 
1
  import yaml
2
  import warnings
3
+
4
+ from pdf_extract_kit.registry.registry import TASK_REGISTRY
5
 
6
 
7
  def load_config(config_path):
8
  if config_path is None:
9
+ warnings.warn("Configuration path is None. Please provide a valid configuration file path.")
 
 
10
  return None
11
+
12
  with open(config_path, 'r') as file:
13
  config = yaml.safe_load(file)
14
  return config
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def initialize_tasks_and_models(config):
 
18
  task_instances = {}
 
19
 
20
+ for task_name, task_config in config['tasks'].items():
21
+ # config 中读取 type 字段作为注册器查找键
22
+ task_type = task_config.get("type")
23
+ if task_type is None:
24
+ raise ValueError(f"Task '{task_name}' missing required 'type' field in config.")
25
 
26
+ # 查找注册的类
27
+ TaskClass = TASK_REGISTRY.get(task_type)
28
+ if TaskClass is None:
29
+ raise ValueError(f"Task type '{task_type}' not found in TASK_REGISTRY.")
30
 
31
+ # 其他字段作为初始化参数
32
+ kwargs = {k: v for k, v in task_config.items() if k != "type"}
33
+ task_instance = TaskClass(**kwargs)
34
 
35
  task_instances[task_name] = task_instance
36
 
37
+ return task_instances