Ruby でダイクストラ法を使って最短経路木 (Shortest Path Tree) を作る

何か呆然としたので,最短経路木の生成プログラムとか実装してみた.以下は本体.長いけど,上のほうの大部分はちょっと便利のために class を定義している.

#! /usr/bin/ruby
# -*- coding: utf-8 -*-

# root ノードの番号
SINK = 0
# グラフを作るとき,この距離以下にあるノード間に辺を作る
MAXDIST = 1

class Node
  @x
  @y
  @id
  @parent
  @neighbors
  @pdist

  def initialize( id, x, y )
    @id = id
    @x = x
    @y = y
    @neighbors = []
    @pdist = 10000000000 # 本来は Float.MAX などの最大値
    @parent = nil
  end
  
  def x
    @x
  end

  def y
    @y
  end

  def distToSink
    @pdist
  end

  def id
    @id
  end

  def parent
    @parent
  end
  
  def setDist( d )
    @pdist = d
  end

  def setParent( node )
    @parent = node
  end

  def to_s
    @id.to_s + ',' + @x.to_s + ',' + @y.to_s
  end

  # グラフの生成用
  def makeNeighbors( nodes )
    @neighbors = []
    nodes.each do |n|
      next if n == self
      @neighbors.push n if dist( n ) < MAXDIST
    end
  end

  # ユークリッド距離
  def dist( node )
    Math.sqrt( dist2( node ) )
  end

  # 距離の二乗
  def dist2( node )
    ( @x - node.x ) * ( @x - node.x ) + ( @y - node.y ) * ( @y - node.y )
  end

  def neighbors
    @neighbors
  end

  # node 経由で root へ向かうほうが距離が短かければ,そのように経路を変更する
  def checkDistance( node )
    d = node.distToSink + dist( node )
    return false if @pdist < d
    @pdist = d
    @parent = node
    return true
  end

  def dumpEdge
    [@id,@parent.id,@pdist,@x,@y,@parent.x,@parent.y].join(',')
  end
end


# ダイクストラ法で最短経路木を作る
def makeSPT( nodes, top )
  results = []
  # まず点列からグラフを生成する.距離 MAXDIST 以下の距離にある点間に辺を作っている.
  nodes.each do |n|
    n.makeNeighbors( nodes )
  end

  # ルートノードを設定する.ここでは top を設定
  fronts = [ top ]
  top.setDist 0.0
  top.setParent top

  # fronts にある点ひとつとりだして p とする.
  while p = fronts.shift
    # p から辺で結ばれている点列を np として取りだし,以下を実行する.
    p.neighbors.each do |np|
      next if np == p
      results.push p
      # p 経由で np へ行くパスが最短経路なら,パスを書きかえる.
      if np.checkDistance( p )
        next if fronts.include?( np )
        # np が fronts に含まれていなければ追加する.
        fronts.push( np )
      end
    end
  end
  results
end

def main
  nodes = []
  top = nil
  while buf = gets
    # ファイルは csv で,下記の順でデータが並んでいるものとする.
    # id,x,y
    #
    if buf =~ /(\d+),([\d\.]+),([\d\.]+)/
      id = $1.to_i
      x = $2.to_f
      y = $3.to_f
      n = Node.new( id, x, y )
      nodes.push( n )
      top = n if id == SINK
    end
  end
  makeSPT( nodes, top )
  # root に近い順にソート.しなくてもいい
  nodes.sort! do |a,b|
    a.distToSink <=> b.distToSink
  end
  # 結果の出力
  nodes.each do |n|
    next if n.parent == nil
    puts n.dumpEdge
  end
end

main

サンプルのノードデータは以下のプログラムで作ることができる.

#!/usr/bin/ruby

i = 0
MAX = 10
while i < 1000
  x = rand(0) * MAX
  y = rand(0) * MAX
  puts [i,x,y].join(',')
  i += 1
end

このプログラムは以下のような csv な文字列をはく.

0,3.93444139497376,0.620915326566045
1,5.22221115084765,1.60085754393937
2,7.48946555733086,8.86162639586321
3,8.24431346816341,8.25508129049564
4,0.0801466615790514,5.54539220065131

sample.dat とかに保存して,最初のプログラムに食わせてやると,以下のような出力がでる.これが最短経路木になってる.sample_spt.dat とかにでも保存しておく.

0,0,0.0,3.93444139497376,0.620915326566045,3.93444139497376,0.620915326566045
1,517,0.674088846825506,5.22221115084765,1.60085754393937,4.89500921652028,1.7090319060696
2,435,3.42205571287596,7.48946555733086,8.86162639586321,7.7646720207587,8.70853320725571
3,70,3.22829696976259,8.24431346816341,8.25508129049564,7.95272371892281,7.94010352034856
4,905,2.30014526834389,0.0801466615790514,5.54539220065131,0.137751012699807,5.57371478843895

このフォーマットは以下のとおり.

自分のノードid,親ノードのid,rootまでの距離(積算),x,y,親のx,親のy

適当な描画プログラムに x,y から 親のx, 親のy まで線を引いてやれば,最短経路木を可視化できる.とりあえず手元に R (R: The R Project for Statistical Computing)があったので R で表示させる方法と,表示させた結果でも置いておく.上のファイルを ~/sample_spt.dat という名前で保存したとして,R で以下のコマンドをうつ.

ss <- read.csv('~/sample_spt.dat',header=FALSE)
plot(ss$V4,ss$V5)
segments(ss$V4,ss$V5,ss$V6,ss$V7)
points(ss$V4[1],ss$V5[1],col='red')

これで下みたいな図がでる.はず.

赤いとこが root ノード.いかにもそれっぽい感じ.