diff --git a/control/cmd_authenticate.go b/control/cmd_authenticate.go new file mode 100644 index 0000000..3d01aa0 --- /dev/null +++ b/control/cmd_authenticate.go @@ -0,0 +1,108 @@ +package control + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "io/ioutil" + "strings" +) + +// Authenticate authenticates with the Tor instance using the "best" possible authentication method if not already +// authenticated and sets the Authenticated field. The password argument is optional, and will only be used if the +// "SAFECOOKIE" and "NULL" authentication methods are not available and "HASHEDPASSWORD" is. +func (c *Conn) Authenticate(password string) error { + if c.Authenticated { + return nil + } + // Determine the supported authentication methods, and the cookie path. + pi, err := c.ProtocolInfo() + if err != nil { + return err + } + // Get the bytes to pass to with authenticate + var authBytes []byte + if pi.HasAuthMethod("NULL") { + // No auth bytes + } else if pi.HasAuthMethod("SAFECOOKIE") { + if pi.CookieFile == "" { + return c.protoErr("Invalid (empty) COOKIEFILE") + } + cookie, err := ioutil.ReadFile(pi.CookieFile) + if err != nil { + return c.protoErr("Failed to read COOKIEFILE: %v", err) + } else if len(cookie) != 32 { + return c.protoErr("Invalid cookie file length: %v", len(cookie)) + } + + // Send an AUTHCHALLENGE command, and parse the response. + var clientNonce [32]byte + if _, err := rand.Read(clientNonce[:]); err != nil { + return c.protoErr("Failed to generate clientNonce: %v", err) + } + resp, err := c.SendRequest("AUTHCHALLENGE %s %s", "SAFECOOKIE", hex.EncodeToString(clientNonce[:])) + if err != nil { + return err + } + splitResp := strings.Split(resp.Reply, " ") + if len(splitResp) != 3 || !strings.HasPrefix(splitResp[1], "SERVERHASH=") || + !strings.HasPrefix(splitResp[2], "SERVERNONCE=") { + return c.protoErr("Invalid AUTHCHALLENGE response") + } + serverHash, err := hex.DecodeString(splitResp[1][11:]) + if err != nil { + return c.protoErr("Failed to decode ServerHash: %v", err) + } + if len(serverHash) != 32 { + return c.protoErr("Invalid ServerHash length: %d", len(serverHash)) + } + serverNonce, err := hex.DecodeString(splitResp[2][12:]) + if err != nil { + return c.protoErr("Failed to decode ServerNonce: %v", err) + } + if len(serverNonce) != 32 { + return c.protoErr("Invalid ServerNonce length: %d", len(serverNonce)) + } + + // Validate the ServerHash. + m := hmac.New(sha256.New, []byte("Tor safe cookie authentication server-to-controller hash")) + m.Write(cookie) + m.Write(clientNonce[:]) + m.Write(serverNonce) + dervServerHash := m.Sum(nil) + if !hmac.Equal(serverHash, dervServerHash) { + return c.protoErr("invalid ServerHash: mismatch") + } + + // Calculate the ClientHash, and issue the AUTHENTICATE. + m = hmac.New(sha256.New, []byte("Tor safe cookie authentication controller-to-server hash")) + m.Write(cookie) + m.Write(clientNonce[:]) + m.Write(serverNonce) + authBytes = m.Sum(nil) + } else if pi.HasAuthMethod("HASHEDPASSWORD") { + // Despite the name HASHEDPASSWORD, the raw password is actually sent. According to the code, this can either be + // a QuotedString, or base16 encoded, so go with the later since it's easier to handle. + if password == "" { + return c.protoErr("password auth needs a password") + } + authBytes = []byte(password) + } else { + return c.protoErr("No supported authentication methods") + } + // Send it + if err = c.sendAuthenticate(authBytes); err == nil { + c.Authenticated = true + } + return err +} + +func (c *Conn) sendAuthenticate(byts []byte) (err error) { + if len(byts) == 0 { + _, err = c.SendRequest("AUTHENTICATE") + } else { + _, err = c.SendRequest("AUTHENTICATE %v", hex.EncodeToString(byts)) + } + return +} diff --git a/control/cmd_protocolinfo.go b/control/cmd_protocolinfo.go index 3710489..0231548 100644 --- a/control/cmd_protocolinfo.go +++ b/control/cmd_protocolinfo.go @@ -22,7 +22,15 @@ func (p *ProtocolInfo) HasAuthMethod(authMethod string) bool { return false } -func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) { +func (c *Conn) ProtocolInfo() (*ProtocolInfo, error) { + var err error + if c.protocolInfo == nil { + c.protocolInfo, err = c.sendProtocolInfo() + } + return c.protocolInfo, err +} + +func (c *Conn) sendProtocolInfo() (*ProtocolInfo, error) { resp, err := c.SendRequest("PROTOCOLINFO") if err != nil { return nil, err @@ -37,7 +45,7 @@ func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) { switch key { case "PROTOCOLINFO": if val != "1" { - return nil, newProtocolError("Invalid PIVERSION: %v", val) + return nil, c.protoErr("Invalid PIVERSION: %v", val) } case "AUTH": methods, cookieFile, _ := util.PartitionString(val, ' ') diff --git a/control/conn.go b/control/conn.go index bcb8985..8be3212 100644 --- a/control/conn.go +++ b/control/conn.go @@ -16,6 +16,11 @@ type Conn struct { asyncChansLock sync.RWMutex // Never mutated outside of lock, always created anew asyncChans []chan<- *Response + + // Set lazily + protocolInfo *ProtocolInfo + + Authenticated bool } func NewConn(conn *textproto.Conn) *Conn { return &Conn{conn: conn} } @@ -26,6 +31,9 @@ func (c *Conn) SendSignal(signal string) error { } func (c *Conn) SendRequest(format string, args ...interface{}) (*Response, error) { + if c.debugEnabled() { + c.debugf("Write line: %v", fmt.Sprintf(format, args...)) + } id, err := c.conn.Cmd(format, args...) if err != nil { return nil, err @@ -108,6 +116,6 @@ func (c *Conn) debugf(format string, args ...interface{}) { } } -func newProtocolError(format string, args ...interface{}) textproto.ProtocolError { +func (*Conn) protoErr(format string, args ...interface{}) textproto.ProtocolError { return textproto.ProtocolError(fmt.Sprintf(format, args...)) } diff --git a/control/controltest/cmd_authenticate_test.go b/control/controltest/cmd_authenticate_test.go new file mode 100644 index 0000000..b8233a8 --- /dev/null +++ b/control/controltest/cmd_authenticate_test.go @@ -0,0 +1,23 @@ +package controltest + +import "testing" + +func TestAuthenticateNull(t *testing.T) { + ctx, conn := NewTestContextConnected(t) + defer ctx.CloseConnected(conn) + ctx.Require.NoError(conn.Authenticate("")) +} + +func TestAuthenticateSafeCookie(t *testing.T) { + ctx, conn := NewTestContextConnected(t, "--CookieAuthentication", "1") + defer ctx.CloseConnected(conn) + ctx.Require.NoError(conn.Authenticate("")) +} + +func TestAuthenticateHashedPassword(t *testing.T) { + // "testpass" - 16:5417AE717521511A609921392778FFA8518EC089BF2162A199241AEB4A + ctx, conn := NewTestContextConnected(t, "--HashedControlPassword", + "16:5417AE717521511A609921392778FFA8518EC089BF2162A199241AEB4A") + defer ctx.CloseConnected(conn) + ctx.Require.NoError(conn.Authenticate("testpass")) +} diff --git a/control/controltest/cmd_protocolinfo_test.go b/control/controltest/cmd_protocolinfo_test.go index 5bfed52..9dae54f 100644 --- a/control/controltest/cmd_protocolinfo_test.go +++ b/control/controltest/cmd_protocolinfo_test.go @@ -1,17 +1,14 @@ package controltest import ( - "context" "strings" "testing" ) func TestProtocolInfo(t *testing.T) { - ctx := NewTestContext(context.Background(), t) - defer ctx.Close() - conn := ctx.ConnectTestTor() - defer conn.Close() - info, err := conn.RequestProtocolInfo() + ctx, conn := NewTestContextConnected(t) + defer ctx.CloseConnected(conn) + info, err := conn.ProtocolInfo() ctx.Require.NoError(err) ctx.Require.Contains(info.AuthMethods, "NULL") ctx.Require.True(strings.HasPrefix(info.TorVersion, "0.3")) diff --git a/control/controltest/test_context.go b/control/controltest/test_context.go index f8e3aec..c124502 100644 --- a/control/controltest/test_context.go +++ b/control/controltest/test_context.go @@ -16,18 +16,29 @@ import ( type TestContext struct { context.Context *testing.T - Require *require.Assertions - TestTor *TestTor + ExtraTorArgs []string + Require *require.Assertions + TestTor *TestTor } -func NewTestContext(ctx context.Context, t *testing.T) *TestContext { - return &TestContext{Context: ctx, T: t, Require: require.New(t)} +func NewTestContext(ctx context.Context, t *testing.T, extraTorArgs ...string) *TestContext { + return &TestContext{Context: ctx, T: t, ExtraTorArgs: extraTorArgs, Require: require.New(t)} +} + +func NewTestContextConnected(t *testing.T, extraTorArgs ...string) (*TestContext, *control.Conn) { + ctx := NewTestContext(context.Background(), t, extraTorArgs...) + conn, err := ctx.ConnectTestTor() + if err != nil { + ctx.Close() + ctx.Fatal(err) + } + return ctx, conn } func (t *TestContext) EnsureTestTorStarted() { if t.TestTor == nil { var err error - if t.TestTor, err = StartTestTor(t); err != nil { + if t.TestTor, err = StartTestTor(t, t.ExtraTorArgs...); err != nil { t.Fatal(err) } } @@ -41,13 +52,20 @@ func (t *TestContext) Close() { } } -func (t *TestContext) ConnectTestTor() *control.Conn { +func (t *TestContext) CloseConnected(conn *control.Conn) { + if err := conn.Close(); err != nil { + fmt.Printf("Warning, close failed on tor conn: %v", err) + } + t.Close() +} + +func (t *TestContext) ConnectTestTor() (*control.Conn, error) { t.EnsureTestTorStarted() textConn, err := textproto.Dial("tcp", "127.0.0.1:"+strconv.Itoa(t.TestTor.ControlPort)) if err != nil { - t.Fatal(err) + return nil, err } conn := control.NewConn(textConn) conn.DebugWriter = os.Stdout - return conn + return conn, nil } diff --git a/control/controltest/test_tor.go b/control/controltest/test_tor.go index 7f54305..7694d69 100644 --- a/control/controltest/test_tor.go +++ b/control/controltest/test_tor.go @@ -27,7 +27,7 @@ type TestTor struct { processCancelFn context.CancelFunc } -func StartTestTor(ctx context.Context) (*TestTor, error) { +func StartTestTor(ctx context.Context, extraArgs ...string) (*TestTor, error) { dataDir, err := ioutil.TempDir(".", "test-data-dir-") if err != nil { return nil, err @@ -35,13 +35,13 @@ func StartTestTor(ctx context.Context) (*TestTor, error) { controlPortFile := filepath.Join(dataDir, "control-port") ret := &TestTor{ DataDir: dataDir, - OrigArgs: []string{ + OrigArgs: append([]string{ // "--quiet", "--DisableNetwork", "1", "--ControlPort", "auto", "--ControlPortWriteToFile", controlPortFile, "--DataDirectory", dataDir, - }, + }, extraArgs...), } errCh := make(chan error, 1) var processCtx context.Context diff --git a/control/response.go b/control/response.go index daba82e..d2b4218 100644 --- a/control/response.go +++ b/control/response.go @@ -6,27 +6,24 @@ import ( "strings" ) -// Response is a response to a control port command, or an asyncrhonous event. +// Response is a response to a control port command, or an asynchronous event. type Response struct { - // Err is the status code and string representation associated with a - // response. Responses that have completed successfully will also have - // Err set to indicate such. + // Err is the status code and string representation associated with a response. Responses that have completed + // successfully will also have Err set to indicate such. Err *textproto.Error // Reply is the text on the EndReplyLine of the response. Reply string - // Data is the MidReplyLines/DataReplyLines of the response. Dot encoded - // data is "decoded" and presented as a single string (terminal ".CRLF" - // removed, all intervening CRs stripped). + // Data is the MidReplyLines/DataReplyLines of the response. Dot encoded data is "decoded" and presented as a single + // string (terminal ".CRLF" removed, all intervening CRs stripped). Data []string // RawLines is all of the lines of a response, without CRLFs. RawLines []string } -// IsOk returns true if the response status code indicates success or -// an asynchronous event. +// IsOk returns true if the response status code indicates success or an asynchronous event. func (r *Response) IsOk() bool { switch r.Err.Code { case StatusOk, StatusOkUnneccecary, StatusAsyncEvent: @@ -41,9 +38,7 @@ func (r *Response) IsAsync() bool { return r.Err.Code == StatusAsyncEvent } -// ReadResponse returns the next response object. Calling this -// simultaniously with Read, Request, or StartAsyncReader will lead to -// undefined behavior +// ReadResponse returns the next response object. func (c *Conn) ReadResponse() (*Response, error) { var resp *Response var statusCode int @@ -56,45 +51,31 @@ func (c *Conn) ReadResponse() (*Response, error) { // Parse the line that was just read. if len(line) < 4 { - return nil, newProtocolError("truncated response: '%s'", line) + return nil, c.protoErr("Truncated response: %v", line) } - if code, err := strconv.Atoi(line[0:3]); err != nil { - return nil, newProtocolError("invalid status code: '%s'", line[0:3]) - } else if code < 100 { - return nil, newProtocolError("invalid status code: '%s'", line[0:3]) + if code, err := strconv.Atoi(line[0:3]); err != nil || code < 100 { + return nil, c.protoErr("Invalid status code: %v", line[0:3]) } else if resp == nil { - resp = new(Response) + resp = &Response{} statusCode = code } else if code != statusCode { - // The status code should stay fixed for all lines of the - // response, since events can't be interleaved with response - // lines. - return nil, newProtocolError("status code changed: %03d != %03d", code, statusCode) + // The status code should stay fixed for all lines of the response, since events can't be interleaved with + // response lines. + return nil, c.protoErr("Status code changed: %v != %v", code, statusCode) } - if resp.RawLines == nil { - resp.RawLines = make([]string, 0, 1) - } - - if line[3] == ' ' { + resp.RawLines = append(resp.RawLines, line) + switch line[3] { + case ' ': // Final line in the response. resp.Reply = line[4:] resp.Err = statusCodeToError(statusCode, resp.Reply) - resp.RawLines = append(resp.RawLines, line) return resp, nil - } - - if resp.Data == nil { - resp.Data = make([]string, 0, 1) - } - switch line[3] { case '-': // Continuation, keep reading. resp.Data = append(resp.Data, line[4:]) - resp.RawLines = append(resp.RawLines, line) case '+': // A "dot-encoded" payload follows. resp.Data = append(resp.Data, line[4:]) - resp.RawLines = append(resp.RawLines, line) dotBody, err := c.conn.ReadDotBytes() if err != nil { return nil, err @@ -106,7 +87,7 @@ func (c *Conn) ReadResponse() (*Response, error) { } resp.RawLines = append(resp.RawLines, ".") default: - return nil, newProtocolError("invalid separator: '%c'", line[3]) + return nil, c.protoErr("Invalid separator: '%v'", line[3]) } } }