internal/dag: remove unused code

Since the DAG package was lifted from Terraform, its contents are more
than what we need for now, so this commit cleans-up the package to keep
only the currently needed parts of code.
If we need to support more in the future, we can revert this commit, or
pickup the changes again from Terraform.
This commit is contained in:
Lucas Bajolet 2024-10-29 15:44:53 -04:00 committed by Lucas Bajolet
parent 418ebca7ef
commit 9076c7b24a
10 changed files with 7 additions and 2424 deletions

View file

@ -6,7 +6,6 @@ package dag
import (
"errors"
"fmt"
"sort"
"strings"
"github.com/hashicorp/hcl/v2"
@ -28,105 +27,6 @@ func (g *AcyclicGraph) DirectedGraph() Grapher {
return g
}
// Returns a Set that includes every Vertex yielded by walking down from the
// provided starting Vertex v.
func (g *AcyclicGraph) Ancestors(vs ...Vertex) (Set, error) {
s := make(Set)
memoFunc := func(v Vertex, d int) error {
s.Add(v)
return nil
}
start := make(Set)
for _, v := range vs {
for _, dep := range g.downEdgesNoCopy(v) {
start.Add(dep)
}
}
if err := g.DepthFirstWalk(start, memoFunc); err != nil {
return nil, err
}
return s, nil
}
// Returns a Set that includes every Vertex yielded by walking up from the
// provided starting Vertex v.
func (g *AcyclicGraph) Descendents(vs ...Vertex) (Set, error) {
s := make(Set)
memoFunc := func(v Vertex, d int) error {
s.Add(v)
return nil
}
start := make(Set)
for _, v := range vs {
for _, dep := range g.upEdgesNoCopy(v) {
start.Add(dep)
}
}
if err := g.ReverseDepthFirstWalk(start, memoFunc); err != nil {
return nil, err
}
return s, nil
}
// Root returns the root of the DAG, or an error.
//
// Complexity: O(V)
func (g *AcyclicGraph) Root() (Vertex, error) {
roots := make([]Vertex, 0, 1)
for _, v := range g.Vertices() {
if g.upEdgesNoCopy(v).Len() == 0 {
roots = append(roots, v)
}
}
if len(roots) > 1 {
// TODO(mitchellh): make this error message a lot better
return nil, fmt.Errorf("multiple roots: %#v", roots)
}
if len(roots) == 0 {
return nil, fmt.Errorf("no roots found")
}
return roots[0], nil
}
// TransitiveReduction performs the transitive reduction of graph g in place.
// The transitive reduction of a graph is a graph with as few edges as
// possible with the same reachability as the original graph. This means
// that if there are three nodes A => B => C, and A connects to both
// B and C, and B connects to C, then the transitive reduction is the
// same graph with only a single edge between A and B, and a single edge
// between B and C.
//
// The graph must be free of cycles for this operation to behave properly.
//
// Complexity: O(V(V+E)), or asymptotically O(VE)
func (g *AcyclicGraph) TransitiveReduction() {
// For each vertex u in graph g, do a DFS starting from each vertex
// v such that the edge (u,v) exists (v is a direct descendant of u).
//
// For each v-prime reachable from v, remove the edge (u, v-prime).
for _, u := range g.Vertices() {
uTargets := g.downEdgesNoCopy(u)
_ = g.DepthFirstWalk(g.downEdgesNoCopy(u), func(v Vertex, d int) error {
shared := uTargets.Intersection(g.downEdgesNoCopy(v))
for _, vPrime := range shared {
g.RemoveEdge(BasicEdge(u, vPrime))
}
return nil
})
}
}
// Validate validates the DAG. A DAG is valid if it has no cycles or self-referencing vertex.
func (g *AcyclicGraph) Validate() error {
// Look for cycles of more than 1 component
@ -167,36 +67,14 @@ func (g *AcyclicGraph) Cycles() [][]Vertex {
return cycles
}
// Walk walks the graph, calling your callback as each node is visited.
// This will walk nodes in parallel if it can. The resulting diagnostics
// contains problems from all graphs visited, in no particular order.
func (g *AcyclicGraph) Walk(cb WalkFunc) hcl.Diagnostics {
w := &Walker{Callback: cb, Reverse: true}
w.Update(g)
return w.Wait()
}
type walkType uint64
// simple convenience helper for converting a dag.Set to a []Vertex
func AsVertexList(s Set) []Vertex {
vertexList := make([]Vertex, 0, len(s))
for _, raw := range s {
vertexList = append(vertexList, raw.(Vertex))
}
return vertexList
}
type vertexAtDepth struct {
Vertex Vertex
Depth int
}
// TopologicalOrder returns a topological sort of the given graph, with source
// vertices ordered before the targets of their edges. The nodes are not sorted,
// and any valid order may be returned. This function will panic if it
// encounters a cycle.
func (g *AcyclicGraph) TopologicalOrder() []Vertex {
return g.topoOrder(upOrder)
}
const (
depthFirst walkType = 1 << iota
breadthFirst
downOrder
upOrder
)
// ReverseTopologicalOrder returns a topological sort of the given graph, with
// target vertices ordered before the sources of their edges. The nodes are not
@ -254,127 +132,3 @@ func (g *AcyclicGraph) topoOrder(order walkType) []Vertex {
return sorted
}
type walkType uint64
const (
depthFirst walkType = 1 << iota
breadthFirst
downOrder
upOrder
)
// DepthFirstWalk does a depth-first walk of the graph starting from
// the vertices in start.
func (g *AcyclicGraph) DepthFirstWalk(start Set, f DepthWalkFunc) error {
return g.walk(depthFirst|downOrder, false, start, f)
}
// ReverseDepthFirstWalk does a depth-first walk _up_ the graph starting from
// the vertices in start.
func (g *AcyclicGraph) ReverseDepthFirstWalk(start Set, f DepthWalkFunc) error {
return g.walk(depthFirst|upOrder, false, start, f)
}
// BreadthFirstWalk does a breadth-first walk of the graph starting from
// the vertices in start.
func (g *AcyclicGraph) BreadthFirstWalk(start Set, f DepthWalkFunc) error {
return g.walk(breadthFirst|downOrder, false, start, f)
}
// ReverseBreadthFirstWalk does a breadth-first walk _up_ the graph starting from
// the vertices in start.
func (g *AcyclicGraph) ReverseBreadthFirstWalk(start Set, f DepthWalkFunc) error {
return g.walk(breadthFirst|upOrder, false, start, f)
}
// Setting test to true will walk sets of vertices in sorted order for
// deterministic testing.
func (g *AcyclicGraph) walk(order walkType, test bool, start Set, f DepthWalkFunc) error {
seen := make(map[Vertex]struct{})
frontier := make([]vertexAtDepth, 0, len(start))
for _, v := range start {
frontier = append(frontier, vertexAtDepth{
Vertex: v,
Depth: 0,
})
}
if test {
testSortFrontier(frontier)
}
for len(frontier) > 0 {
// Pop the current vertex
var current vertexAtDepth
switch {
case order&depthFirst != 0:
// depth first, the frontier is used like a stack
n := len(frontier)
current = frontier[n-1]
frontier = frontier[:n-1]
case order&breadthFirst != 0:
// breadth first, the frontier is used like a queue
current = frontier[0]
frontier = frontier[1:]
default:
panic(fmt.Sprint("invalid visit order", order))
}
// Check if we've seen this already and return...
if _, ok := seen[current.Vertex]; ok {
continue
}
seen[current.Vertex] = struct{}{}
// Visit the current node
if err := f(current.Vertex, current.Depth); err != nil {
return err
}
var edges Set
switch {
case order&downOrder != 0:
edges = g.downEdgesNoCopy(current.Vertex)
case order&upOrder != 0:
edges = g.upEdgesNoCopy(current.Vertex)
default:
panic(fmt.Sprint("invalid walk order", order))
}
if test {
frontier = testAppendNextSorted(frontier, edges, current.Depth+1)
} else {
frontier = appendNext(frontier, edges, current.Depth+1)
}
}
return nil
}
func appendNext(frontier []vertexAtDepth, next Set, depth int) []vertexAtDepth {
for _, v := range next {
frontier = append(frontier, vertexAtDepth{
Vertex: v,
Depth: depth,
})
}
return frontier
}
func testAppendNextSorted(frontier []vertexAtDepth, edges Set, depth int) []vertexAtDepth {
var newEdges []vertexAtDepth
for _, v := range edges {
newEdges = append(newEdges, vertexAtDepth{
Vertex: v,
Depth: depth,
})
}
testSortFrontier(newEdges)
return append(frontier, newEdges...)
}
func testSortFrontier(f []vertexAtDepth) {
sort.Slice(f, func(i, j int) bool {
return VertexName(f[i].Vertex) < VertexName(f[j].Vertex)
})
}

View file

@ -5,15 +5,8 @@ package dag
import (
"flag"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"sync"
"testing"
"github.com/hashicorp/hcl/v2"
)
func TestMain(m *testing.M) {
@ -21,172 +14,6 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func TestAcyclicGraphRoot(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
g.Connect(BasicEdge(3, 1))
if root, err := g.Root(); err != nil {
t.Fatalf("err: %s", err)
} else if root != 3 {
t.Fatalf("bad: %#v", root)
}
}
func TestAcyclicGraphRoot_cycle(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(3, 1))
if _, err := g.Root(); err == nil {
t.Fatal("should error")
}
}
func TestAcyclicGraphRoot_multiple(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
if _, err := g.Root(); err == nil {
t.Fatal("should error")
}
}
func TestAyclicGraphTransReduction(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(2, 3))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestAyclicGraphTransReduction_more(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(1, 4))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(2, 4))
g.Connect(BasicEdge(3, 4))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionMoreStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestAyclicGraphTransReduction_multipleRoots(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(1, 4))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(2, 4))
g.Connect(BasicEdge(3, 4))
g.Add(5)
g.Add(6)
g.Add(7)
g.Add(8)
g.Connect(BasicEdge(5, 6))
g.Connect(BasicEdge(5, 7))
g.Connect(BasicEdge(5, 8))
g.Connect(BasicEdge(6, 7))
g.Connect(BasicEdge(6, 8))
g.Connect(BasicEdge(7, 8))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionMultipleRootsStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
// use this to simulate slow sort operations
type counter struct {
Name string
Calls int64
}
func (s *counter) String() string {
s.Calls++
return s.Name
}
// Make sure we can reduce a sizable, fully-connected graph.
func TestAyclicGraphTransReduction_fullyConnected(t *testing.T) {
var g AcyclicGraph
const nodeCount = 200
nodes := make([]*counter, nodeCount)
for i := 0; i < nodeCount; i++ {
nodes[i] = &counter{Name: strconv.Itoa(i)}
}
// Add them all to the graph
for _, n := range nodes {
g.Add(n)
}
// connect them all
for i := range nodes {
for j := range nodes {
if i == j {
continue
}
g.Connect(BasicEdge(nodes[i], nodes[j]))
}
}
g.TransitiveReduction()
vertexNameCalls := int64(0)
for _, n := range nodes {
vertexNameCalls += n.Calls
}
switch {
case vertexNameCalls > 2*nodeCount:
// Make calling it more the 2x per node fatal.
// If we were sorting this would give us roughly ln(n)(n^3) calls, or
// >59000000 calls for 200 vertices.
t.Fatalf("VertexName called %d times", vertexNameCalls)
case vertexNameCalls > 0:
// we don't expect any calls, but a change here isn't necessarily fatal
t.Logf("WARNING: VertexName called %d times", vertexNameCalls)
}
}
func TestAcyclicGraphValidate(t *testing.T) {
var g AcyclicGraph
g.Add(1)
@ -225,363 +52,3 @@ func TestAcyclicGraphValidate_cycleSelf(t *testing.T) {
t.Fatal("should error")
}
}
func TestAcyclicGraphAncestors(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Add(5)
g.Connect(BasicEdge(0, 1))
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(3, 4))
g.Connect(BasicEdge(4, 5))
actual, err := g.Ancestors(2)
if err != nil {
t.Fatalf("err: %#v", err)
}
expected := []Vertex{3, 4, 5}
if actual.Len() != len(expected) {
t.Fatalf("bad length! expected %#v to have len %d", actual, len(expected))
}
for _, e := range expected {
if !actual.Include(e) {
t.Fatalf("expected: %#v to include: %#v", expected, actual)
}
}
}
func TestAcyclicGraphDescendents(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Add(5)
g.Connect(BasicEdge(0, 1))
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(3, 4))
g.Connect(BasicEdge(4, 5))
actual, err := g.Descendents(2)
if err != nil {
t.Fatalf("err: %#v", err)
}
expected := []Vertex{0, 1}
if actual.Len() != len(expected) {
t.Fatalf("bad length! expected %#v to have len %d", actual, len(expected))
}
for _, e := range expected {
if !actual.Include(e) {
t.Fatalf("expected: %#v to include: %#v", expected, actual)
}
}
}
func TestAcyclicGraphWalk(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
g.Connect(BasicEdge(3, 1))
var visits []Vertex
var lock sync.Mutex
err := g.Walk(func(v Vertex) hcl.Diagnostics {
lock.Lock()
defer lock.Unlock()
visits = append(visits, v)
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
expected := [][]Vertex{
{1, 2, 3},
{2, 1, 3},
}
for _, e := range expected {
if reflect.DeepEqual(visits, e) {
return
}
}
t.Fatalf("bad: %#v", visits)
}
func TestAcyclicGraphWalk_error(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Connect(BasicEdge(4, 3))
g.Connect(BasicEdge(3, 2))
g.Connect(BasicEdge(2, 1))
var visits []Vertex
var lock sync.Mutex
err := g.Walk(func(v Vertex) hcl.Diagnostics {
lock.Lock()
defer lock.Unlock()
var diags hcl.Diagnostics
if v == 2 {
diags = diags.Append(&hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "walk error",
Detail: "simulated error on vertex 2",
})
return diags
}
visits = append(visits, v)
return diags
})
if err == nil {
t.Fatal("should error")
}
expected := []Vertex{1}
if !reflect.DeepEqual(visits, expected) {
t.Errorf("wrong visits\ngot: %#v\nwant: %#v", visits, expected)
}
}
func BenchmarkDAG(b *testing.B) {
for i := 0; i < b.N; i++ {
count := 150
b.StopTimer()
g := &AcyclicGraph{}
// create 4 layers of fully connected nodes
// layer A
for i := 0; i < count; i++ {
g.Add(fmt.Sprintf("A%d", i))
}
// layer B
for i := 0; i < count; i++ {
B := fmt.Sprintf("B%d", i)
g.Add(B)
for j := 0; j < count; j++ {
g.Connect(BasicEdge(B, fmt.Sprintf("A%d", j)))
}
}
// layer C
for i := 0; i < count; i++ {
c := fmt.Sprintf("C%d", i)
g.Add(c)
for j := 0; j < count; j++ {
// connect them to previous layers so we have something that requires reduction
g.Connect(BasicEdge(c, fmt.Sprintf("A%d", j)))
g.Connect(BasicEdge(c, fmt.Sprintf("B%d", j)))
}
}
// layer D
for i := 0; i < count; i++ {
d := fmt.Sprintf("D%d", i)
g.Add(d)
for j := 0; j < count; j++ {
g.Connect(BasicEdge(d, fmt.Sprintf("A%d", j)))
g.Connect(BasicEdge(d, fmt.Sprintf("B%d", j)))
g.Connect(BasicEdge(d, fmt.Sprintf("C%d", j)))
}
}
b.StartTimer()
// Find dependencies for every node
for _, v := range g.Vertices() {
_, err := g.Ancestors(v)
if err != nil {
b.Fatal(err)
}
}
// reduce the final graph
g.TransitiveReduction()
}
}
func TestAcyclicGraphWalkOrder(t *testing.T) {
/* Sample dependency graph,
all edges pointing downwards.
1 2
/ \ / \
3 4 5
/ \ /
6 7
/ | \
8 9 10
\ | /
11
*/
var g AcyclicGraph
for i := 1; i <= 11; i++ {
g.Add(i)
}
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(1, 4))
g.Connect(BasicEdge(2, 4))
g.Connect(BasicEdge(2, 5))
g.Connect(BasicEdge(3, 6))
g.Connect(BasicEdge(4, 7))
g.Connect(BasicEdge(5, 7))
g.Connect(BasicEdge(7, 8))
g.Connect(BasicEdge(7, 9))
g.Connect(BasicEdge(7, 10))
g.Connect(BasicEdge(8, 11))
g.Connect(BasicEdge(9, 11))
g.Connect(BasicEdge(10, 11))
start := make(Set)
start.Add(2)
start.Add(1)
reverse := make(Set)
reverse.Add(11)
reverse.Add(6)
t.Run("DepthFirst", func(t *testing.T) {
var visits []vertexAtDepth
g.walk(depthFirst|downOrder, true, start, func(v Vertex, d int) error {
visits = append(visits, vertexAtDepth{v, d})
return nil
})
expect := []vertexAtDepth{
{2, 0}, {5, 1}, {7, 2}, {9, 3}, {11, 4}, {8, 3}, {10, 3}, {4, 1}, {1, 0}, {3, 1}, {6, 2},
}
if !reflect.DeepEqual(visits, expect) {
t.Errorf("expected visits:\n%v\ngot:\n%v\n", expect, visits)
}
})
t.Run("ReverseDepthFirst", func(t *testing.T) {
var visits []vertexAtDepth
g.walk(depthFirst|upOrder, true, reverse, func(v Vertex, d int) error {
visits = append(visits, vertexAtDepth{v, d})
return nil
})
expect := []vertexAtDepth{
{6, 0}, {3, 1}, {1, 2}, {11, 0}, {9, 1}, {7, 2}, {5, 3}, {2, 4}, {4, 3}, {8, 1}, {10, 1},
}
if !reflect.DeepEqual(visits, expect) {
t.Errorf("expected visits:\n%v\ngot:\n%v\n", expect, visits)
}
})
t.Run("BreadthFirst", func(t *testing.T) {
var visits []vertexAtDepth
g.walk(breadthFirst|downOrder, true, start, func(v Vertex, d int) error {
visits = append(visits, vertexAtDepth{v, d})
return nil
})
expect := []vertexAtDepth{
{1, 0}, {2, 0}, {3, 1}, {4, 1}, {5, 1}, {6, 2}, {7, 2}, {10, 3}, {8, 3}, {9, 3}, {11, 4},
}
if !reflect.DeepEqual(visits, expect) {
t.Errorf("expected visits:\n%v\ngot:\n%v\n", expect, visits)
}
})
t.Run("ReverseBreadthFirst", func(t *testing.T) {
var visits []vertexAtDepth
g.walk(breadthFirst|upOrder, true, reverse, func(v Vertex, d int) error {
visits = append(visits, vertexAtDepth{v, d})
return nil
})
expect := []vertexAtDepth{
{11, 0}, {6, 0}, {10, 1}, {8, 1}, {9, 1}, {3, 1}, {7, 2}, {1, 2}, {4, 3}, {5, 3}, {2, 4},
}
if !reflect.DeepEqual(visits, expect) {
t.Errorf("expected visits:\n%v\ngot:\n%v\n", expect, visits)
}
})
t.Run("TopologicalOrder", func(t *testing.T) {
order := g.topoOrder(downOrder)
// Validate the order by checking it against the initial graph. We only
// need to verify that each node has it's direct dependencies
// satisfied.
completed := map[Vertex]bool{}
for _, v := range order {
deps := g.DownEdges(v)
for _, dep := range deps {
if !completed[dep] {
t.Fatalf("walking node %v, but dependency %v was not yet seen", v, dep)
}
}
completed[v] = true
}
})
t.Run("ReverseTopologicalOrder", func(t *testing.T) {
order := g.topoOrder(upOrder)
// Validate the order by checking it against the initial graph. We only
// need to verify that each node has it's direct dependencies
// satisfied.
completed := map[Vertex]bool{}
for _, v := range order {
deps := g.UpEdges(v)
for _, dep := range deps {
if !completed[dep] {
t.Fatalf("walking node %v, but dependency %v was not yet seen", v, dep)
}
}
completed[v] = true
}
})
}
const testGraphTransReductionStr = `
1
2
2
3
3
`
const testGraphTransReductionMoreStr = `
1
2
2
3
3
4
4
`
const testGraphTransReductionMultipleRootsStr = `
1
2
2
3
3
4
4
5
6
6
7
7
8
8
`

