drone/store/database/dbtx/runner_test.go

349 lines
8.4 KiB
Go

// Copyright 2023 Harness, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dbtx
import (
"context"
"database/sql"
"errors"
"testing"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
)
// TestWithTx checks how runnerDB.WithTx finishes the transaction it starts
// (commit or rollback), which error it returns, and whether a panic reaches
// the caller, for every way the transaction callback can end: success, error,
// panic, or an explicit Commit/Rollback issued from inside the callback.
//
//nolint:gocognit
func TestWithTx(t *testing.T) {
	errTest := errors.New("dummy error")

	tests := []struct {
		name            string
		fn              func(tx Transaction) error // body of the transaction callback
		errCommit       error                      // error returned by the mocked Commit
		cancelCtx       bool                       // cancel the context from inside the callback
		expectErr       error                      // matched against the WithTx result with errors.Is
		expectPanic     bool                       // whether a panic must come out of WithTx
		expectCommitted bool
		expectRollback  bool
	}{
		{
			name:            "successful",
			fn:              func(Transaction) error { return nil },
			expectCommitted: true,
		},
		{
			name:           "err-in-transaction",
			fn:             func(Transaction) error { return errTest },
			expectErr:      errTest,
			expectRollback: true,
		},
		{
			name:           "commit-failed",
			fn:             func(Transaction) error { return nil },
			errCommit:      errTest,
			expectErr:      errTest,
			expectRollback: true,
		},
		{
			name:           "commit-failed-tx-done",
			fn:             func(Transaction) error { return nil },
			errCommit:      sql.ErrTxDone,
			expectErr:      sql.ErrTxDone,
			expectRollback: true,
		},
		{
			name:           "commit-failed-ctx-cancelled",
			fn:             func(Transaction) error { return nil },
			errCommit:      sql.ErrTxDone,
			cancelCtx:      true,
			expectErr:      context.Canceled,
			expectRollback: true,
		},
		{
			name:           "panic-in-transaction",
			fn:             func(Transaction) error { panic("dummy panic") },
			expectPanic:    true,
			expectRollback: true,
		},
		{
			name: "commit-in-transaction",
			fn: func(tx Transaction) error {
				_ = tx.Commit()
				return nil
			},
			expectCommitted: true,
		},
		{
			name: "commit-in-transaction-fn-returns-err",
			fn: func(tx Transaction) error {
				_ = tx.Commit()
				return errTest
			},
			expectErr:       errTest,
			expectCommitted: true,
		},
		{
			name: "rollback-in-transaction",
			fn: func(tx Transaction) error {
				_ = tx.Rollback()
				return nil
			},
			expectRollback: true,
		},
		{
			name: "rollback-in-transaction-fn-returns-err",
			fn: func(tx Transaction) error {
				_ = tx.Rollback()
				return errTest
			},
			expectErr:      errTest,
			expectRollback: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			mock := &dbMock{
				t:         t,
				errCommit: test.errCommit,
			}
			run := &runnerDB{
				db: mock,
				mx: lockerNop{},
			}

			ctx, cancelFn := context.WithCancel(context.Background())
			defer cancelFn()

			var (
				err       error
				recovered any
			)
			func() {
				// Keep the recovered value instead of discarding it, so that an
				// unexpected panic fails the test rather than being swallowed.
				defer func() {
					recovered = recover()
				}()

				err = run.WithTx(ctx, func(ctx context.Context) error {
					if test.cancelCtx {
						cancelFn()
					}
					return test.fn(GetTransaction(ctx))
				})
			}()

			if want, got := test.expectPanic, recovered != nil; want != got {
				t.Errorf("expected panic %t, but got %t (recovered: %v)", want, got, recovered)
			}

			tx := mock.createdTx
			if tx == nil {
				t.Error("did not start a transaction")
				return
			}

			if !tx.finished {
				t.Error("transaction not finished")
			}

			if want, got := test.expectErr, err; !errors.Is(got, want) {
				t.Errorf("expected error %v, but got %v", want, got)
			}

			if want, got := test.expectCommitted, tx.committed; want != got {
				t.Errorf("expected committed %t, but got %t", want, got)
			}

			if want, got := test.expectRollback, tx.rollback; want != got {
				t.Errorf("expected rollback %t, but got %t", want, got)
			}
		})
	}
}
// dbMock is the transactor given to runnerDB in TestWithTx. It keeps a
// reference to the transaction mock created by startTx so the test can inspect
// that transaction's final state.
type dbMock struct {
	// Embedded only to fulfill the Accessor interface; it is always nil.
	*sqlx.DB

	t         *testing.T // handed to every txMock created by startTx
	errCommit error      // handed to every txMock created by startTx
	createdTx *txMock    // set by startTx; nil until a transaction is started
}

// Compile-time check that *dbMock satisfies transactor.
var _ transactor = (*dbMock)(nil)
// startTx creates a fresh txMock, stores it on the receiver as createdTx and
// returns that same mock. Both arguments are ignored and the returned error is
// always nil.
func (d *dbMock) startTx(context.Context, *sql.TxOptions) (TransactionAccessor, error) {
	// finished, committed and rollback are left at their zero value (false).
	tx := new(txMock)
	tx.t = d.t
	tx.errCommit = d.errCommit

	d.createdTx = tx

	return tx, nil
}
// txMock is a fake transaction that only records how it was finished. Commit
// and Rollback update the flags below and report a test error through t when
// called on an already finished transaction.
type txMock struct {
	// Embedded only to fulfill the Accessor interface; it is always nil.
	*sqlx.Tx

	t         *testing.T
	errCommit error // returned by Commit; when non-nil, Commit leaves the flags untouched

	finished  bool // set by a successful Commit or by Rollback
	committed bool // set by a successful Commit
	rollback  bool // set by Rollback
}

