Wouter Coekaerts's square root puzzle presents a difficult challenge. We need to figure out the number he's thinking of without seeing the number. The only question we can ask is: "is it this number?" (In fact, it's a bit worse than that: when we ask that question, he just writes the answer on the console if we're right.) There are a few possible toeholds here where we could get started. I'll discuss some I rejected before I get to my actual solution.
Wouter gives us the source code for the answer() method, so we know how he's testing the answer. We also know how he generated the answer: a BigInteger of approximately 10000 random bits generated from a SecureRandom.
The easiest approach would be to just look at the answer. However the answer (actually the square of the answer) is stored in a private field. We can't get at it without using reflection, JNI, sun.misc.Unsafe, or something else of that nature. The challenge doesn't permit those types of attacks. What about accessing it as an inner class? Inner classes can access private fields of their enclosing classes, so could we trick the JVM into thinking our code is an inner class of square.Square? You can't do this, because inner classes are just a fiction created by the compiler. It might look like you can access private fields of your enclosing class, but the compiler actually adds hidden accessors for those fields which it calls under the covers. This approach won't work.
What if we could influence the answer? Can we control what SecureRandom does? We can do this by providing a different source of entropy from the command line or by modifying /dev/random. However I'm not sure that these can be done with a security manager in place or without root access, and I think that they violate the "no security vulnerabilities" spirit of the challenge.
One promising approach relies on the lack of a final modifier on BigInteger. What if we don't pass in a BigInteger for the answer, but pass in our own subclass of BigInteger? For example, could we override BigInteger.equals() to always return true? This looked like a promising line of attack, and I think it might actually work in some implementations. Unfortunately, the way Wouter wrote answer(), he doesn't directly invoke any methods on the candidate object we control, and OpenJDK's implementation of BigInteger.divide() doesn't call any methods on the divisor object either, so there are no opportunities there (however, GNU Classpath's BigInteger.divide() does call methods on the divisor). BigInteger probably should have been final, but unfortunately this oversight doesn't seem to be sufficient to solve this puzzle. (If answer() had been implemented as root.equals(n.divide(root)) instead of n.divide(root).equals(root), this solution would have worked well, because our own equals() would be invoked.)
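To make the difference between the two call orders concrete, here's a minimal sketch. The AlwaysEqual and SubclassDemo classes and the tiny numbers are made up for illustration; only the two expressions come from the discussion above, and the behaviour shown is what I'd expect on an OpenJDK-style BigInteger that reads the other operand's fields directly.

```java
import java.math.BigInteger;

// Hypothetical subclass for illustration; not part of the puzzle or the solution.
class AlwaysEqual extends BigInteger {
    AlwaysEqual(String value) {
        super(value);
    }

    @Override
    public boolean equals(Object other) {
        return true; // claims to be equal to everything
    }
}

class SubclassDemo {
    // The order used in the puzzle: equals() runs on the quotient, a plain BigInteger.
    static boolean asWritten(BigInteger n, BigInteger root) {
        return n.divide(root).equals(root);
    }

    // The hypothetical reversed order: equals() runs on the object we supply.
    static boolean reversed(BigInteger n, BigInteger root) {
        return root.equals(n.divide(root));
    }

    public static void main(String[] args) {
        BigInteger n = new BigInteger("144");
        BigInteger wrongGuess = new AlwaysEqual("5");
        System.out.println(asWritten(n, wrongGuess)); // false: our override is never called
        System.out.println(reversed(n, wrongGuess));  // true: our override decides
    }
}
```

With the order Wouter actually used, the subclass never gets a chance to lie.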
If we can't subclass BigInteger, can we just replace it? We can write our own implementation of BigInteger which always returns true for equals, but we'd need to replace the built-in one. This would require modifying the bootstrap classpath with -Xbootclasspath/p:, which I think violates the rules, or at least the spirit, of the challenge.
What about a simple brute force approach? I was actually tempted to submit this. We know the upper bound of the answer (2^10000), so a simple for loop ought to suffice. Assuming we can test one answer per nanosecond (highly optimistic!), WolframAlpha estimates that it would take 6.322×10^2996 years, or 4.6×10^2986 times the age of the universe. The solution is provably correct, if somewhat impractical.
When I was looking at the implementation of BigInteger.divide() to evaluate the feasibility of the subclassing attack, I noticed that there are a few special cases at the top of the method. Consider a/b. If a = b, the result of integral division will always be 1. If a < b, the result of integral division is always 0. If a > b, the code is much more complicated. This suggested that a timing attack might be feasible: we should be able to determine if our candidate is less than Wouter's recorded number by how long it takes for the answer() function to complete.
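The three cases can be sketched as follows. This is only an illustration of the shape described above for positive a and b, not the JDK source; the DivideCases class and its quotient method are names I made up.

```java
import java.math.BigInteger;

class DivideCases {
    // Illustrative only: mirrors the special cases described above
    // for positive a and b. This is not the JDK source.
    static BigInteger quotient(BigInteger a, BigInteger b) {
        int cmp = a.compareTo(b);
        if (cmp < 0) {
            return BigInteger.ZERO;  // a < b: cheap, just a comparison
        }
        if (cmp == 0) {
            return BigInteger.ONE;   // a == b: also cheap
        }
        return a.divide(b);          // a > b: the expensive long-division path
    }

    public static void main(String[] args) {
        BigInteger b = BigInteger.ONE.shiftLeft(64);                  // 2^64
        System.out.println(quotient(b.subtract(BigInteger.ONE), b));  // 0
        System.out.println(quotient(b, b));                           // 1
        System.out.println(quotient(b.multiply(b), b).equals(b));     // true
    }
}
```

The first two branches cost little more than a comparison, which is what makes the elapsed time leak information about the operands.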
Note that this attack is dependent on a number of things. We know how Wouter is testing the answer: n.divide(root).equals(root). If he'd done the more efficient n.equals(root.multiply(root)) instead, this attack wouldn't work (or would be much more difficult). It also relies on the JDK not doing unnecessary work for division when the dividend is less than the divisor. In Java 6 the JDK still calculates the remainder for this case (and discards it), making the attack more difficult but still feasible. In Java 7 the timing difference is more pronounced.
Timing the function is also quite challenging. For one thing, JVMs use JIT compilers which may compile code multiple times, with increasing levels of optimization, as they discover what's important. Therefore it's important to give the JIT plenty of opportunity to compile the code before we start trying to time it. I did most of my initial testing using -Xint, which disables the JIT, to eliminate this variable until I was confident that the solution worked. Garbage collection can also interfere with the timing, but it can only increase the time. If we only look at minimum times, and do enough measurements, it shouldn't matter. Other programs running on the machine can interfere with our timing, as can power management features. I tried running the solution on my MacBook Pro at first, but the timing was all over the place. In the end I used an isolated Linux x86-64 Xeon X5690 (Westmere) machine and Hotspot Java 1.7.0_09_b05 to test the solution. It was much more reliable.
Once the JIT is warmed up so we can get reliable timing, the next step is to calculate a baseline. I do this by generating my own random 20000-bit number (the square of a 10000-bit number is a 20000-bit number) and testing how long it takes to divide that by a number slightly lower than it. I repeat this a million times and record the minimum time it took to run my own answer() function, which is identical to Wouter's. Now we have the baseline for a 'slow' divide.
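Here's a minimal sketch of that kind of measurement: warm up, then keep the minimum over many runs. The Baseline class, the minNanos helper, the sink field and the run counts are all assumptions made for this example (the counts are much smaller than the million runs mentioned above), and the absolute numbers will vary by machine and JVM.

```java
import java.math.BigInteger;
import java.security.SecureRandom;

class Baseline {
    // Keeps the result observable so the JIT cannot discard the division.
    static volatile boolean sink;

    // Same shape as the check described in the text.
    static boolean answer(BigInteger n, BigInteger root) {
        return n.divide(root).equals(root);
    }

    // Minimum elapsed time of answer(n, root) over `runs` calls, in nanoseconds.
    static long minNanos(BigInteger n, BigInteger root, int runs) {
        long best = Long.MAX_VALUE;
        for (int i = 0; i < runs; i++) {
            long start = System.nanoTime();
            sink = answer(n, root);
            long elapsed = System.nanoTime() - start;
            if (elapsed < best) {
                best = elapsed;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // A random number with the top bit forced on, so it really has 20000 bits.
        BigInteger n = new BigInteger(20000, new SecureRandom()).setBit(19999);
        BigInteger slightlyLower = n.subtract(BigInteger.ONE);
        minNanos(n, slightlyLower, 20_000);                 // warm-up, result discarded
        long slow = minNanos(n, slightlyLower, 100_000);    // divisor below n: slow path
        long fast = minNanos(n, n.add(BigInteger.ONE), 100_000); // divisor above n: fast path
        System.out.println("slow baseline (ns): " + slow);
        System.out.println("fast sample (ns):   " + fast);
    }
}
```

Taking the minimum rather than the mean is what makes garbage collection pauses and other one-sided noise mostly harmless.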
Someone with a better statistical background than me could probably come up with a robust way of distinguishing between 'fast' and 'slow' division. I just used a heuristic and ran the test many times. If the divide time is >= 75% of the fastest baseline time, I consider it slow. If it's <= 50% of the fastest baseline time, I consider it fast. If it's somewhere in between, I run the measurement again until I get a conclusive result. I short circuit for fast times, since I assume that these must be fast due to the number being tested, but not for slow times, since they could be slow for any number of the reasons discussed above.
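The heuristic can be sketched like this. The 75% and 50% thresholds come from the description above; the TimingClassifier class, the samplesPerRound and maxRounds parameters, and the exact loop structure are assumptions for the sake of a runnable example (it also uses LongSupplier, which needs Java 8 or later).

```java
import java.util.function.LongSupplier;

class TimingClassifier {
    enum Verdict { FAST, SLOW, INCONCLUSIVE }

    // Classify one timing against the 'slow' baseline.
    static Verdict classify(long nanos, long baselineNanos) {
        if (nanos * 4 >= baselineNanos * 3) {
            return Verdict.SLOW;          // >= 75% of baseline
        }
        if (nanos * 2 <= baselineNanos) {
            return Verdict.FAST;          // <= 50% of baseline
        }
        return Verdict.INCONCLUSIVE;
    }

    // Sample repeatedly. One fast sample settles the question immediately;
    // 'slow' is only accepted if even the minimum of a whole round is slow.
    static Verdict decide(LongSupplier sample, long baselineNanos,
                          int samplesPerRound, int maxRounds) {
        for (int round = 0; round < maxRounds; round++) {
            long best = Long.MAX_VALUE;
            for (int i = 0; i < samplesPerRound; i++) {
                long t = sample.getAsLong();
                if (classify(t, baselineNanos) == Verdict.FAST) {
                    return Verdict.FAST;  // short circuit
                }
                best = Math.min(best, t);
            }
            Verdict verdict = classify(best, baselineNanos);
            if (verdict != Verdict.INCONCLUSIVE) {
                return verdict;
            }
        }
        return Verdict.INCONCLUSIVE;
    }
}
```

The asymmetry is deliberate: noise can make a fast divide look slow, but it's very unlikely to make a slow divide look fast.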
Then I start to identify bits. I start with the highest possible bit and work my way down. Whenever I discover that the divide is fast with a particular bit set, and slow with the same bit clear, I know that we've identified the next bit in n. When the first stage of my program finishes (it takes about 80 minutes), it should have identified the highest number for which answer() is 'slow'. In other words, we know n-1.
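The bit-by-bit search can be sketched with the timing measurement replaced by a simulated oracle. The BitRecovery class and largestSlow method are hypothetical names, and the oracle here is just a comparison against a known secret, standing in for "the divide was slow"; a real run would plug in timing-based decisions instead.

```java
import java.math.BigInteger;
import java.util.function.Predicate;

class BitRecovery {
    // Returns the largest value below 2^bits for which isSlow is true.
    // Assumes isSlow is monotone: true for every candidate below some
    // threshold and false from the threshold upwards.
    static BigInteger largestSlow(int bits, Predicate<BigInteger> isSlow) {
        BigInteger known = BigInteger.ZERO;
        for (int bit = bits - 1; bit >= 0; bit--) {
            BigInteger candidate = known.setBit(bit);
            if (isSlow.test(candidate)) {
                known = candidate;   // still below the secret: keep this bit set
            }                        // otherwise the divide was fast: leave it clear
        }
        return known;
    }

    public static void main(String[] args) {
        BigInteger secret = new BigInteger("123456789").pow(2);
        // Simulated oracle: the divide is 'slow' only when the candidate is below the secret.
        Predicate<BigInteger> simulated = c -> c.compareTo(secret) < 0;
        BigInteger recovered = largestSlow(secret.bitLength(), simulated).add(BigInteger.ONE);
        System.out.println(recovered.equals(secret)); // true
    }
}
```

Each bit needs only one oracle decision, so a 20000-bit number takes 20000 decisions rather than an astronomically large search.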
Obviously determining n from n-1 is trivial, but Wouter doesn't actually want us to find n: he wants the square root of n. Unfortunately, there's no built-in library function to do this. I could have implemented Newton's method to solve the square root, but fortunately someone else already did this for me: BigIntegerMath in Google's Guava has a very efficient implementation that handles a lot of corner cases I probably wouldn't have bothered with. Wouter wanted the solution in a single class, and I didn't want to rely on libraries not included with the base JDK, so I copied the relevant code from Guava. Since this is an Apache-licensed project, that's conveniently legal.
My program runs in about 80 minutes on Java 7 on an Intel Xeon Linux server. The actual time depends on the ratio of 0 bits to 1 bits in n. 0 bits can be identified much faster than 1 bits. It may or may not work on other machines without tweaking some of the timing parameters. There's no way to tell for sure from within the program if it got the right answer, since Wouter just prints something to the console, but I do check to see if it's feasible. Most integers have irrational square roots, but we know that n is a perfect square. Therefore I test that the number I identified is also a perfect square. If it's not, I repeat the whole process again. A more sophisticated solution might track a confidence level for each bit and try remeasuring those first, but it's always found the solution on the first attempt.
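For illustration, here's a plain Newton's-method integer square root and the perfect-square check built on it. This is not the Guava code mentioned above, which is more optimized; the IntegerSqrt class and its method names are made up for this sketch.

```java
import java.math.BigInteger;

class IntegerSqrt {
    // floor(sqrt(n)) for n >= 0, using Newton's iteration on integers.
    static BigInteger floorSqrt(BigInteger n) {
        if (n.signum() < 0) {
            throw new ArithmeticException("negative input");
        }
        if (n.signum() == 0) {
            return BigInteger.ZERO;
        }
        // Start from a power of two that is at least sqrt(n), so the
        // iteration only ever moves downwards.
        BigInteger x = BigInteger.ONE.shiftLeft((n.bitLength() + 1) / 2);
        while (true) {
            BigInteger next = x.add(n.divide(x)).shiftRight(1);
            if (next.compareTo(x) >= 0) {
                return x;            // stopped decreasing: x is the floor of the root
            }
            x = next;
        }
    }

    // True if n is the square of some integer.
    static boolean isPerfectSquare(BigInteger n) {
        if (n.signum() < 0) {
            return false;
        }
        BigInteger r = floorSqrt(n);
        return r.multiply(r).equals(n);
    }

    public static void main(String[] args) {
        System.out.println(floorSqrt(new BigInteger("144")));       // 12
        System.out.println(floorSqrt(new BigInteger("143")));       // 11
        System.out.println(isPerfectSquare(new BigInteger("143"))); // false
    }
}
```

If the recovered n fails isPerfectSquare, at least one bit was misread and the measurement has to be repeated.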
You can find the source code on GitHub.
Thanks to my colleagues at Two Sigma who provided useful insights into this problem. Andrew Berman and Yaron Gvili suggested attacking SecureRandom, either directly or via /dev/random and /dev/urandom. Yaron also discussed the brute force approach. Isaac Dooley steered me away from mean and towards min and also explored ClassLoader based solutions. Trammell Hudson encouraged me to continue with the timing attack even when initial results were frustratingly ambiguous. Another colleague who shall remain anonymous suggested I prove that P=NP, and then implement a polynomial-time solution.