矩阵乘法计算最短路径对
背景介绍
想起这个的原因是在研究半环的时候在网上看到可以通过矩阵乘法可以实现最短路径对的计算。原理和Floyd-Warshall算法差不多,每两个点i j之间寻找是否存在另一个点k使得i到k再到j的距离比i到j的距离短,只是通过矩阵乘法来实现。
矩阵乘法对于矩阵内元素所应用的操作只有两个:乘法和加法,因此元素只需要是一个半环即可。在最短路径对的计算中,乘法和加法分别对应加法以及最小值。
准备工作
首先我们需要准备我们的Scala环境。对于半环,我们没有现成定义,不过半群的话Cats有定义,且有相应测试,因此我们需要添加Cats的依赖。在写作的时候,Scala 3也准备得差不多了,因此我们也顺便用上Scala 3。
build.sbt
:
ThisBuild / scalaVersion := "3.0.0-RC2"
ThisBuild / version := "0.1.0-SNAPSHOT"
ThisBuild / organization := "online.aoxiang"
ThisBuild / organizationName := "天上的八哥"
ThisBuild / organizationHomepage := Some(url("https://aoxiang.online"))
ThisBuild / licenses += "MIT" -> url("https://mit-license.org/")
lazy val root = project
.in(file("."))
.settings(
name := "apsp",
libraryDependencies ++= Seq(
"org.typelevel" %% "cats-core" % "2.6.0",
"org.scalatest" %% "scalatest" % "3.2.7" % "test"
)
)
project/build.properties
:
sbt.version=1.5.2
定义半环
半环即一个定义在加法上的半群和一个在乘法上的半群。半群在Cats中是Semigroup
。如果直接写trait Semiring[A] extends Semigroup[A]
肯定会有问题,因为我们需要两个半群。为此,用Scala 3的新特性,即trait也可以接收参数。于是有了:
src/main/scala/Semiring.scala
:
import cats._
import cats.implicits._
trait Semiring[A](using val addSG: Semigroup[A], val multiplySG: Semigroup[A]) {
def add = addSG.combine
def multiply = multiplySG.combine
}
我们这里让我们的Semiring
依赖两个Semigroup
的实现。当然,其中加法的Semiring
应该符合交换律,但是这个无法在声明中显示出来,只能在测试中定义。
之后我们提供了一部分实现:
src/main/scala/Semiring.scala
(续):
object Semiring {
given normalInt: Semiring[Int] with {
override val addSG = new Semigroup[Int] {
override def combine(a: Int, b: Int): Int = {
a + b
}
}
override val multiplySG = new Semigroup[Int] {
override def combine(a: Int, b: Int): Int = {
a * b
}
}
}
given distanceSemiring: Semiring[Double] with {
override val addSG = new Semigroup[Double] {
override def combine(a: Double, b: Double): Double = {
Math.min(a, b)
}
}
override val multiplySG = new Semigroup[Double] {
override def combine(a: Double, b: Double): Double = {
a + b
}
}
}
}
可以看到normalInt
就是对应整数,而distanceSemiring
则是之后用矩阵乘法计算最短路径对时所用的。
定义矩阵
矩阵的定义就比较简单随意,数据本身采用一个Map
加上长度宽度。乘法用了些Scala 3的新语法使其成为中缀运算符。
src/main/scala/Matrix.scala
:
object Matrix {
extension [A](x: Matrix[A]) def * (y: Matrix[A])(using semiring: Semiring[A]): Matrix[A] = {
if (x.width != y.height) throw IllegalArgumentException(s"Matrix size mismatch: ${x.width} != ${y.height}")
else {
val matrix = for {
i <- 0 until y.width
j <- 0 until x.height
} yield {
val ls = for (k <- 0 until x.width) yield semiring.multiply(x.matrix((k, j)), y.matrix((i, k)))
(i, j) -> ls.reduce(semiring.add)
}
Matrix(matrix.toMap, y.width, x.height)
}
}
}
运用矩阵乘法来计算最短路径对
这里多了一步转换,方便输入的时候直接用List
输入。
src/main/scala/APSP.scala
:
object APSP {
import Semiring.distanceSemiring
def calculate(matrix: List[List[Double]]): Matrix[Double] = {
if (matrix.length == 0 || matrix.head.length != matrix.length)
throw IllegalArgumentException("This is not a distance matrix")
else {
val matrixAux = Matrix(
(for {
i <- 0 until matrix.length
j <- 0 until matrix.length
} yield ((i, j) -> matrix(i)(j))).toMap,
matrix.length,
matrix.length
)
(0 until matrix.length).map(_ => matrixAux).reduce(_ * _)
}
}
}
最后我们就可以在main
函数中调用,来获得结果。同样,这里也用了Scala 3对main函数定义的新语法。
src/main/scala/Main.scala
:
@main def hello: Unit = {
println("Hello world!")
println(APSP.calculate(distanceMatrix))
}
def distanceMatrix = List(
List(0.0, 1, 4, 2),
List(Double.PositiveInfinity, 0, 1, Double.PositiveInfinity),
List(Double.PositiveInfinity, Double.PositiveInfinity, 0, 5),
List(2, Double.PositiveInfinity, Double.PositiveInfinity, 0)
)
最后就可以获得结果了。
总结
这个例子一方面纯粹好玩,另一方面也体现了抽象的力量,使得我们可以用任意矩阵乘法来计算最短路径对。如果有矩阵乘法效率极高的实现,那么我们也可以对其他部分不加更改,直接利用这一实现。如果有对于稀疏矩阵特化的乘法,当然也可以利用。不过,目前的实现效率低下,由于矩阵乘法本身复杂度是O(V^3),这个算法的复杂度是O(V^4),V为节点数量。如果使用重复乘方的算法,则可以将复杂度降到O(V^3*lnV)。