// File: dep.go (6.8 KB)

  1. package topo
import (
	"fmt"
	"sort"
	"strings"
)
  6. type Mapable interface {
  7. Key() string
  8. }
  9. type (
  10. AliasMap[T comparable] map[T]T
  11. NodeSet[T comparable] map[T]bool
  12. DepMap[T comparable] map[T]NodeSet[T]
  13. )
  14. type NodeInfo struct {
  15. Color string
  16. }
  17. type Graph[T comparable] struct {
  18. alias AliasMap[T]
  19. nodes NodeSet[T]
  20. // node info map
  21. nodeInfo map[T]NodeInfo
  22. // `dependencies` tracks child -> parents.
  23. dependencies DepMap[T]
  24. // `dependents` tracks parent -> children.
  25. dependents DepMap[T]
  26. // Keep track of the nodes of the graph themselves.
  27. }
  28. func New[T comparable]() *Graph[T] {
  29. return &Graph[T]{
  30. nodes: make(NodeSet[T]),
  31. dependencies: make(DepMap[T]),
  32. dependents: make(DepMap[T]),
  33. alias: make(AliasMap[T]),
  34. nodeInfo: make(map[T]NodeInfo),
  35. }
  36. }
  37. func (g *Graph[T]) Alias(node, alias T) error {
  38. if alias == node {
  39. return ErrSelfReferential
  40. }
  41. // add node
  42. g.nodes[node] = true
  43. // add alias
  44. if _, ok := g.alias[alias]; ok {
  45. return ErrConflictingAlias
  46. }
  47. g.alias[alias] = node
  48. return nil
  49. }
  50. func (g *Graph[T]) AddNode(node T) {
  51. // check aliases
  52. if aliasNode, ok := g.alias[node]; ok {
  53. node = aliasNode
  54. }
  55. g.nodes[node] = true
  56. }
  57. func (g *Graph[T]) SetNodeInfo(node T, nodeInfo *NodeInfo) {
  58. // check aliases
  59. if aliasNode, ok := g.alias[node]; ok {
  60. node = aliasNode
  61. }
  62. g.nodeInfo[node] = *nodeInfo
  63. }
  64. func (g *Graph[T]) DependOn(child, parent T) error {
  65. if child == parent {
  66. return ErrSelfReferential
  67. }
  68. if g.DependsOn(parent, child) {
  69. return ErrCircular
  70. }
  71. g.AddNode(parent)
  72. g.AddNode(child)
  73. // Add nodes.
  74. g.nodes[parent] = true
  75. g.nodes[child] = true
  76. // Add edges.
  77. g.dependents.addNodeToNodeset(parent, child)
  78. g.dependencies.addNodeToNodeset(child, parent)
  79. return nil
  80. }
  81. func (g *Graph[T]) String() string {
  82. var sb strings.Builder
  83. sb.WriteString("digraph {\n")
  84. sb.WriteString("compound=true;\n")
  85. sb.WriteString("concentrate=true;\n")
  86. sb.WriteString("node [shape = record, ordering=out];\n")
  87. for node := range g.nodes {
  88. extra := ""
  89. if info, ok := g.nodeInfo[node]; ok {
  90. if info.Color != "" {
  91. extra = fmt.Sprintf("[color = %s]", info.Color)
  92. }
  93. }
  94. sb.WriteString(fmt.Sprintf("\t\"%v\"%s;\n", node, extra))
  95. }
  96. for parent, children := range g.dependencies {
  97. for child := range children {
  98. sb.WriteString(fmt.Sprintf("\t\"%v\" -> \"%v\";\n", parent, child))
  99. }
  100. }
  101. sb.WriteString("}")
  102. return sb.String()
  103. }
  104. func (g *Graph[T]) DependsOn(child, parent T) bool {
  105. deps := g.Dependencies(child)
  106. _, ok := deps[parent]
  107. return ok
  108. }
  109. func (g *Graph[T]) HasDependent(parent, child T) bool {
  110. deps := g.Dependents(parent)
  111. _, ok := deps[child]
  112. return ok
  113. }
  114. func (g *Graph[T]) Leaves() []T {
  115. leaves := make([]T, 0)
  116. for node := range g.nodes {
  117. if _, ok := g.dependencies[node]; !ok {
  118. leaves = append(leaves, node)
  119. }
  120. }
  121. return leaves
  122. }
  123. // TopoSortedLayers returns a slice of all of the graph nodes in topological sort order. That is,
  124. // if `B` depends on `A`, then `A` is guaranteed to come before `B` in the sorted output.
  125. // The graph is guaranteed to be cycle-free because cycles are detected while building the
  126. // graph. Additionally, the output is grouped into "layers", which are guaranteed to not have
  127. // any dependencies within each layer. This is useful, e.g. when building an execution plan for
  128. // some DAG, in which case each element within each layer could be executed in parallel. If you
  129. // do not need this layered property, use `Graph.TopoSorted()`, which flattens all elements.
  130. func (g *Graph[T]) TopoSortedLayers() [][]T {
  131. layers := [][]T{}
  132. // Copy the graph
  133. shrinkingGraph := g.clone()
  134. for {
  135. leaves := shrinkingGraph.Leaves()
  136. if len(leaves) == 0 {
  137. break
  138. }
  139. layers = append(layers, leaves)
  140. for _, leafNode := range leaves {
  141. shrinkingGraph.remove(leafNode)
  142. }
  143. }
  144. return layers
  145. }
  146. func (dm DepMap[T]) removeFromDepmap(key, node T) {
  147. if nodes := dm[key]; len(nodes) == 1 {
  148. // The only element in the nodeset must be `node`, so we
  149. // can delete the entry entirely.
  150. delete(dm, key)
  151. } else {
  152. // Otherwise, remove the single node from the nodeset.
  153. delete(nodes, node)
  154. }
  155. }
  156. func (g *Graph[T]) remove(node T) {
  157. // Remove edges from things that depend on `node`.
  158. for dependent := range g.dependents[node] {
  159. g.dependencies.removeFromDepmap(dependent, node)
  160. }
  161. delete(g.dependents, node)
  162. // Remove all edges from node to the things it depends on.
  163. for dependency := range g.dependencies[node] {
  164. g.dependents.removeFromDepmap(dependency, node)
  165. }
  166. delete(g.dependencies, node)
  167. // Finally, remove the node itself.
  168. delete(g.nodes, node)
  169. }
  170. // TopoSorted returns all the nodes in the graph is topological sort order.
  171. // See also `Graph.TopoSortedLayers()`.
  172. func (g *Graph[T]) TopoSorted() []T {
  173. nodeCount := 0
  174. layers := g.TopoSortedLayers()
  175. for _, layer := range layers {
  176. nodeCount += len(layer)
  177. }
  178. allNodes := make([]T, 0, nodeCount)
  179. for _, layer := range layers {
  180. allNodes = append(allNodes, layer...)
  181. }
  182. return allNodes
  183. }
  184. func (g *Graph[T]) Dependencies(child T) NodeSet[T] {
  185. return g.buildTransitive(child, g.immediateDependencies)
  186. }
  187. func (g *Graph[T]) immediateDependencies(node T) NodeSet[T] {
  188. return g.dependencies[node]
  189. }
  190. func (g *Graph[T]) Dependents(parent T) NodeSet[T] {
  191. return g.buildTransitive(parent, g.immediateDependents)
  192. }
  193. func (g *Graph[T]) immediateDependents(node T) NodeSet[T] {
  194. return g.dependents[node]
  195. }
  196. func (g *Graph[T]) clone() *Graph[T] {
  197. return &Graph[T]{
  198. dependencies: g.dependencies.copy(),
  199. dependents: g.dependents.copy(),
  200. nodes: g.nodes.copy(),
  201. }
  202. }
  203. // buildTransitive starts at `root` and continues calling `nextFn` to keep discovering more nodes until
  204. // the graph cannot produce any more. It returns the set of all discovered nodes.
  205. func (g *Graph[T]) buildTransitive(root T, nextFn func(T) NodeSet[T]) NodeSet[T] {
  206. if _, ok := g.nodes[root]; !ok {
  207. return nil
  208. }
  209. out := make(NodeSet[T])
  210. searchNext := []T{root}
  211. for len(searchNext) > 0 {
  212. // List of new nodes from this layer of the dependency graph. This is
  213. // assigned to `searchNext` at the end of the outer "discovery" loop.
  214. discovered := []T{}
  215. for _, node := range searchNext {
  216. // For each node to discover, find the next nodes.
  217. for nextNode := range nextFn(node) {
  218. // If we have not seen the node before, add it to the output as well
  219. // as the list of nodes to traverse in the next iteration.
  220. if _, ok := out[nextNode]; !ok {
  221. out[nextNode] = true
  222. discovered = append(discovered, nextNode)
  223. }
  224. }
  225. }
  226. searchNext = discovered
  227. }
  228. return out
  229. }
  230. func (s NodeSet[T]) copy() NodeSet[T] {
  231. out := make(NodeSet[T], len(s))
  232. for k, v := range s {
  233. out[k] = v
  234. }
  235. return out
  236. }
  237. func (m DepMap[T]) copy() DepMap[T] {
  238. out := make(DepMap[T], len(m))
  239. for k, v := range m {
  240. out[k] = v.copy()
  241. }
  242. return out
  243. }
  244. func (dm DepMap[T]) addNodeToNodeset(key, node T) {
  245. nodes, ok := dm[key]
  246. if !ok {
  247. nodes = make(NodeSet[T])
  248. dm[key] = nodes
  249. }
  250. nodes[node] = true
  251. }