from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoTokenizer
from xtuner.dataset import InternVL_V1_5_Dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import InternVL_V1_5
from xtuner.utils import PROMPT_TEMPLATE
# Model
path = '/root/project/joke/model/InternVL2-2B'

# Data
data_root = '/root/project/joke/datasets/CLoT_cn_2000/'
data_path = data_root + 'ex_cn.json'
image_folder = data_root
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 6656
# Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 6
optim_type = AdamW
lr = 2e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.03
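# Note: with batch_size = 4 and accumulative_counts = 4, each optimizer step
# accumulates 4 * 4 = 16 samples per GPU; multiply by the number of GPUs for
# the global effective batch size.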
# Save
save_steps = 1000
save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)
# QLoRA-style setup for the LLM: the base LLM is frozen, quantized, and adapted
# with LoRA; the visual encoder is frozen and left unquantized.
model = dict(
    type=InternVL_V1_5,
    model_path=path,
    freeze_llm=True,
    freeze_visual_encoder=True,
    quantization_llm=True,  # or False
    quantization_vit=False,  # or True and uncomment visual_encoder_lora
    # comment the following lines if you don't want to use LoRA in the llm
    llm_lora=dict(
        type=LoraConfig,
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=None,
        task_type='CAUSAL_LM'),
    # uncomment the following lines if you want to use LoRA in the visual encoder  # noqa
    # visual_encoder_lora=dict(
    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
)
llava_dataset = dict(
    type=InternVL_V1_5_Dataset,
    model_path=path,
    data_paths=data_path,
    image_folders=image_folder,
    template=prompt_template,
    max_length=max_length)
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')
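# The schedule below runs a linear warm-up for the first
# warmup_ratio * max_epochs = 0.03 * 6 = 0.18 epochs (converted to iterations),
# then cosine-anneals the learning rate from `lr` down to eta_min = 0.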
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True)
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)
visualizer = None
log_level = 'INFO'
load_from = None
resume = False
randomness = dict(seed=None, deterministic=False)
log_processor = dict(by_epoch=False)
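# Usage sketch (assumption: the standard XTuner CLI; the config filename below
# is a placeholder, not something defined in this file):
#   xtuner train internvl_v2_2b_lora_clot.py --deepspeed deepspeed_zero1
# A saved .pth checkpoint can then be converted to HuggingFace format with
#   xtuner convert pth_to_hf <config> <pth> <save_dir>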