mirror of https://github.com/harness/drone.git
349 lines
8.4 KiB
Go
349 lines
8.4 KiB
Go
// Copyright 2023 Harness, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package dbtx
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
//nolint:gocognit
|
|
func TestWithTx(t *testing.T) {
|
|
errTest := errors.New("dummy error")
|
|
|
|
tests := []struct {
|
|
name string
|
|
fn func(tx Transaction) error
|
|
errCommit error
|
|
cancelCtx bool
|
|
expectErr error
|
|
expectCommitted bool
|
|
expectRollback bool
|
|
}{
|
|
{
|
|
name: "successful",
|
|
fn: func(Transaction) error { return nil },
|
|
expectCommitted: true,
|
|
},
|
|
{
|
|
name: "err-in-transaction",
|
|
fn: func(Transaction) error { return errTest },
|
|
expectErr: errTest,
|
|
expectRollback: true,
|
|
},
|
|
{
|
|
name: "commit-failed",
|
|
fn: func(Transaction) error { return nil },
|
|
errCommit: errTest,
|
|
expectErr: errTest,
|
|
expectRollback: true,
|
|
},
|
|
{
|
|
name: "commit-failed-tx-done",
|
|
fn: func(Transaction) error { return nil },
|
|
errCommit: sql.ErrTxDone,
|
|
expectErr: sql.ErrTxDone,
|
|
expectRollback: true,
|
|
},
|
|
{
|
|
name: "commit-failed-ctx-cancelled",
|
|
fn: func(Transaction) error { return nil },
|
|
errCommit: sql.ErrTxDone,
|
|
cancelCtx: true,
|
|
expectErr: context.Canceled,
|
|
expectRollback: true,
|
|
},
|
|
{
|
|
name: "panic-in-transaction",
|
|
fn: func(Transaction) error { panic("dummy panic") },
|
|
expectRollback: true,
|
|
},
|
|
{
|
|
name: "commit-in-transaction",
|
|
fn: func(tx Transaction) error {
|
|
_ = tx.Commit()
|
|
return nil
|
|
},
|
|
expectCommitted: true,
|
|
},
|
|
{
|
|
name: "commit-in-transaction-fn-returns-err",
|
|
fn: func(tx Transaction) error {
|
|
_ = tx.Commit()
|
|
return errTest
|
|
},
|
|
expectErr: errTest,
|
|
expectCommitted: true,
|
|
},
|
|
{
|
|
name: "rollback-in-transaction",
|
|
fn: func(tx Transaction) error {
|
|
_ = tx.Rollback()
|
|
return nil
|
|
},
|
|
expectRollback: true,
|
|
},
|
|
{
|
|
name: "rollback-in-transaction-fn-returns-err",
|
|
fn: func(tx Transaction) error {
|
|
_ = tx.Rollback()
|
|
return errTest
|
|
},
|
|
expectErr: errTest,
|
|
expectRollback: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
mock := &dbMock{
|
|
t: t,
|
|
errCommit: test.errCommit,
|
|
}
|
|
run := &runnerDB{
|
|
db: mock,
|
|
mx: lockerNop{},
|
|
}
|
|
|
|
ctx, cancelFn := context.WithCancel(context.Background())
|
|
defer cancelFn()
|
|
|
|
var err error
|
|
|
|
func() {
|
|
defer func() {
|
|
_ = recover()
|
|
}()
|
|
|
|
err = run.WithTx(ctx, func(ctx context.Context) error {
|
|
if test.cancelCtx {
|
|
cancelFn()
|
|
}
|
|
return test.fn(GetTransaction(ctx))
|
|
})
|
|
}()
|
|
|
|
tx := mock.createdTx
|
|
if tx == nil {
|
|
t.Error("did not start a transaction")
|
|
return
|
|
}
|
|
|
|
if !tx.finished {
|
|
t.Error("transaction not finished")
|
|
}
|
|
|
|
if want, got := test.expectErr, err; !errors.Is(got, want) {
|
|
t.Errorf("expected error %v, but got %v", want, got)
|
|
}
|
|
|
|
if want, got := test.expectCommitted, tx.committed; want != got {
|
|
t.Errorf("expected committed %t, but got %t", want, got)
|
|
}
|
|
|
|
if want, got := test.expectRollback, tx.rollback; want != got {
|
|
t.Errorf("expected rollback %t, but got %t", want, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type dbMock struct {
|
|
*sqlx.DB // only to fulfill the Accessor interface, will be nil
|
|
t *testing.T
|
|
errCommit error
|
|
createdTx *txMock
|
|
}
|
|
|
|
var _ transactor = (*dbMock)(nil)
|
|
|
|
func (d *dbMock) startTx(context.Context, *sql.TxOptions) (TransactionAccessor, error) {
|
|
d.createdTx = &txMock{
|
|
t: d.t,
|
|
errCommit: d.errCommit,
|
|
finished: false,
|
|
committed: false,
|
|
rollback: false,
|
|
}
|
|
return d.createdTx, nil
|
|
}
|
|
|
|
type txMock struct {
|
|
*sqlx.Tx // only to fulfill the Accessor interface, will be nil
|
|
t *testing.T
|
|
errCommit error
|
|
finished bool
|
|
committed bool
|
|
rollback bool
|
|
}
|
|
|
|
var _ TransactionAccessor = (*txMock)(nil)
|
|
|
|
func (tx *txMock) Commit() error {
|
|
if tx.finished {
|
|
tx.t.Error("Committing an already finished transaction")
|
|
return nil
|
|
}
|
|
if tx.errCommit == nil {
|
|
tx.finished = true
|
|
tx.committed = true
|
|
}
|
|
return tx.errCommit
|
|
}
|
|
|
|
func (tx *txMock) Rollback() error {
|
|
if tx.finished {
|
|
tx.t.Error("Rolling back an already finished transaction")
|
|
return nil
|
|
}
|
|
tx.finished = true
|
|
tx.rollback = true
|
|
return nil
|
|
}
|
|
|
|
// nolint:rowserrcheck,sqlclosecheck // it's a unit test, works with mocked DB
|
|
func TestLocking(t *testing.T) {
|
|
const dummyQuery = ""
|
|
tests := []struct {
|
|
name string
|
|
fn func(db AccessorTx, l *lockerCounter)
|
|
}{
|
|
{
|
|
name: "exec-lock",
|
|
fn: func(db AccessorTx, l *lockerCounter) {
|
|
ctx := context.Background()
|
|
_, _ = db.ExecContext(ctx, dummyQuery)
|
|
_, _ = db.ExecContext(ctx, dummyQuery)
|
|
_, _ = db.ExecContext(ctx, dummyQuery)
|
|
|
|
assert.Zero(t, l.RLocks)
|
|
assert.Zero(t, l.RUnlocks)
|
|
assert.Equal(t, 3, l.Locks)
|
|
assert.Equal(t, 3, l.Unlocks)
|
|
},
|
|
},
|
|
{
|
|
name: "tx-lock",
|
|
fn: func(db AccessorTx, l *lockerCounter) {
|
|
ctx := context.Background()
|
|
_ = db.WithTx(ctx, func(ctx context.Context) error {
|
|
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
|
|
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
|
|
return nil
|
|
})
|
|
|
|
assert.Zero(t, l.RLocks)
|
|
assert.Zero(t, l.RUnlocks)
|
|
assert.Equal(t, 1, l.Locks)
|
|
assert.Equal(t, 1, l.Unlocks)
|
|
},
|
|
},
|
|
{
|
|
name: "tx-read-lock",
|
|
fn: func(db AccessorTx, l *lockerCounter) {
|
|
ctx := context.Background()
|
|
_ = db.WithTx(ctx, func(ctx context.Context) error {
|
|
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
|
|
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
|
|
return nil
|
|
}, TxDefaultReadOnly)
|
|
|
|
assert.Equal(t, 1, l.RLocks)
|
|
assert.Equal(t, 1, l.RUnlocks)
|
|
assert.Zero(t, l.Locks)
|
|
assert.Zero(t, l.Unlocks)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
l := &lockerCounter{}
|
|
t.Run(test.name, func(_ *testing.T) {
|
|
test.fn(runnerDB{
|
|
db: dbMockNop{},
|
|
mx: l,
|
|
}, l)
|
|
})
|
|
}
|
|
}
|
|
|
|
// lockerCounter stands in for the runner's locker in tests. It performs no
// synchronization at all; each method only increments its own call counter.
type lockerCounter struct {
	Locks, Unlocks   int // number of Lock and Unlock calls
	RLocks, RUnlocks int // number of RLock and RUnlock calls
}

// Lock counts one Lock call.
func (c *lockerCounter) Lock() {
	c.Locks++
}

// Unlock counts one Unlock call.
func (c *lockerCounter) Unlock() {
	c.Unlocks++
}

// RLock counts one RLock call.
func (c *lockerCounter) RLock() {
	c.RLocks++
}

// RUnlock counts one RUnlock call.
func (c *lockerCounter) RUnlock() {
	c.RUnlocks++
}
|
|
|
|
type dbMockNop struct{}
|
|
|
|
func (dbMockNop) DriverName() string { return "" }
|
|
func (dbMockNop) Rebind(string) string { return "" }
|
|
func (dbMockNop) BindNamed(string, interface{}) (string, []interface{}, error) { return "", nil, nil }
|
|
|
|
//nolint:nilnil // it's a mock
|
|
func (dbMockNop) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
//nolint:nilnil // it's a mock
|
|
func (dbMockNop) QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) {
|
|
return nil, nil
|
|
}
|
|
func (dbMockNop) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row { return nil }
|
|
func (dbMockNop) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
|
|
return nil, nil
|
|
}
|
|
func (dbMockNop) QueryRowContext(context.Context, string, ...any) *sql.Row {
|
|
return nil
|
|
}
|
|
|
|
//nolint:nilnil // it's a mock
|
|
func (dbMockNop) PrepareContext(context.Context, string) (*sql.Stmt, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
//nolint:nilnil // it's a mock
|
|
func (dbMockNop) PreparexContext(context.Context, string) (*sqlx.Stmt, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
//nolint:nilnil // it's a mock
|
|
func (dbMockNop) PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) {
|
|
return nil, nil
|
|
}
|
|
func (dbMockNop) GetContext(context.Context, interface{}, string, ...interface{}) error {
|
|
return nil
|
|
}
|
|
func (dbMockNop) SelectContext(context.Context, interface{}, string, ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (dbMockNop) Commit() error { return nil }
|
|
func (dbMockNop) Rollback() error { return nil }
|
|
|
|
func (d dbMockNop) startTx(context.Context, *sql.TxOptions) (TransactionAccessor, error) {
|
|
return d, nil
|
|
}
|