SparK ResNet and global feature interaction #80

Open
csvance opened this issue Jan 30, 2024 · 8 comments

Comments

@csvance

csvance commented Jan 30, 2024

Hello, thanks for the great paper.

With the ResNet version of SparK using sparse convolution and sparse batch normalization together, the flow and mixing of global semantic information is heavily restricted: the sparse operations effectively mask the receptive field, and batch norm provides no global interaction across feature channels. This information seems likely to struggle to propagate, especially in shallower networks with a smaller receptive field such as ResNet50. The paper empirically showed that ResNet50 benefited the least from SparK, failing to match the performance of supervised ResNet101. I was wondering whether the authors or anyone else has tried using sparse group normalization with ResNet, so that there would be some global interaction of feature channels to better allow high-level features to be learned. Masked autoencoder pretraining has shown a lot of promise for data-limited tasks in medical imaging, and ResNet50 is commonly used by practitioners, so understanding how to use SparK pretraining most effectively has big implications for many in the field.
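To make the idea concrete, here is a minimal sketch of what a sparse group normalization layer could look like: statistics are computed only over un-masked positions (as in SparK's sparse BN), but across the channels within each group, so some cross-channel interaction survives the sparse forward pass. This is a hypothetical layer, not part of the SparK code base; the class name, mask convention, and shapes are my own assumptions.

```python
import torch
import torch.nn as nn

class SparseGroupNorm(nn.Module):
    """Hypothetical GroupNorm variant that normalizes only over un-masked positions."""
    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.num_groups, self.eps = num_groups, eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W); mask: (B, 1, H, W) with 1 = visible, 0 = masked (assumed convention)
        B, C, H, W = x.shape
        g, cg = self.num_groups, C // self.num_groups
        xg = x.view(B, g, cg, H, W)
        mg = mask.unsqueeze(2)                                   # (B, 1, 1, H, W)
        # number of un-masked elements per group: visible positions x channels per group
        n = mg.sum(dim=(2, 3, 4), keepdim=True).clamp(min=1) * cg
        mean = (xg * mg).sum(dim=(2, 3, 4), keepdim=True) / n
        var = (((xg - mean) ** 2) * mg).sum(dim=(2, 3, 4), keepdim=True) / n
        xg = (xg - mean) / torch.sqrt(var + self.eps)
        out = xg.view(B, C, H, W) * mask                         # keep masked positions at zero
        return out * self.weight.view(1, C, 1, 1) + self.bias.view(1, C, 1, 1)
```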

@keyu-tian
Owner

keyu-tian commented Jan 30, 2024

@csvance Very insightful thinking. I've also heard that using a 3D sparse convolutional backbone can lead to insufficient global information interaction in 3D point cloud perception (the interaction actually only occurs within connected components).

Yeah, GroupNorm, LayerNorm, or some attention-like operator could alleviate this problem. It's a promising direction to explore.

@csvance
Author

csvance commented Jan 31, 2024

I'm running some experiments with this on an internal dataset using the Big Transfer ResNetV2 architecture. One of the other reasons I think GroupNorm might be promising is its transfer learning performance, as demonstrated in the Big Transfer paper. Even though sparse normalization counteracts some of the distribution shift, there is going to be a higher degree of feature interaction with unmasked input. GroupNorm could be more robust to this than batch norm when going from pretraining to training (see the sketch below). If it shows promise, I will do an ImageNet run and post the results here.
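For anyone wanting to try the same swap on a stock BN-based ResNet, a hedged sketch of how the normalization layers could be replaced follows. The helper name and `groups=32` default are my own; BiT-style ResNetV2 models already ship with GroupNorm plus weight standardization, so this is only needed when starting from a BatchNorm backbone.

```python
import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, groups: int = 32) -> nn.Module:
    """Recursively swap every BatchNorm2d for a GroupNorm over the same channels.
    Assumes channel counts are divisible by `groups` (true for standard ResNet widths)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name,
                    nn.GroupNorm(min(groups, child.num_features), child.num_features))
        else:
            replace_bn_with_gn(child, groups)
    return module

# e.g. resnet = replace_bn_with_gn(torchvision.models.resnet50())
```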

@csvance
Author

csvance commented Mar 8, 2024

Hello, I know the paper says you use a batch size of 4096, but I was curious how many GPUs that was split between. I'm having some stability issues and I suspect they have to do with the effective batch size for batch norm in the decoder. Previously I was using a batch size of 64 and accumulating gradients 64 times on a single RTX 3090 24GB to reach 4096. Now I have access to 4x A6000 48GB and am trying batch size 128 per GPU + gradient accumulation of 8 to reach 4096, using sync batch norm in the decoder the same as SparK. I'm hoping that a much higher effective batch size for the decoder's batch norm will be the key to stopping training from diverging.
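For reference, a minimal sketch of this setup under the usual PyTorch DDP pattern, written as a self-contained function (the forward interface, loss scaling, and hyper-parameters are placeholders, not the SparK training script). The key point is that gradient accumulation only enlarges the optimizer's effective batch (4096 here), while BatchNorm statistics still only see the per-step batch (4 GPUs x 128 with SyncBatchNorm).

```python
import torch
import torch.nn as nn

def pretrain_with_accumulation(model: nn.Module, loader, local_rank: int,
                               accum_steps: int = 8) -> None:
    """Sketch: SyncBatchNorm across 4 GPUs (per-step BN batch = 4 x 128) and
    gradient accumulation x8 so the optimizer step sees an effective 4096.
    Assumes the process group is already initialized and that
    `model(images, mask)` returns a scalar loss (placeholder interface)."""
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda(local_rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.05)
    for step, (images, mask) in enumerate(loader):
        loss = model(images.cuda(local_rank), mask.cuda(local_rank)) / accum_steps
        loss.backward()                                  # BN stats only see the per-step batch
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()
```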

@keyu-tian
Owner

@csvance Yeah, sync batch norm and a big enough batch size are important for BN stability. We used 32 Tesla A100s with bs=128 per GPU (so the total bs is 4096) most of the time, and didn't use gradient accumulation. I think bs=64 is too small for BN; 4x128=512 should be better.

@csvance
Author

csvance commented Mar 12, 2024

Yeah, I'm definitely seeing a big difference between my new and old setup. There is still some instability with a 4*128 effective batch size for sync batch norm, but things converge much better than I have seen before. It looks like BatchNorm + a large batch size is crucial for the decoder here; I have tried a GroupNorm decoder and convergence is significantly worse, without any improvement in stability.

Just as an experiment, I'm running with an image size of 128x128 and a batch size of 512 per GPU, giving me a sync BN batch size of 2048 (accumulating gradients twice to get 4096 per optimizer step). It will be interesting to see whether there are still issues with constant gradient explosion. Here is what the divergence looks like in the loss curve; it pretty much always happens when I reach a loss of around 0.3 MSE. It doesn't matter how I tune gradient clipping, learning rate, etc.; it's as if the loss landscape is extremely sharp/unstable without a sufficient batch size for batch norm.

[screenshot: pretraining loss curve diverging at around 0.3 MSE]

@csvance
Author

csvance commented Mar 15, 2024

I was able to get SparK to converge with LayerNorm in the decoder instead of BatchNorm! I had forgotten to enable decoupled weight decay with the optimizer I was using, which was the source of the divergences (too much weight decay relative to the learning rate). During training there are still times where the loss spikes a bit, but it's not extreme and the loss starts decreasing again toward a better minimum.

I have no doubt that BatchNorm will still converge faster, but using LayerNorm in the decoder could be a good option for those who do not have access to a huge number of GPUs.
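As a rough illustration of the two changes that mattered here (not the SparK decoder itself, and the hyper-parameters and the tiny decoder below are placeholders): a channels-last LayerNorm wrapper in the style of ConvNeXt as a drop-in BN replacement for NCHW feature maps, and AdamW so that weight decay is decoupled from the gradient, whereas plain Adam with `weight_decay` applies coupled L2 regularization.

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of an NCHW feature map (ConvNeXt-style);
    no batch statistics involved, so the per-step batch size stops mattering."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)      # NCHW -> NHWC
        x = super().forward(x)         # normalize over the channel dimension
        return x.permute(0, 3, 1, 2)   # NHWC -> NCHW

# Placeholder decoder block: replace nn.BatchNorm2d(256) with LayerNorm2d(256).
decoder = nn.Sequential(nn.Conv2d(512, 256, 3, padding=1), LayerNorm2d(256), nn.GELU())

# Decoupled weight decay: AdamW decays weights directly instead of adding an L2
# term to the gradient, which was the missing piece behind the divergences above.
optimizer = torch.optim.AdamW(decoder.parameters(), lr=2e-4, weight_decay=0.05)
```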

@keyu-tian
Owner

@csvance Happy to hear that, and thanks for your effort! Substituting BN with LN or GN (GroupNorm) is indeed worth trying, and I suspect BN isn't always essential. We initially adopted BN just because UNet used it, but I believe LN or GN could effectively replace it without much of a performance drop, and yes, this could be particularly beneficial for those with limited GPU resources.

@csvance
Author

csvance commented Mar 18, 2024

For using SparK with backscatter X-ray images, I found it helpful to use a larger epsilon for tile normalization and also to normalize the x_hat tiles. The reason is that many tiles are mostly background, since a lot of X-rays are taller than they are wide and often have large segments of noisy background. This made the learned representation transfer better to downstream problems. Without the large epsilon, training is unstable at the start when normalizing the x_hat tiles, which seems to negatively impact the learned representation. I suspect normalizing x_hat is a useful inductive bias, but I haven't tried any of this with ImageNet yet.
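A minimal sketch of the tile normalization described above; the function name, tile shapes, and the epsilon value are illustrative assumptions, not SparK's defaults.

```python
import torch

def normalize_tiles(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """Per-tile normalization with a deliberately large epsilon.
    x: (B, N, P) flattened tiles. The large eps keeps nearly-uniform background
    tiles (common in tall, noisy X-ray images) from exploding after normalization."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True)
    return (x - mean) / (var + eps).sqrt()

# Normalizing the reconstruction as well makes the loss scale-invariant per tile:
#   loss = torch.nn.functional.mse_loss(normalize_tiles(x_hat), normalize_tiles(x_target))
```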

Until now I have been working with a relatively small subset of my dataset, roughly ~100k images. I'm going to ramp things up by several orders of magnitude now. Results on downstream tasks are very promising even with so few images; downstream performance is already close to ImageNet21K transfer performance.
