Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MMBench supports Siglip #524

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 17 additions & 6 deletions xtuner/tools/mmbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from torch.utils.data import Dataset
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, CLIPImageProcessor,
CLIPVisionModel, GenerationConfig)
CLIPVisionModel, GenerationConfig,
SiglipImageProcessor, SiglipVisionModel)

from xtuner.dataset.utils import decode_base64_to_image, expand2square
from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
Expand Down Expand Up @@ -330,10 +331,21 @@ def main():
'Please specify the `--visual-encoder`!')
visual_encoder_path = args.visual_encoder
with LoadWoInit():
visual_encoder = CLIPVisionModel.from_pretrained(
visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
image_processor = CLIPImageProcessor.from_pretrained(
visual_encoder_path)
if 'clip' in visual_encoder_path.lower():
visual_encoder = CLIPVisionModel.from_pretrained(
visual_encoder_path,
torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
image_processor = CLIPImageProcessor.from_pretrained(
visual_encoder_path)
elif 'siglip' in visual_encoder_path.lower():
visual_encoder = SiglipVisionModel.from_pretrained(
visual_encoder_path,
torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
image_processor = SiglipImageProcessor.from_pretrained(
visual_encoder_path)
else:
raise f'Visual encoders not supported : {visual_encoder_path}'

master_print(f'Load visual_encoder from {visual_encoder_path}')

# load adapter
Expand Down Expand Up @@ -506,5 +518,4 @@ def main():


if __name__ == '__main__':

main()