-
Notifications
You must be signed in to change notification settings - Fork 31
/
convert_to_hf.py
24 lines (20 loc) · 985 Bytes
/
convert_to_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from contrastors.models.huggingface import NomicBertForPreTraining, NomicBertConfig
from contrastors.models.biencoder import BiEncoder, BiEncoderConfig
from argparse import ArgumentParser
def parse_args(argv=None):
    """Parse the command-line options of the checkpoint conversion script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of calling ``parse_args()`` with no arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``ckpt_path`` (str),
        ``model_name`` (str), ``private`` (bool) and ``biencoder`` (bool).

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or an unknown option is supplied.
    """
    parser = ArgumentParser(
        description="Load a checkpoint and push the resulting model to the Hugging Face Hub."
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="Checkpoint location handed to from_pretrained().",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Repository name handed to push_to_hub().",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Forwarded to push_to_hub() as private=True.",
    )
    parser.add_argument(
        "--biencoder",
        action="store_true",
        help="Load the checkpoint as a BiEncoder and push only its trunk.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: load a checkpoint with the class selected on the
    # command line, then upload the resulting model to the Hub.
    cli = parse_args()

    if not cli.biencoder:
        # Default path: the checkpoint is loaded directly as a
        # NomicBertForPreTraining model.
        bert_config = NomicBertConfig.from_pretrained(cli.ckpt_path)
        hub_model = NomicBertForPreTraining.from_pretrained(cli.ckpt_path, config=bert_config)
    else:
        # Bi-encoder path: load the full BiEncoder, but keep only its
        # `trunk` attribute for upload.
        # NOTE(review): presumably `trunk` is the underlying encoder — confirm.
        biencoder_config = BiEncoderConfig.from_pretrained(cli.ckpt_path)
        biencoder = BiEncoder.from_pretrained(cli.ckpt_path, config=biencoder_config)
        hub_model = biencoder.trunk

    # Both paths end with the same upload call.
    hub_model.push_to_hub(cli.model_name, private=cli.private)