跳转至

高精度計算

太長不看版:結尾自取模板……

定義

高精度計算(Arbitrary-Precision Arithmetic),也被稱作大整數(bignum)計算,運用了一些算法結構來支持更大整數間的運算(數字大小超過語言內建整型)。

引入

高精度問題包含很多小的細節,實現上也有很多講究。

所以今天就來一起實現一個簡單的計算器吧。

任務

輸入:一個形如 a <op> b 的表達式。

  • ab 分別是長度不超過 \(1000\) 的十進制非負整數;
  • <op> 是一個字符(+-*/),表示運算。
  • 整數與運算符之間由一個空格分隔。

輸出:運算結果。

  • 對於 +-* 運算,輸出一行表示結果;
  • 對於 / 運算,輸出兩行分別表示商和餘數。
  • 保證結果均為非負整數。

存儲

在平常的實現中,高精度數字利用字符串表示,每一個字符表示數字的一個十進制位。因此可以説,高精度數值計算實際上是一種特別的字符串處理。

讀入字符串時,數字最高位在字符串首(下標小的位置)。但是習慣上,下標最小的位置存放的是數字的 最低位,即存儲反轉的字符串。這麼做的原因在於,數字的長度可能發生變化,但我們希望同樣權值位始終保持對齊(例如,希望所有的個位都在下標 [0],所有的十位都在下標 [1]……);同時,加、減、乘的運算一般都從個位開始進行(回想小學的豎式運算),這都給了「反轉存儲」以充分的理由。

此後我們將一直沿用這一約定。定義一個常數 LEN = 1004 表示程序所容納的最大長度。

由此不難寫出讀入高精度數字的代碼:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
void clear(int a[]) {
  for (int i = 0; i < LEN; ++i) a[i] = 0;
}

void read(int a[]) {
  static char s[LEN + 1];
  scanf("%s", s);

  clear(a);

  int len = strlen(s);
  // 如上所述,反轉
  for (int i = 0; i < len; ++i) a[len - i - 1] = s[i] - '0';
  // s[i] - '0' 就是 s[i] 所表示的數碼
  // 有些同學可能更習慣用 ord(s[i]) - ord('0') 的方式理解
}

輸出也按照存儲的逆序輸出。由於不希望輸出前導零,故這裏從最高位開始向下尋找第一個非零位,從此處開始輸出;終止條件 i >= 1 而不是 i >= 0 是因為當整個數字等於 \(0\) 時仍希望輸出一個字符 0

1
2
3
4
5
6
7
void print(int a[]) {
  int i;
  for (i = LEN - 1; i >= 1; --i)
    if (a[i] != 0) break;
  for (; i >= 0; --i) putchar(a[i] + '0');
  putchar('\n');
}

拼起來就是一個完整的復讀機程序咯。

copycat.cpp
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <cstdio>
#include <cstring>

static const int LEN = 1004;

int a[LEN], b[LEN];

void clear(int a[]) {
  for (int i = 0; i < LEN; ++i) a[i] = 0;
}

void read(int a[]) {
  static char s[LEN + 1];
  scanf("%s", s);

  clear(a);

  int len = strlen(s);
  for (int i = 0; i < len; ++i) a[len - i - 1] = s[i] - '0';
}

void print(int a[]) {
  int i;
  for (i = LEN - 1; i >= 1; --i)
    if (a[i] != 0) break;
  for (; i >= 0; --i) putchar(a[i] + '0');
  putchar('\n');
}

int main() {
  read(a);
  print(a);

  return 0;
}

四則運算

四則運算中難度也各不相同。最簡單的是高精度加減法,其次是高精度—單精度(普通的 int)乘法和高精度—高精度乘法,最後是高精度—高精度除法。

我們將按這個順序分別實現所有要求的功能。

加法

高精度加法,其實就是豎式加法啦。

也就是從最低位開始,將兩個加數對應位置上的數碼相加,並判斷是否達到或超過 \(10\)。如果達到,那麼處理進位:將更高一位的結果上增加 \(1\),當前位的結果減少 \(10\)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
void add(int a[], int b[], int c[]) {
  clear(c);

  // 高精度實現中,一般令數組的最大長度 LEN 比可能的輸入大一些
  // 然後略去末尾的幾次循環,這樣一來可以省去不少邊界情況的處理
  // 因為實際輸入不會超過 1000 位,故在此循環到 LEN - 1 = 1003 已經足夠
  for (int i = 0; i < LEN - 1; ++i) {
    // 將相應位上的數碼相加
    c[i] += a[i] + b[i];
    if (c[i] >= 10) {
      // 進位
      c[i + 1] += 1;
      c[i] -= 10;
    }
  }
}

