Mosh compiler/VM のデバッグを1週間続けた話

Mosh は R7RS に対応しつつあるので ecraven/r7rs-benchmarks というベンチマークを走らせてみた。今回は速度は気にせずにR7RS Scheme 処理系としてコードを間違いなく正しく実行できるかに注力。結局57個中で3つのベンチマークで実行に失敗することが分かった。うち2つは比較的簡単に修正できた。最後の1つが手強かったので記録を残す。

スタート

失敗するのは conform というベンチマーク。期待される結果とは違うものを Mosh が返してしまう。そして実行時間がやけに短い。conform ベンチマークスクリプトr7rs-benchmarks/conform.scm)を見てみるとグラフ構造を作って何かやっているらしい。正直コードを理解しようとするのは3秒で諦めた。

この時点ではデバッグが難航する事は知る由もないのでデバッグ print を入れてなぜ期待と違うかを調べようとするがすぐに挫折。なぜならば

  • A) グラフ構造を扱っているので自分の中で自分を参照していてデバッグ print と相性が悪いこと。write/ss で shared structure を print できるがそれでも視認性が悪い。
  • B) データ構造が大きく print しきれない
  • C) そもそも何をやっているのかコードから読み取れない。

この状態で2-3時間無駄に使った。

心理的安全性

これは難航しそうだとようやく気付いたので少し落ち着くことにした。正しい状態がよく分からないのが問題なので Gauche で実行して比べることとした。次に処理内容が分からないのは良いとして、メインの処理を何となく名前から特定した。ここでようやくメインの処理にデバッグ print を入れて Gauche と比較できるようになり、ある関数で一瞬で間違った値が返っていることが分かった。

間違うポイントが分かったので勝利を確信。その関数への入力を維持したまま再現コードを小さくしていくことにした。ところがこれがかなり難しい。入力も出力もグラフなので文字列や数字を扱うのとは別の難しさがある。色々やっているうちにぐちゃぐちゃになってしまった。元に戻らなくなってしまい大反省。debug という git branch を作り少しずつ進むようにしたら急に捗ったし壊すことも無くなった。チェックポイントあるし壊しても大丈夫という心理的安全性大事。1日かけて小さなコードに絞ることができた。

(define (foo)
  (define (bar  n)
    (cond ((= n 1) #t)
          (else
           (let loop ((lst '(1)))
             (if (null? lst)
                 #t
                 (and 
                  (display "=> recusive1\n")
                  (bar 1)
                  (display "=> loop\n")                         
                  (loop (cdr lst))))))))
    (bar 0)
    )
(foo)

このコードの (bar 1) の呼び出しの後に (display "=> loop\n") が呼ばれないことが分かった。これは明らかにおかしいなぜならば (bar 1) は #t を返すから。

それで誰が悪いのか?

Scheme 処理系を書いたことのある人なら分かると思うが、これは色々怪しい。define の中に define があり named let があり末尾再帰があり。どこでバグっていてもおかしくない。

この時点でバグの原因候補は実行順に

というのを切り替えないといけない。その上この辺りを触るのは10年以上ぶりである!

コンパイラVM を調べる

最適化を OFF にすると再現しなくなったので

  • 最適化そのものが間違っている
  • 最適化で出てくるコードを正しく実行できないパスがある

あたりが怪しい。

pass2 の最適化の途中。1週間のデバッグを終えた今の目で見ればおかしいところは明らか。

($if ($asm NUMBER_EQUAL
       ($lref n[1;0])
       ($const 1))
  ($const #t)
  ($letrec ((loop[0;0] ($lambda[loop;2] (lst[2;0])
                         ($label #0
                           ($if ($asm NULL_P
                                  ($lref lst[2;0]))
                             ($const #t)
                             ($if ($call ($gref display)
                                    ($const "=> recusive1\n"))
                               ($if ($call[tail-rec] ($lref bar[2;0])
                                      ($const 1))
                                 ($if ($call ($gref display)
                                        ($const "=> loop\n"))
                                   ($call[jump] ($call[embed] ($lambda[loop;2] (lst[2;0])
                                                                label#0)
                                                  ($const (1)))
                                     ($asm CDR
                                       ($lref lst[2;0])))
                                   ($it))
                                 ($it))
                               ($it))))))
            )
    ($call[embed] ($lambda[loop;2] (lst[2;0])
                    label#0)
      ($const (1)))))

pass2 の最適化後の iform は以下の通り。ここで時間を使いすぎた。

($define () foo
  ($lambda[foo;0] ()
    ($call[embed 4 0] ($lambda[bar;2] (n[1;0])
                        ($label #0
                          ($if ($asm NUMBER_EQUAL
                                 ($lref n[1;0])
                                 ($const 1))
                            ($const #t)
                            ($call[embed 7 0] ($lambda[loop;2] (lst[2;0])
                                                ($label #1
                                                  ($if ($asm NULL_P
                                                         ($lref lst[2;0]))
                                                    ($const #t)
                                                    ($if ($call ($gref display)
                                                           ($const "=> recusive1\n"))
                                                      ($if ($call[jump 0 0] ($call[embed 4 0] ($lambda[bar;2] (n[1;0])
                                                                                                label#0)
                                                                              ($const 0))
                                                             ($const 1))
                                                        ($if ($call ($gref display)
                                                               ($const "=> loop\n"))
                                                          ($call[jump 0 0] ($call[embed 7 0] ($lambda[loop;2] (lst[2;0])
                                                                                               label#1)
                                                                             ($const (1)))
                                                            ($asm CDR
                                                              ($lref lst[2;0])))
                                                          ($it))
                                                        ($it))
                                                      ($it)))))
                              ($const (1))))))
      ($const 0))))

結局 pass2 の最適化で local call を埋め込む処理あたりで何かがおかしい事は分かるのだが。この iform がおかしいのか。後続の処理がおかしいのか分からないので後続も見る。 実際の VM instruction 列を表示してみると。ますます分からない。

  CLOSURE
  81
  0
  #f
  0
  11
  ((reproduce.scm 1) foo)
  LET_FRAME
  7
  CONSTANT_PUSH
  0
  ENTER  ;; Label #0
  1
  REFER_LOCAL_PUSH_CONSTANT
  0
  1
  BRANCH_NOT_NUMBER_EQUAL ;; if (= n 1)
  5
  CONSTANT
  #t
  LOCAL_JMP ;; goto label #1
  57
  LET_FRAME 5
  REFER_LOCAL_PUSH
  0
  DISPLAY
  1
  CONSTANT_PUSH
  (1)
  ENTER
  1
  REFER_LOCAL_BRANCH_NOT_NULL
  0
  5
  CONSTANT
  #t
  LOCAL_JMP
  38
  FRAME
  6
  CONSTANT_PUSH
  => recusive1
  REFER_GLOBAL_CALL
  display  ;; (display "=> recusive1\n")
  1        ;; # of args
  TEST     ;; (display ...) return value is true. So we skip the +1 next line and go to +2 line.
  29
  CONSTANT_PUSH ;; Come here after (display ...) call
  1
  SHIFTJ ;; adjust SP and FP
  1      ;; depth
  4      ;; diff
  0      ;; How many closure to go up?
  LOCAL_JMP ;; Jump to label #0
  -42
  TEST
  19
  FRAME
  6
  CONSTANT_PUSH
  => loop
  REFER_GLOBAL_CALL
  display
  1
  TEST
  10
  REFER_LOCAL
  0
  CDR_PUSH
  SHIFTJ
  1
  1
  0
  LOCAL_JMP
  -43
  LEAVE
  1
  LEAVE ;; label #2
  1     ;; adjust stack
  RETURN ;; return to the code (the code is taken from the top of stack). ** But we don't have the code in the stack?***
  0
  DEFINE_GLOBAL
  foo
  HALT NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP)

動的に VM で実行される様子を stack と共に。

========================
FRAME
  FP|0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
========================
REFER_GLOBAL
  FP|0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
========================
CALL
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
========================
LET_FRAME                   # LET_FRAME for lambda[foo]
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
  FP|4: 4                   # Save fp
    |5: "#(#(CLOSURE ...))" # Closure
========================
CONSTANT
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
  FP|4: 4
    |5: "#(#(CLOSURE ...))"
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
  FP|4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0                  # Constant 0
========================
ENTER                      # Adjust fp
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
========================
REFER_LOCAL               # a=(Constant 0)
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0                  # Constant 0
    |7: 0                  # Constant 0
========================
CONSTANT                   # a=(Constant 1)
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 0
========================
BRANCH_NOT_NUMBER_EQUAL   # a != stack-bottom(Constant 0)
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0                  # Discarded stack top.
========================
LET_FRAME                  # LET_FRAME for loop.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 6                   # Save fp
    |8: "#(#(CLOSURE ...))" # Push closure
========================
REFER_LOCAL                 # a=(Constant 0) REALLY???
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 0                  # push (Constant 0)
========================
DISPLAY                    # Create a display and set it to closure.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"# Note stack is popped.
========================
CONSTANT                   # a=(1)
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1                  # (1)
========================
ENTER
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1                 # New FP.
========================
REFER_LOCAL               # a=(1)
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
========================
BRANCH_NOT_NULL          # Go to else.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
========================
FRAME                     # Push stuff.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
    |10: "(#(#f ...)"      # closure
    |11: 9                 # fp
    |12: 54                # pc + n
    |13: "(#(CLOSURE ...)" # Current codes
========================
CONSTANT                   # a="=> recursive1"
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
    |14: "=> recusive1\n"
========================
REFER_GLOBAL               # a=<display>
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
    |14: "=> recusive1\n"
========================
CALL                       # call a=<display>. Note codes is now body of display.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
    |15: ()               # display has optional-arg '()
========================
REFER_LOCAL               # a="=> recursive1\n"
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
    |15: ()
========================
BRANCH_NOT_NULL           # a is not null so we go to Else. But this is really weird.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
    |15: ()
========================
REFER_LOCAL                # a="=> recursive1"
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
    |15: ()
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
    |15: ()
    |16: "=> recusive1\n"
========================
REFER_FREE 0               # Point codes + 6
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
    |15: ()
    |16: "=> recusive1\n"
========================
TAIL_CALL                  # call <display> and jump to #(RETURN 1 ...)
=> recusive1
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
    |9: 1
    |10: "(#(#f ...)"
    |11: 9
    |12: 54
    |13: "(#(CLOSURE ...)"
  FP|14: "=> recusive1\n"
========================
RETURN
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1                   # this is ()
========================
TEST                        # Return value of display is #t
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
========================
CONSTANT
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
    |6: 0
    |7: 6
    |8: "#(#(CLOSURE ...))"
  FP|9: 1                    # (1)
    |10: 1                   # 1
========================
SHIFTJ 1 4 0                 # Adjust frame for jump. Stack is now for bar call.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
========================
LOCAL_JMP
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
========================
REFER_LOCAL                # a=1
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
========================
PUSH
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
    |7: 1
========================
CONSTANT
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
    |7: 1
========================
BRANCH_NOT_NUMBER_EQUAL    # Now 1=1. Jump to else.
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
========================
CONSTANT #t
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
========================
LOCAL_JMP 66
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
    |4: 4
    |5: "#(#(CLOSURE ...))"
  FP|6: 1
========================
LEAVE
    |0: "#(#(FRAME ...))"
    |1: 0
    |2: 6
    |3: "(#(FRAME ...)"
========================
RETURN
========================
HALT

これを脳内で何回も再生し iform と行ったり来たりして実行していくと (bar 1) を実行した後に local_jump したら戻ってこれないのは自明であることに気づく。この lambda 埋め込みは間違いである。それに気づけばあとは簡単で (and の途中で埋め込まれてしまったのは (bar 1) 呼び出しが末尾再帰と誤認されたのが原因。

この1行の修正でお仕事完了。

Scheme で MNIST 続き

Scheme で MNIST からの学び - higepon blog の続き。

f64array

前回のプロトタイプをもとに Mosh のシステムライブラリとして f64array というものを導入した。double 型の2次元の行列を表現したもの。追加した手続きは最低限のもの。

  • make-f64array
  • f64array-ref
  • f64array-set!
  • f64array-shape
  • f64array-dot-product)

f64array の構築、setter/getter、dot product 。プロファイラで見てみると行列の shape を返す手続きも無視できない回数呼ばれている。おそらく行列同士の計算での broadcast で呼ばれているのだろう。

Portable なコードを書く

R7RS では処理系ごとの違いをうまく扱える仕組みが用意されているので GaucheMosh 両方で動くように整えてみよう。 行列の初期化で make-normal-generator を使うのだが Gauche には SRFI-194 がないようなので data.random ライブラリから似たような手続きを持ってきて rename している。ところで SRFI-194 の author は Shiro さんなのであった。

  (cond-expand
   [gauche
      (import (rename (data random) (reals-normal$ make-normal-generator)))]
   [(library (srfi 194))
     (import (only (srfi 194) make-normal-generator))])

また Mosh では .mosh.sld という拡張子のライブラリを優先して読み込むという機能があるので以下のように Mosh が読み込むライブラリ、Mosh 以外が読み込むライブラリを分けることができる。ここで行列の実装を分岐している。

そしてどちらの行列ライブラリにも共通なコードは (include "matrix-common.scm") のように直接ファイルを include する。

というわけでめでたく MoshGauche の両方で動く MNIST デモができました。他のR7RS 処理系でも少しの手直しで動くのではないかと期待している。PR 待ってます。コードは mosh/tests/mnist at master · higepon/mosh · GitHub

Scheme で MNIST からの学び

久しぶりに Mosh Scheme に触っている。良い機会なので ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装 を参考に Scheme で動く MNIST デモを実装しようと思い立つ。MNIST は 0 から 9 までの手書き数字の認識デモ。Neural Network フレームワークのデモやチュートリアルでよく使われるあれだ。なお Scheme には Python でいうところの numpy は存在しないので必要な行列演算はすべて Scheme で実装することとする。

データサイズ

  • train data 画像(28 x 28 x 1) と正解ラベル(0-9)
  • train data: 60000件
  • test data: 10000件

実装前の期待と予想

  • 他の言語・フレームワークでデモ可能だったのだから CPU 上で十分速く動作する。
  • 比較的小さなモデル・train data なので pure Scheme 実装で大丈夫だろう。
  • matrix-mul がボトルネックになるかもしれない。

技術的な選択

  • Matrix は VectorVector として表現する。Uniform Vector (格納する型が決まっているもの)もありだが Mosh では未実装なので見送り。
  • 行列演算の高速実装は特に目指さない。ナイーブな実装で良い。

matrix-mul の実装

特に工夫のないナイーブな実装。3重ループ。

(define (matrix-mul a b)
  (unless (= (matrix-shape a 1) (matrix-shape b 0))
    (error "matrix-mul shapes don't match" (matrix-shape a) (matrix-shape b)))
  (let* ([nrows (matrix-shape a 0)]
         [ncols (matrix-shape b 1)]
         [m     (matrix-shape a 1)]
         [mat (matrix nrows ncols)])
    (define (mul row col)
      (let loop ([k 0]
                 [ret 0])
        (if (= k m)
            ret
            (loop (+ k 1) (+ ret (* (mat-at a row k) (mat-at b k col)))))))
    (do ((i 0 (+ i 1)))
        ((= i nrows) mat)
      (do ((j 0 (+ j 1)))
          ((= j ncols))
        (mat-at mat i j (mul i j))))))

結果1

2層の NN を hidden_size = 10 batch_size=3 で動かしてみたが 1 batch 回したプロファイラの結果。matrix-mul は 32165 回呼ばれ 95%(98 秒)を占めている。

Time%        msec      calls   name                              location
  95        98500      32165   (matrix-mul a b)                  <transcoded-textual-input-port <binary-input-port tests/mnist.scm>>:175
   1         1140      64344   (matrix-map proc a)               <transcoded-textual-input-port <binary-input-port tests/mnist.scm>>:96

結果2

1 は明らかに遅いので C++ 側で実装してみる。98秒->75秒だが思ったよりも速くない。逆に言えば VM ループは思ったよりも速い。

  92        74590      32165   matrix-mul                        
   1          680      64543   (matrix-element-wise op a...)     <transcoded-textual-input-port <binary-input-port tests/mnist.scm>>:186

ここで C++ 側の実装を注意して見る必要がある。途中の計算結果の保持に Scheme Object を使っていること。行列に Fixnum もしくは Flonum が存在するので型チェックをしてくれる Arithmetic::add と mul を利用して計算していること。これだと most inner loop で heap allocation が発生してしまう。

    Object ret = Object::makeVector(numRowsA);
    for (size_t i = 0; i < numRowsA; i++) {
        ret.toVector()->set(i, Object::makeVector(numColsB));
    }

    for (size_t i = 0; i < numRowsA; i++) {
        for (size_t j = 0; j < numColsB; j++) {
            Object value = Object::makeFlonum(0.0);
            for (size_t k = 0; k < numColsA; k++) {
                Object aik = a->ref(i).toVector()->ref(k);
                Object bkj = b->ref(k).toVector()->ref(j);
                value = Arithmetic::add(value, Arithmetic::mul(aik, bkj));
            }
            ret.toVector()->ref(i).toVector()->set(j, value);
        }
    }

結果3

もしも Uniform Vector なら内部データの型チェックは不要になり、すべての計算を double のまま行うことが可能なはず。これを模した適当な実装をやってみる。不正なデータを入力されたらインタプリタが落ちてしまうがデモなのでよしとする。 98秒が60m秒になり実用的な速度に落ち着いた。heap allocation は避けたいと思わせる結果。

    1           60      32165   matrix-mul 
            double value = 0.0;
            for (size_t k = 0; k < numColsA; k++) {
                Object aa = a->ref(i).toVector()->ref(k);
                Object bb = b->ref(k).toVector()->ref(j);
                double aik = aa.isFixnum() ? aa.toFixnum() : aa.toFlonum()->value();
                double bkj = bb.isFixnum() ? bb.toFixnum() : bb.toFlonum()->value();
                value += aik * bkj;
            }
            ret.toVec

学びと改善のアイデア

  • Matrix を Vector で表現すると型チェックのオーバーヘッドがある
  • Matrix の内部表現は double/int の2次元配列とし C++ でアクセサと mul を実装すれば良い。
  • matrix-mul 以外の行列操作は Scheme 側で書いてもほぼ問題なさそう。多少遅くても呼ばれる回数が matrix-mul と比べて極端に少ない。例)add, element-wise multiply, transpose, shape。
  • matrix-mul は SIMD などでもっと速くなるだろう。

余談

  • numpy は歴史が長いだけあってよく出来ている。
  • オブジェクト指向の操作の方がコードが読みやすい場合がある a.t は a の転置。(t a) だと若干意味合いが違うし広すぎる。
  • numpy の broadcasting の仕組みは便利だし必須部品であることがよく分かった。スカラ lr と 行列A の掛け算とか。
  • numpy の subset が SRFI で提案されたら面白いと思う。ただ上記の理由により Scheme だけで実装すると遅いと思うので悩ましい。各処理系が独自に実装することを期待はできなそうな。
  • 現実的には Tensorflow フロントエンドで Scheme が使えれば良いのかもしれないが、それなら Python で十分。
  • Scheme で気軽に機械学習が楽しめるという Path は意外と遠い。Python 1強である理由も分かった。

Mosh + clang-tidy

伝説のプログラマポッドキャストJohn Carmack: Doom, Quake, VR, AGI, Programming, Video Games, and Rockets | Lex Fridman Podcast #309 - YouTubeを聴いていたらデバッガと static analyzer がいかに素晴らしいかを説いていたので clang-tidy を使ってみることにした。 Ubuntu 20.04 aarch64 上で

$ apt install bear clang-tidy
$ cd mosh.git
# clang-tidy が必要とする mosh.git/compile_commands.json を make コマンドの結果から生成
$ bear make
# とりあえずチェックする
$ run-clang-tidy -checks="-*,bugprone-macro-parentheses"
# 修正も実行
$ run-clang-tidy -checks="-*,bugprone-macro-parentheses" -fix

CS25: Transformers United を受講した

Stanford CS 25 Transformers Unitedを受講した。猫も杓子も Transformer という時代なので基礎の復習とキャッチアップを兼ねて。

良かった点

  • Transformer の基礎の復習がきちんとできた
  • self attention について「分かったような分からないような」という理解から「ぎりぎり他人に説明できるかも」くらいの理解になった。
  • LSTM は近くがよく見える。Transformer は遠くまでよく見える。という以前の学びも。そうだよねと腹落ちした。
  • 言語以外の応用エリア(Vision・音声)などが学べた
  • Transformer の登場に関わった人の話が面白かった

イマイチな点

  • 一部の講師・ゲストスピーカーは準備不足 or 説明が退屈だった。オンラインでの講義に慣れていないのもあると思う。
  • 自分が学生で単位を取ろうとしていたら大変だったかも。前提知識もかなり必要だし。話題の幅も広い。

The Rust Programming Language を読んだ - 2022夏休み

〇〇言語の後継XYZ言語がリリース。とニュースがあるたびに「少し触っとくか?まあいいか」を繰り返し。最後に「新しい」言語を学んだのは Swift という状態だったのでRust の本をオンラインで読みできる限りコードを写経した。以下感想。

  • メインで学びたかった Ownership 。コンパイラによる強制部分は「これがコンパイルエラーになるのか面白い」と楽しかった。
  • String と str は分かるけど分からない。String の内部表現が UTF-8 みたいな話は UTF-32 にすればいいの(誰にも支持されない)
  • OOP の話がほんのかなり終盤に出てくるのは言語の思想なんだろうか?
  • cargo は便利。
  • VS Code のサポートが貧弱な気がする。ちゃんと拡張を調べられてないだけかも。
  • VM Loop とかどう書くのだろうか。VM 命令の Pattern Match で十分速度出るのかな?