diff --git a/control/cmd_event.go b/control/cmd_event.go index 52fa78c..7907b8c 100644 --- a/control/cmd_event.go +++ b/control/cmd_event.go @@ -1,6 +1,7 @@ package control import ( + "context" "strconv" "strings" "time" @@ -33,7 +34,76 @@ const ( EventCodeStreamBandwidth EventCode = "STREAM_BW" ) -func (c *Conn) AddEventListener(events []EventCode, ch chan<- Event) error { +func EventCodes() []EventCode { + return []EventCode{ + EventCodeAddrMap, + EventCodeBandwidth, + EventCodeCircuit, + EventCodeClientsSeen, + EventCodeDescChanged, + EventCodeGuard, + EventCodeLogDebug, + EventCodeLogErr, + EventCodeLogInfo, + EventCodeLogNotice, + EventCodeLogWarn, + EventCodeNetworkStatus, + EventCodeNewConsensus, + EventCodeNewDesc, + EventCodeORConn, + EventCodeStatusClient, + EventCodeStatusGeneral, + EventCodeStatusServer, + EventCodeStream, + EventCodeStreamBandwidth, + } +} + +// HandleEvents loops until the context is closed dispatching async events. Can dispatch events even after context is +// done and of course during synchronous request. This will always end with an error, either from ctx.Done() or from an +// error reading/handling the event. +func (c *Conn) HandleEvents(ctx context.Context) error { + errCh := make(chan error, 1) + go func() { + for ctx.Err() == nil { + if err := c.HandleNextEvent(); err != nil { + errCh <- err + break + } + } + }() + select { + case err := <-errCh: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// HandleNextEvent attempts to read and handle the next event. It will return on first message seen, event or not. +// Otherwise it will wait until there is a message read. +func (c *Conn) HandleNextEvent() error { + c.readLock.Lock() + defer c.readLock.Unlock() + // We'll just peek for the next 3 bytes and see if they are async + byts, err := c.conn.R.Peek(3) + if err != nil { + return err + } + statusCode, err := strconv.Atoi(string(byts)) + if err != nil || statusCode != StatusAsyncEvent { + return err + } + // Read the entire thing and handle it + resp, err := c.ReadResponse() + if err != nil { + return err + } + c.onAsyncResponse(resp) + return nil +} + +func (c *Conn) AddEventListener(ch chan<- Event, events ...EventCode) error { // TODO: do we want to set the local map first? Or do we want to lock on the net request too? c.eventListenersLock.Lock() for _, event := range events { @@ -48,7 +118,7 @@ func (c *Conn) AddEventListener(events []EventCode, ch chan<- Event) error { return c.sendSetEvents() } -func (c *Conn) RemoveEventListener(events []EventCode, ch chan<- Event) error { +func (c *Conn) RemoveEventListener(ch chan<- Event, events ...EventCode) error { // TODO: do we want to mutate the local map first? c.eventListenersLock.Lock() for _, event := range events { @@ -87,10 +157,17 @@ func (c *Conn) sendSetEvents() error { } func (c *Conn) relayAsyncEvents(resp *Response) { - code, data, _ := util.PartitionString(resp.Reply, ' ') - // If there is an element in the data array, use that instead for the data + // If there is data, use the first line as the code and + var code, data string if len(resp.Data) > 0 { - data = resp.Data[0] + firstNewline := strings.Index(resp.Data[0], "\r\n") + if firstNewline == -1 { + return + } + code, data = resp.Data[0][:firstNewline], resp.Data[0][firstNewline+2:] + } else { + // Otherwise, the reply line has the data + code, data, _ = util.PartitionString(resp.Reply, ' ') } // Only relay if there are chans c.eventListenersLock.RLock() diff --git a/control/conn.go b/control/conn.go index d5ffdfa..3d82023 100644 --- a/control/conn.go +++ b/control/conn.go @@ -25,6 +25,8 @@ type Conn struct { eventListenersLock sync.RWMutex // The value slices can be traversed outside of lock, they are completely replaced on change, never mutated eventListeners map[EventCode][]chan<- Event + + readLock sync.Mutex } func NewConn(conn *textproto.Conn) *Conn { @@ -47,6 +49,8 @@ func (c *Conn) SendRequest(format string, args ...interface{}) (*Response, error if err != nil { return nil, err } + c.readLock.Lock() + defer c.readLock.Unlock() c.conn.StartResponse(id) defer c.conn.EndResponse(id) // Get the first non-async response diff --git a/control/controltest/cmd_event_test.go b/control/controltest/cmd_event_test.go new file mode 100644 index 0000000..68b4d2d --- /dev/null +++ b/control/controltest/cmd_event_test.go @@ -0,0 +1,46 @@ +package controltest + +import ( + "context" + "testing" + "time" + + "github.com/cretz/bine/control" +) + +func TestEvents(t *testing.T) { + SkipIfNotRunningSpecifically(t) + ctx, conn := NewTestContextAuthenticated(t) + defer ctx.CloseConnected(conn) + // Turn on event handler + eventCtx, cancelFn := context.WithCancel(ctx) + defer cancelFn() + go func() { ctx.Require.Equal(context.Canceled, conn.HandleEvents(eventCtx)) }() + // Enable all events and hold on to which ones were seen + allEvents := control.EventCodes() + seenEvents := map[control.EventCode]struct{}{} + ch := make(chan control.Event, 1000) + ctx.Require.NoError(conn.AddEventListener(ch, allEvents...)) + // Turn on the network + ctx.Require.NoError(conn.SetConf(control.NewKeyVal("DisableNetwork", "0"))) +MainLoop: + for { + select { + case e := <-ch: + // Remove the event listener if it was seen + if _, ok := seenEvents[e.Code()]; !ok { + ctx.Debugf("Got event: %v", e.Code()) + seenEvents[e.Code()] = struct{}{} + ctx.Require.NoError(conn.RemoveEventListener(ch, e.Code())) + } + case <-time.After(3 * time.Second): + ctx.Debugf("3 seconds passed") + break MainLoop + } + } + // Check that each event was sent at least once + for _, event := range allEvents { + _, ok := seenEvents[event] + ctx.Debugf("Event %v seen? %v", event, ok) + } +} diff --git a/control/controltest/cmd_hiddenservice_test.go b/control/controltest/cmd_hiddenservice_test.go index d846ce7..b9ffdbe 100644 --- a/control/controltest/cmd_hiddenservice_test.go +++ b/control/controltest/cmd_hiddenservice_test.go @@ -2,16 +2,10 @@ package controltest import ( "testing" - "time" - - "github.com/cretz/bine/control" ) func TestGetHiddenServiceDescriptorAsync(t *testing.T) { ctx, conn := NewTestContextAuthenticated(t) defer ctx.CloseConnected(conn) - // Enable the network - ctx.Require.NoError(conn.SetConf(control.NewKeyVal("DisableNetwork", "0"))) - ctx.Require.NoError(conn.GetHiddenServiceDescriptorAsync("facebookcorewwwi", "")) - time.Sleep(60 * time.Second) + t.Skip("TODO") } diff --git a/control/controltest/control_test.go b/control/controltest/control_test.go new file mode 100644 index 0000000..2309240 --- /dev/null +++ b/control/controltest/control_test.go @@ -0,0 +1,12 @@ +package controltest + +import ( + "flag" + "testing" +) + +func SkipIfNotRunningSpecifically(t *testing.T) { + if f := flag.Lookup("test.run"); f == nil || f.Value == nil || f.Value.String() != t.Name() { + t.Skip("Only runs if -run specifies this test exactly") + } +} diff --git a/control/controltest/test_context.go b/control/controltest/test_context.go index 0de6b83..023dfab 100644 --- a/control/controltest/test_context.go +++ b/control/controltest/test_context.go @@ -2,7 +2,9 @@ package controltest import ( "context" + "flag" "fmt" + "io" "net/textproto" "os" "strconv" @@ -19,10 +21,18 @@ type TestContext struct { ExtraTorArgs []string Require *require.Assertions TestTor *TestTor + DebugWriter io.Writer } func NewTestContext(ctx context.Context, t *testing.T, extraTorArgs ...string) *TestContext { - return &TestContext{Context: ctx, T: t, ExtraTorArgs: extraTorArgs, Require: require.New(t)} + ret := &TestContext{Context: ctx, T: t, ExtraTorArgs: extraTorArgs, Require: require.New(t)} + testVerboseFlag := flag.Lookup("test.v") + if testVerboseFlag != nil && testVerboseFlag.Value != nil && testVerboseFlag.Value.String() == "true" { + ret.DebugWriter = os.Stdout + } else { + ret.ExtraTorArgs = append(append([]string{}, ret.ExtraTorArgs...), "--quiet") + } + return ret } func NewTestContextConnected(t *testing.T, extraTorArgs ...string) (*TestContext, *control.Conn) { @@ -76,6 +86,16 @@ func (t *TestContext) ConnectTestTor() (*control.Conn, error) { return nil, err } conn := control.NewConn(textConn) - conn.DebugWriter = os.Stdout + conn.DebugWriter = t.DebugWriter return conn, nil } + +func (t *TestContext) DebugEnabled() bool { + return t.DebugWriter != nil +} + +func (t *TestContext) Debugf(format string, args ...interface{}) { + if w := t.DebugWriter; w != nil { + fmt.Fprintf(w, format+"\n", args...) + } +} diff --git a/control/controltest/test_tor.go b/control/controltest/test_tor.go index 78c3c1f..dd6b6ba 100644 --- a/control/controltest/test_tor.go +++ b/control/controltest/test_tor.go @@ -42,7 +42,6 @@ func StartTestTor(ctx context.Context, extraArgs ...string) (*TestTor, error) { DataDir: dataDir, OrigArgs: append([]string{ "-f", torrcFile, - // "--quiet", "--DisableNetwork", "1", "--ControlPort", "auto", "--ControlPortWriteToFile", controlPortFile, diff --git a/control/response.go b/control/response.go index cdf371c..5b455fe 100644 --- a/control/response.go +++ b/control/response.go @@ -87,8 +87,8 @@ func (c *Conn) ReadResponse() (*Response, error) { return nil, err } dotBodyStr := strings.TrimRight(string(dotBody), "\n\r") - c.debugf("Read dot body:\n---\n%v\n---", dotBodyStr) - resp.Data = append(resp.Data, line[4:]+dotBodyStr) + // c.debugf("Read dot body:\n---\n%v\n---", dotBodyStr) + resp.Data = append(resp.Data, line[4:]+"\r\n"+dotBodyStr) dotLines := strings.Split(dotBodyStr, "\n") for _, dotLine := range dotLines[:len(dotLines)-1] { resp.RawLines = append(resp.RawLines, dotLine)