diff --git a/control/cmd_protocolinfo.go b/control/cmd_protocolinfo.go new file mode 100644 index 0000000..f3c27b3 --- /dev/null +++ b/control/cmd_protocolinfo.go @@ -0,0 +1,62 @@ +package control + +import ( + "strings" +) + +type ProtocolInfo struct { + AuthMethods []string + CookieFile string + TorVersion string + RawResponse *Response +} + +func (p *ProtocolInfo) HasAuthMethod(authMethod string) bool { + for _, m := range p.AuthMethods { + if m == authMethod { + return true + } + } + return false +} + +func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) { + resp, err := c.SendRequest("PROTOCOLINFO") + if err != nil { + return nil, err + } + // Check PIVERSION + if len(resp.Data) == 0 || resp.Data[0] != "1" { + return nil, newProtocolError("Invalid PIVERSION: %s", resp.Reply) + } + // Get other response vals + ret := &ProtocolInfo{RawResponse: resp} + for _, piece := range resp.Data { + key, val, ok := partitionString(piece, ' ') + if !ok { + continue + } + switch key { + case "AUTH": + methods, cookieFile, _ := partitionString(val, ' ') + if !strings.HasPrefix(methods, "METHODS=") { + continue + } + if cookieFile != "" { + if !strings.HasPrefix(cookieFile, "COOKIEFILE=") { + continue + } + if ret.CookieFile, err = parseQuotedString(cookieFile[11:]); err != nil { + continue + } + } + ret.AuthMethods = strings.Split(methods[8:], ",") + case "VERSION": + torVersion, _, _ := partitionString(val, ' ') + if strings.HasPrefix(torVersion, "Tor=") { + ret.TorVersion, _ = parseQuotedString(torVersion[4:]) + } + } + } + return ret, nil +} diff --git a/control/conn.go b/control/conn.go new file mode 100644 index 0000000..3ea4484 --- /dev/null +++ b/control/conn.go @@ -0,0 +1,99 @@ +package control + +import ( + "fmt" + "net/textproto" + "sync" +) + +type Conn struct { + conn *textproto.Conn + + asyncChansLock sync.RWMutex + // Never mutated outside of lock, always created anew + asyncChans []chan<- *Response +} + +func NewConn(conn *textproto.Conn) *Conn { return &Conn{conn: conn} } + +func (c *Conn) SendSignal(signal string) error { + _, err := c.SendRequest("SIGNAL %v", signal) + return err +} + +func (c *Conn) SendRequest(format string, args ...interface{}) (*Response, error) { + id, err := c.conn.Cmd(format, args...) + if err != nil { + return nil, err + } + c.conn.StartResponse(id) + defer c.conn.EndResponse(id) + // Get the first non-async response + var resp *Response + for { + if resp, err = c.ReadResponse(); err != nil || !resp.IsAsync() { + break + } + c.onAsyncResponse(resp) + } + if err == nil && !resp.IsOk() { + err = resp.Err + } + return resp, err +} + +func (c *Conn) Close() error { + // We'll close all the chans first + c.asyncChansLock.Lock() + for _, ch := range c.asyncChans { + close(ch) + } + c.asyncChans = nil + c.asyncChansLock.Unlock() + // Ignore the response and ignore the error + c.SendRequest("QUIT") + return c.conn.Close() +} + +func (c *Conn) AddAsyncChan(ch chan<- *Response) { + c.asyncChansLock.Lock() + chans := make([]chan<- *Response, len(c.asyncChans)+1) + copy(chans, c.asyncChans) + chans[len(chans)-1] = ch + c.asyncChans = chans + c.asyncChansLock.Unlock() +} + +// Does not close +func (c *Conn) RemoveChan(ch chan<- *Response) bool { + c.asyncChansLock.Lock() + chans := make([]chan<- *Response, len(c.asyncChans)+1) + copy(chans, c.asyncChans) + index := -1 + for i, existing := range chans { + if existing == ch { + index = i + break + } + } + if index != -1 { + chans = append(chans[:index], chans[index+1:]...) + } + c.asyncChans = chans + c.asyncChansLock.Unlock() + return index != -1 +} + +func (c *Conn) onAsyncResponse(resp *Response) { + c.asyncChansLock.RLock() + chans := c.asyncChans + c.asyncChansLock.RUnlock() + // We will allow channels to block + for _, ch := range chans { + ch <- resp + } +} + +func newProtocolError(format string, args ...interface{}) textproto.ProtocolError { + return textproto.ProtocolError(fmt.Sprintf(format, args...)) +} diff --git a/control/parse.go b/control/parse.go new file mode 100644 index 0000000..89f4649 --- /dev/null +++ b/control/parse.go @@ -0,0 +1,47 @@ +package control + +import ( + "fmt" + "strings" +) + +func partitionString(str string, ch byte) (string, string, bool) { + index := strings.IndexByte(str, ch) + if index == -1 { + return str, "", false + } + return str[:index], str[index+1:], true +} + +func parseQuotedString(str string) (string, error) { + if len(str) < 2 || str[0] != '"' || str[len(str)-1] != '"' { + return "", fmt.Errorf("Missing quotes") + } + return unescapeQuoted(str) +} + +func unescapeQuoted(str string) (string, error) { + ret := "" + escaping := false + for _, c := range str { + switch c { + case '\\': + if escaping { + ret += "\\" + } + escaping = !escaping + case '"': + if !escaping { + return "", fmt.Errorf("Unescaped quote") + } + ret += "\"" + escaping = false + default: + if escaping { + return "", fmt.Errorf("Unexpected escape") + } + ret += string(c) + } + } + return ret, nil +} diff --git a/control/response.go b/control/response.go new file mode 100644 index 0000000..09b3954 --- /dev/null +++ b/control/response.go @@ -0,0 +1,111 @@ +package control + +import ( + "net/textproto" + "strconv" + "strings" +) + +// Response is a response to a control port command, or an asyncrhonous event. +type Response struct { + // Err is the status code and string representation associated with a + // response. Responses that have completed successfully will also have + // Err set to indicate such. + Err *textproto.Error + + // Reply is the text on the EndReplyLine of the response. + Reply string + + // Data is the MidReplyLines/DataReplyLines of the response. Dot encoded + // data is "decoded" and presented as a single string (terminal ".CRLF" + // removed, all intervening CRs stripped). + Data []string + + // RawLines is all of the lines of a response, without CRLFs. + RawLines []string +} + +// IsOk returns true if the response status code indicates success or +// an asynchronous event. +func (r *Response) IsOk() bool { + switch r.Err.Code { + case StatusOk, StatusOkUnneccecary, StatusAsyncEvent: + return true + default: + return false + } +} + +// IsAsync returns true if the response is an asyncrhonous event. +func (r *Response) IsAsync() bool { + return r.Err.Code == StatusAsyncEvent +} + +// ReadResponse returns the next response object. Calling this +// simultaniously with Read, Request, or StartAsyncReader will lead to +// undefined behavior +func (c *Conn) ReadResponse() (*Response, error) { + var resp *Response + var statusCode int + for { + line, err := c.conn.ReadLine() + if err != nil { + return nil, err + } + + // Parse the line that was just read. + if len(line) < 4 { + return nil, newProtocolError("truncated response: '%s'", line) + } + if code, err := strconv.Atoi(line[0:3]); err != nil { + return nil, newProtocolError("invalid status code: '%s'", line[0:3]) + } else if code < 100 { + return nil, newProtocolError("invalid status code: '%s'", line[0:3]) + } else if resp == nil { + resp = new(Response) + statusCode = code + } else if code != statusCode { + // The status code should stay fixed for all lines of the + // response, since events can't be interleaved with response + // lines. + return nil, newProtocolError("status code changed: %03d != %03d", code, statusCode) + } + if resp.RawLines == nil { + resp.RawLines = make([]string, 0, 1) + } + + if line[3] == ' ' { + // Final line in the response. + resp.Reply = line[4:] + resp.Err = statusCodeToError(statusCode, resp.Reply) + resp.RawLines = append(resp.RawLines, line) + return resp, nil + } + + if resp.Data == nil { + resp.Data = make([]string, 0, 1) + } + switch line[3] { + case '-': + // Continuation, keep reading. + resp.Data = append(resp.Data, line[4:]) + resp.RawLines = append(resp.RawLines, line) + case '+': + // A "dot-encoded" payload follows. + resp.Data = append(resp.Data, line[4:]) + resp.RawLines = append(resp.RawLines, line) + dotBody, err := c.conn.ReadDotBytes() + if err != nil { + return nil, err + } + resp.Data = append(resp.Data, strings.TrimRight(string(dotBody), "\n\r")) + dotLines := strings.Split(string(dotBody), "\n") + for _, dotLine := range dotLines[:len(dotLines)-1] { + resp.RawLines = append(resp.RawLines, dotLine) + } + resp.RawLines = append(resp.RawLines, ".") + default: + return nil, newProtocolError("invalid separator: '%c'", line[3]) + } + } +} diff --git a/control/status.go b/control/status.go new file mode 100644 index 0000000..2e8469c --- /dev/null +++ b/control/status.go @@ -0,0 +1,64 @@ +package control + +import ( + "fmt" + "net/textproto" + "strings" +) + +// The various control port StatusCode constants. +const ( + StatusOk = 250 + StatusOkUnneccecary = 251 + + StatusErrResourceExhausted = 451 + StatusErrSyntaxError = 500 + StatusErrUnrecognizedCmd = 510 + StatusErrUnimplementedCmd = 511 + StatusErrSyntaxErrorArg = 512 + StatusErrUnrecognizedCmdArg = 513 + StatusErrAuthenticationRequired = 514 + StatusErrBadAuthentication = 515 + StatusErrUnspecifiedTorError = 550 + StatusErrInternalError = 551 + StatusErrUnrecognizedEntity = 552 + StatusErrInvalidConfigValue = 553 + StatusErrInvalidDescriptor = 554 + StatusErrUnmanagedEntity = 555 + + StatusAsyncEvent = 650 +) + +var statusCodeStringMap = map[int]string{ + StatusOk: "OK", + StatusOkUnneccecary: "Operation was unnecessary", + + StatusErrResourceExhausted: "Resource exhausted", + StatusErrSyntaxError: "Syntax error: protocol", + StatusErrUnrecognizedCmd: "Unrecognized command", + StatusErrUnimplementedCmd: "Unimplemented command", + StatusErrSyntaxErrorArg: "Syntax error in command argument", + StatusErrUnrecognizedCmdArg: "Unrecognized command argument", + StatusErrAuthenticationRequired: "Authentication required", + StatusErrBadAuthentication: "Bad authentication", + StatusErrUnspecifiedTorError: "Unspecified Tor error", + StatusErrInternalError: "Internal error", + StatusErrUnrecognizedEntity: "Unrecognized entity", + StatusErrInvalidConfigValue: "Invalid configuration value", + StatusErrInvalidDescriptor: "Invalid descriptor", + StatusErrUnmanagedEntity: "Unmanaged entity", + + StatusAsyncEvent: "Asynchronous event notification", +} + +func statusCodeToError(code int, reply string) *textproto.Error { + err := new(textproto.Error) + err.Code = code + if msg, ok := statusCodeStringMap[code]; ok { + trimmedReply := strings.TrimSpace(strings.TrimPrefix(reply, msg)) + err.Msg = fmt.Sprintf("%s: %s", msg, trimmedReply) + } else { + err.Msg = fmt.Sprintf("Unknown status code (%03d): %s", code, reply) + } + return err +} diff --git a/process/process.go b/process/process.go new file mode 100644 index 0000000..3f13d1d --- /dev/null +++ b/process/process.go @@ -0,0 +1,22 @@ +package process + +import ( + "context" + "os/exec" +) + +type Tor interface { + Start(ctx context.Context, args []string) error +} + +type exeTor struct { + exePath string +} + +func FromExePath(exePath string) Tor { + return &exeTor{exePath} +} + +func (e *exeTor) Start(ctx context.Context, args []string) error { + return exec.CommandContext(ctx, e.exePath, args...).Start() +}