1

グラフの連結要素をカウントするスクリプトを作成しようとしましたが、正しい解が得られません。6 つのノード (頂点) を持つ単純なグラフがあり、ノード 1 と 2 が接続され、ノード 3 と 4 が接続されています (6 つの頂点; 1-2,3-4,5,6)。したがって、グラフには 4 つの連結要素が含まれます。次のスクリプトを使用して連結成分を数えますが、間違った結果が得られます (2)。

nodes = [[1, [2], False], [2, [1], False], [3, [4], False], [4, [3], False], [5, [], False], [6, [], False]]
# 6 nodes, every node has an id, list of connected nodes and boolean whether the node has already been visited    

componentsCount = 0

def mark_nodes( list_of_nodes):
    global componentsCount
    componentsCount = 0
    for node in list_of_nodes:
      node[2] = False
      mark_node_auxiliary( node)

def mark_node_auxiliary( node): 
    global componentsCount
    if not node[2] == True: 
      node[2] = True
      for neighbor in node[1]:
        nodes[neighbor - 1][2] = True
        mark_node_auxiliary( nodes[neighbor - 1])
    else:
      unmarkedNodes = []
      for neighbor in node[1]:
        if not nodes[neighbor - 1][2] == True:  # This condition is never met. WHY???
          unmarkedNodes.append( neighbor)
          componentsCount += 1   
      for unmarkedNode in unmarkedNodes:
        mark_node_auxiliary( nodes[unmarkedNode - 1])

def get_connected_components_number( graph):
    result = componentsCount
    mark_nodes( graph)
    for node in nodes:
      if len( node[1]) == 0:      # For every vertex without neighbor...  
        result += 1               # ... increment number of connected components by 1.
    return result

print get_connected_components_number( nodes)

誰でも間違いを見つけるのを手伝ってもらえますか?

4

3 に答える 3

6

ばらばらなデータ構造は、ここで明確なコードを書くのに本当に役立ちます。ウィキペディアを参照してください。

基本的な考え方は、セットをグラフの各ノードに関連付け、エッジごとに 2 つのエンドポイントのセットをマージすることです。2 組xy同じ場合x.find() == y.find()

これは最も単純な実装 (最悪の場合の複雑さが悪い) ですが、ウィキペディアのページに DisjointSet クラスのいくつかの最適化があり、その上に数行のコードを追加することでこれを効率的にしています。明確にするためにそれらを省略しました。

nodes = [[1, [2]], [2, [1]], [3, [4]], [4, [3]], [5, []], [6, []]]

def count_components(nodes):
    sets = {}
    for node in nodes:
      sets[node[0]] = DisjointSet()
    for node in nodes:
        for vtx in node[1]:
            sets[node[0]].union(sets[vtx])
    return len(set(x.find() for x in sets.itervalues()))

class DisjointSet(object):
    def __init__(self):
        self.parent = None

    def find(self):
        if self.parent is None: return self
        return self.parent.find()

    def union(self, other):
        them = other.find()
        us = self.find()
        if them != us:
            us.parent = them

print count_components(nodes)
于 2010-10-24T12:20:30.883 に答える
4

コードを読むよりも書く方が簡単な場合があります。

これをいくつかのテストにかけます。すべての接続が双方向である限り(あなたの例のように)、常に機能すると確信しています。

def recursivelyMark(nodeID, nodes):
    (connections, visited) = nodes[nodeID]
    if visited:
        return
    nodes[nodeID][1] = True
    for connectedNodeID in connections:
        recursivelyMark(connectedNodeID, nodes)

def main():
    nodes = [[[1], False], [[0], False], [[3], False], [[2], False], [[], False], [[], False]]
    componentsCount = 0
    for (nodeID, (connections, visited)) in enumerate(nodes):
        if visited == False:
            componentsCount += 1
            recursivelyMark(nodeID, nodes)
    print(componentsCount)

if __name__ == '__main__':
    main()

配列内の位置が ID であるため、ノード情報から ID を削除したことに注意してください。このプログラムが必要なことを行わない場合はお知らせください。

于 2010-10-23T17:57:35.147 に答える