How to add custom loss using uni-core package #183

Open
DrakeAmmer opened this issue Nov 18, 2023 · 1 comment

DrakeAmmer commented Nov 18, 2023

Thank you for providing this great pre-trained model. We want to extend it to a downstream task.
Is there an easy way to add a custom head with a different loss using the uni-core package? It would be very helpful if you could provide an example script showing how to do this.

It seems I have to define a `UnicoreLoss` with a `forward` method that takes `model` and `sample` arguments and returns `loss`, `sample_size` and `logging_output`.
I also need to define a `UnicoreTask` class for the training process. However, I cannot find where to define the input data and the custom head. Could you point out where I can define the input data format (i.e. the `sample` argument in the loss class) and where to implement a customized head?
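
For reference, here is a minimal, dependency-free sketch of what "defining the input data format" can look like: the task turns raw records into the batched `sample` dict that later reaches the loss. The `net_input`/`target` layout is an assumption modeled on the Uni-Mol code, and `ToyRegressionTask`, its record format and `collate` are hypothetical stand-ins for a real `UnicoreTask` subclass, which would build Uni-Core datasets and tensors rather than Python lists.

```python
from typing import Dict, List


class ToyRegressionTask:
    """Hypothetical stand-in for a UnicoreTask subclass (not Uni-Core's API)."""

    def __init__(self, raw_data: Dict[str, List[dict]]):
        # raw_data maps a split name ("train", "valid", ...) to records
        # such as {"features": [1.0, 2.0], "label": 3.0}.
        self.raw_data = raw_data
        self.datasets: Dict[str, List[dict]] = {}

    def load_dataset(self, split: str) -> None:
        # Keep only the fields the model and the loss need for this split.
        self.datasets[split] = [
            {"features": list(r["features"]), "label": float(r["label"])}
            for r in self.raw_data[split]
        ]

    @staticmethod
    def collate(items: List[dict]) -> dict:
        # Batch items into the nested dict the loss receives as `sample`:
        # model inputs under "net_input", supervision under "target" (assumed layout).
        return {
            "net_input": {"features": [it["features"] for it in items]},
            "target": {"labels": [it["label"] for it in items]},
        }


task = ToyRegressionTask({"train": [{"features": [1.0, 2.0], "label": 3.0},
                                    {"features": [0.5, 0.5], "label": 1.0}]})
task.load_dataset("train")
sample = task.collate(task.datasets["train"])
print(sample)
```

Whatever keys the task puts into `sample` are exactly what the loss and the model can read, so the three pieces have to agree on this layout.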

It would be very helpful to have a minimal usage example for the Uni-Core package. Thanks!

Naplessss (Contributor) commented

> It seems I have to define a `UnicoreLoss` with a `forward` method that takes `model` and `sample` arguments and returns `loss`, `sample_size` and `logging_output`.

Yes, you can add a custom head for downstream tasks by following this approach.
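
To illustrate that contract, here is a hedged, dependency-free sketch of a squared-error loss. `UnicoreLossStub` is a hypothetical placeholder for `unicore.losses.UnicoreLoss` so the snippet runs on its own; in real code you would subclass the Uni-Core class and work with PyTorch tensors. The `net_input`/`target` keys of `sample`, the `toy_model` function and the `logging_output` fields are assumptions for illustration.

```python
from typing import Callable, Dict, List, Tuple


class UnicoreLossStub:
    """Hypothetical placeholder for unicore.losses.UnicoreLoss."""

    def __init__(self, task=None):
        self.task = task


class ToyMSELoss(UnicoreLossStub):
    def forward(self, model: Callable, sample: Dict) -> Tuple[float, int, Dict]:
        # Run the model on the inputs the task placed under "net_input" (assumed key).
        preds: List[float] = model(**sample["net_input"])
        labels: List[float] = sample["target"]["labels"]
        # Summed squared error; dividing by sample_size is assumed to happen in the trainer.
        loss = sum((p - y) ** 2 for p, y in zip(preds, labels))
        sample_size = len(labels)
        logging_output = {"loss": loss, "sample_size": sample_size}
        return loss, sample_size, logging_output


def toy_model(features):
    # Hypothetical model: predicts the sum of each feature vector.
    return [sum(f) for f in features]


sample = {"net_input": {"features": [[1.0, 2.0], [0.5, 0.5]]},
          "target": {"labels": [3.0, 2.0]}}
loss, sample_size, logging_output = ToyMSELoss().forward(toy_model, sample)
print(loss, sample_size)  # 1.0 2
```
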
> input data and the custom head

You should define the input in your custom task and the custom head in your custom model. A typical reference example is the Uni-Mol-based conformation generation:
- task: https://github.com/dptech-corp/Uni-Mol/blob/main/unimol/unimol/tasks/unimol_conf_gen.py
- model: https://github.com/dptech-corp/Uni-Mol/blob/main/unimol/unimol/models/conf_gen.py
- loss: https://github.com/dptech-corp/Uni-Mol/blob/main/unimol/unimol/losses/conf_gen.py
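
For the model side, a hedged sketch of the pattern: wrap the pre-trained backbone and put a new task-specific head on its output. `ToyBackbone`, `RegressionHead` and `DownstreamModel` are hypothetical names, and plain Python stands in for the `torch.nn.Module` layers you would use in practice (the backbone would be the pre-trained encoder, e.g. Uni-Mol's). The `features` keyword is assumed to match the keys the task puts under `sample["net_input"]`.

```python
from typing import List


class ToyBackbone:
    """Hypothetical stand-in for a pre-trained encoder."""

    def encode(self, features: List[List[float]]) -> List[List[float]]:
        # Pretend representation: the features plus their mean as an extra dimension.
        return [f + [sum(f) / len(f)] for f in features]


class RegressionHead:
    """Hypothetical task-specific head: one linear layer with fixed weights."""

    def __init__(self, weights: List[float], bias: float = 0.0):
        self.weights = weights
        self.bias = bias

    def __call__(self, reps: List[List[float]]) -> List[float]:
        return [sum(w * x for w, x in zip(self.weights, r)) + self.bias for r in reps]


class DownstreamModel:
    """Backbone + new head; called with the contents of sample["net_input"]."""

    def __init__(self, backbone: ToyBackbone, head: RegressionHead):
        self.backbone = backbone
        self.head = head

    def __call__(self, features: List[List[float]]) -> List[float]:
        return self.head(self.backbone.encode(features))


model = DownstreamModel(ToyBackbone(), RegressionHead(weights=[1.0, 1.0, 0.0]))
print(model(features=[[1.0, 2.0], [0.5, 0.5]]))  # [3.0, 1.0]
```

Keeping the head separate from the backbone makes it straightforward to load pre-trained weights into the backbone only and train the new head from scratch.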
