# Reward-Forcing/videoalign/train_reward.py
# Training entry point for the VideoAlign video reward model.
import ast
import json
import os
import random
from dataclasses import asdict
from functools import partial
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoProcessor, HfArgumentParser
from trl import get_kbit_device_map, get_quantization_config
from videoalign.trainer import Qwen2VLRewardModelBT, VideoVLMRewardTrainer, compute_multi_attr_accuracy, PartialEmbeddingUpdateCallback
from videoalign.data import DataConfig, QWen2VLDataCollator, convert_GSB_csv_to_reward_data
from videoalign.utils import ModelConfig, PEFTLoraConfig, TrainingConfig
from videoalign.utils import load_model_from_checkpoint


def save_configs_to_json(data_config, training_args, model_config, peft_lora_config):
"""
Save all configurations to a JSON file.
"""
config_dict = {
"data_config": asdict(data_config),
"training_args": asdict(training_args),
"model_config": asdict(model_config),
"peft_lora_config": asdict(peft_lora_config),
}
    # drop device-specific fields that are not portable across runs
    del config_dict["training_args"]["local_rank"]
    del config_dict["training_args"]["_n_gpu"]

    save_path = os.path.join(training_args.output_dir, "model_config.json")
    os.makedirs(training_args.output_dir, exist_ok=True)
    print(f"Saving run configuration to {save_path}")
    with open(save_path, "w") as f:
        json.dump(config_dict, f, indent=4)
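
# The emitted model_config.json has this overall shape (illustrative; the exact fields
# come from the four config dataclasses):
# {
#     "data_config": {...},
#     "training_args": {...},
#     "model_config": {...},
#     "peft_lora_config": {...}
# }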


def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=False):
    """
    Find the target linear (and embedding) modules for LoRA.
    """
    if lora_namespan_exclude is None:
        lora_namespan_exclude = []
    linear_cls = torch.nn.Linear
    embedding_cls = torch.nn.Embedding
    lora_module_names = []
    for name, module in model.named_modules():
        # skip modules whose name contains any excluded keyword
        if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
            continue
        if isinstance(module, (linear_cls, embedding_cls)):
            lora_module_names.append(name)
    # optionally keep only the last num_lora_modules matches
    if num_lora_modules > 0:
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        print(f"Found {len(lora_module_names)} LoRA modules: {lora_module_names}")
    return lora_module_names
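
# Illustrative (assumed module naming for a Qwen2-VL backbone): the returned names look
# like "model.layers.0.self_attn.q_proj" or "model.layers.0.mlp.gate_proj"; passing
# lora_namespan_exclude=["visual"] keeps the vision tower out of the LoRA targets.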

def set_requires_grad(parameters, requires_grad):
    """Enable or disable gradients for an iterable of parameters."""
    for p in parameters:
        p.requires_grad = requires_grad

def create_model_and_processor(
model_config, peft_lora_config, training_args,
cache_dir=None,
):
# create model
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
        # KV caching is incompatible with gradient checkpointing
        use_cache=False if training_args.gradient_checkpointing else True,
)
# create processor and set padding
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path,
padding_side="right",
cache_dir=cache_dir)
    special_token_ids = None
    if model_config.use_special_tokens:
        # one dedicated reward-readout token per evaluation dimension
        # (VQ / MQ / TA: visual quality, motion quality, text alignment)
        special_tokens = ["<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>"]
        processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
        special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
model = Qwen2VLRewardModelBT.from_pretrained(
model_config.model_name_or_path,
output_dim=model_config.output_dim,
reward_token=model_config.reward_token,
special_token_ids=special_token_ids,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa",
cache_dir=cache_dir,
**model_kwargs
)
    if model_config.use_special_tokens:
        # grow the embedding matrix to cover the newly added reward tokens
        model.resize_token_embeddings(len(processor.tokenizer))
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
# create lora and peft model
if peft_lora_config.lora_enable:
target_modules = find_target_linear_names(model,
num_lora_modules=peft_lora_config.num_lora_modules,
lora_namespan_exclude=peft_lora_config.lora_namespan_exclude)
peft_config = LoraConfig(
target_modules=target_modules,
r=peft_lora_config.lora_r,
lora_alpha=peft_lora_config.lora_alpha,
lora_dropout=peft_lora_config.lora_dropout,
task_type=peft_lora_config.lora_task_type,
use_rslora=peft_lora_config.use_rslora,
bias="none",
modules_to_save=peft_lora_config.lora_modules_to_save,
)
model = get_peft_model(model, peft_config)
else:
peft_config = None
model.config.tokenizer_padding_side = processor.tokenizer.padding_side
model.config.pad_token_id = processor.tokenizer.pad_token_id
return model, processor, peft_config


def create_dataset(data_config, meta_file=None):
if meta_file is None:
meta_file = data_config.meta_data
dataset = load_dataset('csv', data_files=meta_file)
    def add_idx(example, idx):
        example['metainfo_idx'] = idx
        return example

    dataset['train'] = dataset['train'].map(add_idx, with_indices=True)

    if not data_config.use_tied_data:
        # drop rows that are ties ("same") on every evaluated dimension
        filter_func = lambda example: any(example[dim] != "same" for dim in data_config.eval_dim)
        dataset = dataset.filter(filter_func)

    # convert GSB (good/same/bad) annotations into pairwise reward-training samples
    convert_func = lambda example: convert_GSB_csv_to_reward_data(example, data_config.data_dir, data_config.eval_dim,
                                                                  data_config.max_frame_pixels, data_config.fps, data_config.num_frames,
                                                                  data_config.prompt_template_type,
                                                                  sample_type=data_config.sample_type)
    dataset = dataset.map(convert_func, remove_columns=dataset['train'].column_names, load_from_cache_file=False)
    return dataset['train']
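
# Assumed metadata layout (inferred from the filtering above; not verified against the
# released annotations): each CSV row carries a prompt, the paths of the two compared
# videos, and one GSB (good/same/bad) label per eval dimension, e.g. a "VQ" column
# holding "A", "B", or "same".
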
def train():
## ===> Step 1: Parse arguments
parser = HfArgumentParser((DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig))
data_config, training_args, model_config, peft_lora_config = parser.parse_args_into_dataclasses()
    # validate the LoRA configuration
    assert not (peft_lora_config.lora_enable and model_config.freeze_llm), \
        'When using LoRA, the LLM should not be frozen. If you want to freeze the LLM, please disable LoRA.'
    if not peft_lora_config.lora_enable:
        assert not peft_lora_config.vision_lora, \
            "Error: peft_lora_config.lora_enable is not enabled, but peft_lora_config.vision_lora is enabled."
    else:
        if peft_lora_config.lora_namespan_exclude is not None:
            # the exclude list arrives from the CLI as a string, e.g. "['visual']"
            peft_lora_config.lora_namespan_exclude = ast.literal_eval(peft_lora_config.lora_namespan_exclude)
        else:
            peft_lora_config.lora_namespan_exclude = []
        if not peft_lora_config.vision_lora:
            # keep the vision tower out of LoRA unless vision_lora is requested
            peft_lora_config.lora_namespan_exclude += ["visual"]
## ===> Step 2: Load model and configure
model, processor, peft_config = create_model_and_processor(
model_config=model_config,
peft_lora_config=peft_lora_config,
training_args=training_args,
)
## load model
if training_args.load_from_pretrained is not None:
model, checkpoint_step = load_model_from_checkpoint(model, training_args.load_from_pretrained, training_args.load_from_pretrained_step)
model.train()
    if peft_lora_config.lora_enable:
        # unwrap the PEFT wrapper to reach the underlying reward model
        model_to_configure = model.model
    else:
        model_to_configure = model
# set requires_grad for LLM
set_requires_grad(model_to_configure.model.parameters(), not model_config.freeze_llm)
if not peft_lora_config.vision_lora:
# set requires_grad for visual encoder and merger
set_requires_grad(model_to_configure.visual.parameters(), not model_config.freeze_vision_tower)
set_requires_grad(model_to_configure.visual.merger.parameters(), model_config.tune_merger)
# set requires_grad for regression head
set_requires_grad(model_to_configure.rm_head.parameters(), True)
## ===> Step 3: Load Dataset and configure
    # allow a single eval dimension to be passed as a bare string
    if isinstance(data_config.eval_dim, str):
        data_config.eval_dim = [data_config.eval_dim]
train_dataset = create_dataset(data_config)
train_dataset = train_dataset.shuffle(seed=42)
if training_args.conduct_eval:
if data_config.meta_data_test is not None:
random.seed(42)
valid_dataset = create_dataset(data_config, meta_file=data_config.meta_data_test)
# indices = random.sample(range(len(valid_dataset)), 1000)
# valid_dataset = valid_dataset.select(indices)
        else:
            # no held-out metadata file: carve 2% of the training data out for validation
            dataset = train_dataset.train_test_split(test_size=0.02)
train_dataset = dataset['train']
valid_dataset = dataset['test']
else:
valid_dataset = None
print(f"===> Selected {len(train_dataset)} samples for training.")
print(f"===> Selected {len(valid_dataset)} samples for testing.")
    # WORLD_SIZE is set by the distributed launcher (e.g. torchrun); default to one process
    num_gpu = int(os.environ.get("WORLD_SIZE", 1))
data_collator = QWen2VLDataCollator(processor, add_noise=data_config.add_noise,
p_shuffle_frames=data_config.p_shuffle_frames,
p_color_jitter=data_config.p_color_jitter,)
compute_metrics = partial(compute_multi_attr_accuracy, eval_dims=data_config.eval_dim)
    # effective batch size per optimizer step across all processes
    actual_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * num_gpu
    total_steps = training_args.num_train_epochs * len(train_dataset) // actual_batch_size
    # the *_epochs options, when set, override the corresponding *_steps in fractions of an epoch
if training_args.save_epochs is not None:
training_args.save_steps = round(training_args.save_epochs * len(train_dataset) / actual_batch_size)
if training_args.eval_epochs is not None:
training_args.eval_steps = round(training_args.eval_epochs * len(train_dataset) / actual_batch_size)
if training_args.logging_epochs is not None:
training_args.logging_steps = round(training_args.logging_epochs * len(train_dataset) / actual_batch_size)
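    # Worked example (hypothetical numbers): with 10,000 training pairs, per-device batch
    # size 2, gradient accumulation 8 and 4 GPUs, actual_batch_size = 2 * 8 * 4 = 64, so
    # save_epochs = 0.5 maps to save_steps = round(0.5 * 10000 / 64) = 78.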
    if training_args.local_rank in (-1, 0):  # log from the main process only
print(f"===> Using {num_gpu} GPUs.")
print(f"===> Total Batch Size: {actual_batch_size}")
print(f"===> Training Epochs: {training_args.num_train_epochs}")
print(f"===> Total Steps: {total_steps}")
print(f"===> Save Steps: {training_args.save_steps}")
print(f"===> Eval Steps: {training_args.eval_steps}")
print(f"===> Logging Steps: {training_args.logging_steps}")
## ===> Step 4: Save configs for re-check
    if training_args.local_rank in (-1, 0):
save_configs_to_json(data_config, training_args, model_config, peft_lora_config)
print(train_dataset)
## ===> Step 5: Start Training!
special_token_ids = model.special_token_ids
    callbacks = []
    if special_token_ids is not None:
        # restrict embedding-matrix updates to the newly added reward tokens
        callbacks.append(PartialEmbeddingUpdateCallback(special_token_ids))
trainer = VideoVLMRewardTrainer(
model=model,
compute_metrics=compute_metrics,
data_collator=data_collator,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset if training_args.conduct_eval else None,
peft_config=peft_config,
callbacks=callbacks,
loss_type=model_config.loss_type,
tokenizer=processor.tokenizer,
)
trainer.train()
    if training_args.local_rank in (-1, 0):
        # save the final weights and config from the main process only
        model_state_dict = model.state_dict()
        torch.save(model_state_dict, os.path.join(training_args.output_dir, 'final_model.pth'))
        model.config.save_pretrained(training_args.output_dir)
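

# Example launch (illustrative; the exact flag names are defined by the DataConfig,
# TrainingConfig, ModelConfig and PEFTLoraConfig dataclasses, so verify before use):
#   torchrun --nproc_per_node=4 train_reward.py \
#       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
#       --meta_data /path/to/train_pairs.csv \
#       --output_dir ./ckpts/videoalign_rm \
#       --lora_enable True --lora_r 64 --lora_alpha 128 \
#       --per_device_train_batch_size 2 --gradient_accumulation_steps 8 \
#       --num_train_epochs 1 --bf16 True
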
if __name__ == "__main__":
train()