Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quantize_embeddings + KeyedJaggedTensor + VBE cannot work #2513

Open
yjjinjie opened this issue Apr 18, 2024 · 0 comments
Open

quantize_embeddings + KeyedJaggedTensor + VBE cannot work #2513

yjjinjie opened this issue Apr 18, 2024 · 0 comments

Comments

@yjjinjie
Copy link



import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig,EmbeddingConfig
from torchrec import EmbeddingBagCollection,EmbeddingCollection


kt2 = KeyedJaggedTensor(
    keys=['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 'click_50_seq__raw_1'], 
    values=torch.tensor([573174,   5073,   3562,      3,     18,     13,     11,     49,     26,
             4,      2,      2,      4,      2,      4, 736847, 849333, 997432,
        640218,   9926,   9926,      0,      0,      0,      0,  59926,  59926,
             0,      0,      0,      0,   2835,    769,   1265,   8232,   6399,
           114,   7487,   2876,    953,   7840,   7538,   7998,   7852,   3528,
          1475,   7620,   6110,    572,    735,   4405,   5655,   6736,   2173,
          3421,   2311,   7122,   2159,   4535,   2162,   4657,   3151,   4522,
          1075,    306,   8968,   2056,   2256,   3919,   8624,   5372,   6018,
          3861,   4114,   3984,   2287,   1481,   4757,   1189,   2518,    913,
          9421,   3093,   5911,   9704,   8168,   9410,    728,   2451,    243,
          5187,   5836,   8830,   4894,    614,   7705,   9258,   3518,   4434,
             4,      2,      4,      2,      4,      2,      3,      2,      2,
             3,      3,      3,      4,      4,      3,      0,      4,      0,
             2,      2,      3,      4,      4,      0,      2,      2,      4,
             0,      3,      2,      2,      3,      0,      4,      0,      4,
             4,      4,      2,      2,      3,      4,      2,      4,      3,
             4,      2,      4,      2,      2,      2,      2,      0,      3,
             4,      4,      3,      2,      4,      4,      4,      4,      3,
             2,      3,      4,      2,      4,      0,      4,      4,      4,
             4,      0,      0,      2,      1,      1,      0,      3,      4,
             4,      2,      4,      1,      1,      4,      2,      2,      4,
             0,      4,      4,      4,      4,      4,      1,      4,      2,
             0,      0,      0,      2,      4,      4,      2,      4,      2,
             4,      4,      1,      1,      4,      1,      4,      4,      1,
             0,      4,      4,      4,      3,      0,      0,      2,      4,
             2,      2,      4,      4,      4,      2,      2,      4,      2,
             3]),
    lengths=torch.tensor([ 1,  1,  1,  1,  0,  0,  1,  2,  2,  1,  1,  4,  2,  2,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1, 24, 44, 24, 44, 24, 44], dtype=torch.int64),
    stride_per_key_per_rank=[[1], [2], [2], [2], [2], [2], [1], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2]],
    inverse_indices=(['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 
                      'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 
                      'click_50_seq__raw_1'], 
                     torch.tensor([[0, 0], [0, 1],[0, 1], [0, 1], [0, 1], [0, 1],[0, 0], [0, 1], [0, 1], [0, 1],
                                   [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])
    )
)

eb_configs2=[
    EmbeddingBagConfig(num_embeddings=1000000, embedding_dim=16, name='user_id_emb', feature_names=['user_id'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=10000, embedding_dim=16, name='item_id_emb', feature_names=['item_id'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=5, embedding_dim=8, name='id_3_emb', feature_names=['id_3'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False, ),
EmbeddingBagConfig(num_embeddings=100, embedding_dim=16, name='id_4_emb', feature_names=['id_4', 'id_5'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='raw_1_emb', feature_names=['raw_1'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='raw_4_emb', feature_names=['raw_4'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=1000000, embedding_dim=16, name='combo_1_emb', feature_names=['combo_1'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=10000, embedding_dim=8, name='lookup_2_emb', feature_names=['lookup_2'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,   need_pos=False, ),
EmbeddingBagConfig(num_embeddings=1000, embedding_dim=8, name='lookup_3_emb', feature_names=['lookup_3'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='lookup_4_emb', feature_names=['lookup_4'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),
EmbeddingBagConfig(num_embeddings=100000, embedding_dim=16, name='match_2_emb', feature_names=['match_2'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,   need_pos=False, ),
EmbeddingBagConfig(num_embeddings=10000, embedding_dim=8, name='match_3_emb', feature_names=['match_3'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,   need_pos=False, ),
EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='match_4_emb', feature_names=['match_4'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None,  need_pos=False, ),

]
ebc = EmbeddingBagCollection(eb_configs2)

print(ebc(kt2))
from torchrec.inference.modules import quantize_embeddings


import torch
import torch.nn as nn

class EmbeddingGroupImpl(nn.Module):
    def __init__(self,ebc):
        super().__init__()
        self.ebc=ebc
    
    def forward(
        self,
        sparse_feature
    ):
        self.ebc(sparse_feature)

a=EmbeddingGroupImpl(ebc=ebc)
a.forward(kt2)

quant_model = quantize_embeddings(a, dtype=torch.qint8, inplace=True)
print(quant_model(kt2))
Traceback (most recent call last):
  File "/larec/tzrec/tests/test_per2.py", line 89, in <module>
    print(quant_model(kt2))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/larec/tzrec/tests/test_per2.py", line 83, in forward
    self.ebc(sparse_feature)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torchrec/quant/embedding_modules.py", line 487, in forward
    else emb_op.forward(
  File "/opt/conda/lib/python3.10/site-packages/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py", line 764, in forward
    torch.ops.fbgemm.bounds_check_indices(
  File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 758, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: offsets size 27 is not equal to B (1) * T (14) + 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant