dep.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package topo
import (
	"fmt"
	"sort"
	"strings"
)
  6. type Mapable interface {
  7. Key() string
  8. }
  9. type (
  10. AliasMap[T comparable] map[T]T
  11. NodeSet[T comparable] map[T]bool
  12. DepMap[T comparable] map[T]NodeSet[T]
  13. )
  14. type Graph[T comparable] struct {
  15. alias AliasMap[T]
  16. nodes NodeSet[T]
  17. // `dependencies` tracks child -> parents.
  18. dependencies DepMap[T]
  19. // `dependents` tracks parent -> children.
  20. dependents DepMap[T]
  21. // Keep track of the nodes of the graph themselves.
  22. }
  23. func New[T comparable]() *Graph[T] {
  24. return &Graph[T]{
  25. nodes: make(NodeSet[T]),
  26. dependencies: make(DepMap[T]),
  27. dependents: make(DepMap[T]),
  28. alias: make(AliasMap[T]),
  29. }
  30. }
  31. func (g *Graph[T]) Alias(node, alias T) error {
  32. if alias == node {
  33. return ErrSelfReferential
  34. }
  35. // add node
  36. g.nodes[node] = true
  37. // add alias
  38. if _, ok := g.alias[alias]; ok {
  39. return ErrConflictingAlias
  40. }
  41. g.alias[alias] = node
  42. return nil
  43. }
  44. func (g *Graph[T]) AddNode(node T) {
  45. // check aliases
  46. if aliasNode, ok := g.alias[node]; ok {
  47. node = aliasNode
  48. }
  49. g.nodes[node] = true
  50. }
  51. func (g *Graph[T]) DependOn(child, parent T) error {
  52. if child == parent {
  53. return ErrSelfReferential
  54. }
  55. if g.DependsOn(parent, child) {
  56. return ErrCircular
  57. }
  58. g.AddNode(parent)
  59. g.AddNode(child)
  60. // Add nodes.
  61. g.nodes[parent] = true
  62. g.nodes[child] = true
  63. // Add edges.
  64. g.dependents.addNodeToNodeset(parent, child)
  65. g.dependencies.addNodeToNodeset(child, parent)
  66. return nil
  67. }
  68. func (g *Graph[T]) String() string {
  69. var sb strings.Builder
  70. sb.WriteString("digraph {\n")
  71. // sb.WriteString("rankdir=LR;\n")
  72. sb.WriteString("node [shape = record, ordering=out];\n")
  73. for node := range g.nodes {
  74. sb.WriteString(fmt.Sprintf("\t\"%v\";\n", node))
  75. }
  76. for parent, children := range g.dependencies {
  77. for child := range children {
  78. sb.WriteString(fmt.Sprintf("\t\"%v\" -> \"%v\";\n", parent, child))
  79. }
  80. }
  81. sb.WriteString("}")
  82. return sb.String()
  83. }
  84. func (g *Graph[T]) DependsOn(child, parent T) bool {
  85. deps := g.Dependencies(child)
  86. _, ok := deps[parent]
  87. return ok
  88. }
  89. func (g *Graph[T]) HasDependent(parent, child T) bool {
  90. deps := g.Dependents(parent)
  91. _, ok := deps[child]
  92. return ok
  93. }
  94. func (g *Graph[T]) Leaves() []T {
  95. leaves := make([]T, 0)
  96. for node := range g.nodes {
  97. if _, ok := g.dependencies[node]; !ok {
  98. leaves = append(leaves, node)
  99. }
  100. }
  101. return leaves
  102. }
  103. // TopoSortedLayers returns a slice of all of the graph nodes in topological sort order. That is,
  104. // if `B` depends on `A`, then `A` is guaranteed to come before `B` in the sorted output.
  105. // The graph is guaranteed to be cycle-free because cycles are detected while building the
  106. // graph. Additionally, the output is grouped into "layers", which are guaranteed to not have
  107. // any dependencies within each layer. This is useful, e.g. when building an execution plan for
  108. // some DAG, in which case each element within each layer could be executed in parallel. If you
  109. // do not need this layered property, use `Graph.TopoSorted()`, which flattens all elements.
  110. func (g *Graph[T]) TopoSortedLayers() [][]T {
  111. layers := [][]T{}
  112. // Copy the graph
  113. shrinkingGraph := g.clone()
  114. for {
  115. leaves := shrinkingGraph.Leaves()
  116. if len(leaves) == 0 {
  117. break
  118. }
  119. layers = append(layers, leaves)
  120. for _, leafNode := range leaves {
  121. shrinkingGraph.remove(leafNode)
  122. }
  123. }
  124. return layers
  125. }
  126. func (dm DepMap[T]) removeFromDepmap(key, node T) {
  127. if nodes := dm[key]; len(nodes) == 1 {
  128. // The only element in the nodeset must be `node`, so we
  129. // can delete the entry entirely.
  130. delete(dm, key)
  131. } else {
  132. // Otherwise, remove the single node from the nodeset.
  133. delete(nodes, node)
  134. }
  135. }
  136. func (g *Graph[T]) remove(node T) {
  137. // Remove edges from things that depend on `node`.
  138. for dependent := range g.dependents[node] {
  139. g.dependencies.removeFromDepmap(dependent, node)
  140. }
  141. delete(g.dependents, node)
  142. // Remove all edges from node to the things it depends on.
  143. for dependency := range g.dependencies[node] {
  144. g.dependents.removeFromDepmap(dependency, node)
  145. }
  146. delete(g.dependencies, node)
  147. // Finally, remove the node itself.
  148. delete(g.nodes, node)
  149. }
  150. // TopoSorted returns all the nodes in the graph is topological sort order.
  151. // See also `Graph.TopoSortedLayers()`.
  152. func (g *Graph[T]) TopoSorted() []T {
  153. nodeCount := 0
  154. layers := g.TopoSortedLayers()
  155. for _, layer := range layers {
  156. nodeCount += len(layer)
  157. }
  158. allNodes := make([]T, 0, nodeCount)
  159. for _, layer := range layers {
  160. allNodes = append(allNodes, layer...)
  161. }
  162. return allNodes
  163. }
  164. func (g *Graph[T]) Dependencies(child T) NodeSet[T] {
  165. return g.buildTransitive(child, g.immediateDependencies)
  166. }
  167. func (g *Graph[T]) immediateDependencies(node T) NodeSet[T] {
  168. return g.dependencies[node]
  169. }
  170. func (g *Graph[T]) Dependents(parent T) NodeSet[T] {
  171. return g.buildTransitive(parent, g.immediateDependents)
  172. }
  173. func (g *Graph[T]) immediateDependents(node T) NodeSet[T] {
  174. return g.dependents[node]
  175. }
  176. func (g *Graph[T]) clone() *Graph[T] {
  177. return &Graph[T]{
  178. dependencies: g.dependencies.copy(),
  179. dependents: g.dependents.copy(),
  180. nodes: g.nodes.copy(),
  181. }
  182. }
  183. // buildTransitive starts at `root` and continues calling `nextFn` to keep discovering more nodes until
  184. // the graph cannot produce any more. It returns the set of all discovered nodes.
  185. func (g *Graph[T]) buildTransitive(root T, nextFn func(T) NodeSet[T]) NodeSet[T] {
  186. if _, ok := g.nodes[root]; !ok {
  187. return nil
  188. }
  189. out := make(NodeSet[T])
  190. searchNext := []T{root}
  191. for len(searchNext) > 0 {
  192. // List of new nodes from this layer of the dependency graph. This is
  193. // assigned to `searchNext` at the end of the outer "discovery" loop.
  194. discovered := []T{}
  195. for _, node := range searchNext {
  196. // For each node to discover, find the next nodes.
  197. for nextNode := range nextFn(node) {
  198. // If we have not seen the node before, add it to the output as well
  199. // as the list of nodes to traverse in the next iteration.
  200. if _, ok := out[nextNode]; !ok {
  201. out[nextNode] = true
  202. discovered = append(discovered, nextNode)
  203. }
  204. }
  205. }
  206. searchNext = discovered
  207. }
  208. return out
  209. }
  210. func (s NodeSet[T]) copy() NodeSet[T] {
  211. out := make(NodeSet[T], len(s))
  212. for k, v := range s {
  213. out[k] = v
  214. }
  215. return out
  216. }
  217. func (m DepMap[T]) copy() DepMap[T] {
  218. out := make(DepMap[T], len(m))
  219. for k, v := range m {
  220. out[k] = v.copy()
  221. }
  222. return out
  223. }
  224. func (dm DepMap[T]) addNodeToNodeset(key, node T) {
  225. nodes, ok := dm[key]
  226. if !ok {
  227. nodes = make(NodeSet[T])
  228. dm[key] = nodes
  229. }
  230. nodes[node] = true
  231. }