試着和上一部分結合,可以得到一個加法計算器。

adder.cpp
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <cstdio>
#include <cstring>

static const int LEN = 1004;

int a[LEN], b[LEN], c[LEN];

void clear(int a[]) {
  for (int i = 0; i < LEN; ++i) a[i] = 0;
}

void read(int a[]) {
  static char s[LEN + 1];
  scanf("%s", s);

  clear(a);

  int len = strlen(s);
  for (int i = 0; i < len; ++i) a[len - i - 1] = s[i] - '0';
}

void print(int a[]) {
  int i;
  for (i = LEN - 1; i >= 1; --i)
    if (a[i] != 0) break;
  for (; i >= 0; --i) putchar(a[i] + '0');
  putchar('\n');
}

void add(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    c[i] += a[i] + b[i];
    if (c[i] >= 10) {
      c[i + 1] += 1;
      c[i] -= 10;
    }
  }
}

int main() {
  read(a);
  read(b);

  add(a, b, c);
  print(c);

  return 0;
}

減法

高精度減法,也就是豎式減法啦。

從個位起逐位相減,遇到負的情況則向上一位借 \(1\)。整體思路與加法完全一致。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
void sub(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    // 逐位相減
    c[i] += a[i] - b[i];
    if (c[i] < 0) {
      // 借位
      c[i + 1] -= 1;
      c[i] += 10;
    }
  }
}

將上一個程序中的 add() 替換成 sub(),就有了一個減法計算器。

subtractor.cpp
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <cstdio>
#include <cstring>

static const int LEN = 1004;

int a[LEN], b[LEN], c[LEN];

void clear(int a[]) {
  for (int i = 0; i < LEN; ++i) a[i] = 0;
}

void read(int a[]) {
  static char s[LEN + 1];
  scanf("%s", s);

  clear(a);

  int len = strlen(s);
  for (int i = 0; i < len; ++i) a[len - i - 1] = s[i] - '0';
}

void print(int a[]) {
  int i;
  for (i = LEN - 1; i >= 1; --i)
    if (a[i] != 0) break;
  for (; i >= 0; --i) putchar(a[i] + '0');
  putchar('\n');
}

void sub(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    c[i] += a[i] - b[i];
    if (c[i] < 0) {
      c[i + 1] -= 1;
      c[i] += 10;
    }
  }
}

int main() {
  read(a);
  read(b);

  sub(a, b, c);
  print(c);

  return 0;
}

試一試,輸入 1 2——輸出 /9999999,誒這個 OI Wiki 怎麼給了我一份假的代碼啊……

事實上,上面的代碼只能處理減數 \(a\) 大於等於被減數 \(b\) 的情況。處理被減數比減數小,即 \(a<b\) 時的情況很簡單。

\(a-b=-(b-a)\)

要計算 \(b-a\) 的值,因為有 \(b>a\),可以調用以上代碼中的 sub 函數,寫法為 sub(b,a,c)。要得到 \(a-b\) 的值,在得數前加上負號即可。

乘法

高精度—單精度

高精度乘法,也就是豎……等會兒等會兒!

先考慮一個簡單的情況:乘數中的一個是普通的 int 類型。有沒有簡單的處理方法呢?

一個直觀的思路是直接將 \(a\) 每一位上的數字乘以 \(b\)。從數值上來説,這個方法是正確的,但它並不符合十進制表示法,因此需要將它重新整理成正常的樣子。

重整的方式,也是從個位開始逐位向上處理進位。但是這裏的進位可能非常大,甚至遠大於 \(9\),因為每一位被乘上之後都可能達到 \(9b\) 的數量級。所以這裏的進位不能再簡單地進行 \(-10\) 運算,而是要通過除以 \(10\) 的商以及餘數計算。詳見代碼註釋,也可以參考下圖展示的一個計算高精度數 \(1337\) 乘以單精度數 \(42\) 的過程。

當然,也是出於這個原因,這個方法需要特別關注乘數 \(b\) 的範圍。若它和 \(10^9\)(或相應整型的取值上界)屬於同一數量級,那麼需要慎用高精度—單精度乘法。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
void mul_short(int a[], int b, int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    // 直接把 a 的第 i 位數碼乘以乘數,加入結果
    c[i] += a[i] * b;

    if (c[i] >= 10) {
      // 處理進位
      // c[i] / 10 即除法的商數成為進位的增量值
      c[i + 1] += c[i] / 10;
      // 而 c[i] % 10 即除法的餘數成為在當前位留下的值
      c[i] %= 10;
    }
  }
}

高精度—高精度

如果兩個乘數都是高精度,那麼豎式乘法又可以大顯身手了。

回想豎式乘法的每一步,實際上是計算了若干 \(a \times b_i \times 10^i\) 的和。例如計算 \(1337 \times 42\),計算的就是 \(1337 \times 2 \times 10^0 + 1337 \times 4 \times 10^1\)

於是可以將 \(b\) 分解為它的所有數碼,其中每個數碼都是單精度數,將它們分別與 \(a\) 相乘,再向左移動到各自的位置上相加即得答案。當然,最後也需要用與上例相同的方式處理進位。

注意這個過程與豎式乘法不盡相同,我們的算法在每一步乘的過程中並不進位,而是將所有的結果保留在對應的位置上,到最後再統一處理進位,但這不會影響結果。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
void mul(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    // 這裏直接計算結果中的從低到高第 i 位,且一併處理了進位
    // 第 i 次循環為 c[i] 加上了所有滿足 p + q = i 的 a[p] 與 b[q] 的乘積之和
    // 這樣做的效果和直接進行上圖的運算最後求和是一樣的,只是更加簡短的一種實現方式
    for (int j = 0; j <= i; ++j) c[i] += a[j] * b[i - j];

    if (c[i] >= 10) {
      c[i + 1] += c[i] / 10;
      c[i] %= 10;
    }
  }
}

除法

高精度除法的一種實現方式就是豎式長除法。

豎式長除法實際上可以看作一個逐次減法的過程。例如上圖中商數十位的計算可以這樣理解:將 \(45\) 減去三次 \(12\) 後變得小於 \(12\),不能再減,故此位為 \(3\)

為了減少冗餘運算,我們提前得到被除數的長度 \(l_a\) 與除數的長度 \(l_b\),從下標 \(l_a - l_b\) 開始,從高位到低位來計算商。這和手工計算時將第一次乘法的最高位與被除數最高位對齊的做法是一樣的。

參考程序實現了一個函數 greater_eq() 用於判斷被除數以下標 last_dg 為最低位,是否可以再減去除數而保持非負。此後對於商的每一位,不斷調用 greater_eq(),並在成立的時候用高精度減法從餘數中減去除數,也即模擬了豎式除法的過程。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// 被除數 a 以下標 last_dg 為最低位,是否可以再減去除數 b 而保持非負
// len 是除數 b 的長度,避免反覆計算
bool greater_eq(int a[], int b[], int last_dg, int len) {
  // 有可能被除數剩餘的部分比除數長,這個情況下最多多出 1 位,故如此判斷即可
  if (a[last_dg + len] != 0) return true;
  // 從高位到低位,逐位比較
  for (int i = len - 1; i >= 0; --i) {
    if (a[last_dg + i] > b[i]) return true;
    if (a[last_dg + i] < b[i]) return false;
  }
  // 相等的情形下也是可行的
  return true;
}

void div(int a[], int b[], int c[], int d[]) {
  clear(c);
  clear(d);

  int la, lb;
  for (la = LEN - 1; la > 0; --la)
    if (a[la - 1] != 0) break;
  for (lb = LEN - 1; lb > 0; --lb)
    if (b[lb - 1] != 0) break;
  if (lb == 0) {
    puts("> <");
    return;
  }  // 除數不能為零

  // c 是商
  // d 是被除數的剩餘部分,算法結束後自然成為餘數
  for (int i = 0; i < la; ++i) d[i] = a[i];
  for (int i = la - lb; i >= 0; --i) {
    // 計算商的第 i 位
    while (greater_eq(d, b, i, lb)) {
      // 若可以減,則減
      // 這一段是一個高精度減法
      for (int j = 0; j < lb; ++j) {
        d[i + j] -= b[j];
        if (d[i + j] < 0) {
          d[i + j + 1] -= 1;
          d[i + j] += 10;
        }
      }
      // 使商的這一位增加 1
      c[i] += 1;
      // 返回循環開頭,重新檢查
    }
  }
}

入門篇完成!

將上面介紹的四則運算的實現結合,即可完成開頭提到的計算器程序。

calculator.cpp
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <cstdio>
#include <cstring>

static const int LEN = 1004;

int a[LEN], b[LEN], c[LEN], d[LEN];

void clear(int a[]) {
  for (int i = 0; i < LEN; ++i) a[i] = 0;
}

void read(int a[]) {
  static char s[LEN + 1];
  scanf("%s", s);

  clear(a);

  int len = strlen(s);
  for (int i = 0; i < len; ++i) a[len - i - 1] = s[i] - '0';
}

void print(int a[]) {
  int i;
  for (i = LEN - 1; i >= 1; --i)
    if (a[i] != 0) break;
  for (; i >= 0; --i) putchar(a[i] + '0');
  putchar('\n');
}

void add(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    c[i] += a[i] + b[i];
    if (c[i] >= 10) {
      c[i + 1] += 1;
      c[i] -= 10;
    }
  }
}

void sub(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    c[i] += a[i] - b[i];
    if (c[i] < 0) {
      c[i + 1] -= 1;
      c[i] += 10;
    }
  }
}

void mul(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    for (int j = 0; j <= i; ++j) c[i] += a[j] * b[i - j];

    if (c[i] >= 10) {
      c[i + 1] += c[i] / 10;
      c[i] %= 10;
    }
  }
}

bool greater_eq(int a[], int b[], int last_dg, int len) {
  if (a[last_dg + len] != 0) return true;
  for (int i = len - 1; i >= 0; --i) {
    if (a[last_dg + i] > b[i]) return true;
    if (a[last_dg + i] < b[i]) return false;
  }
  return true;
}

void div(int a[], int b[], int c[], int d[]) {
  clear(c);
  clear(d);

  int la, lb;
  for (la = LEN - 1; la > 0; --la)
    if (a[la - 1] != 0) break;
  for (lb = LEN - 1; lb > 0; --lb)
    if (b[lb - 1] != 0) break;
  if (lb == 0) {
    puts("> <");
    return;
  }

  for (int i = 0; i < la; ++i) d[i] = a[i];
  for (int i = la - lb; i >= 0; --i) {
    while (greater_eq(d, b, i, lb)) {
      for (int j = 0; j < lb; ++j) {
        d[i + j] -= b[j];
        if (d[i + j] < 0) {
          d[i + j + 1] -= 1;
          d[i + j] += 10;
        }
      }
      c[i] += 1;
    }
  }
}

int main() {
  read(a);

  char op[4];
  scanf("%s", op);

  read(b);

  switch (op[0]) {
    case '+':
      add(a, b, c);
      print(c);
      break;
    case '-':
      sub(a, b, c);
      print(c);
      break;
    case '*':
      mul(a, b, c);
      print(c);
      break;
    case '/':
      div(a, b, c, d);
      print(c);
      print(d);
      break;
    default:
      puts("> <");
  }

  return 0;
}

壓位高精度

引入

在一般的高精度加法,減法,乘法運算中,我們都是將參與運算的數拆分成一個個單獨的數碼進行運算。

例如計算 \(8192\times 42\) 時,如果按照高精度乘高精度的計算方式,我們實際上算的是 \((8000+100+90+2)\times(40+2)\)

在位數較多的時候,拆分出的數也很多,高精度運算的效率就會下降。

有沒有辦法作出一些優化呢?

注意到拆分數字的方式並不影響最終的結果,因此我們可以將若干個數碼進行合併。

過程

還是以上面這個例子為例,如果我們每兩位拆分一個數,我們可以拆分成 \((8100+92)\times 42\)

這樣的拆分不影響最終結果,但是因為拆分出的數字變少了,計算效率也就提升了。

進位制 的角度理解這一過程,我們通過在較大的進位制(上面每兩位拆分一個數,可以認為是在 \(100\) 進制下進行運算)下進行運算,從而達到減少參與運算的數字的位數,提升運算效率的目的。

這就是 壓位高精度 的思想。

下面我們給出壓位高精度的加法代碼,用於進一步闡述其實現方法:

壓位高精度加法參考實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
// 這裏的 a,b,c 數組均為 p 進制下的數
// 最終輸出答案時需要將數字轉為十進制
void add(int a[], int b[], int c[]) {
  clear(c);

  for (int i = 0; i < LEN - 1; ++i) {
    c[i] += a[i] + b[i];
    if (c[i] >= p) {  // 在普通高精度運算下,p=10
      c[i + 1] += 1;
      c[i] -= p;
    }
  }
}

壓位高精下的高效豎式除法

在使用壓位高精時,如果試商時仍然使用上文介紹的方法,由於試商次數會很多,計算常數會非常大。例如在萬進制下,平均每個位需要試商 5000 次,這個巨大的常數是不可接受的。因此我們需要一個更高效的試商辦法。

我們可以把 double 作為媒介。假設被除數有 4 位,是 \(a_4,a_3,a_2,a_1\),除數有 3 位,是 \(b_3,b_2,b_1\),那麼我們只要試一位的商:使用 \(base\) 進制,用式子 \(\dfrac{a_4 base + a_3}{b_3 + b_2 base^{-1} + (b_1+1)base^{-2}}\) 來估商。而對於多個位的情況,就是一位的寫法加個循環。由於除數使用 3 位的精度來參與估商,能保證估的商 q' 與實際商 q 的關係滿足 \(q-1 \le q' \le q\),這樣每個位在最壞的情況下也只需要兩次試商。但與此同時要求 \(base^3\) 在 double 的有效精度內,即 \(base^3 < 2^{53}\),所以在運用這個方法時建議不要超過 32768 進制,否則很容易因精度不足產生誤差從而導致錯誤。

另外,由於估的商總是小於等於實際商,所以還有再進一步優化的空間。絕大多數情況下每個位只估商一次,這樣在下一個位估商時,雖然得到的商有可能因為前一位的誤差造成試商結果大於等於 base,但這沒有關係,只要在最後再最後做統一進位便可。舉個例子,假設 base 是 10,求 \(395081/9876\),試商計算步驟如下:

  1. 首先試商計算得到 \(3950/988=3\),於是 \(395081-(9876 \times 3 \times 10^1) = 98801\),這一步出現了誤差,但不用管,繼續下一步計算。
  2. 對餘數 98801 繼續試商計算得到 \(9880/988=10\),於是 \(98801-(9876 \times 10 \times 10^0) = 41\),這就是最終餘數。
  3. 把試商過程的結果加起來並處理進位,即 \(3 \times 10^1 + 10 \times 10^0 = 40\) 便是準確的商。

方法雖然看着簡單,但具體實現上很容易進坑,所以以下提供一個經過多番驗證確認沒有問題的實現供大家參考,要注意的細節也寫在註釋當中。

壓位高精度高效豎式除法參考實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// 完整模板和實現 https://baobaobear.github.io/post/20210228-bigint1/
// 對b乘以mul再左移offset的結果相減,為除法服務
BigIntSimple &sub_mul(const BigIntSimple &b, int mul, int offset) {
  if (mul == 0) return *this;
  int borrow = 0;
  // 與減法不同的是,borrow可能很大,不能使用減法的寫法
  for (size_t i = 0; i < b.v.size(); ++i) {
    borrow += v[i + offset] - b.v[i] * mul - BIGINT_BASE + 1;
    v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
    borrow /= BIGINT_BASE;
  }
  // 如果還有借位就繼續處理
  for (size_t i = b.v.size(); borrow; ++i) {
    borrow += v[i + offset] - BIGINT_BASE + 1;
    v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
    borrow /= BIGINT_BASE;
  }
  return *this;
}

BigIntSimple div_mod(const BigIntSimple &b, BigIntSimple &r) const {
  BigIntSimple d;
  r = *this;
  if (absless(b)) return d;
  d.v.resize(v.size() - b.v.size() + 1);
  // 提前算好除數的最高三位+1的倒數,若最高三位是a3,a2,a1
  // 那麼db是a3+a2/base+(a1+1)/base^2的倒數,最後用乘法估商的每一位
  // 此法在BIGINT_BASE<=32768時可在int32範圍內用
  // 但即使使用int64,那麼也只有BIGINT_BASE<=131072時可用(受double的精度限制)
  // 能保證估計結果q'與實際結果q的關係滿足q'<=q<=q'+1
  // 所以每一位的試商平均只需要一次,只要後面再統一處理進位即可
  // 如果要使用更大的base,那麼需要更換其它試商方案
  double t = (b.get((unsigned)b.v.size() - 2) +
              (b.get((unsigned)b.v.size() - 3) + 1.0) / BIGINT_BASE);
  double db = 1.0 / (b.v.back() + t / BIGINT_BASE);
  for (size_t i = v.size() - 1, j = d.v.size() - 1; j <= v.size();) {
    int rm = r.get(i + 1) * BIGINT_BASE + r.get(i);
    int m = std::max((int)(db * rm), r.get(i + 1));
    r.sub_mul(b, m, j);
    d.v[j] += m;
    if (!r.get(i + 1))  // 檢查最高位是否已為0,避免極端情況
      --i, --j;
  }
  r.trim();
  // 修正結果的個位
  int carry = 0;
  while (!r.absless(b)) {
    r.subtract(b);
    ++carry;
  }
  // 修正每一位的進位
  for (size_t i = 0; i < d.v.size(); ++i) {
    carry += d.v[i];
    d.v[i] = carry % BIGINT_BASE;
    carry /= BIGINT_BASE;
  }
  d.trim();
  d.sign = sign * b.sign;
  return d;
}

BigIntSimple operator/(const BigIntSimple &b) const {
  BigIntSimple r;
  return div_mod(b, r);
}

BigIntSimple operator%(const BigIntSimple &b) const {
  BigIntSimple r;
  div_mod(b, r);
  return r;
}

Karatsuba 乘法

記高精度數字的位數為 \(n\),那麼高精度—高精度豎式乘法需要花費 \(O(n^2)\) 的時間。本節介紹一個時間複雜度更為優秀的算法,由前蘇聯(俄羅斯)數學家 Anatoly Karatsuba 提出,是一種分治算法。

考慮兩個十進制大整數 \(x\)\(y\),均包含 \(n\) 個數碼(可以有前導零)。任取 \(0 < m < n\),記

\[ \begin{aligned} x &= x_1 \cdot 10^m + x_0, \\ y &= y_1 \cdot 10^m + y_0, \\ x \cdot y &= z_2 \cdot 10^{2m} + z_1 \cdot 10^m + z_0, \end{aligned} \]

其中 \(x_0, y_0, z_0, z_1 < 10^m\)。可得

\[ \begin{aligned} z_2 &= x_1 \cdot y_1, \\ z_1 &= x_1 \cdot y_0 + x_0 \cdot y_1, \\ z_0 &= x_0 \cdot y_0. \end{aligned} \]

觀察知

\[ z_1 = (x_1 + x_0) \cdot (y_1 + y_0) - z_2 - z_0, \]

於是要計算 \(z_1\),只需計算 \((x_1 + x_0) \cdot (y_1 + y_0)\),再與 \(z_0\)\(z_2\) 相減即可。

上式實際上是 Karatsuba 算法的核心,它將長度為 \(n\) 的乘法問題轉化為了 \(3\) 個長度更小的子問題。若令 \(m = \left\lceil \dfrac n 2 \right\rceil\),記 Karatsuba 算法計算兩個 \(n\) 位整數乘法的耗時為 \(T(n)\),則有 \(T(n) = 3 \cdot T \left(\left\lceil \dfrac n 2 \right\rceil\right) + O(n)\),由主定理可得 \(T(n) = \Theta(n^{\log_2 3}) \approx \Theta(n^{1.585})\)

整個過程可以遞歸實現。為清晰起見,下面的代碼通過 Karatsuba 算法實現了多項式乘法,最後再處理所有的進位問題。

