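"""Train a Qwen2-VL based pairwise (Bradley-Terry) video reward model.

This script parses the data/model/LoRA/training configs via HfArgumentParser, builds
the reward model and processor (optionally with LoRA and per-dimension reward
tokens), converts GSB-style comparison CSVs into reward data, and runs
VideoVLMRewardTrainer.
"""
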
import ast
import json
import os
import random
from dataclasses import asdict
from functools import partial

import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model
from transformers import AutoProcessor, HfArgumentParser
from trl import get_kbit_device_map, get_quantization_config

from videoalign.trainer import (
    Qwen2VLRewardModelBT,
    VideoVLMRewardTrainer,
    compute_multi_attr_accuracy,
    PartialEmbeddingUpdateCallback,
)
from videoalign.data import DataConfig, QWen2VLDataCollator, convert_GSB_csv_to_reward_data
from videoalign.utils import ModelConfig, PEFTLoraConfig, TrainingConfig, load_model_from_checkpoint


def save_configs_to_json(data_config, training_args, model_config, peft_lora_config):
    """
    Save all configurations to a JSON file.
    """
    config_dict = {
        "data_config": asdict(data_config),
        "training_args": asdict(training_args),
        "model_config": asdict(model_config),
        "peft_lora_config": asdict(peft_lora_config),
    }
    # Drop device-local fields that are meaningless outside this run.
    del config_dict["training_args"]["local_rank"]
    del config_dict["training_args"]["_n_gpu"]

    save_path = os.path.join(training_args.output_dir, "model_config.json")
    os.makedirs(training_args.output_dir, exist_ok=True)
    print(f"===> Saving configs to {save_path}")
    with open(save_path, "w") as f:
        json.dump(config_dict, f, indent=4)


def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=False):
    """
    Find the target linear (and embedding) modules for LoRA.
    """
    # Avoid a mutable default argument.
    if lora_namespan_exclude is None:
        lora_namespan_exclude = []
    linear_cls = torch.nn.Linear
    embedding_cls = torch.nn.Embedding
    lora_module_names = []
    for name, module in model.named_modules():
        if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
            # print(f"Excluding module: {name}")
            continue
        if isinstance(module, (linear_cls, embedding_cls)):
            lora_module_names.append(name)
    # Optionally keep only the last `num_lora_modules` matches.
    if num_lora_modules > 0:
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
    return lora_module_names
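
# Illustrative call (hypothetical; module names assume a Qwen2-VL-style checkpoint):
#   find_target_linear_names(model, lora_namespan_exclude=["visual", "rm_head"], verbose=True)
#   # -> e.g. ['model.layers.0.self_attn.q_proj', ..., 'lm_head']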


def set_requires_grad(parameters, requires_grad):
    for p in parameters:
        p.requires_grad = requires_grad


def create_model_and_processor(
    model_config, peft_lora_config, training_args,
    cache_dir=None,
):
    # Resolve the torch dtype ("auto" and None are passed through unchanged).
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        # use_cache is incompatible with gradient checkpointing, so disable it there.
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    # Create the processor and use right padding for training.
    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path,
        padding_side="right",
        cache_dir=cache_dir,
    )
    special_token_ids = None
    if model_config.use_special_tokens:
        # One readout token per reward dimension.
        special_tokens = ["<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>"]
        processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
        special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
    model = Qwen2VLRewardModelBT.from_pretrained(
        model_config.model_name_or_path,
        output_dim=model_config.output_dim,
        reward_token=model_config.reward_token,
        special_token_ids=special_token_ids,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa",
        cache_dir=cache_dir,
        **model_kwargs,
    )
    if model_config.use_special_tokens:
        # Grow the embedding matrix to cover the newly added tokens.
        model.resize_token_embeddings(len(processor.tokenizer))

    if training_args.bf16:
        model.to(torch.bfloat16)
    if training_args.fp16:
        model.to(torch.float16)
    # Create LoRA config and wrap the model with PEFT.
    if peft_lora_config.lora_enable:
        target_modules = find_target_linear_names(
            model,
            num_lora_modules=peft_lora_config.num_lora_modules,
            lora_namespan_exclude=peft_lora_config.lora_namespan_exclude,
        )
        peft_config = LoraConfig(
            target_modules=target_modules,
            r=peft_lora_config.lora_r,
            lora_alpha=peft_lora_config.lora_alpha,
            lora_dropout=peft_lora_config.lora_dropout,
            task_type=peft_lora_config.lora_task_type,
            use_rslora=peft_lora_config.use_rslora,
            bias="none",
            modules_to_save=peft_lora_config.lora_modules_to_save,
        )
        model = get_peft_model(model, peft_config)
    else:
        peft_config = None

    model.config.tokenizer_padding_side = processor.tokenizer.padding_side
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    return model, processor, peft_config
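
# Note: "BT" in Qwen2VLRewardModelBT presumably refers to a Bradley-Terry pairwise
# objective over preference pairs; the concrete variant is selected later via
# model_config.loss_type when the trainer is constructed.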


def create_dataset(data_config, meta_file=None):
    if meta_file is None:
        meta_file = data_config.meta_data
    dataset = load_dataset('csv', data_files=meta_file)

    # Remember each row's position in the metadata file.
    def add_idx(example, idx):
        example['metainfo_idx'] = idx
        return example

    dataset['train'] = dataset['train'].map(add_idx, with_indices=True)

    # Optionally drop ties, i.e. rows labeled "same" on every evaluated dimension.
    if not data_config.use_tied_data:
        filter_func = lambda example: any(example[dim] != "same" for dim in data_config.eval_dim)
        dataset = dataset.filter(filter_func)

    # Convert GSB comparison rows into pairwise reward data.
    convert_func = lambda example: convert_GSB_csv_to_reward_data(
        example, data_config.data_dir, data_config.eval_dim,
        data_config.max_frame_pixels, data_config.fps, data_config.num_frames,
        data_config.prompt_template_type,
        sample_type=data_config.sample_type,
    )
    dataset = dataset.map(convert_func, remove_columns=dataset['train'].column_names, load_from_cache_file=False)
    dataset = dataset['train']
    return dataset
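
# Expected metadata layout (a sketch inferred from the code above; exact column
# names beyond the per-dimension labels are assumptions): one CSV row per video
# pair, plus one column per entry in data_config.eval_dim holding the preference
# label, where "same" marks a tie.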


def train():
    ## ===> Step 1: Parse arguments
    parser = HfArgumentParser((DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig))
    data_config, training_args, model_config, peft_lora_config = parser.parse_args_into_dataclasses()
    # Validate the LoRA config.
    assert not (peft_lora_config.lora_enable and model_config.freeze_llm), \
        'When using LoRA, the LLM should not be frozen. If you want to freeze the LLM, please disable LoRA.'
    if not peft_lora_config.lora_enable:
        assert not peft_lora_config.vision_lora, \
            "Error: peft_lora_config.lora_enable is not enabled, but peft_lora_config.vision_lora is enabled."
    else:
        if peft_lora_config.lora_namespan_exclude is not None:
            # The exclude list arrives as a string such as "['lm_head']"; parse it safely.
            peft_lora_config.lora_namespan_exclude = ast.literal_eval(peft_lora_config.lora_namespan_exclude)
        else:
            peft_lora_config.lora_namespan_exclude = []
        if not peft_lora_config.vision_lora:
            peft_lora_config.lora_namespan_exclude += ["visual"]
    ## ===> Step 2: Load model and configure
    model, processor, peft_config = create_model_and_processor(
        model_config=model_config,
        peft_lora_config=peft_lora_config,
        training_args=training_args,
    )

    # Optionally resume weights from an earlier checkpoint.
    if training_args.load_from_pretrained is not None:
        model, checkpoint_step = load_model_from_checkpoint(
            model, training_args.load_from_pretrained, training_args.load_from_pretrained_step)
    model.train()

    # With LoRA enabled the model is a PEFT wrapper, so configure the base model inside it.
    if peft_lora_config.lora_enable:
        model_to_configure = model.model
    else:
        model_to_configure = model

    # Set requires_grad for the LLM backbone.
    set_requires_grad(model_to_configure.model.parameters(), not model_config.freeze_llm)
    if not peft_lora_config.vision_lora:
        # Set requires_grad for the visual encoder and its merger.
        set_requires_grad(model_to_configure.visual.parameters(), not model_config.freeze_vision_tower)
        set_requires_grad(model_to_configure.visual.merger.parameters(), model_config.tune_merger)
    # The regression head is always trained.
    set_requires_grad(model_to_configure.rm_head.parameters(), True)
    ## ===> Step 3: Load Dataset and configure
    if isinstance(data_config.eval_dim, str):
        data_config.eval_dim = [data_config.eval_dim]

    # datasets = create_dataset(data_config)
    # train_dataset = concatenate_datasets([datasets[dim] for dim in data_config.eval_dim])
    train_dataset = create_dataset(data_config)
    train_dataset = train_dataset.shuffle(seed=42)

    if training_args.conduct_eval:
        if data_config.meta_data_test is not None:
            random.seed(42)
            valid_dataset = create_dataset(data_config, meta_file=data_config.meta_data_test)
            # indices = random.sample(range(len(valid_dataset)), 1000)
            # valid_dataset = valid_dataset.select(indices)
        else:
            # No held-out metadata file: carve a small validation split off the training set.
            dataset = train_dataset.train_test_split(test_size=0.02)
            train_dataset = dataset['train']
            valid_dataset = dataset['test']
    else:
        valid_dataset = None

    print(f"===> Selected {len(train_dataset)} samples for training.")
    if valid_dataset is not None:
        print(f"===> Selected {len(valid_dataset)} samples for testing.")
    num_gpu = int(os.environ.get("WORLD_SIZE", 1))
    data_collator = QWen2VLDataCollator(
        processor,
        add_noise=data_config.add_noise,
        p_shuffle_frames=data_config.p_shuffle_frames,
        p_color_jitter=data_config.p_color_jitter,
    )
    compute_metrics = partial(compute_multi_attr_accuracy, eval_dims=data_config.eval_dim)

    # Translate epoch-based save/eval/logging intervals into optimizer-step counts.
    actual_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * num_gpu
    total_steps = training_args.num_train_epochs * len(train_dataset) // actual_batch_size
    if training_args.save_epochs is not None:
        training_args.save_steps = round(training_args.save_epochs * len(train_dataset) / actual_batch_size)
    if training_args.eval_epochs is not None:
        training_args.eval_steps = round(training_args.eval_epochs * len(train_dataset) / actual_batch_size)
    if training_args.logging_epochs is not None:
        training_args.logging_steps = round(training_args.logging_epochs * len(train_dataset) / actual_batch_size)
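
    # Worked example with illustrative numbers: per_device_train_batch_size=2,
    # gradient_accumulation_steps=4 and WORLD_SIZE=8 give actual_batch_size = 64;
    # on a 12,800-sample set, save_epochs=0.5 maps to round(0.5 * 12800 / 64) = 100 steps.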
    if training_args.local_rank in (-1, 0):
        print(f"===> Using {num_gpu} GPUs.")
        print(f"===> Total Batch Size: {actual_batch_size}")
        print(f"===> Training Epochs: {training_args.num_train_epochs}")
        print(f"===> Total Steps: {total_steps}")
        print(f"===> Save Steps: {training_args.save_steps}")
        print(f"===> Eval Steps: {training_args.eval_steps}")
        print(f"===> Logging Steps: {training_args.logging_steps}")
    ## ===> Step 4: Save configs for re-check
    if training_args.local_rank in (-1, 0):
        save_configs_to_json(data_config, training_args, model_config, peft_lora_config)
        print(train_dataset)
    ## ===> Step 5: Start Training!
    special_token_ids = model.special_token_ids
    callbacks = []
    if special_token_ids is not None:
        # Restrict embedding-weight updates to the newly added special tokens.
        callbacks.append(PartialEmbeddingUpdateCallback(special_token_ids))

    trainer = VideoVLMRewardTrainer(
        model=model,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset if training_args.conduct_eval else None,
        peft_config=peft_config,
        callbacks=callbacks,
        loss_type=model_config.loss_type,
        tokenizer=processor.tokenizer,
    )
    trainer.train()

    # Save final weights and config on the main process only.
    if training_args.local_rank in (-1, 0):
        model_state_dict = model.state_dict()
        torch.save(model_state_dict, os.path.join(training_args.output_dir, 'final_model.pth'))
        model.config.save_pretrained(training_args.output_dir)
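
# Example launch (hypothetical script name, paths, and flag values; flag names follow
# from the dataclass fields consumed by HfArgumentParser above):
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
#       --meta_data data/train.csv --meta_data_test data/test.csv \
#       --output_dir ./ckpts --lora_enable True --lora_r 64 --lora_alpha 128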


if __name__ == "__main__":
    train()