OS: Windows 11 22631.3296
Python: 3.11.8
PyTorch: 2.2.1 (installed in conda env)
CUDA: 12.1 (installed in conda env)
NV Driver: 551.76
Gemma Model: 7b-it
I was trying to run inference. Before I started, about 6 GB of memory was in use and 26 GB was free.
I observed that when the code reached the load_weights function, memory usage climbed to 98% of my total 32 GB of RAM, stayed there for about a minute, and then dropped back to normal. At that point I had not yet called the to(device) function on the next line.
From Task Manager, at the time of high usage, python.exe had a working set of about 28 GB, while its active private working set was about 14 GB. The Windows page file was being used to keep the system running.
However, the 7B-it model (16-bit float) should not exceed 16 GB. Allocating 28 GB in this process seems unnecessary.
As I noted above, memory usage eventually dropped back to normal without to(device) ever being called, which suggests that loading does not actually require that much memory.
Sorry, I don't know the details of how Python or PyTorch manage memory, but I'm wondering whether this line could be improved to smooth out the memory-usage spike.
Maybe one workaround could be loading the weights layer by layer in sequence and garbage-collecting each layer's source tensors immediately after that layer is loaded. That way the peak memory usage should be lower.
@michaelmoynihan you have investigated this before; do you have any insights?
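The layer-by-layer idea above could be sketched roughly as follows; the toy model and the in-memory source dict are placeholders (a real implementation would stream each layer's tensors from disk), so at most one extra layer's worth of weights is alive at a time:

```python
import gc
import torch
import torch.nn as nn

# Toy model standing in for the real network; names are illustrative.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Placeholder for per-layer weights; a real loader would read these
# from the checkpoint one layer at a time rather than all at once.
source_state = {k: torch.randn_like(v) for k, v in model.state_dict().items()}

for name, param in model.named_parameters():
    with torch.no_grad():
        # Copy one tensor into the model in place...
        param.copy_(source_state[name])
    # ...then drop the source immediately so its storage can be freed
    # before the next layer is loaded.
    del source_state[name]
    gc.collect()

print(len(source_state))  # all source tensors released
```

With streaming from disk, the peak would be model size plus one layer, instead of model size plus a full second copy of the checkpoint.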