karatsuba_mulc.cpp
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
int *karatsuba_polymul(int n, int *a, int *b) {
  if (n <= 32) {
    // 規模較小時直接計算,避免繼續遞歸帶來的效率損失
    int *r = new int[n * 2 + 1]();
    for (int i = 0; i <= n; ++i)
      for (int j = 0; j <= n; ++j) r[i + j] += a[i] * b[j];
    return r;
  }

  int m = n / 2 + 1;
  int *r = new int[m * 4 + 1]();
  int *z0, *z1, *z2;

  z0 = karatsuba_polymul(m - 1, a, b);
  z2 = karatsuba_polymul(n - m, a + m, b + m);

  // 計算 z1
  // 臨時更改,計算完畢後恢復
  for (int i = 0; i + m <= n; ++i) a[i] += a[i + m];
  for (int i = 0; i + m <= n; ++i) b[i] += b[i + m];
  z1 = karatsuba_polymul(m - 1, a, b);
  for (int i = 0; i + m <= n; ++i) a[i] -= a[i + m];
  for (int i = 0; i + m <= n; ++i) b[i] -= b[i + m];
  for (int i = 0; i <= (m - 1) * 2; ++i) z1[i] -= z0[i];
  for (int i = 0; i <= (n - m) * 2; ++i) z1[i] -= z2[i];

  // 由 z0、z1、z2 組合獲得結果
  for (int i = 0; i <= (m - 1) * 2; ++i) r[i] += z0[i];
  for (int i = 0; i <= (m - 1) * 2; ++i) r[i + m] += z1[i];
  for (int i = 0; i <= (n - m) * 2; ++i) r[i + m * 2] += z2[i];

  delete[] z0;
  delete[] z1;
  delete[] z2;
  return r;
}

void karatsuba_mul(int a[], int b[], int c[]) {
  int *r = karatsuba_polymul(LEN - 1, a, b);
  memcpy(c, r, sizeof(int) * LEN);
  for (int i = 0; i < LEN - 1; ++i)
    if (c[i] >= 10) {
      c[i + 1] += c[i] / 10;
      c[i] %= 10;
    }
  delete[] r;
}
關於 newdelete

內存池

但是這樣的實現存在一個問題:在 \(b\) 進制下,多項式的每一個係數都有可能達到 \(n \cdot b^2\) 量級,在壓位高精度實現中可能造成整數溢出;而若在多項式乘法的過程中處理進位問題,則 \(x_1 + x_0\)\(y_1 + y_0\) 的結果可能達到 \(2 \cdot b^m\),增加一個位(如果採用 \(x_1 - x_0\) 的計算方式,則不得不特殊處理負數的情況)。因此,需要依照實際的應用場景來決定採用何種實現方式。

基於多項式的高效大整數乘法

如果數據規模達到了 \(10^{10^5}\) 或更大,普通的高精度乘法可能會超時。本節將介紹用多項式優化此類乘法的方法。

對於一個 \(n\) 位的十進制整數 \(a\),可以將它看作一個每位係數均為整數且不超過 \(10\) 的多項式 \(A=a_{0} 10^0+a_{1} 10^1+\cdots+a_{n-1} 10^{n-1}\)。這樣,我們就將兩個整數乘法轉化為了兩個多項式乘法。

普通的多項式乘法時間複雜度仍是 \(O(n^2)\),但可以用多項式一節中的 快速傅里葉變換快速數論變換 等算法優化,優化後的時間複雜度是 \(O(n\log n)\)

封裝類

這裏 有一個封裝好的高精度整數類,以及 這裏 支持動態長度及四則運算的超迷你實現類。

這裏是另一個模板
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#define MAXN 9999
// MAXN 是一位中最大的數字
#define MAXSIZE 10024
// MAXSIZE 是位數
#define DLEN 4

// DLEN 記錄壓幾位
struct Big {
  int a[MAXSIZE], len;
  bool flag;  // 標記符號'-'

  Big() {
    len = 1;
    memset(a, 0, sizeof a);
    flag = 0;
  }

  Big(const int);
  Big(const char*);
  Big(const Big&);
  Big& operator=(const Big&);
  Big operator+(const Big&) const;
  Big operator-(const Big&) const;
  Big operator*(const Big&) const;
  Big operator/(const int&) const;
  // TODO: Big / Big;
  Big operator^(const int&) const;
  // TODO: Big ^ Big;

  // TODO: Big 位運算;

  int operator%(const int&) const;
  // TODO: Big ^ Big;
  bool operator<(const Big&) const;
  bool operator<(const int& t) const;
  void print() const;
};

Big::Big(const int b) {
  int c, d = b;
  len = 0;
  // memset(a,0,sizeof a);
  CLR(a);
  while (d > MAXN) {
    c = d - (d / (MAXN + 1) * (MAXN + 1));
    d = d / (MAXN + 1);
    a[len++] = c;
  }
  a[len++] = d;
}

