From 4156ef3cfc34544bf0adba8b8b184ba8025a8df5 Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Thu, 10 May 2018 16:29:16 -0500 Subject: [PATCH] Test framework and first set of tests --- control/cmd_protocolinfo.go | 22 +++-- control/conn.go | 16 +++- control/controltest/cmd_protocolinfo_test.go | 18 ++++ control/controltest/test_context.go | 53 +++++++++++ control/controltest/test_tor.go | 98 ++++++++++++++++++++ control/response.go | 1 + process/creator.go | 21 +++++ process/process.go | 17 +--- process/util.go | 18 ++++ control/parse.go => util/strings.go | 10 +- 10 files changed, 245 insertions(+), 29 deletions(-) create mode 100644 control/controltest/cmd_protocolinfo_test.go create mode 100644 control/controltest/test_context.go create mode 100644 control/controltest/test_tor.go create mode 100644 process/creator.go create mode 100644 process/util.go rename control/parse.go => util/strings.go (72%) diff --git a/control/cmd_protocolinfo.go b/control/cmd_protocolinfo.go index f3c27b3..3710489 100644 --- a/control/cmd_protocolinfo.go +++ b/control/cmd_protocolinfo.go @@ -2,6 +2,8 @@ package control import ( "strings" + + "github.com/cretz/bine/util" ) type ProtocolInfo struct { @@ -25,20 +27,20 @@ func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) { if err != nil { return nil, err } - // Check PIVERSION - if len(resp.Data) == 0 || resp.Data[0] != "1" { - return nil, newProtocolError("Invalid PIVERSION: %s", resp.Reply) - } - // Get other response vals + // Check data vals ret := &ProtocolInfo{RawResponse: resp} for _, piece := range resp.Data { - key, val, ok := partitionString(piece, ' ') + key, val, ok := util.PartitionString(piece, ' ') if !ok { continue } switch key { + case "PROTOCOLINFO": + if val != "1" { + return nil, newProtocolError("Invalid PIVERSION: %v", val) + } case "AUTH": - methods, cookieFile, _ := partitionString(val, ' ') + methods, cookieFile, _ := util.PartitionString(val, ' ') if !strings.HasPrefix(methods, "METHODS=") { continue } @@ -46,15 +48,15 @@ func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) { if !strings.HasPrefix(cookieFile, "COOKIEFILE=") { continue } - if ret.CookieFile, err = parseQuotedString(cookieFile[11:]); err != nil { + if ret.CookieFile, err = util.ParseSimpleQuotedString(cookieFile[11:]); err != nil { continue } } ret.AuthMethods = strings.Split(methods[8:], ",") case "VERSION": - torVersion, _, _ := partitionString(val, ' ') + torVersion, _, _ := util.PartitionString(val, ' ') if strings.HasPrefix(torVersion, "Tor=") { - ret.TorVersion, _ = parseQuotedString(torVersion[4:]) + ret.TorVersion, err = util.ParseSimpleQuotedString(torVersion[4:]) } } } diff --git a/control/conn.go b/control/conn.go index 3ea4484..bcb8985 100644 --- a/control/conn.go +++ b/control/conn.go @@ -2,11 +2,15 @@ package control import ( "fmt" + "io" "net/textproto" "sync" ) type Conn struct { + // No debug logs if nil + DebugWriter io.Writer + conn *textproto.Conn asyncChansLock sync.RWMutex @@ -65,7 +69,7 @@ func (c *Conn) AddAsyncChan(ch chan<- *Response) { } // Does not close -func (c *Conn) RemoveChan(ch chan<- *Response) bool { +func (c *Conn) RemoveAsyncChan(ch chan<- *Response) bool { c.asyncChansLock.Lock() chans := make([]chan<- *Response, len(c.asyncChans)+1) copy(chans, c.asyncChans) @@ -94,6 +98,16 @@ func (c *Conn) onAsyncResponse(resp *Response) { } } +func (c *Conn) debugEnabled() bool { + return c.DebugWriter != nil +} + +func (c *Conn) debugf(format string, args ...interface{}) { + if w := c.DebugWriter; w != nil { + fmt.Fprintf(w, format+"\n", args...) + } +} + func newProtocolError(format string, args ...interface{}) textproto.ProtocolError { return textproto.ProtocolError(fmt.Sprintf(format, args...)) } diff --git a/control/controltest/cmd_protocolinfo_test.go b/control/controltest/cmd_protocolinfo_test.go new file mode 100644 index 0000000..5bfed52 --- /dev/null +++ b/control/controltest/cmd_protocolinfo_test.go @@ -0,0 +1,18 @@ +package controltest + +import ( + "context" + "strings" + "testing" +) + +func TestProtocolInfo(t *testing.T) { + ctx := NewTestContext(context.Background(), t) + defer ctx.Close() + conn := ctx.ConnectTestTor() + defer conn.Close() + info, err := conn.RequestProtocolInfo() + ctx.Require.NoError(err) + ctx.Require.Contains(info.AuthMethods, "NULL") + ctx.Require.True(strings.HasPrefix(info.TorVersion, "0.3")) +} diff --git a/control/controltest/test_context.go b/control/controltest/test_context.go new file mode 100644 index 0000000..f8e3aec --- /dev/null +++ b/control/controltest/test_context.go @@ -0,0 +1,53 @@ +package controltest + +import ( + "context" + "fmt" + "net/textproto" + "os" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cretz/bine/control" +) + +type TestContext struct { + context.Context + *testing.T + Require *require.Assertions + TestTor *TestTor +} + +func NewTestContext(ctx context.Context, t *testing.T) *TestContext { + return &TestContext{Context: ctx, T: t, Require: require.New(t)} +} + +func (t *TestContext) EnsureTestTorStarted() { + if t.TestTor == nil { + var err error + if t.TestTor, err = StartTestTor(t); err != nil { + t.Fatal(err) + } + } +} + +func (t *TestContext) Close() { + if t.TestTor != nil { + if err := t.TestTor.Close(); err != nil { + fmt.Printf("Warning, close failed on tor inst: %v", err) + } + } +} + +func (t *TestContext) ConnectTestTor() *control.Conn { + t.EnsureTestTorStarted() + textConn, err := textproto.Dial("tcp", "127.0.0.1:"+strconv.Itoa(t.TestTor.ControlPort)) + if err != nil { + t.Fatal(err) + } + conn := control.NewConn(textConn) + conn.DebugWriter = os.Stdout + return conn +} diff --git a/control/controltest/test_tor.go b/control/controltest/test_tor.go new file mode 100644 index 0000000..7f54305 --- /dev/null +++ b/control/controltest/test_tor.go @@ -0,0 +1,98 @@ +package controltest + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "time" + + "github.com/cretz/bine/process" +) + +var torExePath string + +func init() { + flag.StringVar(&torExePath, "tor", "tor", "The TOR exe path") + flag.Parse() +} + +type TestTor struct { + DataDir string + OrigArgs []string + ControlPort int + + processCancelFn context.CancelFunc +} + +func StartTestTor(ctx context.Context) (*TestTor, error) { + dataDir, err := ioutil.TempDir(".", "test-data-dir-") + if err != nil { + return nil, err + } + controlPortFile := filepath.Join(dataDir, "control-port") + ret := &TestTor{ + DataDir: dataDir, + OrigArgs: []string{ + // "--quiet", + "--DisableNetwork", "1", + "--ControlPort", "auto", + "--ControlPortWriteToFile", controlPortFile, + "--DataDirectory", dataDir, + }, + } + errCh := make(chan error, 1) + var processCtx context.Context + processCtx, ret.processCancelFn = context.WithCancel(ctx) + go func() { + p, err := process.NewProcessCreator(torExePath).New(processCtx, ret.OrigArgs...) + if err == nil { + err = p.Run() + } + errCh <- err + }() + err = nil + for err == nil { + select { + case err = <-errCh: + if err == nil { + err = fmt.Errorf("Process returned earlier than expected") + } + case <-processCtx.Done(): + err = ctx.Err() + default: + // Try to read the controlport file, or wait a bit + var byts []byte + if byts, err = ioutil.ReadFile(controlPortFile); err == nil { + if ret.ControlPort, err = process.ControlPortFromFileContents(string(byts)); err == nil { + return ret, nil + } + } else if os.IsNotExist(err) { + // Wait a bit + err = nil + time.Sleep(100 * time.Millisecond) + } + } + } + // Delete the data dir and stop the process since we errored + if closeErr := ret.Close(); closeErr != nil { + fmt.Printf("Warning, unable to remove data dir %v: %v", dataDir, closeErr) + } + return nil, err +} + +func (t *TestTor) Close() (err error) { + if t.processCancelFn != nil { + t.processCancelFn() + } + // Try this twice while waiting a bit between each + for i := 0; i < 2; i++ { + if err = os.RemoveAll(t.DataDir); err == nil { + break + } + time.Sleep(300 * time.Millisecond) + } + return +} diff --git a/control/response.go b/control/response.go index 09b3954..daba82e 100644 --- a/control/response.go +++ b/control/response.go @@ -52,6 +52,7 @@ func (c *Conn) ReadResponse() (*Response, error) { if err != nil { return nil, err } + c.debugf("Read line: %v", line) // Parse the line that was just read. if len(line) < 4 { diff --git a/process/creator.go b/process/creator.go new file mode 100644 index 0000000..e122fcb --- /dev/null +++ b/process/creator.go @@ -0,0 +1,21 @@ +package process + +import ( + "context" + "os/exec" +) + +type ProcessCreator interface { + New(ctx context.Context, args ...string) (Process, error) +} + +type exeProcessCreator struct { + exePath string +} + +func NewProcessCreator(exePath string) ProcessCreator { + return &exeProcessCreator{exePath} +} +func (e *exeProcessCreator) New(ctx context.Context, args ...string) (Process, error) { + return &exeProcess{Cmd: exec.CommandContext(ctx, e.exePath, args...)}, nil +} diff --git a/process/process.go b/process/process.go index 3f13d1d..2cb7593 100644 --- a/process/process.go +++ b/process/process.go @@ -1,22 +1,13 @@ package process import ( - "context" "os/exec" ) -type Tor interface { - Start(ctx context.Context, args []string) error +type Process interface { + Run() error } -type exeTor struct { - exePath string -} - -func FromExePath(exePath string) Tor { - return &exeTor{exePath} -} - -func (e *exeTor) Start(ctx context.Context, args []string) error { - return exec.CommandContext(ctx, e.exePath, args...).Start() +type exeProcess struct { + *exec.Cmd } diff --git a/process/util.go b/process/util.go new file mode 100644 index 0000000..00fab4d --- /dev/null +++ b/process/util.go @@ -0,0 +1,18 @@ +package process + +import ( + "fmt" + "strconv" + "strings" + + "github.com/cretz/bine/util" +) + +func ControlPortFromFileContents(contents string) (int, error) { + contents = strings.TrimSpace(contents) + _, port, ok := util.PartitionString(contents, ':') + if !ok || !strings.HasPrefix(contents, "PORT=") { + return 0, fmt.Errorf("Invalid port format: %v", contents) + } + return strconv.Atoi(port) +} diff --git a/control/parse.go b/util/strings.go similarity index 72% rename from control/parse.go rename to util/strings.go index 89f4649..6fb9b5b 100644 --- a/control/parse.go +++ b/util/strings.go @@ -1,11 +1,11 @@ -package control +package util import ( "fmt" "strings" ) -func partitionString(str string, ch byte) (string, string, bool) { +func PartitionString(str string, ch byte) (string, string, bool) { index := strings.IndexByte(str, ch) if index == -1 { return str, "", false @@ -13,14 +13,14 @@ func partitionString(str string, ch byte) (string, string, bool) { return str[:index], str[index+1:], true } -func parseQuotedString(str string) (string, error) { +func ParseSimpleQuotedString(str string) (string, error) { if len(str) < 2 || str[0] != '"' || str[len(str)-1] != '"' { return "", fmt.Errorf("Missing quotes") } - return unescapeQuoted(str) + return UnescapeSimpleQuoted(str[1 : len(str)-1]) } -func unescapeQuoted(str string) (string, error) { +func UnescapeSimpleQuoted(str string) (string, error) { ret := "" escaping := false for _, c := range str {