

Binary classification .predict raises a ValueError: could not broadcast input array from shape (2,2) into shape (1,2) #1563

Open
gaborfodor720818 opened this issue Feb 2, 2024 · 3 comments

Comments

@gaborfodor720818

Hi,

I tried the original simpletransformers sample code locally with the latest simpletransformers version, 0.65.1.
I trained the model with model type bert and "bert-base-uncased". Training completed, but prediction always raises an error.

The code:

```python
predictions, raw_outputs = model.predict(to_predict=["Example sentence belonging to class 1"])
```

Error:

```
predictions, raw_outputs = model.predict(to_predict=["Example sentence belonging to class 1"])
  File "C:\Work\PythonProjects\TenderAI\venv\lib\site-packages\simpletransformers\classification\classification_model.py", line 2217, in predict
    preds[start_index:end_index] = logits.detach().cpu().numpy()
ValueError: could not broadcast input array from shape (2,2) into shape (1,2)
```
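The error itself is a plain NumPy shape mismatch: `predict()` preallocates a `preds` buffer with one row per input sentence, but the logits it assigns into that buffer have a different number of rows (here two, matching the traceback). A minimal sketch in plain NumPy, with the shapes taken from the traceback above, reproduces the same ValueError:

```python
import numpy as np

# predict() preallocates one row per input sentence (1 example, 2 classes)
preds = np.empty((1, 2))

# if stale/cached features produce logits for a different number of
# examples, the batch has 2 rows instead of 1
logits = np.zeros((2, 2))

try:
    # same pattern as preds[start_index:end_index] = logits... in the traceback
    preds[0:1] = logits
except ValueError as e:
    print(e)  # cannot broadcast a (2, 2) array into a (1, 2) slice
```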

@ThilinaRajapakse
Owner

  1. Does this only happen with the latest version?
  2. Does it only happen with the model that you trained with an earlier version when trying to use it with the current version?

I tried to reproduce it with just bert-base-uncased but it's working fine for me.

@gaborfodor720818
Author

>   1. Does this only happen with the latest version?
>   2. Does it only happen with the model that you trained with an earlier version when trying to use it with the current version?
>
> I tried to reproduce it with just bert-base-uncased but it's working fine for me.

Strange, now it works. I tried it on a Linux machine (previously it was Windows 11) with a clean setup. simpletransformers is again the latest version, 0.65.1. (I have never tried an older version.)
Thanks for testing.

@gaborfodor720818
Author

I am not sure it is a bug, but I have found the cause. If I set "use_cached_eval_features": True, I get the error message. I also have "evaluate_during_training": True, because I run evaluation during training: model.train_model(train_df, eval_df=eval_df).
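For reference, a sketch of the model args combination described above (assuming the standard `model_args` dict passed to `ClassificationModel`; the option names are the ones quoted in the comment). The key point is that `use_cached_eval_features` makes feature caches from the evaluation set get reused, which can have a different number of examples than the sentences later passed to `predict()`:

```python
# model_args combination that triggered the broadcasting error above;
# leaving use_cached_eval_features at its default (False) avoids it
model_args = {
    "evaluate_during_training": True,   # evaluate on eval_df during train_model()
    "use_cached_eval_features": True,   # reuse cached eval features -> stale shapes in predict()
}
```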
