Improve WoPBS API to avoid mistakes when using ciphertexts with varying degrees, consider having a carry invariant radix LUT generation #1016

Open
cgouert opened this issue Mar 21, 2024 · 19 comments
Labels
enhancement New feature or request

Comments

@cgouert

cgouert commented Mar 21, 2024

Describe the bug
After a smart_mul, the WoPBS does not yield the expected answer.

To Reproduce

use std::{collections::HashMap};
use tfhe::{
    integer::{
        gen_keys_radix, wopbs::*,
    },
    shortint::parameters::{
        parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
        PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    },
};

fn foo(x: u64, lut_entries: &HashMap<u64, u64>) -> u64 {
    lut_entries[&x]
}

fn main() {
    let nb_blocks: usize = 4;

    // Generate radix keys
    let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

    // Generate key for PBS (without padding)
    let wopbs_key = WopbsKey::new_wopbs_key(
        &client_key,
        &server_key,
        &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    );

    // Create ciphertexts 
    let mut ct = client_key.encrypt(2_u64);
    let mut ct_2 = client_key.encrypt(4_u64);

    // Generate LUTs for WoPBS
    let mut lut_1_map : HashMap<u64, u64> = HashMap::new();
    let mut lut_2_map : HashMap<u64, u64> = HashMap::new();
    for i in 0..256 {
      lut_1_map.insert(i, 2*i % 256);
      lut_2_map.insert(i, 3*i % 256);
    }
    let lut_1 = wopbs_key.generate_lut_radix(&ct, |x: u64| foo(x, &lut_1_map));
    let lut_2 = wopbs_key.generate_lut_radix(&ct, |x: u64| foo(x, &lut_2_map));
    
    // Multiply input ciphertexts: 2 * 4 = 8
    ct = server_key.smart_mul(&mut ct, &mut ct_2);

    // Apply LUT #1
    ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &ct);
    let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
    let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
    let test: u64 = client_key.decrypt(&lut_1_res);
    println!("Lut #1 result: {:?}", &test);
    println!("Expected result: {:?}", &lut_1_map[&8]);

    // Apply LUT #2
    let lut_2_res = wopbs_key.wopbs(&ct, &lut_2);
    let lut_2_res = wopbs_key.keyswitch_to_pbs_params(&lut_2_res);
    let test: u64 = client_key.decrypt(&lut_2_res);
    println!("Lut #2 result: {:?}", &test);
    println!("Expected result: {:?}", &lut_2_map[&8]);

}

Expected behaviour
Each decrypted LUT result should match the expected value printed after it. With the code above, they do not match.

Evidence
[Screenshot of the program output]

Configuration:

  • OS: Ubuntu 22.04

cc: @jimouris

@IceTDrinker
Member

IceTDrinker commented Mar 21, 2024

@IceTDrinker
Member

I see a mention of "without padding" in your code example?

@IceTDrinker
Member

Ah, no, my bad, I mixed things up 😵‍💫

@IceTDrinker
Member

Can confirm it reproduces on the latest main.

@IceTDrinker IceTDrinker added bug Something isn't working and removed triage_required labels Mar 21, 2024
@IceTDrinker
Member

mul_parallelized does not have the issue. It looks like bad carry management somewhere; it could be the keyswitching or the LUT generation not handling this properly.

@IceTDrinker
Copy link
Member

IceTDrinker commented Mar 21, 2024

This is not a bug on our end but a hard-to-use feature, @cgouert.

See the updated code below. The main change is to move the LUT generation (which adapts to the degree of the ciphertext it is given) to right before applying the WoPBS:

use std::collections::HashMap;
use tfhe::integer::gen_keys_radix;
use tfhe::integer::wopbs::*;
use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn foo(x: u64, lut_entries: &HashMap<u64, u64>) -> u64 {
    lut_entries[&x]
}

fn main() {
    let nb_blocks: usize = 4;

    // Generate radix keys
    let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

    // Generate key for PBS (without padding)
    let wopbs_key = WopbsKey::new_wopbs_key(
        &client_key,
        &server_key,
        &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    );

    let clear_1 = 2_u64;
    let clear_2 = 4_u64;

    // Create ciphertexts
    let mut ct = client_key.encrypt(clear_1);
    let mut ct_2 = client_key.encrypt(clear_2);

    // Generate LUTs for WoPBS
    let mut lut_1_map: HashMap<u64, u64> = HashMap::new();
    let mut lut_2_map: HashMap<u64, u64> = HashMap::new();
    for i in 0..256 {
        lut_1_map.insert(i, 2 * i % 256);
        lut_2_map.insert(i, 3 * i % 256);
    }

    let f1 = |x: u64| foo(x, &lut_1_map);
    let f2 = |x: u64| foo(x, &lut_2_map);

    // Multiply input ciphertexts: 2 * 4 = 8
    ct = server_key.smart_mul(&mut ct, &mut ct_2);

    let sanity_dec: u64 = client_key.decrypt(&ct);
    let clear_prod = clear_1 * clear_2;
    assert_eq!(sanity_dec, clear_prod);

    // Lut generation just in time
    let lut_1 = wopbs_key.generate_lut_radix(&ct, f1);
    let lut_2 = wopbs_key.generate_lut_radix(&ct, f2);

    // Apply LUT #1
    ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &ct);
    let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
    let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
    let test: u64 = client_key.decrypt(&lut_1_res);
    println!("Lut #1 result: {:?}", &test);
    println!("Expected result: {:?}", f1(clear_prod));

    // Apply LUT #2
    let lut_2_res = wopbs_key.wopbs(&ct, &lut_2);
    let lut_2_res = wopbs_key.keyswitch_to_pbs_params(&lut_2_res);
    let test: u64 = client_key.decrypt(&lut_2_res);
    println!("Lut #2 result: {:?}", &test);
    println!("Expected result: {:?}", f2(clear_prod));
}
RUSTFLAGS="-C target-cpu=native" cargo run --profile devo --features=x86_64-unix,integer,internal-keycache --example wop_smart_mul -p tfhe
   Compiling tfhe v0.6.0 (/home/***/Documents/zama/code/tfhe-rs/tfhe)
    Finished devo [optimized + debuginfo] target(s) in 8.01s
     Running `target/devo/examples/wop_smart_mul`
Lut #1 result: 16
Expected result: 16
Lut #2 result: 24
Expected result: 24
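The updated code works because the LUT is built from the degrees the blocks have after the multiplication. A minimal plain-Rust model of that dependency (no tfhe calls; `bits_for_degree` and `lut_layout` are hypothetical helpers) assumes each block contributes ceil(log2(degree + 1)) bits to the LUT index, which is what a degree-based `vec_deg_basis` computation would do:

```rust
/// Bits needed to represent every value up to `degree`,
/// i.e. ceil(log2(degree + 1)), computed with integer ops only.
fn bits_for_degree(degree: u64) -> u32 {
    u64::BITS - degree.leading_zeros()
}

/// Per-block bit widths and total LUT index width for a degree profile.
/// Two ciphertexts share a LUT layout only if these results match.
fn lut_layout(degrees: &[u64]) -> (Vec<u32>, u32) {
    let bits: Vec<u32> = degrees.iter().map(|&d| bits_for_degree(d)).collect();
    let total: u32 = bits.iter().sum();
    (bits, total)
}
```

For example, four fresh 2-bit blocks (degree 3 each) give 2 bits per block and 8 bits in total, while a hypothetical profile `[3, 5, 3, 3]` needs 9 bits, so a LUT generated for one profile is laid out differently from what the other one needs.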

@IceTDrinker IceTDrinker removed the bug Something isn't working label Mar 21, 2024
@IceTDrinker
Member

We'll consider adding an "apply_function" that just takes the function and manages the just-in-time (JIT) LUT generation.

@cgouert
Author

cgouert commented Mar 21, 2024

Thanks for your help with this @IceTDrinker! Just to clarify, are you saying that smart_mul changes the degree of the ciphertexts? If so, does the LUT not work as expected because it was generated based on the degree of a fresh encryption?

@jimouris

@IceTDrinker Thanks for your responses!

It seems that just moving the smart_mul before generate_lut_radix solves the issue, but what if we want to do something like:

smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);

smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);

smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);

In this case, would we have to run generate_lut_radix three times even though the LUTs are the same?

@IceTDrinker
Member

> Thanks for your help with this @IceTDrinker! Just to clarify, are you saying the smart_mul changes the degree of the ciphertexts? If this is the case, the LUT doesn't work as expected because it was generated based on the original degree of a fresh encryption?

Exactly, and yes, it's expected that multiplying changes the degrees of the underlying blocks :)
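A toy sketch of why the degrees move. A block's degree is an upper bound on the value it holds, so, assuming the simple rules below (sums add bounds, products multiply them), the product of two fresh 2-bit blocks already has degree 9. These rules are a simplification for illustration, not a transcription of the bookkeeping tfhe-rs does inside smart_mul:

```rust
/// Upper bound on a block after adding two blocks with the given degrees.
fn degree_after_add(d1: u64, d2: u64) -> u64 {
    d1 + d2
}

/// Upper bound on a block after multiplying two blocks (before any carry
/// extraction): the largest product of values bounded by d1 and d2.
fn degree_after_mul(d1: u64, d2: u64) -> u64 {
    d1 * d2
}

/// Whether a block of this degree still fits in the message + carry space.
fn fits(degree: u64, message_modulus: u64, carry_modulus: u64) -> bool {
    degree < message_modulus * carry_modulus
}
```

With a 2-bit message and a 2-bit carry (16 values per block), degree 9 still fits, but adding two such partial products (degree 18) would not, which is why carries have to be managed along the way and why the final degrees differ from those of a fresh encryption.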

@IceTDrinker
Member

> It seems that just putting the smart_mul before the generate_lut_radix solves the issue but what if we want to do something like:
> [...]
> In this case, we'd have to run generate_lut_radix three times although the LUTs will be the same?

My best guess here is that it does a lazy LUT evaluation, filling as few coefficients as it can, since IIRC having more 0s in the LUT results in less noise in the output (I could be wrong on that, but a trivial all-zero LUT gives a noiseless encryption of 0, so there may be some of that).

I don't think the LUT generation would show up in a performance measurement compared to the runtime of a WoPBS, so yes, you likely need to generate the LUT each time. You could write a small wrapper that generates the LUT and applies it right afterwards, so that you never have issues with degrees.
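The wrapper idea can be modeled in plain Rust. Everything here is hypothetical and unencrypted: `ToyRadix` stands in for a radix ciphertext, `generate_lut` packs ceil(log2(degree + 1)) bits per block (an assumed layout), and `evaluate` builds the index from the ciphertext's current degrees. The test scenario uses a made-up degree profile `[1, 3]`, not one actually produced by smart_mul:

```rust
/// Toy stand-in for a radix ciphertext (hypothetical, unencrypted):
/// clear block values plus the degree tracked for each block.
struct ToyRadix {
    blocks: Vec<u64>,  // least significant block first
    degrees: Vec<u64>, // upper bound on each block's value
    message_modulus: u64,
}

/// Hypothetical LUT laid out for one degree profile.
struct ToyLut {
    table: Vec<u64>,
}

/// ceil(log2(degree + 1)): bits needed to index a block of this degree.
fn bits_for_degree(degree: u64) -> u32 {
    u64::BITS - degree.leading_zeros()
}

/// Build a LUT whose index packs ceil(log2(degree + 1)) bits per block,
/// using the degrees `ct` has at the time of the call.
fn generate_lut(ct: &ToyRadix, f: impl Fn(u64) -> u64) -> ToyLut {
    let bits: Vec<u32> = ct.degrees.iter().map(|&d| bits_for_degree(d)).collect();
    let total: u32 = bits.iter().sum();
    let table = (0..1u64 << total)
        .map(|index| {
            // Unpack each block's field and recombine in base message_modulus.
            let mut rest = index;
            let mut value = 0;
            let mut weight = 1;
            for &b in &bits {
                value += (rest & ((1u64 << b) - 1)) * weight;
                rest >>= b;
                weight *= ct.message_modulus;
            }
            f(value)
        })
        .collect();
    ToyLut { table }
}

/// "Apply" a LUT: pack each block value using the ciphertext's *current*
/// degrees, then read the table (None if the index is out of range).
fn evaluate(ct: &ToyRadix, lut: &ToyLut) -> Option<u64> {
    let mut index = 0u64;
    let mut shift = 0u32;
    for (&v, &d) in ct.blocks.iter().zip(&ct.degrees) {
        index |= v << shift;
        shift += bits_for_degree(d);
    }
    lut.table.get(index as usize).copied()
}

/// Just-in-time wrapper: regenerate the LUT from the current degrees right
/// before evaluating, so both sides always use the same layout
/// (assuming every block value is at most its degree).
fn apply_function(ct: &ToyRadix, f: impl Fn(u64) -> u64) -> u64 {
    let lut = generate_lut(ct, f);
    evaluate(ct, &lut).expect("LUT layout matches by construction")
}
```

In this model, a LUT for `x -> 2x` built on fresh degrees `[3, 3]` and evaluated on the value 8 with degrees `[1, 3]` silently returns 8, while `apply_function` rebuilds the table and returns 16.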

@IceTDrinker
Copy link
Member

Or rather, the way the LUT is organized is not invariant to the carry bits of the ciphertexts 🤔

Maybe there is something to do about it, but I'm really not sure; I was not part of the team that initially worked on this. Still, it could be interesting to investigate whether carry-invariant LUTs are "easy" to write for the WoPBS. The lazy-evaluation idea is likely wrong now that I think about it.
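One possible reading of a carry-invariant LUT, sketched in plain Rust (hypothetical, not a tfhe-rs API; encryption, noise and the padding bit are ignored): every block contributes all of its message + carry bits to the index, and the content of block i is weighted by message_modulus^i, so carries count toward the next digit and any representation of the same integer reaches the same output. The cost is a table of 2^(nb_blocks * block_bits) entries:

```rust
/// Hypothetical carry-invariant LUT: every block contributes all of its
/// message + carry bits (`block_bits`) to the index, whatever its degree.
fn carry_invariant_lut(
    nb_blocks: u32,
    block_bits: u32,
    message_modulus: u64,
    f: impl Fn(u64) -> u64,
) -> Vec<u64> {
    let mask = (1u64 << block_bits) - 1;
    // Radix integers wrap around modulo message_modulus^nb_blocks.
    let total_modulus = message_modulus.pow(nb_blocks);
    (0..1u64 << (nb_blocks * block_bits))
        .map(|index| {
            let mut value = 0;
            let mut weight = 1;
            for i in 0..nb_blocks {
                // Block i's full content, carries included.
                let block = (index >> (i * block_bits)) & mask;
                // Weighting by message_modulus^i makes a carry worth the
                // same as one unit of the next block.
                value += block * weight;
                weight *= message_modulus;
            }
            f(value % total_modulus)
        })
        .collect()
}

/// Index into such a LUT: concatenate the full-width block values.
fn carry_invariant_index(blocks: &[u64], block_bits: u32) -> usize {
    blocks
        .iter()
        .enumerate()
        .fold(0u64, |acc, (i, &b)| acc | (b << (i as u32 * block_bits))) as usize
}
```

With 2-bit messages and 2-bit carries (`block_bits = 4`), the value 8 stored as blocks `[0, 2]`, `[4, 1]` or `[8, 0]` reads the same entry; with 4 blocks the table would need 2^16 entries per output block.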

@IceTDrinker
Copy link
Member

I'm going to keep this issue open as an "enhancement" issue to see whether something can be done in the LUT generation or the API to make it less error-prone.

@IceTDrinker IceTDrinker changed the title WoPBS error for radix ciphertexts after multiplication Improve WoPBS API to avoid mistakes when using ciphertexts with varying degress, consider having a carry invariant radix LUT generation Mar 21, 2024
@IceTDrinker IceTDrinker added the enhancement New feature or request label Mar 21, 2024
@IceTDrinker IceTDrinker changed the title Improve WoPBS API to avoid mistakes when using ciphertexts with varying degress, consider having a carry invariant radix LUT generation Improve WoPBS API to avoid mistakes when using ciphertexts with varying degrees, consider having a carry invariant radix LUT generation Mar 21, 2024
@IceTDrinker
Copy link
Member

The solution will likely be an API that takes a function and builds the LUT just in time.

@IceTDrinker IceTDrinker self-assigned this May 21, 2024
@Juul-Mc-Goa

Stumbled on the same issue recently.

Generating the LUT manually solved the problem. The only change in this function is the computation of vec_deg_basis: in my fix, every block is assumed to use its full modulus (i.e., the degree is assumed to be maximal) instead of the number of bits implied by its current degree.

Note that this quick fix can probably be improved:

    // Needs IntegerWopbsLUT, IntegerCiphertext, PlaintextCount, CiphertextCount
    // and the radix helpers encode_radix / encode_mix_radix / decode_radix in
    // scope (exact module paths depend on the tfhe-rs version).
    pub fn generate_lut_radix<F, T>(wopbs_key: &tfhe::shortint::WopbsKey, ct: &T, f: F) -> IntegerWopbsLUT
    where
        F: Fn(u64) -> u64,
        T: IntegerCiphertext,
    {
        let mut total_bit = 0;
        let block_nb = ct.blocks().len();
        let mut modulus = 1;

        // Number of bits each block contributes to the LUT index. Unlike the
        // original, this ignores the block degrees and uses each block's full
        // modulus, so the LUT no longer depends on the ciphertext's degrees.
        let mut vec_deg_basis = vec![];

        for block_modulus in ct.moduli().iter() {
            modulus *= block_modulus;
            let b = f64::log2(*block_modulus as f64).ceil() as u64;
            vec_deg_basis.push(b);
            total_bit += b;
        }

        // 2^total_bit entries, but never fewer than the polynomial size.
        let lut_size = if 1 << total_bit < wopbs_key.param.polynomial_size.0 as u64 {
            wopbs_key.param.polynomial_size.0
        } else {
            1 << total_bit
        };
        let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(block_nb));

        let basis = ct.moduli()[0];
        let delta: u64 = (1 << 63)
            / (wopbs_key.param.message_modulus.0 * wopbs_key.param.carry_modulus.0) as u64;

        // For every LUT index: decode it to a cleartext, apply f, and store
        // each radix digit of the result (scaled by delta) in its block's LUT.
        for lut_index_val in 0..(1 << total_bit) {
            let encoded_with_deg_val = encode_mix_radix(lut_index_val, &vec_deg_basis, basis);
            let decoded_val = decode_radix(&encoded_with_deg_val, basis);
            let f_val = f(decoded_val % modulus) % modulus;
            let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
            for (lut_number, radix_encoded_val) in encoded_f_val.iter().enumerate().take(block_nb) {
                lut[lut_number][lut_index_val as usize] = radix_encoded_val * delta;
            }
        }
        lut
    }
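Why this version is degree-independent, as a simplified plain-Rust sketch: when every block's bit count is log2 of its (power-of-two) modulus, splitting the LUT index into per-block fields and decoding them in base `basis` gives back the index itself, so each output block's LUT is just one digit of `f(index)`. `to_digits` and `full_modulus_luts` are hypothetical helpers, not the tfhe-rs functions; the sketch omits the `delta` scaling and the padding to the polynomial size, and it only lines up if the evaluation side reads the same number of bits per block:

```rust
/// Split `val` into `nb_blocks` digits in base `basis`, least significant first.
fn to_digits(mut val: u64, basis: u64, nb_blocks: usize) -> Vec<u64> {
    (0..nb_blocks)
        .map(|_| {
            let digit = val % basis;
            val /= basis;
            digit
        })
        .collect()
}

/// One table per output block, assuming each block uses its full modulus:
/// the index decodes to itself, so entry `index` holds a digit of f(index).
fn full_modulus_luts(basis: u64, nb_blocks: usize, f: impl Fn(u64) -> u64) -> Vec<Vec<u64>> {
    let modulus = basis.pow(nb_blocks as u32);
    let mut luts = vec![vec![0u64; modulus as usize]; nb_blocks];
    for index in 0..modulus {
        let f_val = f(index) % modulus;
        for (j, digit) in to_digits(f_val, basis, nb_blocks).into_iter().enumerate() {
            luts[j][index as usize] = digit;
        }
    }
    luts
}
```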

Remarks

  1. This issue appears because generate_lut_radix is tied to a specific ciphertext, so evaluating the LUT on a different ciphertext is currently unsupported.
  2. I guess one could change the signature of generate_lut_radix to take as input not a ciphertext but an argument like maximum_index: u64. The function would then generate a LUT that can compute the value at any index between 0 and maximum_index (inclusive). The for lut_index_val in ... loop would then iterate over 0..=maximum_index, and the computation of vec_deg_basis would probably be removed.
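A minimal sketch of the maximum_index idea from point 2 (hypothetical type, plain Rust, no encryption): the table covers 0..=max_index regardless of degrees, and inputs outside that range are reported instead of being misread:

```rust
/// Hypothetical degree-independent LUT covering every index in 0..=max_index.
struct MaxIndexLut {
    table: Vec<u64>,
}

impl MaxIndexLut {
    fn new(max_index: u64, f: impl Fn(u64) -> u64) -> Self {
        MaxIndexLut { table: (0..=max_index).map(f).collect() }
    }

    /// Returns None for inputs the LUT was not built to cover.
    fn eval(&self, x: u64) -> Option<u64> {
        self.table.get(x as usize).copied()
    }
}
```

For the 8-bit example of this issue, `MaxIndexLut::new(255, |x| (2 * x) % 256)` maps 8 to 16 whatever the degrees of the input blocks are.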

@Juul-Mc-Goa

I assume the computation of vec_deg_basis is an optimization of the LUT for the one ciphertext passed to generate_lut_radix. This optimization is legitimate, but it breaks when the LUT is applied to other ciphertexts.

One solution would be to mimic Rust's FnOnce/Fn pattern:

  • have an IntegerWopbsLUTOnce type where the current optimization is kept,
  • have an IntegerWopbsLUT type without the optimization, so that it can be applied to several ciphertexts.
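The FnOnce/Fn split can be expressed through ownership. In this hypothetical sketch, `WopbsLutOnce::apply` takes `self` by value (like `FnOnce::call_once`), so the degree-optimized LUT is consumed by its single use, while `WopbsLut::apply` borrows `&self` (like `Fn::call`) and can be reused:

```rust
/// Hypothetical LUT optimized for one degree profile; consumed on use.
struct WopbsLutOnce {
    table: Vec<u64>,
}

impl WopbsLutOnce {
    fn apply(self, index: usize) -> u64 {
        self.table[index]
    }
}

/// Hypothetical degree-independent LUT; borrowed on use, so reusable.
struct WopbsLut {
    table: Vec<u64>,
}

impl WopbsLut {
    fn apply(&self, index: usize) -> u64 {
        self.table[index]
    }
}
```

Calling `apply` a second time on the same `WopbsLutOnce` is rejected at compile time as a use of a moved value. This only limits reuse; it does not check which ciphertext the LUT is applied to.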

@IceTDrinker
Copy link
Member

I think the reasoning when this was done was to keep the LUT as small as possible and, when possible, extract only the bits that are needed instead of all of them. We'll most likely go the route of the apply function and put warnings on this API.
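The size trade-off can be made concrete with the sizing rule from the quick fix posted in this thread (2^(total index bits), but at least the polynomial size). The polynomial size of 1024 used in the checks is an assumed value, not taken from the WoPBS parameters:

```rust
/// Entries per output-block LUT: 2^(sum of per-block index bits),
/// but never fewer than the polynomial size.
fn lut_entries(bits_per_block: &[u32], polynomial_size: u64) -> u64 {
    let total_bits: u32 = bits_per_block.iter().sum();
    (1u64 << total_bits).max(polynomial_size)
}
```

With 4 blocks, indexing only the 2 message bits per block gives 2^8 = 256 entries (padded to 1024 here) and 8 extracted bits, while indexing all 4 message + carry bits per block gives 2^16 = 65536 entries per output block and 16 extracted bits.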

@IceTDrinker
Copy link
Member

The FnOnce/Fn idea is interesting, but how do you enforce running the LUT only on the ciphertext it was designed for? On the other hand, any ciphertext with the same degree profile is compatible with that LUT.

@Juul-Mc-Goa

I guess you can either:

  • not enforce anything, and put a big warning sign in the documentation saying: "this works only for some degree profiles"
  • clone and store the ct argument passed to the generate_lut_radix method, then reuse it when evaluating the LUT
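A sketch of a variant of the second option that stores only the degree profile rather than the whole ciphertext, since, as noted earlier in the thread, any ciphertext with the same degree profile is compatible with the LUT. `CheckedLut` and `LutError` are hypothetical types:

```rust
/// Error returned when a LUT is applied to a mismatching degree profile.
#[derive(Debug, PartialEq)]
enum LutError {
    DegreeMismatch { expected: Vec<u64>, found: Vec<u64> },
}

/// Hypothetical LUT that remembers the degree profile it was built for.
struct CheckedLut {
    degrees: Vec<u64>,
    table: Vec<u64>,
}

impl CheckedLut {
    /// Looks up `index`, but only for ciphertexts with the stored degree profile.
    fn apply(&self, degrees: &[u64], index: usize) -> Result<u64, LutError> {
        if self.degrees.as_slice() != degrees {
            return Err(LutError::DegreeMismatch {
                expected: self.degrees.clone(),
                found: degrees.to_vec(),
            });
        }
        Ok(self.table[index])
    }
}
```

This turns the silent wrong result into an explicit error at the cost of a runtime check; a caller could react to the error by regenerating the LUT.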

Projects
None yet
Development

No branches or pull requests

4 participants