View file

@ -1,285 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dag
import (
"bytes"
"fmt"
"sort"
"strings"
)
// DotOpts are the options for generating a dot formatted Graph.
type DotOpts struct {
// Allows some nodes to decide to only show themselves when the user has
// requested the "verbose" graph.
Verbose bool
// Highlight Cycles
DrawCycles bool
// How many levels to expand modules as we draw
MaxDepth int
// use this to keep the cluster_ naming convention from the previous dot writer
cluster bool
}
// GraphNodeDotter can be implemented by a node to cause it to be included
// in the dot graph. The Dot method will be called which is expected to
// return a representation of this node.
type GraphNodeDotter interface {
// Dot is called to return the dot formatting for the node.
// The first parameter is the title of the node.
// The second parameter includes user-specified options that affect the dot
// graph. See GraphDotOpts below for details.
DotNode(string, *DotOpts) *DotNode
}
// DotNode provides a structure for Vertices to return in order to specify their
// dot format.
type DotNode struct {
Name string
Attrs map[string]string
}
// Returns the DOT representation of this Graph.
func (g *marshalGraph) Dot(opts *DotOpts) []byte {
if opts == nil {
opts = &DotOpts{
DrawCycles: true,
MaxDepth: -1,
Verbose: true,
}
}
var w indentWriter
_, _ = w.WriteString("digraph {\n")
w.Indent()
// some dot defaults
_, _ = w.WriteString(`compound = "true"` + "\n")
_, _ = w.WriteString(`newrank = "true"` + "\n")
// the top level graph is written as the first subgraph
_, _ = w.WriteString(`subgraph "root" {` + "\n")
g.writeBody(opts, &w)
// cluster isn't really used other than for naming purposes in some graphs
opts.cluster = opts.MaxDepth != 0
maxDepth := opts.MaxDepth
if maxDepth == 0 {
maxDepth = -1
}
for _, s := range g.Subgraphs {
g.writeSubgraph(s, opts, maxDepth, &w)
}
w.Unindent()
_, _ = w.WriteString("}\n")
return w.Bytes()
}
func (v *marshalVertex) dot(g *marshalGraph, opts *DotOpts) []byte {
var buf bytes.Buffer
graphName := g.Name
if graphName == "" {
graphName = "root"
}
name := v.Name
attrs := v.Attrs
if v.graphNodeDotter != nil {
node := v.graphNodeDotter.DotNode(name, opts)
if node == nil {
return []byte{}
}
newAttrs := make(map[string]string)
for k, v := range attrs {
newAttrs[k] = v
}
for k, v := range node.Attrs {
newAttrs[k] = v
}
name = node.Name
attrs = newAttrs
}
buf.WriteString(fmt.Sprintf(`"[%s] %s"`, graphName, name))
writeAttrs(&buf, attrs)
buf.WriteByte('\n')
return buf.Bytes()
}
func (e *marshalEdge) dot(g *marshalGraph) string {
var buf bytes.Buffer
graphName := g.Name
if graphName == "" {
graphName = "root"
}
sourceName := g.vertexByID(e.Source).Name
targetName := g.vertexByID(e.Target).Name
s := fmt.Sprintf(`"[%s] %s" -> "[%s] %s"`, graphName, sourceName, graphName, targetName)
buf.WriteString(s)
writeAttrs(&buf, e.Attrs)
return buf.String()
}
func cycleDot(e *marshalEdge, g *marshalGraph) string {
return e.dot(g) + ` [color = "red", penwidth = "2.0"]`
}
// Write the subgraph body. The is recursive, and the depth argument is used to
// record the current depth of iteration.
func (g *marshalGraph) writeSubgraph(sg *marshalGraph, opts *DotOpts, depth int, w *indentWriter) {
if depth == 0 {
return
}
depth--
name := sg.Name
if opts.cluster {
// we prefix with cluster_ to match the old dot output
name = "cluster_" + name
sg.Attrs["label"] = sg.Name
}
_, _ = w.WriteString(fmt.Sprintf("subgraph %q {\n", name))
sg.writeBody(opts, w)
for _, sg := range sg.Subgraphs {
g.writeSubgraph(sg, opts, depth, w)
}
}
func (g *marshalGraph) writeBody(opts *DotOpts, w *indentWriter) {
w.Indent()
for _, as := range attrStrings(g.Attrs) {
_, _ = w.WriteString(as + "\n")
}
// list of Vertices that aren't to be included in the dot output
skip := map[string]bool{}
for _, v := range g.Vertices {
if v.graphNodeDotter == nil {
skip[v.ID] = true
continue
}
_, _ = w.Write(v.dot(g, opts))
}
var dotEdges []string
if opts.DrawCycles {
for _, c := range g.Cycles {
if len(c) < 2 {
continue
}
for i, j := 0, 1; i < len(c); i, j = i+1, j+1 {
if j >= len(c) {
j = 0
}
src := c[i]
tgt := c[j]
if skip[src.ID] || skip[tgt.ID] {
continue
}
e := &marshalEdge{
Name: fmt.Sprintf("%s|%s", src.Name, tgt.Name),
Source: src.ID,
Target: tgt.ID,
Attrs: make(map[string]string),
}
dotEdges = append(dotEdges, cycleDot(e, g))
src = tgt
}
}
}
for _, e := range g.Edges {
dotEdges = append(dotEdges, e.dot(g))
}
// srot these again to match the old output
sort.Strings(dotEdges)
for _, e := range dotEdges {
_, _ = w.WriteString(e + "\n")
}
w.Unindent()
_, _ = w.WriteString("}\n")
}
func writeAttrs(buf *bytes.Buffer, attrs map[string]string) {
if len(attrs) > 0 {
buf.WriteString(" [")
buf.WriteString(strings.Join(attrStrings(attrs), ", "))
buf.WriteString("]")
}
}
func attrStrings(attrs map[string]string) []string {
strings := make([]string, 0, len(attrs))
for k, v := range attrs {
strings = append(strings, fmt.Sprintf("%s = %q", k, v))
}
sort.Strings(strings)
return strings
}
// Provide a bytes.Buffer like structure, which will indent when starting a
// newline.
type indentWriter struct {
bytes.Buffer
level int
}
func (w *indentWriter) indent() {
newline := []byte("\n")
if !bytes.HasSuffix(w.Bytes(), newline) {
return
}
for i := 0; i < w.level; i++ {
w.Buffer.WriteString("\t")
}
}
// Indent increases indentation by 1
func (w *indentWriter) Indent() { w.level++ }
// Unindent decreases indentation by 1
func (w *indentWriter) Unindent() { w.level-- }
// the following methods intercecpt the byte.Buffer writes and insert the
// indentation when starting a new line.
func (w *indentWriter) Write(b []byte) (int, error) {
w.indent()
return w.Buffer.Write(b)
}
func (w *indentWriter) WriteString(s string) (int, error) {
w.indent()
return w.Buffer.WriteString(s)
}
func (w *indentWriter) WriteByte(b byte) error {
w.indent()
return w.Buffer.WriteByte(b)
}
func (w *indentWriter) WriteRune(r rune) (int, error) {
w.indent()
return w.Buffer.WriteRune(r)
}

