Bitwise Multiply and Add in Java

Let’s begin by looking at the multiplication code. The idea is actually pretty clever. Suppose that you have n1 and n2 written in binary. Then you can think of n2 as a sum of powers of two: n2 = c30 * 2^30 + c29 * 2^29 + … + c1 * 2^1 + c0 * 2^0, where each ci is either 0 or 1. Then you can think of the product n1 * n2 as

n1 * n2 =

n1 * (c30 * 2^30 + c29 * 2^29 + … + c1 * 2^1 + c0 * 2^0) =

n1 * c30 * 2^30 + n1 * c29 * 2^29 + … + n1 * c1 * 2^1 + n1 * c0 * 2^0

This is a bit dense, but the idea is that the product of the two numbers is a sum with one term per bit position i: the first number, times 2^i, times the ith binary digit of the second number. Any term whose digit is 0 vanishes, so only the positions where n2 has a 1 bit contribute.
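To make the expansion concrete, here is a small sketch that evaluates the sum term by term, deliberately using ordinary multiplication and division, and compares it with the built-in product. The class and method names are made up for illustration, and it assumes n2 is non-negative so that 31 digits are enough.

```java
class BinaryExpansionDemo {
    /* Hypothetical helper: adds up n1 * ci * 2^i for i = 0..30.
     * Assumes n2 >= 0, so the digit ci can be read with / and %.
     */
    static int sumOfTerms(int n1, int n2) {
        int sum = 0;
        int powerOfTwo = 1;                   // holds 2^i
        for (int i = 0; i <= 30; i++) {
            int ci = (n2 / powerOfTwo) % 2;   // ith binary digit of n2
            sum += n1 * ci * powerOfTwo;      // the term n1 * ci * 2^i
            powerOfTwo *= 2;                  // wraps after i = 30, but is not used again
        }
        return sum;
    }

    public static void main(String[] args) {
        // 5 is 101 in binary, so the sum is 6 * 1 * 1 + 6 * 0 * 2 + 6 * 1 * 4.
        System.out.println(sumOfTerms(6, 5));  // prints 30
        System.out.println(6 * 5);             // prints 30
    }
}
```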

The question now is whether we can compute the terms of this sum without doing any actual multiplications. In order to do so, we’re going to need to be able to read the binary digits of n2. Fortunately, we can do this using shifts. In particular, suppose we start off with n2 and then just look at the last bit. That’s c0. If we then shift the value down one position, the last bit is now c1, and so on. More generally, after shifting the value of n2 down by i positions, the lowest bit will be ci. To read the very last bit, we can just bitwise AND the value with the number 1. This has a binary representation that’s zero everywhere except the last digit. Since 0 AND x = 0 for any bit x, this clears every bit above the last one. Moreover, since 0 AND 1 = 0 and 1 AND 1 = 1, this operation preserves the last bit of the number.
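Here is a minimal sketch of that shift-and-mask trick on its own. The helper name bit is just for illustration; it uses the unsigned shift >>> so that zeros are shifted in from the top.

```java
class ReadBitsDemo {
    /* Hypothetical helper: returns ci, the ith binary digit of n2
     * (i = 0 is the lowest bit).  Shift the digit down to the last
     * position, then AND with 1 to clear everything above it.
     */
    static int bit(int n2, int i) {
        return (n2 >>> i) & 1;
    }

    public static void main(String[] args) {
        int n2 = 13;  // 1101 in binary
        for (int i = 0; i < 4; i++) {
            System.out.println("c" + i + " = " + bit(n2, i));
        }
        // prints c0 = 1, c1 = 0, c2 = 1, c3 = 1 (one per line)
    }
}
```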

Okay – we now know that we can read the values of ci; so what? Well, the good news is that we can also compute the values of the series n1 * 2^i in a similar fashion. In particular, consider the sequence of values n1 << 0, n1 << 1, etc. Left-shifting by i positions is equivalent to multiplying by 2^i (with the same wraparound on overflow that ordinary int multiplication has), so n1 << i = n1 * 2^i. This means that we now have all the components we need to compute the above sum. Here’s your original source code, commented with what’s going on:

public static void bitwiseMultiply(int n1, int n2) {
    /* This value will hold n1 * 2^i for varying values of i.  It will
     * start off holding n1 * 2^0 = n1, and after each iteration will 
     * be updated to hold the next term in the sequence.
     */
    int a = n1;

    /* This value will be used to read the individual bits out of n2.
     * We'll use the shifting trick to read the bits and will maintain
     * the invariant that after i iterations, b is equal to n2 >>> i.
     */
    int b = n2;

    /* This value will hold the sum of the terms so far. */
    int result = 0;

    /* Continuously loop over more and more bits of n2 until we've
     * consumed the last of them.  Since after i iterations of the
     * loop b = n2 >>> i, this only reaches zero once we've used up
     * all the set bits of the original value of n2.  The unsigned
     * shift (>>>) fills with zeros, so this terminates even when n2
     * is negative.
     */
    while (b != 0)
    {
        /* Using the bitwise AND trick, determine whether the lowest
         * bit of b (which is the ith bit of n2) is a zero or one.  If
         * it's a zero, then the current term in our sum is zero and we
         * don't do anything.  Otherwise, we should add n1 * 2^i.
         */
        if ((b & 1) != 0)
        {
            /* Recall that a = n1 * 2^i at this point, so we're adding
             * in the next term in the sum.
             */
            result = result + a;
        }

        /* To maintain that a = n1 * 2^i after i iterations, scale it
         * by a factor of two by left shifting one position.
         */
        a <<= 1;

        /* To maintain that b = n2 >>> i after i iterations, shift it
         * one spot to the right, filling the top bit with a zero.
         */
        b >>>= 1;
    }

    System.out.println(result);
}
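
If you want to convince yourself that the loop really does match ordinary multiplication, one option is a variant that returns the sum instead of printing it, so it can be compared against the * operator. This is only a sketch: the class name and the multiply method are hypothetical, and the loop body is copied from the code above. Because int arithmetic wraps around and the shift on b is unsigned, the two results should also agree for negative inputs and for products that overflow.

```java
class BitwiseMultiplyCheck {
    /* Hypothetical variant of bitwiseMultiply: same loop, but the
     * result is returned rather than printed.
     */
    static int multiply(int n1, int n2) {
        int a = n1;       // n1 * 2^i
        int b = n2;       // n2 >>> i
        int result = 0;   // sum of the terms so far
        while (b != 0) {
            if ((b & 1) != 0) {
                result = result + a;
            }
            a <<= 1;
            b >>>= 1;
        }
        return result;
    }

    public static void main(String[] args) {
        // A few sample pairs, including negatives and one product that overflows int.
        int[][] cases = { {6, 7}, {0, 9}, {-3, 5}, {7, -4}, {46341, 46341} };
        for (int[] c : cases) {
            System.out.println(c[0] + " * " + c[1] + " -> " + multiply(c[0], c[1])
                    + " (built-in: " + (c[0] * c[1]) + ")");
        }
    }
}
```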

Hope this helps!