Big::Big(const char* s) {
  int t, k, index, l;
  CLR(a);
  l = strlen(s);
  len = l / DLEN;
  if (l % DLEN) ++len;
  index = 0;
  for (int i = l - 1; i >= 0; i -= DLEN) {
    t = 0;
    k = i - DLEN + 1;
    if (k < 0) k = 0;
    g(j, k, i) t = t * 10 + s[j] - '0';
    a[index++] = t;
  }
}

Big::Big(const Big& T) : len(T.len) {
  CLR(a);
  f(i, 0, len) a[i] = T.a[i];
  // TODO:重載此處?
}

Big& Big::operator=(const Big& T) {
  CLR(a);
  len = T.len;
  f(i, 0, len) a[i] = T.a[i];
  return *this;
}

Big Big::operator+(const Big& T) const {
  Big t(*this);
  int big = len;
  if (T.len > len) big = T.len;
  f(i, 0, big) {
    t.a[i] += T.a[i];
    if (t.a[i] > MAXN) {
      ++t.a[i + 1];
      t.a[i] -= MAXN + 1;
    }
  }
  if (t.a[big])
    t.len = big + 1;
  else
    t.len = big;
  return t;
}

Big Big::operator-(const Big& T) const {
  int big;
  bool ctf;
  Big t1, t2;
  if (*this < T) {
    t1 = T;
    t2 = *this;
    ctf = 1;
  } else {
    t1 = *this;
    t2 = T;
    ctf = 0;
  }
  big = t1.len;
  int j = 0;
  f(i, 0, big) {
    if (t1.a[i] < t2.a[i]) {
      j = i + 1;
      while (t1.a[j] == 0) ++j;
      --t1.a[j--];
      // WTF?
      while (j > i) t1.a[j--] += MAXN;
      t1.a[i] += MAXN + 1 - t2.a[i];
    } else
      t1.a[i] -= t2.a[i];
  }
  t1.len = big;
  while (t1.len > 1 && t1.a[t1.len - 1] == 0) {
    --t1.len;
    --big;
  }
  if (ctf) t1.a[big - 1] = -t1.a[big - 1];
  return t1;
}

Big Big::operator*(const Big& T) const {
  Big res;
  int up;
  int te, tee;
  f(i, 0, len) {
    up = 0;
    f(j, 0, T.len) {
      te = a[i] * T.a[j] + res.a[i + j] + up;
      if (te > MAXN) {
        tee = te - te / (MAXN + 1) * (MAXN + 1);
        up = te / (MAXN + 1);
        res.a[i + j] = tee;
      } else {
        up = 0;
        res.a[i + j] = te;
      }
    }
    if (up) res.a[i + T.len] = up;
  }
  res.len = len + T.len;
  while (res.len > 1 && res.a[res.len - 1] == 0) --res.len;
  return res;
}

Big Big::operator/(const int& b) const {
  Big res;
  int down = 0;
  gd(i, len - 1, 0) {
    res.a[i] = (a[i] + down * (MAXN + 1)) / b;
    down = a[i] + down * (MAXN + 1) - res.a[i] * b;
  }
  res.len = len;
  while (res.len > 1 && res.a[res.len - 1] == 0) --res.len;
  return res;
}

int Big::operator%(const int& b) const {
  int d = 0;
  gd(i, len - 1, 0) d = (d * (MAXN + 1) % b + a[i]) % b;
  return d;
}

Big Big::operator^(const int& n) const {
  Big t(n), res(1);
  int y = n;
  while (y) {
    if (y & 1) res = res * t;
    t = t * t;
    y >>= 1;
  }
  return res;
}

bool Big::operator<(const Big& T) const {
  int ln;
  if (len < T.len) return 233;
  if (len == T.len) {
    ln = len - 1;
    while (ln >= 0 && a[ln] == T.a[ln]) --ln;
    if (ln >= 0 && a[ln] < T.a[ln]) return 233;
    return 0;
  }
  return 0;
}

bool Big::operator<(const int& t) const {
  Big tee(t);
  return *this < tee;
}

void Big::print() const {
  printf("%d", a[len - 1]);
  gd(i, len - 2, 0) { printf("%04d", a[i]); }
}

void print(const Big& s) {
  int len = s.len;
  printf("%d", s.a[len - 1]);
  gd(i, len - 2, 0) { printf("%04d", s.a[i]); }
}

char s[100024];

習題

參考資料與鏈接

  1. Karatsuba algorithm - Wikipedia