Test framework and first set of tests

This commit is contained in:
Chad Retz 2018-05-10 16:29:16 -05:00
parent 34ea0edde9
commit 4156ef3cfc
10 changed files with 245 additions and 29 deletions

View File

@ -2,6 +2,8 @@ package control
import (
"strings"
"github.com/cretz/bine/util"
)
type ProtocolInfo struct {
@ -25,20 +27,20 @@ func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) {
if err != nil {
return nil, err
}
// Check PIVERSION
if len(resp.Data) == 0 || resp.Data[0] != "1" {
return nil, newProtocolError("Invalid PIVERSION: %s", resp.Reply)
}
// Get other response vals
// Check data vals
ret := &ProtocolInfo{RawResponse: resp}
for _, piece := range resp.Data {
key, val, ok := partitionString(piece, ' ')
key, val, ok := util.PartitionString(piece, ' ')
if !ok {
continue
}
switch key {
case "PROTOCOLINFO":
if val != "1" {
return nil, newProtocolError("Invalid PIVERSION: %v", val)
}
case "AUTH":
methods, cookieFile, _ := partitionString(val, ' ')
methods, cookieFile, _ := util.PartitionString(val, ' ')
if !strings.HasPrefix(methods, "METHODS=") {
continue
}
@ -46,15 +48,15 @@ func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) {
if !strings.HasPrefix(cookieFile, "COOKIEFILE=") {
continue
}
if ret.CookieFile, err = parseQuotedString(cookieFile[11:]); err != nil {
if ret.CookieFile, err = util.ParseSimpleQuotedString(cookieFile[11:]); err != nil {
continue
}
}
ret.AuthMethods = strings.Split(methods[8:], ",")
case "VERSION":
torVersion, _, _ := partitionString(val, ' ')
torVersion, _, _ := util.PartitionString(val, ' ')
if strings.HasPrefix(torVersion, "Tor=") {
ret.TorVersion, _ = parseQuotedString(torVersion[4:])
ret.TorVersion, err = util.ParseSimpleQuotedString(torVersion[4:])
}
}
}

View File

@ -2,11 +2,15 @@ package control
import (
"fmt"
"io"
"net/textproto"
"sync"
)
type Conn struct {
// No debug logs if nil
DebugWriter io.Writer
conn *textproto.Conn
asyncChansLock sync.RWMutex
@ -65,7 +69,7 @@ func (c *Conn) AddAsyncChan(ch chan<- *Response) {
}
// Does not close
func (c *Conn) RemoveChan(ch chan<- *Response) bool {
func (c *Conn) RemoveAsyncChan(ch chan<- *Response) bool {
c.asyncChansLock.Lock()
chans := make([]chan<- *Response, len(c.asyncChans)+1)
copy(chans, c.asyncChans)
@ -94,6 +98,16 @@ func (c *Conn) onAsyncResponse(resp *Response) {
}
}
func (c *Conn) debugEnabled() bool {
return c.DebugWriter != nil
}
func (c *Conn) debugf(format string, args ...interface{}) {
if w := c.DebugWriter; w != nil {
fmt.Fprintf(w, format+"\n", args...)
}
}
func newProtocolError(format string, args ...interface{}) textproto.ProtocolError {
return textproto.ProtocolError(fmt.Sprintf(format, args...))
}

View File

@ -0,0 +1,18 @@
package controltest
import (
"context"
"strings"
"testing"
)
func TestProtocolInfo(t *testing.T) {
ctx := NewTestContext(context.Background(), t)
defer ctx.Close()
conn := ctx.ConnectTestTor()
defer conn.Close()
info, err := conn.RequestProtocolInfo()
ctx.Require.NoError(err)
ctx.Require.Contains(info.AuthMethods, "NULL")
ctx.Require.True(strings.HasPrefix(info.TorVersion, "0.3"))
}

View File

@ -0,0 +1,53 @@
package controltest
import (
"context"
"fmt"
"net/textproto"
"os"
"strconv"
"testing"
"github.com/stretchr/testify/require"
"github.com/cretz/bine/control"
)
type TestContext struct {
context.Context
*testing.T
Require *require.Assertions
TestTor *TestTor
}
func NewTestContext(ctx context.Context, t *testing.T) *TestContext {
return &TestContext{Context: ctx, T: t, Require: require.New(t)}
}
func (t *TestContext) EnsureTestTorStarted() {
if t.TestTor == nil {
var err error
if t.TestTor, err = StartTestTor(t); err != nil {
t.Fatal(err)
}
}
}
func (t *TestContext) Close() {
if t.TestTor != nil {
if err := t.TestTor.Close(); err != nil {
fmt.Printf("Warning, close failed on tor inst: %v", err)
}
}
}
func (t *TestContext) ConnectTestTor() *control.Conn {
t.EnsureTestTorStarted()
textConn, err := textproto.Dial("tcp", "127.0.0.1:"+strconv.Itoa(t.TestTor.ControlPort))
if err != nil {
t.Fatal(err)
}
conn := control.NewConn(textConn)
conn.DebugWriter = os.Stdout
return conn
}

View File

@ -0,0 +1,98 @@
package controltest
import (
"context"
"flag"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"time"
"github.com/cretz/bine/process"
)
var torExePath string
func init() {
flag.StringVar(&torExePath, "tor", "tor", "The TOR exe path")
flag.Parse()
}
type TestTor struct {
DataDir string
OrigArgs []string
ControlPort int
processCancelFn context.CancelFunc
}
func StartTestTor(ctx context.Context) (*TestTor, error) {
dataDir, err := ioutil.TempDir(".", "test-data-dir-")
if err != nil {
return nil, err
}
controlPortFile := filepath.Join(dataDir, "control-port")
ret := &TestTor{
DataDir: dataDir,
OrigArgs: []string{
// "--quiet",
"--DisableNetwork", "1",
"--ControlPort", "auto",
"--ControlPortWriteToFile", controlPortFile,
"--DataDirectory", dataDir,
},
}
errCh := make(chan error, 1)
var processCtx context.Context
processCtx, ret.processCancelFn = context.WithCancel(ctx)
go func() {
p, err := process.NewProcessCreator(torExePath).New(processCtx, ret.OrigArgs...)
if err == nil {
err = p.Run()
}
errCh <- err
}()
err = nil
for err == nil {
select {
case err = <-errCh:
if err == nil {
err = fmt.Errorf("Process returned earlier than expected")
}
case <-processCtx.Done():
err = ctx.Err()
default:
// Try to read the controlport file, or wait a bit
var byts []byte
if byts, err = ioutil.ReadFile(controlPortFile); err == nil {
if ret.ControlPort, err = process.ControlPortFromFileContents(string(byts)); err == nil {
return ret, nil
}
} else if os.IsNotExist(err) {
// Wait a bit
err = nil
time.Sleep(100 * time.Millisecond)
}
}
}
// Delete the data dir and stop the process since we errored
if closeErr := ret.Close(); closeErr != nil {
fmt.Printf("Warning, unable to remove data dir %v: %v", dataDir, closeErr)
}
return nil, err
}
func (t *TestTor) Close() (err error) {
if t.processCancelFn != nil {
t.processCancelFn()
}
// Try this twice while waiting a bit between each
for i := 0; i < 2; i++ {
if err = os.RemoveAll(t.DataDir); err == nil {
break
}
time.Sleep(300 * time.Millisecond)
}
return
}

View File

@ -52,6 +52,7 @@ func (c *Conn) ReadResponse() (*Response, error) {
if err != nil {
return nil, err
}
c.debugf("Read line: %v", line)
// Parse the line that was just read.
if len(line) < 4 {

21
process/creator.go Normal file
View File

@ -0,0 +1,21 @@
package process
import (
"context"
"os/exec"
)
type ProcessCreator interface {
New(ctx context.Context, args ...string) (Process, error)
}
type exeProcessCreator struct {
exePath string
}
func NewProcessCreator(exePath string) ProcessCreator {
return &exeProcessCreator{exePath}
}
func (e *exeProcessCreator) New(ctx context.Context, args ...string) (Process, error) {
return &exeProcess{Cmd: exec.CommandContext(ctx, e.exePath, args...)}, nil
}

View File

@ -1,22 +1,13 @@
package process
import (
"context"
"os/exec"
)
type Tor interface {
Start(ctx context.Context, args []string) error
type Process interface {
Run() error
}
type exeTor struct {
exePath string
}
func FromExePath(exePath string) Tor {
return &exeTor{exePath}
}
func (e *exeTor) Start(ctx context.Context, args []string) error {
return exec.CommandContext(ctx, e.exePath, args...).Start()
type exeProcess struct {
*exec.Cmd
}

18
process/util.go Normal file
View File

@ -0,0 +1,18 @@
package process
import (
"fmt"
"strconv"
"strings"
"github.com/cretz/bine/util"
)
func ControlPortFromFileContents(contents string) (int, error) {
contents = strings.TrimSpace(contents)
_, port, ok := util.PartitionString(contents, ':')
if !ok || !strings.HasPrefix(contents, "PORT=") {
return 0, fmt.Errorf("Invalid port format: %v", contents)
}
return strconv.Atoi(port)
}

View File

@ -1,11 +1,11 @@
package control
package util
import (
"fmt"
"strings"
)
func partitionString(str string, ch byte) (string, string, bool) {
func PartitionString(str string, ch byte) (string, string, bool) {
index := strings.IndexByte(str, ch)
if index == -1 {
return str, "", false
@ -13,14 +13,14 @@ func partitionString(str string, ch byte) (string, string, bool) {
return str[:index], str[index+1:], true
}
func parseQuotedString(str string) (string, error) {
func ParseSimpleQuotedString(str string) (string, error) {
if len(str) < 2 || str[0] != '"' || str[len(str)-1] != '"' {
return "", fmt.Errorf("Missing quotes")
}
return unescapeQuoted(str)
return UnescapeSimpleQuoted(str[1 : len(str)-1])
}
func unescapeQuoted(str string) (string, error) {
func UnescapeSimpleQuoted(str string) (string, error) {
ret := ""
escaping := false
for _, c := range str {