// Compile-time check that *txMock satisfies TransactionAccessor.
var _ TransactionAccessor = (*txMock)(nil)
// Commit marks the transaction as finished and committed, unless errCommit is
// set, in which case that error is returned and the transaction stays open.
// Calling it on a finished transaction reports a test error and returns nil.
func (tx *txMock) Commit() error {
	switch {
	case tx.finished:
		tx.t.Error("Committing an already finished transaction")
		return nil
	case tx.errCommit != nil:
		// A failed commit leaves the transaction unfinished.
		return tx.errCommit
	default:
		tx.finished, tx.committed = true, true
		return nil
	}
}
// Rollback marks the transaction as finished and rolled back. Calling it on a
// finished transaction reports a test error instead. It always returns nil.
func (tx *txMock) Rollback() error {
	if !tx.finished {
		tx.finished, tx.rollback = true, true
		return nil
	}

	tx.t.Error("Rolling back an already finished transaction")

	return nil
}
// TestLocking checks how often runnerDB calls each method of its locker. Every
// case runs against a no-op database mock and a fresh lockerCounter, then
// asserts the expected counter values:
//   - three ExecContext calls outside a transaction: three Lock/Unlock pairs;
//   - a WithTx call with default options: a single Lock/Unlock pair;
//   - a WithTx call with TxDefaultReadOnly: a single RLock/RUnlock pair.
//
// Each case receives the subtest's *testing.T so that a failed assertion is
// reported against the subtest that produced it, not against the parent test.
//
//nolint:rowserrcheck,sqlclosecheck // it's a unit test, works with mocked DB
func TestLocking(t *testing.T) {
	const dummyQuery = ""

	tests := []struct {
		name string
		fn   func(t *testing.T, db AccessorTx, l *lockerCounter)
	}{
		{
			name: "exec-lock",
			fn: func(t *testing.T, db AccessorTx, l *lockerCounter) {
				ctx := context.Background()
				_, _ = db.ExecContext(ctx, dummyQuery)
				_, _ = db.ExecContext(ctx, dummyQuery)
				_, _ = db.ExecContext(ctx, dummyQuery)

				assert.Zero(t, l.RLocks)
				assert.Zero(t, l.RUnlocks)
				assert.Equal(t, 3, l.Locks)
				assert.Equal(t, 3, l.Unlocks)
			},
		},
		{
			name: "tx-lock",
			fn: func(t *testing.T, db AccessorTx, l *lockerCounter) {
				ctx := context.Background()
				_ = db.WithTx(ctx, func(ctx context.Context) error {
					_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
					_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
					return nil
				})

				assert.Zero(t, l.RLocks)
				assert.Zero(t, l.RUnlocks)
				assert.Equal(t, 1, l.Locks)
				assert.Equal(t, 1, l.Unlocks)
			},
		},
		{
			name: "tx-read-lock",
			fn: func(t *testing.T, db AccessorTx, l *lockerCounter) {
				ctx := context.Background()
				_ = db.WithTx(ctx, func(ctx context.Context) error {
					_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
					_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
					return nil
				}, TxDefaultReadOnly)

				assert.Equal(t, 1, l.RLocks)
				assert.Equal(t, 1, l.RUnlocks)
				assert.Zero(t, l.Locks)
				assert.Zero(t, l.Unlocks)
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// A fresh counter per subtest keeps the cases independent.
			l := &lockerCounter{}

			test.fn(t, runnerDB{
				db: dbMockNop{},
				mx: l,
			}, l)
		})
	}
}
// lockerCounter is a locker that does no locking at all: each method only
// increments the counter named after it, so a test can assert how many times
// it was called. The counters are plain ints with no synchronization.
type lockerCounter struct {
	Locks, Unlocks   int // calls to Lock and Unlock
	RLocks, RUnlocks int // calls to RLock and RUnlock
}

// Lock counts one Lock call.
func (l *lockerCounter) Lock() {
	l.Locks++
}

// Unlock counts one Unlock call.
func (l *lockerCounter) Unlock() {
	l.Unlocks++
}

// RLock counts one RLock call.
func (l *lockerCounter) RLock() {
	l.RLocks++
}

// RUnlock counts one RUnlock call.
func (l *lockerCounter) RUnlock() {
	l.RUnlocks++
}
type dbMockNop struct{}
func (dbMockNop) DriverName() string { return "" }
func (dbMockNop) Rebind(string) string { return "" }
func (dbMockNop) BindNamed(string, interface{}) (string, []interface{}, error) { return "", nil, nil }
//nolint:nilnil // it's a mock
func (dbMockNop) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
return nil, nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) {
return nil, nil
}
func (dbMockNop) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row { return nil }
func (dbMockNop) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
return nil, nil
}
func (dbMockNop) QueryRowContext(context.Context, string, ...any) *sql.Row {
return nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) PrepareContext(context.Context, string) (*sql.Stmt, error) {
return nil, nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) PreparexContext(context.Context, string) (*sqlx.Stmt, error) {
return nil, nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) {
return nil, nil
}
func (dbMockNop) GetContext(context.Context, interface{}, string, ...interface{}) error {
return nil
}
func (dbMockNop) SelectContext(context.Context, interface{}, string, ...interface{}) error {
return nil
}
func (dbMockNop) Commit() error { return nil }
func (dbMockNop) Rollback() error { return nil }
func (d dbMockNop) startTx(context.Context, *sql.TxOptions) (TransactionAccessor, error) {
return d, nil
}