diff --git a/control/cmd_conf.go b/control/cmd_conf.go new file mode 100644 index 0000000..e3d43b0 --- /dev/null +++ b/control/cmd_conf.go @@ -0,0 +1,65 @@ +package control + +import ( + "strings" + + "github.com/cretz/bine/util" +) + +type ConfEntry struct { + Key string + Value *string +} + +func NewConfEntry(key string, value *string) *ConfEntry { + return &ConfEntry{Key: key, Value: value} +} + +func (c *Conn) SetConf(entries []*ConfEntry) error { + return c.sendSetConf("SETCONF", entries) +} + +func (c *Conn) ResetConf(entries []*ConfEntry) error { + return c.sendSetConf("RESETCONF", entries) +} + +func (c *Conn) sendSetConf(cmd string, entries []*ConfEntry) error { + for _, entry := range entries { + cmd += " " + entry.Key + if entry.Value != nil { + cmd += "=" + util.EscapeSimpleQuotedStringIfNeeded(*entry.Value) + } + } + _, err := c.SendRequest(cmd) + return err +} + +func (c *Conn) GetConf(keys ...string) ([]*ConfEntry, error) { + resp, err := c.SendRequest("GETCONF %v", strings.Join(keys, " ")) + if err != nil { + return nil, err + } + data := resp.DataWithReply() + ret := make([]*ConfEntry, 0, len(data)) + for _, data := range data { + key, val, ok := util.PartitionString(data, '=') + entry := &ConfEntry{Key: key} + if ok { + if val, err = util.UnescapeSimpleQuotedStringIfNeeded(val); err != nil { + return nil, err + } + entry.Value = &val + } + ret = append(ret, entry) + } + return ret, nil +} + +func (c *Conn) SaveConf(force bool) error { + cmd := "SAVECONF" + if force { + cmd += " FORCE" + } + _, err := c.SendRequest(cmd) + return err +} diff --git a/control/cmd_event.go b/control/cmd_event.go new file mode 100644 index 0000000..5701756 --- /dev/null +++ b/control/cmd_event.go @@ -0,0 +1,167 @@ +package control + +import ( + "strings" + "time" + + "github.com/cretz/bine/util" +) + +type EventCode string + +const ( + EventCodeAddrMap EventCode = "ADDRMAP" + EventCodeCirc EventCode = "CIRC" +) + +func (c *Conn) AddEventListener(events []EventCode, ch chan<- Event) error { + // TODO: do we want to set the local map first? Or do we want to lock on the net request too? + c.eventListenersLock.Lock() + for _, event := range events { + // Must completely replace the array, never mutate it + prevArr := c.eventListeners[event] + newArr := make([]chan<- Event, len(prevArr)+1) + copy(newArr, prevArr) + newArr[len(newArr)-1] = ch + c.eventListeners[event] = newArr + } + c.eventListenersLock.Unlock() + return c.sendSetEvents() +} + +func (c *Conn) RemoveEventListener(events []EventCode, ch chan<- Event) error { + // TODO: do we want to mutate the local map first? + c.eventListenersLock.Lock() + for _, event := range events { + arr := c.eventListeners[event] + index := -1 + for i, listener := range arr { + if listener == ch { + index = i + break + } + } + if index != -1 { + if len(arr) == 1 { + delete(c.eventListeners, event) + } else { + // Must completely replace the array, never mutate it + newArr := make([]chan<- Event, len(arr)-1) + copy(newArr, arr[:index]) + copy(newArr[index:], arr[index+1:]) + c.eventListeners[event] = newArr + } + } + } + c.eventListenersLock.Unlock() + return c.sendSetEvents() +} + +func (c *Conn) sendSetEvents() error { + c.eventListenersLock.RLock() + cmd := "SETEVENTS" + for event := range c.eventListeners { + cmd += " " + string(event) + } + c.eventListenersLock.RUnlock() + _, err := c.SendRequest(cmd) + return err +} + +// zero on fail +func parseISOTime2Frac(str string) time.Time { + // Essentially time.RFC3339Nano but without TZ info + const layout = "2006-01-02T15:04:05.999999999" + ret, err := time.Parse(layout, str) + if err != nil { + ret = time.Time{} + } + return ret +} + +type CircuitEvent struct { + CircuitID string + Status string + Path []string + BuildFlags []string + Purpose string + HSState string + RendQuery string + TimeCreated time.Time + Reason string + RemoteReason string + SocksUsername string + SocksPassword string + Raw string +} + +func ParseCircuitEvent(raw string) *CircuitEvent { + event := &CircuitEvent{Raw: raw} + event.CircuitID, raw, _ = util.PartitionString(raw, ' ') + var ok bool + event.Status, raw, ok = util.PartitionString(raw, ' ') + var attr string + first := true + for ok { + if attr, raw, ok = util.PartitionString(raw, ' '); !ok { + break + } + key, val, _ := util.PartitionString(attr, '=') + switch key { + case "BUILD_FLAGS": + event.BuildFlags = strings.Split(val, ",") + case "PURPOSE": + event.Purpose = val + case "HS_STATE": + event.HSState = val + case "REND_QUERY": + event.RendQuery = val + case "TIME_CREATED": + event.TimeCreated = parseISOTime2Frac(val) + case "REASON": + event.Reason = val + case "REMOTE_REASON": + event.RemoteReason = val + case "SOCKS_USERNAME": + event.SocksUsername = val + case "SOCKS_PASSWORD": + event.SocksPassword = val + default: + if first { + event.Path = strings.Split(val, ",") + } + } + first = false + } + return event +} + +type Event interface { + Code() EventCode +} + +func (*CircuitEvent) Code() EventCode { return EventCodeCirc } + +func (c *Conn) relayAsyncEvents(resp *Response) { + code, data, _ := util.PartitionString(resp.Reply, ' ') + // Only relay if there are chans + c.eventListenersLock.RLock() + chans := c.eventListeners[EventCode(code)] + c.eventListenersLock.RUnlock() + if len(chans) == 0 { + return + } + // Parse the event + // TODO: more events + var event Event + switch EventCode(code) { + case EventCodeCirc: + event = ParseCircuitEvent(data) + } + if event != nil { + for _, ch := range chans { + // Just send, if closed or blocking, that's not our problem + ch <- event + } + } +} diff --git a/control/cmd_misc.go b/control/cmd_misc.go new file mode 100644 index 0000000..2251a61 --- /dev/null +++ b/control/cmd_misc.go @@ -0,0 +1,36 @@ +package control + +import "github.com/cretz/bine/util" + +func (c *Conn) Signal(signal string) error { + _, err := c.SendRequest("SIGNAL %v", signal) + return err +} + +type MappedAddress struct { + Old string + New string +} + +func NewMappedAddress(old string, new string) *MappedAddress { + return &MappedAddress{Old: old, New: new} +} + +func (c *Conn) MapAddresses(addresses []*MappedAddress) ([]*MappedAddress, error) { + cmd := "MAPADDRESS" + for _, address := range addresses { + cmd += " " + address.New + "=" + address.Old + } + resp, err := c.SendRequest(cmd) + if err != nil { + return nil, err + } + data := resp.DataWithReply() + ret := make([]*MappedAddress, 0, len(data)) + for _, address := range data { + mappedAddress := &MappedAddress{} + mappedAddress.Old, mappedAddress.New, _ = util.PartitionString(address, '=') + ret = append(ret, mappedAddress) + } + return ret, nil +} diff --git a/control/cmd_protocolinfo.go b/control/cmd_protocolinfo.go index 0231548..5b94e39 100644 --- a/control/cmd_protocolinfo.go +++ b/control/cmd_protocolinfo.go @@ -56,7 +56,7 @@ func (c *Conn) sendProtocolInfo() (*ProtocolInfo, error) { if !strings.HasPrefix(cookieFile, "COOKIEFILE=") { continue } - if ret.CookieFile, err = util.ParseSimpleQuotedString(cookieFile[11:]); err != nil { + if ret.CookieFile, err = util.UnescapeSimpleQuotedString(cookieFile[11:]); err != nil { continue } } @@ -64,7 +64,7 @@ func (c *Conn) sendProtocolInfo() (*ProtocolInfo, error) { case "VERSION": torVersion, _, _ := util.PartitionString(val, ' ') if strings.HasPrefix(torVersion, "Tor=") { - ret.TorVersion, err = util.ParseSimpleQuotedString(torVersion[4:]) + ret.TorVersion, err = util.UnescapeSimpleQuotedString(torVersion[4:]) } } } diff --git a/control/conn.go b/control/conn.go index 8be3212..4be92dc 100644 --- a/control/conn.go +++ b/control/conn.go @@ -14,20 +14,24 @@ type Conn struct { conn *textproto.Conn asyncChansLock sync.RWMutex - // Never mutated outside of lock, always created anew + // Can be traversed outside of lock, entire field is replaced on change asyncChans []chan<- *Response // Set lazily protocolInfo *ProtocolInfo Authenticated bool + + eventListenersLock sync.RWMutex + // The value slices can be traversed outside of lock, they are completely replaced on change, never mutated + eventListeners map[EventCode][]chan<- Event } -func NewConn(conn *textproto.Conn) *Conn { return &Conn{conn: conn} } - -func (c *Conn) SendSignal(signal string) error { - _, err := c.SendRequest("SIGNAL %v", signal) - return err +func NewConn(conn *textproto.Conn) *Conn { + return &Conn{ + conn: conn, + eventListeners: map[EventCode][]chan<- Event{}, + } } func (c *Conn) SendRequest(format string, args ...interface{}) (*Response, error) { @@ -54,6 +58,11 @@ func (c *Conn) SendRequest(format string, args ...interface{}) (*Response, error return resp, err } +func (c *Conn) Quit() error { + _, err := c.SendRequest("QUIT") + return err +} + func (c *Conn) Close() error { // We'll close all the chans first c.asyncChansLock.Lock() @@ -63,7 +72,7 @@ func (c *Conn) Close() error { c.asyncChans = nil c.asyncChansLock.Unlock() // Ignore the response and ignore the error - c.SendRequest("QUIT") + c.Quit() return c.conn.Close() } @@ -97,6 +106,8 @@ func (c *Conn) RemoveAsyncChan(ch chan<- *Response) bool { } func (c *Conn) onAsyncResponse(resp *Response) { + // First, relay events + c.relayAsyncEvents(resp) c.asyncChansLock.RLock() chans := c.asyncChans c.asyncChansLock.RUnlock() diff --git a/control/response.go b/control/response.go index d2b4218..dc3bef3 100644 --- a/control/response.go +++ b/control/response.go @@ -26,13 +26,20 @@ type Response struct { // IsOk returns true if the response status code indicates success or an asynchronous event. func (r *Response) IsOk() bool { switch r.Err.Code { - case StatusOk, StatusOkUnneccecary, StatusAsyncEvent: + case StatusOk, StatusOkUnnecessary, StatusAsyncEvent: return true default: return false } } +func (r *Response) DataWithReply() []string { + ret := make([]string, len(r.Data)+1) + copy(ret, r.Data) + ret[len(ret)-1] = r.Reply + return ret +} + // IsAsync returns true if the response is an asyncrhonous event. func (r *Response) IsAsync() bool { return r.Err.Code == StatusAsyncEvent diff --git a/control/status.go b/control/status.go index 2e8469c..8be3464 100644 --- a/control/status.go +++ b/control/status.go @@ -9,7 +9,7 @@ import ( // The various control port StatusCode constants. const ( StatusOk = 250 - StatusOkUnneccecary = 251 + StatusOkUnnecessary = 251 StatusErrResourceExhausted = 451 StatusErrSyntaxError = 500 @@ -31,7 +31,7 @@ const ( var statusCodeStringMap = map[int]string{ StatusOk: "OK", - StatusOkUnneccecary: "Operation was unnecessary", + StatusOkUnnecessary: "Operation was unnecessary", StatusErrResourceExhausted: "Resource exhausted", StatusErrSyntaxError: "Syntax error: protocol", diff --git a/util/strings.go b/util/strings.go index 6fb9b5b..a94916c 100644 --- a/util/strings.go +++ b/util/strings.go @@ -13,14 +13,39 @@ func PartitionString(str string, ch byte) (string, string, bool) { return str[:index], str[index+1:], true } -func ParseSimpleQuotedString(str string) (string, error) { +func EscapeSimpleQuotedStringIfNeeded(str string) string { + if strings.ContainsAny(str, " \\\"\r\n") { + return EscapeSimpleQuotedString(str) + } + return str +} + +var SimpleQuotedStringEscapeReplacer = strings.NewReplacer( + "\\", "\\\\", + "\"", "\\\"", + "\r", "\\r", + "\n", "\\n", +) + +func EscapeSimpleQuotedString(str string) string { + return "\"" + SimpleQuotedStringEscapeReplacer.Replace(str) + "\"" +} + +func UnescapeSimpleQuotedStringIfNeeded(str string) (string, error) { + if len(str) >= 2 && str[0] == '"' && str[len(str)-1] == '"' { + return UnescapeSimpleQuotedString(str) + } + return str, nil +} + +func UnescapeSimpleQuotedString(str string) (string, error) { if len(str) < 2 || str[0] != '"' || str[len(str)-1] != '"' { return "", fmt.Errorf("Missing quotes") } - return UnescapeSimpleQuoted(str[1 : len(str)-1]) + return UnescapeSimpleQuotedStringContents(str[1 : len(str)-1]) } -func UnescapeSimpleQuoted(str string) (string, error) { +func UnescapeSimpleQuotedStringContents(str string) (string, error) { ret := "" escaping := false for _, c := range str {