Calculating π in Clojure (Salamin-Brent)

Took a shot at implementing a PI digit generator in Clojure using a ‘fast’ algorithm.
It seemed like a decent enough exercise to try to understand something about performance in Clojure.
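For reference, here is the iteration that both programs below implement, as read off their code. It is the Gauss-Legendre algorithm, also known as Brent-Salamin. The number of correct digits roughly doubles with every step. The square roots inside it are computed with Newton's method (last line).

```latex
\begin{aligned}
a_0 &= 1, \quad b_0 = \tfrac{1}{\sqrt{2}}, \quad t_0 = \tfrac{1}{4}, \quad x_0 = 1 \\
a_{n+1} &= \tfrac{a_n + b_n}{2}, \qquad b_{n+1} = \sqrt{a_n b_n} \\
t_{n+1} &= t_n - x_n \,(a_n - a_{n+1})^2, \qquad x_{n+1} = 2 x_n \\
\pi &\approx \frac{(a_n + b_n)^2}{4\, t_n} \quad \text{(once } a_n = b_n \text{ at the working precision)} \\
\sqrt{A}:\quad s_{k+1} &= \tfrac{1}{2}\left(s_k + \tfrac{A}{s_k}\right)
\end{aligned}
```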

MacBook Pro – Intel Core 2 Duo 2.26 GHz – 4GB RAM
Java(TM) SE Runtime Environment (build 1.6.0_15-b03-219)
Java HotSpot(TM) 64-Bit Server VM (build 14.1-b02-90, mixed mode)
Clojure 1.1.0-alpha-SNAPSHOT (Aug 20 2009) git commit f1f5ad40984d46bdc314090552b76471ee2b8d01

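In this example Clojure keeps pace with Java: at 1,000 and 10,000 digits the two versions are within about 10% of each other. The timings are the "Elapsed time" each program prints for itself, so JVM startup is not included.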

The Clojure code:

(import '(java.math BigDecimal RoundingMode))
(defn sb-pi
  "Calculates PI digits using the Salamin-Brent algorithm
   and Java's BigDecimal class."
  [places]
  (let [digits (int (+ 10 places)) ;; add some guard digits
        round-mode RoundingMode/DOWN]
    (letfn [(big-sqrt [#^BigDecimal num]
              ;; Square root by Newton's method, starting from the
              ;; double-precision estimate.
              (letfn [(big-sqrt-int
                       [#^BigDecimal num #^BigDecimal x0 #^BigDecimal x1]
                       ;; Newton step x <- (num/x + x)/2; stops once two
                       ;; successive estimates are equal.
                       (let [#^BigDecimal x0new x1
                             #^BigDecimal x1new (-> num (.divide x0new digits round-mode))
                             #^BigDecimal xsum (+ x1new x0new)
                             #^BigDecimal x1tot (-> xsum (.divide 2M digits round-mode))]
                         (if (= x0 x1)
                           x1tot
                           (recur num x1 x1tot))))]
                (big-sqrt-int
                 num 0M (BigDecimal/valueOf
                         (Math/sqrt (.doubleValue num))))))
            (sb-pi-int [#^BigDecimal a #^BigDecimal b
                        #^BigDecimal t #^BigDecimal x #^BigDecimal y]
              ;; One Salamin-Brent step; once a and b agree, returns
              ;; (a+b)^2 / 4t truncated to `places` digits.
              (let
                  [#^BigDecimal y1 a
                   #^BigDecimal absum (+ a b)
                   #^BigDecimal a1 (-> absum (.divide 2M digits round-mode))
                   #^BigDecimal b1 (big-sqrt (* b y1))
                   #^BigDecimal ydiff (- y1 a1)
                   #^BigDecimal t1 (- t (* x ydiff ydiff))
                   #^BigDecimal x1 (* x 2M)]
                (if (== a b)
                  (let [#^BigDecimal absum1 (+ a1 b1)
                        #^BigDecimal absqrd (* absum1 absum1)
                        #^BigDecimal u (* t1 4M)]
                    (-> absqrd
                        (.divide u digits round-mode)
                        (.setScale (int places) round-mode)))
                  (recur a1 b1 t1 x1 y1))))]
      ;; start with a = 1, b = 1/sqrt(2), t = 1/4, x = 1 (y is unused on entry)
      (sb-pi-int 1M (-> 1M (.divide #^BigDecimal (big-sqrt 2M) digits round-mode))
                 (/ 1M 4M) 1M nil))))
(time (println (sb-pi (Integer/parseInt (second *command-line-args*)))))

$ time clj pi.clj 1               -->       3.403 msecs
$ time clj pi.clj 10              -->       3.956 msecs
$ time clj pi.clj 100             -->      10.630 msecs
$ time clj pi.clj 1000            -->     141.937 msecs
$ time clj pi.clj 10000           -->    3316.180 msecs

The same algorithm in Java (using loops instead of recursion):

import java.math.BigDecimal;
import java.math.RoundingMode;
import static java.math.BigDecimal.*;
class Pi {
  private static final BigDecimal TWO = new BigDecimal(2);
  private static final BigDecimal FOUR = new BigDecimal(4);
  // RoundingMode.DOWN replaces the int constant ROUND_DOWN (deprecated since Java 9)
  private static final RoundingMode ROUND_MODE = RoundingMode.DOWN;
  public static void main(String[] args) {
    long start = System.nanoTime();
    System.out.println(pi(Integer.parseInt(args[0])));
    System.out.println("Elapsed time: " +
                       ((System.nanoTime() - start) / 1E6) + " msecs");
  }
  // Salamin-Brent Algorithm
  public static BigDecimal pi(final int digits) {
    final int SCALE = 10 + digits; // add some guard digits
    BigDecimal a = ONE;
    BigDecimal b = ONE.divide(sqrt(TWO, SCALE), SCALE, ROUND_MODE);
    BigDecimal t = new BigDecimal("0.25");
    BigDecimal x = ONE;
    while (a.compareTo(b) != 0) {
      BigDecimal y = a;
      a = a.add(b).divide(TWO, SCALE, ROUND_MODE);   // arithmetic mean
      b = sqrt(b.multiply(y), SCALE);                // geometric mean
      BigDecimal d = y.subtract(a);
      t = t.subtract(x.multiply(d.multiply(d)));
      x = x.multiply(TWO);
    }
    return a.add(b)
      .multiply(a.add(b))
      .divide(t.multiply(FOUR), SCALE, ROUND_MODE)
      .setScale(digits, ROUND_MODE);
  }
  // square root method (Newton's)
  public static BigDecimal sqrt(BigDecimal A, final int SCALE) {
    BigDecimal x0 = ZERO;
    BigDecimal x1 = new BigDecimal(Math.sqrt(A.doubleValue())); // double estimate
    while (x0.compareTo(x1) != 0) {
      x0 = x1;
      x1 = A.divide(x0, SCALE, ROUND_MODE);    // A / x0
      x1 = x1.add(x0);
      x1 = x1.divide(TWO, SCALE, ROUND_MODE);  // (A / x0 + x0) / 2
    }
    return x1;
  }
}

$ time java Pi 1         ---->         2.162 msecs
$ time java Pi 10        ---->         2.425 msecs
$ time java Pi 100       ---->         7.897 msecs
$ time java Pi 1000      ---->       150.610 msecs
$ time java Pi 10000     ---->      3009.705 msecs
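A caveat applies to the square-root helper in both versions. It stops only when two successive Newton estimates are exactly equal. With truncating division, the estimates can instead alternate between two neighbouring values. For example, the Java sqrt(new BigDecimal(3), 0) above bounces between 1 and 2 forever. The inputs pi() uses evidently don't hit this, since all the runs above finished. Still, a stopping rule that cannot cycle is cheap.

Below is a sketch of such a variant. It is not part of the original code, and the class name SafeSqrt is hypothetical. The idea is to take one step to get onto the 10^-scale grid, one more to land at or above the truncated root, and then stop as soon as an estimate fails to decrease.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

// Hypothetical helper, not part of the original post: a truncating Newton
// square root whose loop cannot cycle.
class SafeSqrt {
  private static final BigDecimal TWO = BigDecimal.valueOf(2);

  // One Newton step (A / x + x) / 2, truncated to `scale` decimal places.
  static BigDecimal step(BigDecimal A, BigDecimal x, int scale) {
    return A.divide(x, scale, RoundingMode.DOWN)
            .add(x)
            .divide(TWO, scale, RoundingMode.DOWN);
  }

  // sqrt(A) truncated to `scale` places. Assumes A > 0, A fits in a double,
  // and sqrt(A) is not tiny compared with 10^-scale (true for pi()'s inputs).
  static BigDecimal sqrt(BigDecimal A, int scale) {
    BigDecimal guess = BigDecimal.valueOf(Math.sqrt(A.doubleValue()));
    // The first step puts the estimate on the 10^-scale grid; from a grid
    // point, every further step lands at or above the truncated root.
    BigDecimal x = step(A, step(A, guess, scale), scale);
    while (true) {
      BigDecimal next = step(A, x, scale);
      if (next.compareTo(x) >= 0) {
        return x; // estimates stopped decreasing: x is the truncated root
      }
      x = next;   // still above the truncated root, keep descending
    }
  }

  public static void main(String[] args) {
    System.out.println(sqrt(BigDecimal.valueOf(3), 0));  // 1 (no infinite loop)
    System.out.println(sqrt(BigDecimal.valueOf(2), 20)); // 1.41421356237309504880
  }
}
```

Wherever the original loop does terminate, it stops at the same truncated root (the only fixed point of the truncating Newton step). Swapping this in for Pi.sqrt should therefore leave pi()'s output unchanged, at roughly the same number of Newton steps.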