dep.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package topo
  2. import (
  3. "fmt"
  4. "strings"
  5. )
  6. type Mapable interface {
  7. Key() string
  8. }
  9. type (
  10. AliasMap[T comparable] map[T]T
  11. NodeSet[T comparable] map[T]bool
  12. DepMap[T comparable] map[T]NodeSet[T]
  13. )
  14. type NodeInfo struct {
  15. Color string
  16. }
  17. type Graph[T comparable] struct {
  18. alias AliasMap[T]
  19. nodes NodeSet[T]
  20. // node info map
  21. nodeInfo map[T]NodeInfo
  22. // `dependencies` tracks child -> parents.
  23. dependencies DepMap[T]
  24. // `dependents` tracks parent -> children.
  25. dependents DepMap[T]
  26. // Keep track of the nodes of the graph themselves.
  27. }
  28. func New[T comparable]() *Graph[T] {
  29. return &Graph[T]{
  30. nodes: make(NodeSet[T]),
  31. dependencies: make(DepMap[T]),
  32. dependents: make(DepMap[T]),
  33. alias: make(AliasMap[T]),
  34. nodeInfo: make(map[T]NodeInfo),
  35. }
  36. }
  37. func (g *Graph[T]) Len() int {
  38. return len(g.nodes)
  39. }
  40. func (g *Graph[T]) Exists(node T) bool {
  41. // check aliases
  42. node = g.getAlias(node)
  43. _, ok := g.nodes[node]
  44. return ok
  45. }
  46. func (g *Graph[T]) Alias(node, alias T) error {
  47. if alias == node {
  48. return nil
  49. }
  50. // add node
  51. g.nodes[node] = true
  52. // add alias
  53. if _, ok := g.alias[alias]; ok {
  54. return ErrConflictingAlias
  55. }
  56. g.alias[alias] = node
  57. return nil
  58. }
  59. func (g *Graph[T]) AddNode(node T) {
  60. node = g.getAlias(node)
  61. g.nodes[node] = true
  62. }
  63. func (g *Graph[T]) getAlias(node T) T {
  64. if aliasNode, ok := g.alias[node]; ok {
  65. return aliasNode
  66. }
  67. return node
  68. }
  69. func (g *Graph[T]) SetNodeInfo(node T, nodeInfo *NodeInfo) {
  70. node = g.getAlias(node)
  71. g.nodeInfo[node] = *nodeInfo
  72. }
  73. func (g *Graph[T]) DependOn(child, parent T) error {
  74. child = g.getAlias(child)
  75. parent = g.getAlias(parent)
  76. if child == parent {
  77. return ErrSelfReferential
  78. }
  79. if g.DependsOn(parent, child) {
  80. return ErrCircular
  81. }
  82. g.AddNode(parent)
  83. g.AddNode(child)
  84. // Add nodes.
  85. g.nodes[parent] = true
  86. g.nodes[child] = true
  87. // Add edges.
  88. g.dependents.addNodeToNodeset(parent, child)
  89. g.dependencies.addNodeToNodeset(child, parent)
  90. return nil
  91. }
  92. func (g *Graph[T]) String() string {
  93. var sb strings.Builder
  94. sb.WriteString("digraph {\n")
  95. sb.WriteString("compound=true;\n")
  96. sb.WriteString("concentrate=true;\n")
  97. sb.WriteString("node [shape = record, ordering=out];\n")
  98. for node := range g.nodes {
  99. extra := ""
  100. if info, ok := g.nodeInfo[node]; ok {
  101. if info.Color != "" {
  102. extra = fmt.Sprintf("[color = %s]", info.Color)
  103. }
  104. }
  105. sb.WriteString(fmt.Sprintf("\t\"%v\"%s;\n", node, extra))
  106. }
  107. for parent, children := range g.dependencies {
  108. for child := range children {
  109. sb.WriteString(fmt.Sprintf("\t\"%v\" -> \"%v\";\n", parent, child))
  110. }
  111. }
  112. sb.WriteString("}")
  113. return sb.String()
  114. }
  115. func (g *Graph[T]) DependsOn(child, parent T) bool {
  116. deps := g.Dependencies(child)
  117. _, ok := deps[parent]
  118. return ok
  119. }
  120. func (g *Graph[T]) HasDependent(parent, child T) bool {
  121. deps := g.Dependents(parent)
  122. _, ok := deps[child]
  123. return ok
  124. }
  125. func (g *Graph[T]) Leaves() []T {
  126. leaves := make([]T, 0)
  127. for node := range g.nodes {
  128. if _, ok := g.dependencies[node]; !ok {
  129. leaves = append(leaves, node)
  130. }
  131. }
  132. return leaves
  133. }
  134. // TopoSortedLayers returns a slice of all of the graph nodes in topological sort order. That is,
  135. // if `B` depends on `A`, then `A` is guaranteed to come before `B` in the sorted output.
  136. // The graph is guaranteed to be cycle-free because cycles are detected while building the
  137. // graph. Additionally, the output is grouped into "layers", which are guaranteed to not have
  138. // any dependencies within each layer. This is useful, e.g. when building an execution plan for
  139. // some DAG, in which case each element within each layer could be executed in parallel. If you
  140. // do not need this layered property, use `Graph.TopoSorted()`, which flattens all elements.
  141. func (g *Graph[T]) TopoSortedLayers() [][]T {
  142. layers := [][]T{}
  143. // Copy the graph
  144. shrinkingGraph := g.clone()
  145. for {
  146. leaves := shrinkingGraph.Leaves()
  147. if len(leaves) == 0 {
  148. break
  149. }
  150. layers = append(layers, leaves)
  151. for _, leafNode := range leaves {
  152. shrinkingGraph.remove(leafNode)
  153. }
  154. }
  155. return layers
  156. }
  157. func (dm DepMap[T]) removeFromDepmap(key, node T) {
  158. if nodes := dm[key]; len(nodes) == 1 {
  159. // The only element in the nodeset must be `node`, so we
  160. // can delete the entry entirely.
  161. delete(dm, key)
  162. } else {
  163. // Otherwise, remove the single node from the nodeset.
  164. delete(nodes, node)
  165. }
  166. }
  167. func (g *Graph[T]) remove(node T) {
  168. // Remove edges from things that depend on `node`.
  169. for dependent := range g.dependents[node] {
  170. g.dependencies.removeFromDepmap(dependent, node)
  171. }
  172. delete(g.dependents, node)
  173. // Remove all edges from node to the things it depends on.
  174. for dependency := range g.dependencies[node] {
  175. g.dependents.removeFromDepmap(dependency, node)
  176. }
  177. delete(g.dependencies, node)
  178. // Finally, remove the node itself.
  179. delete(g.nodes, node)
  180. }
  181. // TopoSorted returns all the nodes in the graph is topological sort order.
  182. // See also `Graph.TopoSortedLayers()`.
  183. func (g *Graph[T]) TopoSorted() []T {
  184. nodeCount := 0
  185. layers := g.TopoSortedLayers()
  186. for _, layer := range layers {
  187. nodeCount += len(layer)
  188. }
  189. allNodes := make([]T, 0, nodeCount)
  190. for _, layer := range layers {
  191. allNodes = append(allNodes, layer...)
  192. }
  193. return allNodes
  194. }
  195. func (g *Graph[T]) Dependencies(child T) NodeSet[T] {
  196. return g.buildTransitive(child, g.immediateDependencies)
  197. }
  198. func (g *Graph[T]) immediateDependencies(node T) NodeSet[T] {
  199. return g.dependencies[node]
  200. }
  201. func (g *Graph[T]) Dependents(parent T) NodeSet[T] {
  202. return g.buildTransitive(parent, g.immediateDependents)
  203. }
  204. func (g *Graph[T]) immediateDependents(node T) NodeSet[T] {
  205. return g.dependents[node]
  206. }
  207. func (g *Graph[T]) clone() *Graph[T] {
  208. return &Graph[T]{
  209. dependencies: g.dependencies.copy(),
  210. dependents: g.dependents.copy(),
  211. nodes: g.nodes.copy(),
  212. }
  213. }
  214. // buildTransitive starts at `root` and continues calling `nextFn` to keep discovering more nodes until
  215. // the graph cannot produce any more. It returns the set of all discovered nodes.
  216. func (g *Graph[T]) buildTransitive(root T, nextFn func(T) NodeSet[T]) NodeSet[T] {
  217. if _, ok := g.nodes[root]; !ok {
  218. return nil
  219. }
  220. out := make(NodeSet[T])
  221. searchNext := []T{root}
  222. for len(searchNext) > 0 {
  223. // List of new nodes from this layer of the dependency graph. This is
  224. // assigned to `searchNext` at the end of the outer "discovery" loop.
  225. discovered := []T{}
  226. for _, node := range searchNext {
  227. // For each node to discover, find the next nodes.
  228. for nextNode := range nextFn(node) {
  229. // If we have not seen the node before, add it to the output as well
  230. // as the list of nodes to traverse in the next iteration.
  231. if _, ok := out[nextNode]; !ok {
  232. out[nextNode] = true
  233. discovered = append(discovered, nextNode)
  234. }
  235. }
  236. }
  237. searchNext = discovered
  238. }
  239. return out
  240. }
  241. func (s NodeSet[T]) copy() NodeSet[T] {
  242. out := make(NodeSet[T], len(s))
  243. for k, v := range s {
  244. out[k] = v
  245. }
  246. return out
  247. }
  248. func (m DepMap[T]) copy() DepMap[T] {
  249. out := make(DepMap[T], len(m))
  250. for k, v := range m {
  251. out[k] = v.copy()
  252. }
  253. return out
  254. }
  255. func (dm DepMap[T]) addNodeToNodeset(key, node T) {
  256. nodes, ok := dm[key]
  257. if !ok {
  258. nodes = make(NodeSet[T])
  259. dm[key] = nodes
  260. }
  261. nodes[node] = true
  262. }