-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
606 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from mmengine import Config | ||
|
||
from xtuner.model.base import BaseTune | ||
from xtuner.registry import BUILDER | ||
|
||
|
||
class AutoModel:
    """Factory helpers for building xtuner models from mmengine configs.

    ``from_config`` builds a freshly-initialized model from a config file;
    ``from_pretrained`` additionally loads weights from a checkpoint.
    """

    @classmethod
    def from_config(cls, config: str) -> 'BaseTune':
        """Build a model from a config file path.

        Args:
            config: Path to an mmengine config file whose ``model`` field
                is buildable by the xtuner ``BUILDER`` registry.

        Returns:
            The built (randomly initialized) model.
        """
        # Keep the parsed Config in its own name instead of rebinding the
        # ``config`` parameter to a different type.
        cfg = Config.fromfile(config)
        model: BaseTune = BUILDER.build(cfg.model)
        return model

    @classmethod
    def from_pretrained(cls, config: str, checkpoint: str) -> 'BaseTune':
        """Build a model from a config file and load checkpoint weights.

        Args:
            config: Path to an mmengine config file.
            checkpoint: Path to the checkpoint to load into the model.

        Returns:
            The built model with ``checkpoint`` loaded.
        """
        # Reuse from_config rather than duplicating the build logic.
        model = cls.from_config(config)
        model.load_checkpoint(checkpoint)
        return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import abstractclassmethod, abstractmethod | ||
|
||
from mmengine.model import BaseModel | ||
|
||
from xtuner.types import HybridChatMessages, HybridChatTemplate | ||
|
||
|
||
class BaseTune(BaseModel):
    """Abstract base class for trainable xtuner models.

    Subclasses must expose the wrapped ``llm``/``tokenizer``/``chat_template``
    and implement chat and checkpoint handling. Attribute lookups that fail
    on the wrapper are delegated to the wrapped ``llm``.
    """

    def __init__(self):
        # BUG FIX: the original signature was ``def __init__():`` (missing
        # ``self``), which raised TypeError on every subclass construction.
        super().__init__()

    def init_weights(self):
        """Parent class method.

        To avoid overwriting the loaded weights, overload it to an empty
        function.
        """
        pass

    def avoid_override_weights(self):
        # Mark the model as already initialized so mmengine's weight-init
        # machinery skips it.
        self._is_init = True

    # BUG FIX (here and below): ``@property`` must be the OUTERMOST
    # decorator. The original order (``@abstractmethod`` above
    # ``@property``) tries to set ``__isabstractmethod__`` on the
    # read-only property object and fails at class-creation time.
    @property
    @abstractmethod
    def chat_template(self) -> HybridChatTemplate:
        """Chat template used to serialize messages for this model."""
        pass

    @property
    @abstractmethod
    def llm(self):
        """The wrapped language model."""
        pass

    @property
    @abstractmethod
    def tokenizer(self):
        """Tokenizer paired with ``llm``."""
        pass

    @abstractmethod
    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing to trade compute for memory."""
        pass

    def forward(self, data, data_samples=None, mode='loss'):
        """Overload parent class method, only support training."""

        if mode == 'loss':
            return self.compute_loss(data)
        else:
            raise NotImplementedError(
                f"{type(self)}'s forward is only supported for use during "
                'training. If you want to get predictions or chat, please '
                "directly use `llm`'s forward.")

    @abstractmethod
    def chat(self, messages: HybridChatMessages, sample_params, streamer):
        """Generate a chat response for ``messages``."""
        pass

    @abstractmethod
    def save_checkpoint(self, *args, **kwargs):
        """Persist model weights to disk."""
        pass

    @abstractmethod
    def load_checkpoint(self, *args, **kwargs) -> 'BaseTune':
        """Load model weights from disk and return the model."""
        pass

    def __getattr__(self, name: str):
        # Fall back to the wrapped llm for attributes the wrapper itself
        # does not define (only called when normal lookup fails).
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .base import EncoderWrapper | ||
from .llava import LlavaEncoderWrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import abstractclassmethod, abstractmethod | ||
from typing import List, Union | ||
|
||
import torch | ||
from PIL import Image | ||
from torch import nn | ||
|
||
_ImageType = Union[str, Image.Image] | ||
|
||
|
||
class EncoderWrapper(nn.Module):
    """Abstract wrapper around a vision encoder plus projector.

    Subclasses expose the underlying ``encoder``/``projector`` modules and
    implement image preprocessing, batched inference, and checkpoint
    handling.
    """

    def __init__(self):
        super().__init__()

    # BUG FIX (here and below): ``@property`` must be the OUTERMOST
    # decorator; the original order (``@abstractmethod`` above
    # ``@property``) fails at class-creation time because property's
    # ``__isabstractmethod__`` is read-only.
    @property
    @abstractmethod
    def encoder(self):
        """The wrapped vision encoder module."""
        pass

    @property
    @abstractmethod
    def projector(self):
        """Projector mapping encoder features into the LLM space."""
        pass

    @abstractmethod
    def post_init_proj(self, llm):
        """Initialize the projector once the target ``llm`` is known."""
        pass

    @abstractmethod
    def preprocess(self, image: _ImageType) -> torch.Tensor:
        """Convert a single image (path or PIL image) into a tensor."""
        pass

    @abstractmethod
    def batch_infer(self, images: List[_ImageType]) -> List[torch.Tensor]:
        # BUG FIX: the original signature was missing ``self``.
        """Run the encoder over a batch of images."""
        pass

    @abstractmethod
    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing on the encoder."""
        pass

    # BUG FIX (here and below): the original used the deprecated
    # ``abstractclassmethod`` on methods that take ``self`` — they are
    # instance methods, so plain ``@abstractmethod`` is correct.
    @abstractmethod
    def save_checkpoint(self, *args, **kwargs):
        """Persist encoder/projector weights to disk."""
        pass

    @abstractmethod
    def load_checkpoint(self, *args, **kwargs) -> 'EncoderWrapper':
        """Load encoder/projector weights and return the wrapper."""
        pass

    @abstractmethod
    def only_build_processor(self, *args, **kwargs):
        """Build only the image processor, without model weights."""
        pass
Oops, something went wrong.