Add simplified model manager install API to InvocationContext #6132
Conversation
I have added a migration script that tidies up the `models/core` directory.
I'm not sure what I was expecting the implementation to be, but it definitely wasn't as simple as this - great work.
I've requested a few changes and there's one discussion item that I'd like to marinate on before we change the public invocation API.
invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py
- Set `self._context=context` instead of passing it as an arg
Just a bit of typo protection in lieu of full type safety for these methods, which is difficult due to the typing of `DownloadEventHandler`.
It's inherited from the ABC.
```python
def diffusers_load_directory(directory: Path) -> AnyModel:
    load_class = GenericDiffusersLoader(
        app_config=self._app_config,
        logger=self._logger,
        ram_cache=self._ram_cache,
        convert_cache=self.convert_cache,
    ).get_hf_load_class(directory)
    result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
    return result
```
This function is unused - I think the logic to get the loader should be checking if it's a directory? I'm not sure how to fix this myself because the `diffusers_load_directory` function has a different type signature than the other loader function options.
Fixed. Looks like I never wired up that function! In the process I discovered a long-standing bug that would prevent text encoders from being loaded generically.

I also added a unit test for loading generic diffusers models. To do this, I added the 10 MB "tiny" taesdxl model, which was the smallest loadable diffusers model I could find.
I don't think we should be adding models to the repo. Feels arbitrary to do it just for this one scenario and not all of them, plus 10MB increases the repo size by like 8%. Don't wanna get in the habit of doing this.
I think the best way to test model loading is to package up a models dir that has samples of various model types and configurations, then run full loading tests locally on one of our own machines or a custom CI runner.
I've commented out the unit test for directory loading.
Can you please drop the commits that added the model file so it's not in the repo?
Sorry for the delay in reviewing. I've tidied a few things and tested everything, working great!
Two minor issues noted.
@psychedelicious I've addressed the remaining issues you raised. Thanks for a thorough review.
Yes, mypy is having trouble tracking the return type of several methods. I haven't figured out what causes the problem and don't want to add a `# type: ignore`. But maybe I should, 'cause I'm not ready to turn to pyright.
We shouldn't add `# type: ignore`.
@RyanJDick Would you mind doing one last review of this PR?
You've convinced me. I've switched to pyright!
Looks like 43/44 files have changed since I last looked 😅. I'll plan to spend a chunk of time on this tomorrow.
@RyanJDick You can narrow that down to reviewing `invocation_context.py`.
I just reviewed the `invocation_context.py` API.
""" | ||
Build the migration from database version 9 to 10. | ||
|
||
This migration does the following: | ||
- Moves "core" models previously downloaded with download_with_progress_bar() into new | ||
"models/.download_cache" directory. | ||
- Renames "models/.cache" to "models/.convert_cache". | ||
- Adds `error_type` and `error_message` columns to the session queue table. | ||
- Renames the `error` column to `error_traceback`. | ||
""" |
This docstring is outdated.
Fixed.
```python
if isinstance(source, Path):
    return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
else:
    model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
    return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
```
I don't love that we switch behaviour based on whether `source` is a `Path` or a `str`. It feels like a fragile distinction, especially given the popularity of using `str`s to represent paths in Python. The caller should always know whether they are dealing with a path or a URL/repo name, so I think it's better to make this distinction explicit.

In this discussion we had landed on an API that didn't require this type condition. Was there a reason for moving away from that?
It looks like @psychedelicious removed `load_model_from_url()` in commit b12444. Add it back?
I moved the method from the high-level `services.model_manager` class to the invocation context. The function is otherwise the same.
@psychedelicious I can't find `load_model_from_url()` in `invocation_context.py`, and the commit indicates that it was deleted from the model manager and then `load_and_cache_model()` was modified to no longer call it.

I'll take care of fixing the API if the change wasn't deliberate.
`load_model_from_url()` was this:

```python
class ModelManagerService(ModelManagerServiceBase):
    # ...
    def load_model_from_url(
        self,
        source: str | AnyHttpUrl,
        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
    ) -> LoadedModel:
        model_path = self.install.download_and_cache_model(source=str(source))
        return self.load.load_model_from_path(model_path=model_path, loader=loader)
```
It was called in one place in `invocation_context.py`:

```python
class ModelsInterface(InvocationContextInterface):
    def load_and_cache_model(
        self,
        source: Path | str | AnyHttpUrl,
        loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
    ) -> LoadedModel:
        if isinstance(source, Path):
            return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
        else:
            # Called here
            return self._services.model_manager.load_model_from_url(source=source, loader=loader)
```
What I did was copy the two lines from the function body directly into `ModelsInterface.load_and_cache_model()` - I don't think we should have `load_model_from_url` on the main `ModelManagerService`.

I think the API Ryan is suggesting is to have `load_and_cache_from_path` and `load_and_cache_from_url` on `ModelsInterface`.
For clarity, I have renamed the methods `load_local_model()` and `load_remote_model()`. The former accepts the `Path` to a file or directory, and the latter accepts either a direct download URL or a HuggingFace URL. I have fixed the documentation and updated the pull request description.
Sounds good!
@RyanJDick I've fixed the issues you identified.
`invocation_context.py` looks good to me ✅
I'll defer to @psychedelicious for final approval, since he has a more complete understanding of this PR than me at this point.
Summary
This PR adds three model manager-related methods to the `InvocationContext` uniform API. They are accessible via `context.models.*`:

`load_local_model(model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`
Load the model located at the indicated path.
This will load a local model (`.safetensors`, `.ckpt`, or diffusers directory) into the model manager RAM cache and return its `LoadedModelWithoutConfig`. If the optional `loader` argument is provided, the loader will be invoked to load the model into memory. Otherwise the method will call `safetensors.torch.load_file()`, `torch.load()` (with a pickle scan), or `from_pretrained()` as appropriate to the path type.

Be aware that the `LoadedModelWithoutConfig` object differs from `LoadedModel` by having no `config` attribute.

Here is an example of usage:
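(A minimal sketch; the path is hypothetical, `context` is assumed to be the `InvocationContext` passed to an invocation's `invoke()` method, and `LoadedModelWithoutConfig` is assumed to support the same context-manager locking as `LoadedModel`.)

```python
from pathlib import Path

# Hypothetical local checkpoint; a .ckpt file or a diffusers directory
# would be handled the same way.
model_path = Path("/opt/invokeai/models/my_model.safetensors")

# Loads the file into the model manager RAM cache and returns a
# LoadedModelWithoutConfig.
loaded = context.models.load_local_model(model_path)

with loaded as model:
    # `model` is the raw in-memory model object (AnyModel); run inference
    # here while the cache keeps it locked.
    print(type(model))
```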
`load_remote_model(source: str | AnyHttpUrl, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`

Load the model located at the indicated URL or repo_id.

This is similar to `load_local_model()`, but it accepts either a HuggingFace repo_id (as a string) or a URL. The model's file(s) will be downloaded to `models/.download_cache` and then loaded, returning a `LoadedModelWithoutConfig`.
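Under the same assumptions as above, a sketch of remote loading (the URL is illustrative):

```python
# Illustrative direct-download URL; a HuggingFace repo_id string is
# accepted as well.
remote_source = "https://example.com/models/my_model.safetensors"

# Downloads to models/.download_cache on first use (later calls reuse the
# cached file), then loads and returns a LoadedModelWithoutConfig.
with context.models.load_remote_model(remote_source) as model:
    print(type(model))
```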
`download_and_cache_model(source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0) -> Path`
Download the model file located at `source` to the models cache and return its `Path`. This will check `models/.download_cache` for the desired model file and download it from the indicated source if not already present. The local `Path` to the downloaded file is then returned.
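A sketch of download-only usage, again with an illustrative URL (no model is loaded into RAM):

```python
from pathlib import Path

# Fetches the file into models/.download_cache if it is not already there
# and returns the local path to it.
local_file: Path = context.models.download_and_cache_model(
    source="https://example.com/models/my_model.safetensors",
)
print(local_file)
```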
Other Changes

This PR performs a migration, in which it renames `models/.cache` to `models/.convert_cache`, and migrates previously-downloaded ESRGAN, openpose, DepthAnything and Lama inpaint models from the `models/core` directory into `models/.download_cache`.
There are a number of legacy model files in `models/core`, such as GFPGAN, which are no longer used. This PR deletes them and tidies up the `models/core` directory.

Related Issues / Discussions
I have systematically replaced all the calls to `download_with_progress_bar()`. This function is no longer used elsewhere and has been removed.

QA Instructions
I have added unit tests for the three new calls. You may test that the `load_and_cache_model()` call is working by running the upscaler within the web app. On first try, you will see the model file being downloaded into the `models/.download_cache` directory. On subsequent tries, the model will either load from RAM (if it hasn't been displaced) or will be loaded from the filesystem.

Merge Plan
Squash merge when approved.
Checklist