View file

@ -1,42 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dag
import (
"reflect"
"testing"
)
func TestGraphDot_opts(t *testing.T) {
var v testDotVertex
var g Graph
g.Add(&v)
opts := &DotOpts{MaxDepth: 42}
actual := g.Dot(opts)
if len(actual) == 0 {
t.Fatal("should not be empty")
}
if !v.DotNodeCalled {
t.Fatal("should call DotNode")
}
if !reflect.DeepEqual(v.DotNodeOpts, opts) {
t.Fatalf("bad; %#v", v.DotNodeOpts)
}
}
type testDotVertex struct {
DotNodeCalled bool
DotNodeTitle string
DotNodeOpts *DotOpts
DotNodeReturn *DotNode
}
func (v *testDotVertex) DotNode(title string, opts *DotOpts) *DotNode {
v.DotNodeCalled = true
v.DotNodeTitle = title
v.DotNodeOpts = opts
return v.DotNodeReturn
}

View file

@ -107,80 +107,6 @@ func (g *Graph) Add(v Vertex) Vertex {
return v
}
// Remove removes a vertex from the graph. This will also remove any
// edges with this vertex as a source or target.
func (g *Graph) Remove(v Vertex) Vertex {
// Delete the vertex itself
g.vertices.Delete(v)
// Delete the edges to non-existent things
for _, target := range g.downEdgesNoCopy(v) {
g.RemoveEdge(BasicEdge(v, target))
}
for _, source := range g.upEdgesNoCopy(v) {
g.RemoveEdge(BasicEdge(source, v))
}
return nil
}
// Replace replaces the original Vertex with replacement. If the original
// does not exist within the graph, then false is returned. Otherwise, true
// is returned.
func (g *Graph) Replace(original, replacement Vertex) bool {
// If we don't have the original, we can't do anything
if !g.vertices.Include(original) {
return false
}
// If they're the same, then don't do anything
if original == replacement {
return true
}
// Add our new vertex, then copy all the edges
g.Add(replacement)
for _, target := range g.downEdgesNoCopy(original) {
g.Connect(BasicEdge(replacement, target))
}
for _, source := range g.upEdgesNoCopy(original) {
g.Connect(BasicEdge(source, replacement))
}
// Remove our old vertex, which will also remove all the edges
g.Remove(original)
return true
}
// RemoveEdge removes an edge from the graph.
func (g *Graph) RemoveEdge(edge Edge) {
g.init()
// Delete the edge from the set
g.edges.Delete(edge)
// Delete the up/down edges
if s, ok := g.downEdges[hashcode(edge.Source())]; ok {
s.Delete(edge.Target())
}
if s, ok := g.upEdges[hashcode(edge.Target())]; ok {
s.Delete(edge.Source())
}
}
// UpEdges returns the vertices that are *sources* of edges that target the
// destination Vertex v.
func (g *Graph) UpEdges(v Vertex) Set {
return g.upEdgesNoCopy(v).Copy()
}
// DownEdges returns the vertices that are *targets* of edges that originate
// from the source Vertex v.
func (g *Graph) DownEdges(v Vertex) Set {
return g.downEdgesNoCopy(v).Copy()
}
// downEdgesNoCopy returns the vertices targeted by edges from the source Vertex
// v as a Set. This Set is the same as used internally by the Graph to prevent a
// copy, and must not be modified by the caller.
@ -234,70 +160,6 @@ func (g *Graph) Connect(edge Edge) {
s.Add(source)
}
// Subsume imports all of the nodes and edges from the given graph into the
// reciever, leaving the given graph unchanged.
//
// If any of the nodes in the given graph are already present in the reciever
// then the existing node will be retained and any new edges from the given
// graph will be connected with it.
//
// If the given graph has edges in common with the reciever then they will be
// ignored, because each pair of nodes can only be connected once.
func (g *Graph) Subsume(other *Graph) {
// We're using Set.Filter just as a "visit each element" here, so we're
// not doing anything with the result (which will always be empty).
other.vertices.Filter(func(i interface{}) bool {
g.Add(i)
return false
})
other.edges.Filter(func(i interface{}) bool {
g.Connect(i.(Edge))
return false
})
}
// String outputs some human-friendly output for the graph structure.
func (g *Graph) StringWithNodeTypes() string {
var buf bytes.Buffer
// Build the list of node names and a mapping so that we can more
// easily alphabetize the output to remain deterministic.
vertices := g.Vertices()
names := make([]string, 0, len(vertices))
mapping := make(map[string]Vertex, len(vertices))
for _, v := range vertices {
name := VertexName(v)
names = append(names, name)
mapping[name] = v
}
sort.Strings(names)
// Write each node in order...
for _, name := range names {
v := mapping[name]
targets := g.downEdges[hashcode(v)]
buf.WriteString(fmt.Sprintf("%s - %T\n", name, v))
// Alphabetize dependencies
deps := make([]string, 0, targets.Len())
targetNodes := make(map[string]Vertex)
for _, target := range targets {
dep := VertexName(target)
deps = append(deps, dep)
targetNodes[dep] = target
}
sort.Strings(deps)
// Write dependencies
for _, d := range deps {
buf.WriteString(fmt.Sprintf(" %s - %T\n", d, targetNodes[d]))
}
}
return buf.String()
}
// String outputs some human-friendly output for the graph structure.
func (g *Graph) String() string {
var buf bytes.Buffer
@ -352,11 +214,6 @@ func (g *Graph) init() {
}
}
// Dot returns a dot-formatted representation of the Graph.
func (g *Graph) Dot(opts *DotOpts) []byte {
return newMarshalGraph("", g).Dot(opts)
}
// VertexName returns the name of a vertex.
func VertexName(raw Vertex) string {
switch v := raw.(type) {

View file

@ -36,53 +36,6 @@ func TestGraph_basic(t *testing.T) {
}
}
func TestGraph_remove(t *testing.T) {
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 3))
g.Remove(3)
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphRemoveStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestGraph_replace(t *testing.T) {
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Replace(2, 42)
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphReplaceStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestGraph_replaceSelf(t *testing.T) {
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Replace(2, 2)
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphReplaceSelfStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
// This tests that connecting edges works based on custom Hashcode
// implementations for uniqueness.
func TestGraph_hashcode(t *testing.T) {
@ -173,42 +126,6 @@ func TestGraphEdgesTo(t *testing.T) {
}
}
func TestGraphUpdownEdges(t *testing.T) {
// Verify that we can't inadvertently modify the internal graph sets
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
up := g.UpEdges(2)
if up.Len() != 1 || !up.Include(1) {
t.Fatalf("expected only an up edge of '1', got %#v", up)
}
// modify the up set
up.Add(9)
orig := g.UpEdges(2)
diff := up.Difference(orig)
if diff.Len() != 1 || !diff.Include(9) {
t.Fatalf("expected a diff of only '9', got %#v", diff)
}
down := g.DownEdges(2)
if down.Len() != 1 || !down.Include(3) {
t.Fatalf("expected only a down edge of '3', got %#v", down)
}
// modify the down set
down.Add(8)
orig = g.DownEdges(2)
diff = down.Difference(orig)
if diff.Len() != 1 || !diff.Include(8) {
t.Fatalf("expected a diff of only '8', got %#v", diff)
}
}
type hashVertex struct {
code interface{}
}
@ -233,24 +150,3 @@ const testGraphEmptyStr = `
2
3
`
const testGraphRemoveStr = `
1
2
`
const testGraphReplaceStr = `
1
42
3
42
3
`
const testGraphReplaceSelfStr = `
1
2
2
3
3
`

View file

@ -1,200 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dag
import (
"fmt"
"reflect"
"sort"
"strconv"
)
// the marshal* structs are for serialization of the graph data.
type marshalGraph struct {
// Type is always "Graph", for identification as a top level object in the
// JSON stream.
Type string
// Each marshal structure requires a unique ID so that it can be referenced
// by other structures.
ID string `json:",omitempty"`
// Human readable name for this graph.
Name string `json:",omitempty"`
// Arbitrary attributes that can be added to the output.
Attrs map[string]string `json:",omitempty"`
// List of graph vertices, sorted by ID.
Vertices []*marshalVertex `json:",omitempty"`
// List of edges, sorted by Source ID.
Edges []*marshalEdge `json:",omitempty"`
// Any number of subgraphs. A subgraph itself is considered a vertex, and
// may be referenced by either end of an edge.
Subgraphs []*marshalGraph `json:",omitempty"`
// Any lists of vertices that are included in cycles.
Cycles [][]*marshalVertex `json:",omitempty"`
}
func (g *marshalGraph) vertexByID(id string) *marshalVertex {
for _, v := range g.Vertices {
if id == v.ID {
return v
}
}
return nil
}
type marshalVertex struct {
// Unique ID, used to reference this vertex from other structures.
ID string
// Human readable name
Name string `json:",omitempty"`
Attrs map[string]string `json:",omitempty"`
// This is to help transition from the old Dot interfaces. We record if the
// node was a GraphNodeDotter here, so we can call it to get attributes.
graphNodeDotter GraphNodeDotter
}
func newMarshalVertex(v Vertex) *marshalVertex {
dn, ok := v.(GraphNodeDotter)
if !ok {
dn = nil
}
// the name will be quoted again later, so we need to ensure it's properly
// escaped without quotes.
name := strconv.Quote(VertexName(v))
name = name[1 : len(name)-1]
return &marshalVertex{
ID: marshalVertexID(v),
Name: name,
Attrs: make(map[string]string),
graphNodeDotter: dn,
}
}
// vertices is a sort.Interface implementation for sorting vertices by ID
type vertices []*marshalVertex
func (v vertices) Less(i, j int) bool { return v[i].Name < v[j].Name }
func (v vertices) Len() int { return len(v) }
func (v vertices) Swap(i, j int) { v[i], v[j] = v[j], v[i] }
type marshalEdge struct {
// Human readable name
Name string
// Source and Target Vertices by ID
Source string
Target string
Attrs map[string]string `json:",omitempty"`
}
func newMarshalEdge(e Edge) *marshalEdge {
return &marshalEdge{
Name: fmt.Sprintf("%s|%s", VertexName(e.Source()), VertexName(e.Target())),
Source: marshalVertexID(e.Source()),
Target: marshalVertexID(e.Target()),
Attrs: make(map[string]string),
}
}
// edges is a sort.Interface implementation for sorting edges by Source ID
type edges []*marshalEdge
func (e edges) Less(i, j int) bool { return e[i].Name < e[j].Name }
func (e edges) Len() int { return len(e) }
func (e edges) Swap(i, j int) { e[i], e[j] = e[j], e[i] }
// build a marshalGraph structure from a *Graph
func newMarshalGraph(name string, g *Graph) *marshalGraph {
mg := &marshalGraph{
Type: "Graph",
Name: name,
Attrs: make(map[string]string),
}
for _, v := range g.Vertices() {
id := marshalVertexID(v)
if sg, ok := marshalSubgrapher(v); ok {
smg := newMarshalGraph(VertexName(v), sg)
smg.ID = id
mg.Subgraphs = append(mg.Subgraphs, smg)
}
mv := newMarshalVertex(v)
mg.Vertices = append(mg.Vertices, mv)
}
sort.Sort(vertices(mg.Vertices))
for _, e := range g.Edges() {
mg.Edges = append(mg.Edges, newMarshalEdge(e))
}
sort.Sort(edges(mg.Edges))
for _, c := range (&AcyclicGraph{*g}).Cycles() {
var cycle []*marshalVertex
for _, v := range c {
mv := newMarshalVertex(v)
cycle = append(cycle, mv)
}
mg.Cycles = append(mg.Cycles, cycle)
}
return mg
}
// Attempt to return a unique ID for any vertex.
func marshalVertexID(v Vertex) string {
val := reflect.ValueOf(v)
switch val.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return strconv.Itoa(int(val.Pointer()))
case reflect.Interface:
// A vertex shouldn't contain another layer of interface, but handle
// this just in case.
return fmt.Sprintf("%#v", val.Interface())
}
if v, ok := v.(Hashable); ok {
h := v.Hashcode()
if h, ok := h.(string); ok {
return h
}
}
// fallback to a name, which we hope is unique.
return VertexName(v)
// we could try harder by attempting to read the arbitrary value from the
// interface, but we shouldn't get here from terraform right now.
}
// check for a Subgrapher, and return the underlying *Graph.
func marshalSubgrapher(v Vertex) (*Graph, bool) {
sg, ok := v.(Subgrapher)
if !ok {
return nil, false
}
switch g := sg.Subgraph().DirectedGraph().(type) {
case *Graph:
return g, true
case *AcyclicGraph:
return &g.Graph, true
}
return nil, false
}

View file

@ -1,104 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dag
import (
"strings"
"testing"
)
func TestGraphDot_empty(t *testing.T) {
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
actual := strings.TrimSpace(string(g.Dot(nil)))
expected := strings.TrimSpace(testGraphDotEmptyStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestGraphDot_basic(t *testing.T) {
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 3))
actual := strings.TrimSpace(string(g.Dot(nil)))
expected := strings.TrimSpace(testGraphDotBasicStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestGraphDot_quoted(t *testing.T) {
var g Graph
quoted := `name["with-quotes"]`
other := `other`
g.Add(quoted)
g.Add(other)
g.Connect(BasicEdge(quoted, other))
actual := strings.TrimSpace(string(g.Dot(nil)))
expected := strings.TrimSpace(testGraphDotQuotedStr)
if actual != expected {
t.Fatalf("\ngot: %q\nwanted %q\n", actual, expected)
}
}
func TestGraphDot_attrs(t *testing.T) {
var g Graph
g.Add(&testGraphNodeDotter{
Result: &DotNode{
Name: "foo",
Attrs: map[string]string{"foo": "bar"},
},
})
actual := strings.TrimSpace(string(g.Dot(nil)))
expected := strings.TrimSpace(testGraphDotAttrsStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
type testGraphNodeDotter struct{ Result *DotNode }
func (n *testGraphNodeDotter) Name() string { return n.Result.Name }
func (n *testGraphNodeDotter) DotNode(string, *DotOpts) *DotNode { return n.Result }
const testGraphDotQuotedStr = `digraph {
compound = "true"
newrank = "true"
subgraph "root" {
"[root] name[\"with-quotes\"]" -> "[root] other"
}
}`
const testGraphDotBasicStr = `digraph {
compound = "true"
newrank = "true"
subgraph "root" {
"[root] 1" -> "[root] 3"
}
}
`
const testGraphDotEmptyStr = `digraph {
compound = "true"
newrank = "true"
subgraph "root" {
}
}`
const testGraphDotAttrsStr = `digraph {
compound = "true"
newrank = "true"
subgraph "root" {
"[root] foo" [foo = "bar"]
}
}`

View file

@ -1,453 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dag
import (
"log"
"sync"
"time"
"github.com/hashicorp/hcl/v2"
)
// Walker is used to walk every vertex of a graph in parallel.
//
// A vertex will only be walked when the dependencies of that vertex have
// been walked. If two vertices can be walked at the same time, they will be.
//
// Update can be called to update the graph. This can be called even during
// a walk, changing vertices/edges mid-walk. This should be done carefully.
// If a vertex is removed but has already been executed, the result of that
// execution (any error) is still returned by Wait. Changing or re-adding
// a vertex that has already executed has no effect. Changing edges of
// a vertex that has already executed has no effect.
//
// Non-parallelism can be enforced by introducing a lock in your callback
// function. However, the goroutine overhead of a walk will remain.
// Walker will create V*2 goroutines (one for each vertex, and dependency
// waiter for each vertex). In general this should be of no concern unless
// there are a huge number of vertices.
//
// The walk is depth first by default. This can be changed with the Reverse
// option.
//
// A single walker is only valid for one graph walk. After the walk is complete
// you must construct a new walker to walk again. State for the walk is never
// deleted in case vertices or edges are changed.
type Walker struct {
// Callback is what is called for each vertex
Callback WalkFunc
// Reverse, if true, causes the source of an edge to depend on a target.
// When false (default), the target depends on the source.
Reverse bool
// changeLock must be held to modify any of the fields below. Only Update
// should modify these fields. Modifying them outside of Update can cause
// serious problems.
changeLock sync.Mutex
vertices Set
edges Set
vertexMap map[Vertex]*walkerVertex
// wait is done when all vertices have executed. It may become "undone"
// if new vertices are added.
wait sync.WaitGroup
// diagsMap contains the diagnostics recorded so far for execution,
// and upstreamFailed contains all the vertices whose problems were
// caused by upstream failures, and thus whose diagnostics should be
// excluded from the final set.
//
// Readers and writers of either map must hold diagsLock.
diagsMap map[Vertex]hcl.Diagnostics
upstreamFailed map[Vertex]struct{}
diagsLock sync.Mutex
}
func (w *Walker) init() {
if w.vertices == nil {
w.vertices = make(Set)
}
if w.edges == nil {
w.edges = make(Set)
}
}
type walkerVertex struct {
// These should only be set once on initialization and never written again.
// They are not protected by a lock since they don't need to be since
// they are write-once.
// DoneCh is closed when this vertex has completed execution, regardless
// of success.
//
// CancelCh is closed when the vertex should cancel execution. If execution
// is already complete (DoneCh is closed), this has no effect. Otherwise,
// execution is cancelled as quickly as possible.
DoneCh chan struct{}
CancelCh chan struct{}
// Dependency information. Any changes to any of these fields requires
// holding DepsLock.
//
// DepsCh is sent a single value that denotes whether the upstream deps
// were successful (no errors). Any value sent means that the upstream
// dependencies are complete. No other values will ever be sent again.
//
// DepsUpdateCh is closed when there is a new DepsCh set.
DepsCh chan bool
DepsUpdateCh chan struct{}
DepsLock sync.Mutex
// Below is not safe to read/write in parallel. This behavior is
// enforced by changes only happening in Update. Nothing else should
// ever modify these.
deps map[Vertex]chan struct{}
depsCancelCh chan struct{}
}
// Wait waits for the completion of the walk and returns diagnostics describing
// any problems that arose. Update should be called to populate the walk with
// vertices and edges prior to calling this.
//
// Wait will return as soon as all currently known vertices are complete.
// If you plan on calling Update with more vertices in the future, you
// should not call Wait until after this is done.
func (w *Walker) Wait() hcl.Diagnostics {
// Wait for completion
w.wait.Wait()
var diags hcl.Diagnostics
w.diagsLock.Lock()
for v, vDiags := range w.diagsMap {
if _, upstream := w.upstreamFailed[v]; upstream {
// Ignore diagnostics for nodes that had failed upstreams, since
// the downstream diagnostics are likely to be redundant.
continue
}
diags = diags.Extend(vDiags)
}
w.diagsLock.Unlock()
return diags
}
// Update updates the currently executing walk with the given graph.
// This will perform a diff of the vertices and edges and update the walker.
// Already completed vertices remain completed (including any errors during
// their execution).
//
// This returns immediately once the walker is updated; it does not wait
// for completion of the walk.
//
// Multiple Updates can be called in parallel. Update can be called at any
// time during a walk.
func (w *Walker) Update(g *AcyclicGraph) {
w.init()
v := make(Set)
e := make(Set)
if g != nil {
v, e = g.vertices, g.edges
}
// Grab the change lock so no more updates happen but also so that
// no new vertices are executed during this time since we may be
// removing them.
w.changeLock.Lock()
defer w.changeLock.Unlock()
// Initialize fields
if w.vertexMap == nil {
w.vertexMap = make(map[Vertex]*walkerVertex)
}
// Calculate all our sets
newEdges := e.Difference(w.edges)
oldEdges := w.edges.Difference(e)
newVerts := v.Difference(w.vertices)
oldVerts := w.vertices.Difference(v)
// Add the new vertices
for _, raw := range newVerts {
v := raw.(Vertex)
// Add to the waitgroup so our walk is not done until everything finishes
w.wait.Add(1)
// Add to our own set so we know about it already
w.vertices.Add(raw)
// Initialize the vertex info
info := &walkerVertex{
DoneCh: make(chan struct{}),
CancelCh: make(chan struct{}),
deps: make(map[Vertex]chan struct{}),
}
// Add it to the map and kick off the walk
w.vertexMap[v] = info
}
// Remove the old vertices
for _, raw := range oldVerts {
v := raw.(Vertex)
// Get the vertex info so we can cancel it
info, ok := w.vertexMap[v]
if !ok {
// This vertex for some reason was never in our map. This
// shouldn't be possible.
continue
}
// Cancel the vertex
close(info.CancelCh)
// Delete it out of the map
delete(w.vertexMap, v)
w.vertices.Delete(raw)
}
// Add the new edges
changedDeps := make(Set)
for _, raw := range newEdges {
edge := raw.(Edge)
waiter, dep := w.edgeParts(edge)
// Get the info for the waiter
waiterInfo, ok := w.vertexMap[waiter]
if !ok {
// Vertex doesn't exist... shouldn't be possible but ignore.
continue
}
// Get the info for the dep
depInfo, ok := w.vertexMap[dep]
if !ok {
// Vertex doesn't exist... shouldn't be possible but ignore.
continue
}
// Add the dependency to our waiter
waiterInfo.deps[dep] = depInfo.DoneCh
// Record that the deps changed for this waiter
changedDeps.Add(waiter)
w.edges.Add(raw)
}
// Process removed edges
for _, raw := range oldEdges {
edge := raw.(Edge)
waiter, dep := w.edgeParts(edge)
// Get the info for the waiter
waiterInfo, ok := w.vertexMap[waiter]
if !ok {
// Vertex doesn't exist... shouldn't be possible but ignore.
continue
}
// Delete the dependency from the waiter
delete(waiterInfo.deps, dep)
// Record that the deps changed for this waiter
changedDeps.Add(waiter)
w.edges.Delete(raw)
}
// For each vertex with changed dependencies, we need to kick off
// a new waiter and notify the vertex of the changes.
for _, raw := range changedDeps {
v := raw.(Vertex)
info, ok := w.vertexMap[v]
if !ok {
// Vertex doesn't exist... shouldn't be possible but ignore.
continue
}
// Create a new done channel
doneCh := make(chan bool, 1)
// Create the channel we close for cancellation
cancelCh := make(chan struct{})
// Build a new deps copy
deps := make(map[Vertex]<-chan struct{})
for k, v := range info.deps {
deps[k] = v
}
// Update the update channel
info.DepsLock.Lock()
if info.DepsUpdateCh != nil {
close(info.DepsUpdateCh)
}
info.DepsCh = doneCh
info.DepsUpdateCh = make(chan struct{})
info.DepsLock.Unlock()
// Cancel the older waiter
if info.depsCancelCh != nil {
close(info.depsCancelCh)
}
info.depsCancelCh = cancelCh
// Start the waiter
go w.waitDeps(v, deps, doneCh, cancelCh)
}
// Start all the new vertices. We do this at the end so that all
// the edge waiters and changes are set up above.
for _, raw := range newVerts {
v := raw.(Vertex)
go w.walkVertex(v, w.vertexMap[v])
}
}
// edgeParts returns the waiter and the dependency, in that order.
// The waiter is waiting on the dependency.
func (w *Walker) edgeParts(e Edge) (Vertex, Vertex) {
if w.Reverse {
return e.Source(), e.Target()
}
return e.Target(), e.Source()
}
// walkVertex walks a single vertex, waiting for any dependencies before
// executing the callback.
func (w *Walker) walkVertex(v Vertex, info *walkerVertex) {
// When we're done executing, lower the waitgroup count
defer w.wait.Done()
// When we're done, always close our done channel
defer close(info.DoneCh)
// Wait for our dependencies. We create a [closed] deps channel so
// that we can immediately fall through to load our actual DepsCh.
var depsSuccess bool
var depsUpdateCh chan struct{}
depsCh := make(chan bool, 1)
depsCh <- true
close(depsCh)
for {
select {
case <-info.CancelCh:
// Cancel
return
case depsSuccess = <-depsCh:
// Deps complete! Mark as nil to trigger completion handling.
depsCh = nil
case <-depsUpdateCh:
// New deps, reloop
}
// Check if we have updated dependencies. This can happen if the
// dependencies were satisfied exactly prior to an Update occurring.
// In that case, we'd like to take into account new dependencies
// if possible.
info.DepsLock.Lock()
if info.DepsCh != nil {
depsCh = info.DepsCh
info.DepsCh = nil
}
if info.DepsUpdateCh != nil {
depsUpdateCh = info.DepsUpdateCh
}
info.DepsLock.Unlock()
// If we still have no deps channel set, then we're done!
if depsCh == nil {
break
}
}
// If we passed dependencies, we just want to check once more that
// we're not cancelled, since this can happen just as dependencies pass.
select {
case <-info.CancelCh:
// Cancelled during an update while dependencies completed.
return
default:
}
// Run our callback or note that our upstream failed
var diags hcl.Diagnostics
var upstreamFailed bool
if depsSuccess {
diags = w.Callback(v)
} else {
log.Printf("[TRACE] dag/walk: upstream of %q errored, so skipping", VertexName(v))
// This won't be displayed to the user because we'll set upstreamFailed,
// but we need to ensure there's at least one error in here so that
// the failures will cascade downstream.
diags = diags.Append(&hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Upstream dependencies failed",
})
upstreamFailed = true
}
// Record the result (we must do this after execution because we mustn't
// hold diagsLock while visiting a vertex.)
w.diagsLock.Lock()
if w.diagsMap == nil {
w.diagsMap = make(map[Vertex]hcl.Diagnostics)
}
w.diagsMap[v] = diags
if w.upstreamFailed == nil {
w.upstreamFailed = make(map[Vertex]struct{})
}
if upstreamFailed {
w.upstreamFailed[v] = struct{}{}
}
w.diagsLock.Unlock()
}
func (w *Walker) waitDeps(
v Vertex,
deps map[Vertex]<-chan struct{},
doneCh chan<- bool,
cancelCh <-chan struct{}) {
// For each dependency given to us, wait for it to complete
for dep, depCh := range deps {
DepSatisfied:
for {
select {
case <-depCh:
// Dependency satisfied!
break DepSatisfied
case <-cancelCh:
// Wait cancelled. Note that we didn't satisfy dependencies
// so that anything waiting on us also doesn't run.
doneCh <- false
return
case <-time.After(time.Second * 5):
log.Printf("[TRACE] dag/walk: vertex %q is waiting for %q",
VertexName(v), VertexName(dep))
}
}
}
// Dependencies satisfied! We need to check if any errored
w.diagsLock.Lock()
defer w.diagsLock.Unlock()
for dep := range deps {
if w.diagsMap[dep].HasErrors() {
// One of our dependencies failed, so return false
doneCh <- false
return
}
}
// All dependencies satisfied and successful
doneCh <- true
}

View file

@ -1,307 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dag
import (
"reflect"
"sync"
"testing"
"time"
"github.com/hashicorp/hcl/v2"
)
func TestWalker_basic(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Connect(BasicEdge(1, 2))
// Run it a bunch of times since it is timing dependent
for i := 0; i < 50; i++ {
var order []interface{}
w := &Walker{Callback: walkCbRecord(&order)}
w.Update(&g)
// Wait
if err := w.Wait(); err != nil {
t.Fatalf("err: %s", err)
}
// Check
expected := []interface{}{1, 2}
if !reflect.DeepEqual(order, expected) {
t.Errorf("wrong order\ngot: %#v\nwant: %#v", order, expected)
}
}
}
func TestWalker_updateNilGraph(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Connect(BasicEdge(1, 2))
// Run it a bunch of times since it is timing dependent
for i := 0; i < 50; i++ {
var order []interface{}
w := &Walker{Callback: walkCbRecord(&order)}
w.Update(&g)
w.Update(nil)
// Wait
if err := w.Wait(); err != nil {
t.Fatalf("err: %s", err)
}
}
}
func TestWalker_error(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(3, 4))
// Record function
var order []interface{}
recordF := walkCbRecord(&order)
// Build a callback that delays until we close a channel
cb := func(v Vertex) hcl.Diagnostics {
if v == 2 {
var diags hcl.Diagnostics
diags = diags.Append(&hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "simulated error",
})
return diags
}
return recordF(v)
}
w := &Walker{Callback: cb}
w.Update(&g)
// Wait
if err := w.Wait(); err == nil {
t.Fatal("expect error")
}
// Check
expected := []interface{}{1}
if !reflect.DeepEqual(order, expected) {
t.Errorf("wrong order\ngot: %#v\nwant: %#v", order, expected)
}
}
func TestWalker_newVertex(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Connect(BasicEdge(1, 2))
// Record function
var order []interface{}
recordF := walkCbRecord(&order)
done2 := make(chan int)
// Build a callback that notifies us when 2 has been walked
var w *Walker
cb := func(v Vertex) hcl.Diagnostics {
if v == 2 {
defer close(done2)
}
return recordF(v)
}
// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
// if 2 has been visited, the walk is complete so far
<-done2
// Update the graph
g.Add(3)
w.Update(&g)
// Update the graph again but with the same vertex
g.Add(3)
w.Update(&g)
// Wait
if err := w.Wait(); err != nil {
t.Fatalf("err: %s", err)
}
// Check
expected := []interface{}{1, 2, 3}
if !reflect.DeepEqual(order, expected) {
t.Errorf("wrong order\ngot: %#v\nwant: %#v", order, expected)
}
}
func TestWalker_removeVertex(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Connect(BasicEdge(1, 2))
// Record function
var order []interface{}
recordF := walkCbRecord(&order)
var w *Walker
cb := func(v Vertex) hcl.Diagnostics {
if v == 1 {
g.Remove(2)
w.Update(&g)
}
return recordF(v)
}
// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
// Wait
if err := w.Wait(); err != nil {
t.Fatalf("err: %s", err)
}
// Check
expected := []interface{}{1}
if !reflect.DeepEqual(order, expected) {
t.Errorf("wrong order\ngot: %#v\nwant: %#v", order, expected)
}
}
func TestWalker_newEdge(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Connect(BasicEdge(1, 2))
// Record function
var order []interface{}
recordF := walkCbRecord(&order)
var w *Walker
cb := func(v Vertex) hcl.Diagnostics {
// record where we are first, otherwise the Updated vertex may get
// walked before the first visit.
diags := recordF(v)
if v == 1 {
g.Add(3)
g.Connect(BasicEdge(3, 2))
w.Update(&g)
}
return diags
}
// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
// Wait
if err := w.Wait(); err != nil {
t.Fatalf("err: %s", err)
}
// Check
expected := []interface{}{1, 3, 2}
if !reflect.DeepEqual(order, expected) {
t.Errorf("wrong order\ngot: %#v\nwant: %#v", order, expected)
}
}
func TestWalker_removeEdge(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(3, 2))
// Record function
var order []interface{}
recordF := walkCbRecord(&order)
// The way this works is that our original graph forces
// the order of 1 => 3 => 2. During the execution of 1, we
// remove the edge forcing 3 before 2. Then, during the execution
// of 3, we wait on a channel that is only closed by 2, implicitly
// forcing 2 before 3 via the callback (and not the graph). If
// 2 cannot execute before 3 (edge removal is non-functional), then
// this test will timeout.
var w *Walker
gateCh := make(chan struct{})
cb := func(v Vertex) hcl.Diagnostics {
t.Logf("visit vertex %#v", v)
switch v {
case 1:
g.RemoveEdge(BasicEdge(3, 2))
w.Update(&g)
t.Logf("removed edge from 3 to 2")
case 2:
// this visit isn't completed until we've recorded it
// Once the visit is official, we can then close the gate to
// let 3 continue.
defer close(gateCh)
defer t.Logf("2 unblocked 3")
case 3:
select {
case <-gateCh:
t.Logf("vertex 3 gate channel is now closed")
case <-time.After(500 * time.Millisecond):
t.Logf("vertex 3 timed out waiting for the gate channel to close")
var diags hcl.Diagnostics
diags = diags.Append(&hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "timeout",
Detail: "timeout 3 waiting for 2",
})
return diags
}
}
return recordF(v)
}
// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
// Wait
if diags := w.Wait(); diags.HasErrors() {
t.Fatalf("unexpected errors: %s", diags.Error())
}
// Check
expected := []interface{}{1, 2, 3}
if !reflect.DeepEqual(order, expected) {
t.Errorf("wrong order\ngot: %#v\nwant: %#v", order, expected)
}
}
// walkCbRecord is a test helper callback that just records the order called.
func walkCbRecord(order *[]interface{}) WalkFunc {
var l sync.Mutex
return func(v Vertex) hcl.Diagnostics {
l.Lock()
defer l.Unlock()
*order = append(*order, v)
return nil
}
}