1. Introduction

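This article is based on Neo4j 3.5 and uses the embedded mode. Neo4j already ships its own shortest path algorithm, so although the code here runs on Neo4j, the main goal is to record the idea behind the algorithm. The shortest path discussed here is the unweighted shortest path.

  • The unweighted shortest path differs from the weighted one: in a weighted graph, the path with the smallest total weight is not necessarily the path with the fewest hops. The unweighted shortest path is measured only by the number of hops, so breadth-first search (BFS) is the most suitable method.
  • The approach usually described online is to run BFS from the start node and stop as soon as the end node is reached. That is not the fastest approach. This article describes running BFS from the start node and the end node at the same time (bidirectional BFS), which can greatly speed up the shortest path search.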
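For comparison, the one-directional approach from the second bullet can be sketched as follows. To keep the example independent of Neo4j it assumes a plain in-memory adjacency list, and the class and method names are hypothetical:

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.List;
import java.util.Queue;

// One-directional BFS that stops as soon as the end node is dequeued.
final class SingleBfs {
    // adj.get(v) lists the neighbours of node v; node ids are 0 .. adj.size() - 1.
    // Returns the number of hops on a shortest start-end path, or -1 if end is unreachable.
    static int shortestPathLength(List<List<Integer>> adj, int start, int end) {
        int[] dist = new int[adj.size()];
        Arrays.fill(dist, -1);                  // -1 marks "not visited yet"
        Queue<Integer> queue = new ArrayDeque<>();
        dist[start] = 0;
        queue.add(start);
        while (!queue.isEmpty()) {
            int v = queue.poll();
            if (v == end) {
                return dist[v];                 // BFS dequeues nodes in order of distance
            }
            for (int w : adj.get(v)) {
                if (dist[w] == -1) {
                    dist[w] = dist[v] + 1;
                    queue.add(w);
                }
            }
        }
        return -1;
    }
}
```

On the path graph 0-1-2-3, shortestPathLength(adj, 0, 3) returns 3. Every node closer than the end node is expanded before the search stops, which is the cost the bidirectional version tries to reduce.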

2. Algorithm overview

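  1. Why does bidirectional BFS improve efficiency?
    Because the number of visited nodes grows exponentially with every additional degree, and many of those nodes are visited repeatedly. When the search runs from both the start node and the end node, each side only has to go about half as deep, which greatly reduces the number of visited nodes.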

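Degree: the number of BFS rounds performed from the start node, also called the step count (depth).
For example, a BFS expands like a circle growing around the start node, and two circles of radius 5 have a smaller total area (2 × 25π = 50π) than one circle of radius 10 (100π).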
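The circle analogy can be made concrete with a small calculation. The sketch below assumes an idealized graph in which every node has b new neighbours and no node is reached twice, and simply counts 1 + b + ... + b^d nodes for one search of depth d against two searches of depth about d/2 (the class and method names are hypothetical):

```java
// Idealized model: every node has b new neighbours and no node is reached twice.
final class FrontierGrowth {
    // Nodes reached by a BFS of the given depth: 1 + b + b^2 + ... + b^depth
    static long reached(long b, int depth) {
        long total = 0;
        long level = 1;                         // nodes on the current degree
        for (int i = 0; i <= depth; i++) {
            total += level;
            level *= b;
        }
        return total;
    }

    // One search of depth d from the start node
    static long oneSided(long b, int d) {
        return reached(b, d);
    }

    // Two searches that meet in the middle: depths d / 2 and d - d / 2
    static long twoSided(long b, int d) {
        return reached(b, d / 2) + reached(b, d - d / 2);
    }
}
```

With b = 10 and d = 6 this gives 1,111,111 nodes for the one-sided search against 2,222 for the two-sided one. Real graphs are uneven and overlap, so the actual gain will differ, but the exponent is roughly halved.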

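  2. How do we know the shortest path has been found?
    Maintain two arrays: one records the nodes visited from the start node, the other records the nodes visited from the end node. If a node reached from the end side is already in the start side's array, or a node reached from the start side is already in the end side's array, the two searches have met, and the shortest path length is the sum of the depths at which the two sides reached that node. This holds when each side expands one whole degree at a time; with the further optimization below, the first meeting is not necessarily the shortest one, which the implementation in Section 4 takes care of.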
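A minimal sketch of this check, again on a hypothetical undirected in-memory adjacency list rather than Neo4j: each side keeps its own depth array, the two sides take turns expanding one whole degree, and the search stops at the first node that the other side has already reached:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Bidirectional BFS that alternates between the two sides one whole degree at a time.
final class BidirectionalBfs {
    // adj.get(v) lists the neighbours of v (undirected graph, ids 0 .. adj.size() - 1).
    // Returns the hop count of a shortest start-end path, or -1 if there is none.
    static int shortestPathLength(List<List<Integer>> adj, int start, int end) {
        if (start == end) {
            return 0;
        }
        int[] distS = new int[adj.size()];      // depth reached from start, -1 = not visited
        int[] distE = new int[adj.size()];      // depth reached from end, -1 = not visited
        Arrays.fill(distS, -1);
        Arrays.fill(distE, -1);
        distS[start] = 0;
        distE[end] = 0;
        List<Integer> frontS = new ArrayList<>(List.of(start));
        List<Integer> frontE = new ArrayList<>(List.of(end));
        boolean fromStart = true;
        while (!frontS.isEmpty() && !frontE.isEmpty()) {
            int[] mine = fromStart ? distS : distE;
            int[] other = fromStart ? distE : distS;
            List<Integer> next = new ArrayList<>();
            for (int v : fromStart ? frontS : frontE) {
                for (int w : adj.get(v)) {
                    if (mine[w] != -1) {
                        continue;               // already reached from this side
                    }
                    mine[w] = mine[v] + 1;
                    if (other[w] != -1) {
                        // The searches meet at w. All earlier degrees on both sides are
                        // complete, so this first meeting already gives the shortest length.
                        return mine[w] + other[w];
                    }
                    next.add(w);
                }
            }
            if (fromStart) {
                frontS = next;
            } else {
                frontE = next;
            }
            fromStart = !fromStart;
        }
        return -1;                              // one side ran out of nodes without meeting
    }
}
```

The two depth arrays play the role of the two arrays described above; a directed graph would need the end side to follow relationships in the reverse direction.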

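  3. Further optimization
    The bidirectional BFS above expands one whole degree per step on each side and stops after the collision that yields the shortest path. The collision does not necessarily happen between the last two nodes of a degree, so any nodes traversed in that degree after the collision are wasted work.
    Building on this, we can create an iterator that produces the next degree lazily, one neighbor at a time from the previous round's results, and let the two sides alternate node by node, which reduces wasted traversal.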
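The per-node optimization can likewise be sketched without Neo4j. In this hypothetical version (undirected in-memory adjacency list, names invented for illustration) each side is a cursor that hands out the next neighbour of its current degree one at a time, the two sides alternate node by node, and the stopping rule follows the same idea as the Section 4 code: a meeting on a degree the other side has already finished cannot be beaten, while a meeting on the degree it is still producing only updates the best length found so far:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Bidirectional BFS whose two sides alternate one neighbour at a time.
final class AlternatingBfs {
    // One side of the search: a lazy cursor over the neighbours of the current degree.
    private static final class Side {
        final int[] dist;                       // depth reached from this side's root, -1 = not visited
        List<Integer> frontier = new ArrayList<>();
        List<Integer> next = new ArrayList<>();
        int nodeIdx = 0;                        // position in frontier
        int edgeIdx = 0;                        // position in the current node's neighbour list
        int depth = 1;                          // degree of the nodes currently handed out

        Side(int n, int root) {
            dist = new int[n];
            Arrays.fill(dist, -1);
            dist[root] = 0;
            frontier.add(root);
        }

        // Next neighbour of the current degree; moves on to the next degree when this one
        // is used up. Returns -1 once this side has nothing left to expand.
        int nextNeighbour(List<List<Integer>> adj) {
            while (true) {
                if (nodeIdx == frontier.size()) {
                    if (next.isEmpty()) {
                        return -1;
                    }
                    frontier = next;
                    next = new ArrayList<>();
                    nodeIdx = 0;
                    edgeIdx = 0;
                    depth++;
                }
                List<Integer> nbrs = adj.get(frontier.get(nodeIdx));
                if (edgeIdx < nbrs.size()) {
                    return nbrs.get(edgeIdx++);
                }
                nodeIdx++;
                edgeIdx = 0;
            }
        }
    }

    // Returns the hop count of a shortest start-end path in an undirected graph, or -1.
    static int shortestPathLength(List<List<Integer>> adj, int start, int end) {
        if (start == end) {
            return 0;
        }
        Side[] sides = { new Side(adj.size(), start), new Side(adj.size(), end) };
        int best = Integer.MAX_VALUE;
        for (int turn = 0; ; turn ^= 1) {
            Side mine = sides[turn];
            Side other = sides[turn ^ 1];
            int w = mine.nextNeighbour(adj);
            if (w == -1) {
                break;                          // this side is exhausted
            }
            if (mine.dist[w] != -1) {
                continue;                       // already reached from this side
            }
            mine.dist[w] = mine.depth;
            mine.next.add(w);
            if (other.dist[w] != -1) {
                best = Math.min(best, mine.depth + other.dist[w]);
                // Degrees below other.depth are complete, so a meeting there cannot be beaten;
                // a meeting on degree other.depth might still be, so keep searching.
                if (other.dist[w] < other.depth) {
                    break;
                }
            }
        }
        return best == Integer.MAX_VALUE ? -1 : best;
    }
}
```

In the 8-node test graph below, the first meeting found has length 4 and the search keeps going until it finds the true length 3, which is exactly the case the extra check exists for.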

3. Execution environment

To be able to reproduce the environment later, here is a brief record of it:

<dependency>
    <!-- jar needed when connecting to a Neo4j server -->
    <groupId>org.neo4j.driver</groupId>
    <artifactId>neo4j-java-driver</artifactId>
    <version>1.5.0</version>
</dependency>
<dependency>
    <!-- jar needed for embedded development -->
    <groupId>org.neo4j</groupId>
    <artifactId>neo4j</artifactId>
    <version>3.5.13</version>
</dependency>
<dependency>
    <!-- graph algorithms package -->
    <groupId>org.neo4j</groupId>
    <artifactId>neo4j-graph-algo</artifactId>
    <version>3.5.13</version>
</dependency>

Connecting to the Neo4j database in embedded mode:

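import java.io.File;

import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.factory.GraphDatabaseFactory;

// Source<T> and Config (which supplies GRAPH_NAME) are the author's own project classes
public class EmbeNeo4jSource implements Source<GraphDatabaseService>{
    private GraphDatabaseService graphDb;
    private static EmbeNeo4jSource source;

    private EmbeNeo4jSource(){
        // This single line is what actually opens the embedded database
        graphDb = new GraphDatabaseFactory().newEmbeddedDatabase(new File("neo4j/" + Config.GRAPH_NAME));
    }

    // synchronized so that concurrent callers cannot create two instances on the same store
    public static synchronized EmbeNeo4jSource build(){
        if (source == null){
            source = new EmbeNeo4jSource();
        }
        return source;
    }

    @Override
    public GraphDatabaseService getDataBase() {
        return graphDb;
    }

    @Override
    public void close() {
        graphDb.shutdown();
    }
}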

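Official Neo4j shortest path example: https://neo4j.com/docs/java-reference/current/java-embedded/graph-algorithms/index.html
Note that this link points to the current documentation, so the snippet below uses the newer API (tx.createNode() and BasicEvaluationContext). In Neo4j 3.5, nodes are created with graphDb.createNode() and the finder is obtained with GraphAlgoFactory.shortestPath(PathExpanders.forTypeAndDirection(ExampleTypes.MY_TYPE, Direction.OUTGOING), 15).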

        Node startNode = tx.createNode();
        Node middleNode1 = tx.createNode();
        Node middleNode2 = tx.createNode();
        Node middleNode3 = tx.createNode();
        Node endNode = tx.createNode();
        createRelationshipsBetween( startNode, middleNode1, endNode );
        createRelationshipsBetween( startNode, middleNode2, middleNode3, endNode );

        // Will find the shortest path between startNode and endNode via
        // "MY_TYPE" relationships (in OUTGOING direction), like f.ex:
        //
        // (startNode)-->(middleNode1)-->(endNode)
        //
        PathFinder<Path> finder = GraphAlgoFactory.shortestPath( new BasicEvaluationContext( tx, graphDb ),
            PathExpanders.forTypeAndDirection( ExampleTypes.MY_TYPE, Direction.OUTGOING ), 15 );
        Iterable<Path> paths = finder.findAllPaths( startNode, endNode );

4. Algorithm implementation

The core part of the shortest path implementation:

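public ExecuteResult<Integer> javaapi4(Long startId, Long endId){
    // graphDb, maxId, edgeFilter, vertexFilter, javaDire (direction of the start side) and java2Dire
    // (the reverse direction, used by the end side) are fields of the enclosing class;
    // ExecuteResult is the author's holder for the elapsed time and the path length
    long startTime = System.nanoTime();
    // Neo4j node ids start at 0, so only negative ids or ids above maxId are rejected
    if (startId < 0 || startId > maxId || endId < 0 || endId > maxId){
        return new ExecuteResult<>(System.nanoTime() - startTime, 0);
    }
    // Start and end are the same node: the path length is 0
    if (startId.equals(endId)){
        return new ExecuteResult<>(System.nanoTime() - startTime, 0);
    }
    int depth = Integer.MAX_VALUE;
    long endTime;
    // try-with-resources closes the transaction on every exit path
    try (Transaction tx = graphDb.beginTx()){
        // Degree at which the start / end side reached each node, stored as degree + 1 so that
        // 0 still means "not visited" (with bytes this caps the search at degree 126)
        byte[] startids = new byte[(int) (maxId + 1)];
        byte[] endids = new byte[(int) (maxId + 1)];
        // Neo4j API for fetching nodes by id
        Node startNode = graphDb.getNodeById(startId);
        Node endNode = graphDb.getNodeById(endId);
        // Mark both endpoints as reached at degree 0; without this, a meeting on an endpoint
        // (for example when start and end are adjacent) would never be detected
        startids[startId.intValue()] = 1;
        endids[endId.intValue()] = 1;
        // Nodes to expand in the next degree on the start side
        Collection<Node> sids = new ArrayList<>();
        sids.add(startNode);
        // Nodes to expand in the next degree on the end side
        Collection<Node> eids = new ArrayList<>();
        eids.add(endNode);
        // Iterators over the nodes of the next degree, initially empty
        ResourceIterator<Node> sNextRelationships = Iterators.emptyResourceIterator();
        ResourceIterator<Node> eNextRelationships = Iterators.emptyResourceIterator();
        // Degree currently being produced on each side
        byte sDepth = 0;
        byte eDepth = 0;
        Node node;
        int id;
        while (true){
            if (!sNextRelationships.hasNext()){
                sNextRelationships.close();
                // MyIterator takes the relationship iterator of every node in the current frontier and
                // yields the neighboring nodes one at a time: "3. Further optimization" above
                sNextRelationships = new MyIterator(new ArrayList<>(sids).iterator(),
                        edgeFilter, vertexFilter, javaDire, endNode);
                sids.clear();
                // Nothing left to expand: the start side is exhausted
                if (!sNextRelationships.hasNext()){
                    break;
                }
                sDepth++;
            }
            node = sNextRelationships.next();
            id = (int) node.getId();
            // Only nodes the start side has not visited yet
            if (startids[id] == 0){
                // Expand this node in the next degree
                sids.add(node);
                startids[id] = (byte) (sDepth + 1);
                // The end side has already reached this node: the two searches meet here
                if (endids[id] != 0){
                    int eHop = endids[id] - 1;
                    depth = Math.min(depth, sDepth + eHop);
                    // Because of the optimization, the end side has not necessarily finished degree eDepth.
                    // Every end degree below eDepth is complete, so a meeting there cannot be beaten and we
                    // can stop; a meeting on degree eDepth may still be beaten through degree eDepth - 1
                    if (eHop < eDepth){
                        break;
                    }
                }
            }
            if (!eNextRelationships.hasNext()){
                eNextRelationships.close();
                // The end side follows relationships in the reverse direction (java2Dire)
                eNextRelationships = new MyIterator(new ArrayList<>(eids).iterator(),
                        edgeFilter, vertexFilter, java2Dire, startNode);
                eids.clear();
                if (!eNextRelationships.hasNext()){
                    break;
                }
                eDepth++;
            }
            node = eNextRelationships.next();
            id = (int) node.getId();
            // Mirror image of the start-side step
            if (endids[id] == 0){
                eids.add(node);
                endids[id] = (byte) (eDepth + 1);
                if (startids[id] != 0){
                    int sHop = startids[id] - 1;
                    depth = Math.min(depth, eDepth + sHop);
                    if (sHop < sDepth){
                        break;
                    }
                }
            }
        }
        sNextRelationships.close();
        eNextRelationships.close();
        endTime = System.nanoTime();
        tx.success();
    }
    // 0 means that no path was found
    return new ExecuteResult<>(endTime - startTime, depth == Integer.MAX_VALUE ? 0 : depth);
}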

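Custom iterator:
The iterator is modeled after Neo4j's NestingResourceIterator. It is written from scratch instead of using NestingResourceIterator directly for the following reasons:

  1. to support condition filters on node and relationship properties;
  2. with NestingResourceIterator, the step that takes a node from the node iterator, obtains its relationship iterator and returns relationships from it is not accessible, so it cannot be changed to return the neighboring node instead of the relationship.
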
package com.cl.study.ShortestPath;

import java.util.Iterator;

import org.neo4j.graphdb.Direction;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.helpers.collection.Iterators;
import org.neo4j.helpers.collection.PrefetchingResourceIterator;

public class MyIterator extends PrefetchingResourceIterator<Node>
{
    private final Iterator<Node> source;
    private ResourceIterator<Relationship> currentNestedIterator;
    private Node currentSurfaceItem;
    private Filter edgeFilter;
    private Filter vertexFilter;
    private Direction direction;
    private Node node;

    public MyIterator( Iterator<Node> source, Filter edgeFilter, Filter vertexFilter, Direction direction, Node node)
    {
        this.source = source;
        this.edgeFilter = edgeFilter;
        this.vertexFilter = vertexFilter;
        this.direction = direction;
        this.node = node;
    }

    @Override
    protected Node fetchNextOrNull(){
        while (true){
            Relationship relationship = fetchRelNextOrNull();
            if (relationship == null){
                return null;
            }
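            if (edgeFilter == null || edgeFilter.gl(relationship.getAllProperties())){
                Node other = relationship.getOtherNode(currentSurfaceItem);
                // The target node always passes; it is checked first so the vertex filter never rejects it.
                // Node objects must be compared with equals(), not ==
                if (other.equals(this.node) || vertexFilter == null || vertexFilter.gl(other.getAllProperties())){
                    return other;
                }
            }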
        }
    }

    protected Relationship fetchRelNextOrNull()
    {
        if ( currentNestedIterator == null ||
            !currentNestedIterator.hasNext() )
        {
            while ( source.hasNext() )
            {
                currentSurfaceItem = source.next();
                close();
                currentNestedIterator = createNestedIterator( currentSurfaceItem );
                if ( currentNestedIterator.hasNext() )
                {
                    break;
                }
            }
        }
        return currentNestedIterator != null && currentNestedIterator.hasNext() ? currentNestedIterator.next() : null;
    }

    protected ResourceIterator<Relationship> createNestedIterator( Node node )
    {
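        // Relationships of the node in the given direction (its one-hop neighbors); asResourceIterator
        // wraps the plain iterator as a closable ResourceIterator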
        return Iterators.asResourceIterator(node.getRelationships(direction).iterator());
    }

    @Override
    public void close()
    {
        if ( currentNestedIterator != null )
        {
            currentNestedIterator.close();
        }
    }
}
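Stripped of the Neo4j types, the nesting idea in MyIterator is an outer iterator over frontier nodes, an inner iterator over each node's relationships, and a prefetched "other end" that has passed the filters. A stdlib-only sketch, assuming relationships are modelled as hypothetical int[]{otherNode, label} pairs and the filters as java.util.function.IntPredicate (all names are illustrative):

```java
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.IntPredicate;

// Yields the neighbour nodes of every node from source, skipping filtered edges and nodes.
final class NeighbourIterator implements Iterator<Integer> {
    private final Iterator<Integer> source;     // outer iterator: nodes of the current degree
    private final List<List<int[]>> edges;      // edges.get(v): {otherNode, label} for each edge of v
    private final IntPredicate edgeFilter;      // tested on the edge label; null = accept all
    private final IntPredicate vertexFilter;    // tested on the other node; null = accept all
    private final int target;                   // always accepted, like MyIterator's node field
    private Iterator<int[]> nested = Collections.emptyIterator();
    private Integer prefetched;                 // next node to return, or null if not fetched yet

    NeighbourIterator(Iterator<Integer> source, List<List<int[]>> edges,
                      IntPredicate edgeFilter, IntPredicate vertexFilter, int target) {
        this.source = source;
        this.edges = edges;
        this.edgeFilter = edgeFilter;
        this.vertexFilter = vertexFilter;
        this.target = target;
    }

    @Override
    public boolean hasNext() {
        while (prefetched == null) {
            if (!nested.hasNext()) {
                if (!source.hasNext()) {
                    return false;               // every frontier node has been consumed
                }
                nested = edges.get(source.next()).iterator();
                continue;
            }
            int[] edge = nested.next();
            if (edgeFilter != null && !edgeFilter.test(edge[1])) {
                continue;                       // edge rejected by the edge filter
            }
            int other = edge[0];
            if (other == target || vertexFilter == null || vertexFilter.test(other)) {
                prefetched = other;
            }
        }
        return true;
    }

    @Override
    public Integer next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        Integer result = prefetched;
        prefetched = null;
        return result;
    }
}
```

Prefetching in hasNext() plays the same role as fetchNextOrNull() in PrefetchingResourceIterator: the filters are applied before the caller sees a node, so callers only ever get accepted neighbours.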

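The filter is only a simple comparison-based implementation and not the focus of this article, but it is included here for completeness: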

package com.cl.study.ShortestPath;

import java.util.List;
import java.util.Map;

public abstract class Filter {
    private Filter nextFilter;

    public boolean gl(Map<String, Object> map){
        Filter f = this;
        while(f !=null){
            if (!f.doGl(map)){
                return false;
            }
            f = f.nextFilter;
        }
        return true;
    }

    protected abstract boolean doGl(Map<String, Object> map);

    public static Filter buildFilter(List<Map<String, String>> prop){
        if (prop == null){
            return null;
        }
        Filter filter = null;
        for (Map<String, String> map : prop){
            Filter f = null;
            switch (map.get("type")){
                case "gte":{
                    f = buildFilter1(map.get("name"), map.get("value"));
                    break;
                }
                case "gt":{
                    f = buildFilter2(map.get("name"), map.get("value"));
                    break;
                }
                case "eq":{
                    f = buildFilter3(map.get("name"), map.get("value"));
                    break;
                }
                case "lt":{
                    f = buildFilter4(map.get("name"), map.get("value"));
                    break;
                }
                case "lte":{
                    f = buildFilter5(map.get("name"), map.get("value"));
                    break;
                }
                case "ne":{
                    f = buildFilter6(map.get("name"), map.get("value"));
                    break;
                }
            }
            if (f !=null){
                f.nextFilter = filter;
                filter = f;
            }
        }
        return filter;
    }

    // Compares the property's string form with value lexicographically (so "10" sorts before "9").
    // Returns null when the property is missing, so the filters below reject the element instead of
    // throwing a NullPointerException
    private static Integer compare(Map<String, Object> map, String name, Object value){
        Object actual = map.get(name);
        return actual == null ? null : actual.toString().compareTo(value.toString());
    }

    // gte: property >= value
    public static Filter buildFilter1(String name, Object value){
        return new Filter() {
            @Override
            protected boolean doGl(Map<String, Object> map) {
                Integer c = compare(map, name, value);
                return c != null && c >= 0;
            }
        };
    }

    // gt: property > value
    public static Filter buildFilter2(String name, Object value){
        return new Filter() {
            @Override
            protected boolean doGl(Map<String, Object> map) {
                Integer c = compare(map, name, value);
                return c != null && c > 0;
            }
        };
    }

    // eq: property == value
    public static Filter buildFilter3(String name, Object value){
        return new Filter() {
            @Override
            protected boolean doGl(Map<String, Object> map) {
                Integer c = compare(map, name, value);
                return c != null && c == 0;
            }
        };
    }

    // lt: property < value
    public static Filter buildFilter4(String name, Object value){
        return new Filter() {
            @Override
            protected boolean doGl(Map<String, Object> map) {
                Integer c = compare(map, name, value);
                return c != null && c < 0;
            }
        };
    }

    // lte: property <= value
    public static Filter buildFilter5(String name, Object value){
        return new Filter() {
            @Override
            protected boolean doGl(Map<String, Object> map) {
                Integer c = compare(map, name, value);
                return c != null && c <= 0;
            }
        };
    }

    // ne: property != value
    public static Filter buildFilter6(String name, Object value){
        return new Filter() {
            @Override
            protected boolean doGl(Map<String, Object> map) {
                Integer c = compare(map, name, value);
                return c != null && c != 0;
            }
        };
    }
}
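Because the filter compares the string forms of the values, numeric properties are compared lexicographically ("10" sorts before "9"). If numeric comparison is wanted, one possible alternative, sketched here under that assumption, keeps the same type names and AND semantics but uses java.util.function.Predicate and compares numerically whenever both sides parse as numbers (the class and method names are hypothetical):

```java
import java.util.List;
import java.util.Map;
import java.util.function.IntPredicate;
import java.util.function.Predicate;

final class PropertyPredicates {
    // Maps a condition type to a test on the result of compare(); null for unknown types
    private static IntPredicate comparison(String type) {
        switch (type == null ? "" : type) {
            case "gte": return c -> c >= 0;
            case "gt":  return c -> c > 0;
            case "eq":  return c -> c == 0;
            case "lt":  return c -> c < 0;
            case "lte": return c -> c <= 0;
            case "ne":  return c -> c != 0;
            default:    return null;
        }
    }

    // Numeric comparison when both values parse as numbers, string comparison otherwise
    private static int compare(String actual, String expected) {
        try {
            return Double.compare(Double.parseDouble(actual), Double.parseDouble(expected));
        } catch (NumberFormatException e) {
            return actual.compareTo(expected);
        }
    }

    // AND of all conditions, like the nextFilter chain. Unknown types and incomplete conditions
    // are skipped, and a missing property never matches. Returns null when prop is null.
    static Predicate<Map<String, Object>> build(List<Map<String, String>> prop) {
        if (prop == null) {
            return null;
        }
        Predicate<Map<String, Object>> all = props -> true;
        for (Map<String, String> p : prop) {
            IntPredicate accept = comparison(p.get("type"));
            String name = p.get("name");
            String value = p.get("value");
            if (accept == null || name == null || value == null) {
                continue;
            }
            all = all.and(props -> {
                Object actual = props.get(name);
                return actual != null && accept.test(compare(actual.toString(), value));
            });
        }
        return all;
    }
}
```

The resulting predicate could be applied to relationship.getAllProperties() or node.getAllProperties() in the same places where Filter.gl is called now.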