tests/weghting_n_cutting_the_tree/funcs_test.go

118 lines
3.3 KiB
Go

package main
import (
"fmt"
"github.com/stretchr/testify/require"
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
// TestNode_WeighingTreeAllAlgo runs every tree-weighting implementation
// (recursion, explicit stack, DLL) against trees built by getTestTree and
// requires each result to equal the Weights returned alongside that tree.
// Before every comparison both weight sets are printed via printExpectations
// to make mismatches easier to inspect.
func TestNode_WeighingTreeAllAlgo(t *testing.T) {
	cases := map[string]TestTree{
		"Ordinary VFS tree":                      getTestTree(3, 3, nil),
		"VFS in the form of a pathological tree": getTestTree(1, 7, nil),
	}
	for caseName, tc := range cases {
		t.Run(caseName, func(t *testing.T) {
			// Ordered list so the subtests run in the same sequence as before.
			algorithms := []struct {
				name string
				run  func() map[string]*WeightedNode
			}{
				{"WeightingTreeWithRecursion", func() map[string]*WeightedNode {
					return tc.Tree.WeightingTreeWithRecursion(nil)
				}},
				{"WeightingTreeWithStack", func() map[string]*WeightedNode {
					return tc.Tree.WeightingTreeWithStack()
				}},
				{"WeightingTreeWithDLL", func() map[string]*WeightedNode {
					return tc.Tree.WeightingTreeWithDLL()
				}},
			}
			for _, algo := range algorithms {
				t.Run(algo.name, func(t *testing.T) {
					got := algo.run()
					printExpectations(tc.Weights, got)
					assert.EqualValues(t, tc.Weights, got)
				})
			}
		})
	}
}
// TestDecomposeTree checks DecomposeTree against table-driven cases.
// A case with a non-empty err expects DecomposeTree to fail with exactly that
// message; otherwise the call must succeed and return the expected parts
// (compared without regard to order).
func TestDecomposeTree(t *testing.T) {
	type decomposeCase struct {
		inputTree      map[string]*WeightedNode
		inputThreshold int
		outputParts    []SubTree
		err            string
	}
	cases := map[string]decomposeCase{
		"Normal": {
			inputTree:      smallTree.Tree,
			inputThreshold: smallTree.Threshold,
			outputParts:    smallTree.Parts,
		},
	}
	for caseName, c := range cases {
		t.Run(caseName, func(t *testing.T) {
			parts, err := DecomposeTree(c.inputTree, c.inputThreshold)
			if c.err != "" {
				require.EqualError(t, err, c.err)
				return
			}
			require.NoError(t, err)
			require.ElementsMatch(t, parts, c.outputParts)
		})
	}
}
func Benchmark(b *testing.B) {
b.StopTimer()
b.ResetTimer()
for name,tt := range map[string]struct{
branches int
depth int
}{
"Small tree (85 nodes)":{4, 4}, // 85 nodes,
//"Wide tree (265K nodes)": {515, 3}, // 265'741 nodes,
//"Deep tree (265K nodes)": {3, 12}, // 265'720 nodes,
//"Huge tree (2,44M nodes)": {5, 10}, // 2'441'406 nodes,
"Pathological tree (10K nodes)": {1, 10000}, // 10'000 nodes,
}{
b.Run(name, func(b *testing.B) {
tree := getTestTree(tt.branches, tt.depth, nil)
b.Cleanup(func() {
tree = TestTree{nil, nil}
})
b.Run("Weighing with recursion", func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StartTimer()
assert.EqualValues(b, tree.Weights, tree.Tree.WeightingTreeWithRecursion(nil))
b.StopTimer()
}
})
b.Run("Weighting with stack", func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StartTimer()
assert.EqualValues(b, tree.Weights, tree.Tree.WeightingTreeWithStack())
b.StopTimer()
}
})
b.Run("Weighting with DLL", func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StartTimer()
assert.EqualValues(b, tree.Weights, tree.Tree.WeightingTreeWithDLL())
b.StopTimer()
}
})
})
}
}
// printExpectations writes the node weights of expected and actual to stdout,
// each as an int slice sorted in descending order. Map keys are not printed,
// so the two lines can be compared by eye regardless of map iteration order.
func printExpectations(expected, actual map[string]*WeightedNode) {
	descendingWeights := func(nodes map[string]*WeightedNode) []int {
		var weights []int
		for _, node := range nodes {
			weights = append(weights, int(node.weight))
		}
		sort.Slice(weights, func(i, j int) bool { return weights[i] > weights[j] })
		return weights
	}

	// Both slices are built before anything is printed.
	want, got := descendingWeights(expected), descendingWeights(actual)
	fmt.Printf("expected: %+v\n", want)
	fmt.Printf("actual : %+v\n", got)
}