Authenticate support
This commit is contained in:
parent
4156ef3cfc
commit
917a38aeb6
|
@ -0,0 +1,108 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Authenticate authenticates with the Tor instance using the "best" possible authentication method if not already
|
||||
// authenticated and sets the Authenticated field. The password argument is optional, and will only be used if the
|
||||
// "SAFECOOKIE" and "NULL" authentication methods are not available and "HASHEDPASSWORD" is.
|
||||
func (c *Conn) Authenticate(password string) error {
|
||||
if c.Authenticated {
|
||||
return nil
|
||||
}
|
||||
// Determine the supported authentication methods, and the cookie path.
|
||||
pi, err := c.ProtocolInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Get the bytes to pass to with authenticate
|
||||
var authBytes []byte
|
||||
if pi.HasAuthMethod("NULL") {
|
||||
// No auth bytes
|
||||
} else if pi.HasAuthMethod("SAFECOOKIE") {
|
||||
if pi.CookieFile == "" {
|
||||
return c.protoErr("Invalid (empty) COOKIEFILE")
|
||||
}
|
||||
cookie, err := ioutil.ReadFile(pi.CookieFile)
|
||||
if err != nil {
|
||||
return c.protoErr("Failed to read COOKIEFILE: %v", err)
|
||||
} else if len(cookie) != 32 {
|
||||
return c.protoErr("Invalid cookie file length: %v", len(cookie))
|
||||
}
|
||||
|
||||
// Send an AUTHCHALLENGE command, and parse the response.
|
||||
var clientNonce [32]byte
|
||||
if _, err := rand.Read(clientNonce[:]); err != nil {
|
||||
return c.protoErr("Failed to generate clientNonce: %v", err)
|
||||
}
|
||||
resp, err := c.SendRequest("AUTHCHALLENGE %s %s", "SAFECOOKIE", hex.EncodeToString(clientNonce[:]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
splitResp := strings.Split(resp.Reply, " ")
|
||||
if len(splitResp) != 3 || !strings.HasPrefix(splitResp[1], "SERVERHASH=") ||
|
||||
!strings.HasPrefix(splitResp[2], "SERVERNONCE=") {
|
||||
return c.protoErr("Invalid AUTHCHALLENGE response")
|
||||
}
|
||||
serverHash, err := hex.DecodeString(splitResp[1][11:])
|
||||
if err != nil {
|
||||
return c.protoErr("Failed to decode ServerHash: %v", err)
|
||||
}
|
||||
if len(serverHash) != 32 {
|
||||
return c.protoErr("Invalid ServerHash length: %d", len(serverHash))
|
||||
}
|
||||
serverNonce, err := hex.DecodeString(splitResp[2][12:])
|
||||
if err != nil {
|
||||
return c.protoErr("Failed to decode ServerNonce: %v", err)
|
||||
}
|
||||
if len(serverNonce) != 32 {
|
||||
return c.protoErr("Invalid ServerNonce length: %d", len(serverNonce))
|
||||
}
|
||||
|
||||
// Validate the ServerHash.
|
||||
m := hmac.New(sha256.New, []byte("Tor safe cookie authentication server-to-controller hash"))
|
||||
m.Write(cookie)
|
||||
m.Write(clientNonce[:])
|
||||
m.Write(serverNonce)
|
||||
dervServerHash := m.Sum(nil)
|
||||
if !hmac.Equal(serverHash, dervServerHash) {
|
||||
return c.protoErr("invalid ServerHash: mismatch")
|
||||
}
|
||||
|
||||
// Calculate the ClientHash, and issue the AUTHENTICATE.
|
||||
m = hmac.New(sha256.New, []byte("Tor safe cookie authentication controller-to-server hash"))
|
||||
m.Write(cookie)
|
||||
m.Write(clientNonce[:])
|
||||
m.Write(serverNonce)
|
||||
authBytes = m.Sum(nil)
|
||||
} else if pi.HasAuthMethod("HASHEDPASSWORD") {
|
||||
// Despite the name HASHEDPASSWORD, the raw password is actually sent. According to the code, this can either be
|
||||
// a QuotedString, or base16 encoded, so go with the later since it's easier to handle.
|
||||
if password == "" {
|
||||
return c.protoErr("password auth needs a password")
|
||||
}
|
||||
authBytes = []byte(password)
|
||||
} else {
|
||||
return c.protoErr("No supported authentication methods")
|
||||
}
|
||||
// Send it
|
||||
if err = c.sendAuthenticate(authBytes); err == nil {
|
||||
c.Authenticated = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) sendAuthenticate(byts []byte) (err error) {
|
||||
if len(byts) == 0 {
|
||||
_, err = c.SendRequest("AUTHENTICATE")
|
||||
} else {
|
||||
_, err = c.SendRequest("AUTHENTICATE %v", hex.EncodeToString(byts))
|
||||
}
|
||||
return
|
||||
}
|
|
@ -22,7 +22,15 @@ func (p *ProtocolInfo) HasAuthMethod(authMethod string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) {
|
||||
func (c *Conn) ProtocolInfo() (*ProtocolInfo, error) {
|
||||
var err error
|
||||
if c.protocolInfo == nil {
|
||||
c.protocolInfo, err = c.sendProtocolInfo()
|
||||
}
|
||||
return c.protocolInfo, err
|
||||
}
|
||||
|
||||
func (c *Conn) sendProtocolInfo() (*ProtocolInfo, error) {
|
||||
resp, err := c.SendRequest("PROTOCOLINFO")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -37,7 +45,7 @@ func (c *Conn) RequestProtocolInfo() (*ProtocolInfo, error) {
|
|||
switch key {
|
||||
case "PROTOCOLINFO":
|
||||
if val != "1" {
|
||||
return nil, newProtocolError("Invalid PIVERSION: %v", val)
|
||||
return nil, c.protoErr("Invalid PIVERSION: %v", val)
|
||||
}
|
||||
case "AUTH":
|
||||
methods, cookieFile, _ := util.PartitionString(val, ' ')
|
||||
|
|
|
@ -16,6 +16,11 @@ type Conn struct {
|
|||
asyncChansLock sync.RWMutex
|
||||
// Never mutated outside of lock, always created anew
|
||||
asyncChans []chan<- *Response
|
||||
|
||||
// Set lazily
|
||||
protocolInfo *ProtocolInfo
|
||||
|
||||
Authenticated bool
|
||||
}
|
||||
|
||||
func NewConn(conn *textproto.Conn) *Conn { return &Conn{conn: conn} }
|
||||
|
@ -26,6 +31,9 @@ func (c *Conn) SendSignal(signal string) error {
|
|||
}
|
||||
|
||||
func (c *Conn) SendRequest(format string, args ...interface{}) (*Response, error) {
|
||||
if c.debugEnabled() {
|
||||
c.debugf("Write line: %v", fmt.Sprintf(format, args...))
|
||||
}
|
||||
id, err := c.conn.Cmd(format, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -108,6 +116,6 @@ func (c *Conn) debugf(format string, args ...interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func newProtocolError(format string, args ...interface{}) textproto.ProtocolError {
|
||||
func (*Conn) protoErr(format string, args ...interface{}) textproto.ProtocolError {
|
||||
return textproto.ProtocolError(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package controltest
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAuthenticateNull(t *testing.T) {
|
||||
ctx, conn := NewTestContextConnected(t)
|
||||
defer ctx.CloseConnected(conn)
|
||||
ctx.Require.NoError(conn.Authenticate(""))
|
||||
}
|
||||
|
||||
func TestAuthenticateSafeCookie(t *testing.T) {
|
||||
ctx, conn := NewTestContextConnected(t, "--CookieAuthentication", "1")
|
||||
defer ctx.CloseConnected(conn)
|
||||
ctx.Require.NoError(conn.Authenticate(""))
|
||||
}
|
||||
|
||||
func TestAuthenticateHashedPassword(t *testing.T) {
|
||||
// "testpass" - 16:5417AE717521511A609921392778FFA8518EC089BF2162A199241AEB4A
|
||||
ctx, conn := NewTestContextConnected(t, "--HashedControlPassword",
|
||||
"16:5417AE717521511A609921392778FFA8518EC089BF2162A199241AEB4A")
|
||||
defer ctx.CloseConnected(conn)
|
||||
ctx.Require.NoError(conn.Authenticate("testpass"))
|
||||
}
|
|
@ -1,17 +1,14 @@
|
|||
package controltest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProtocolInfo(t *testing.T) {
|
||||
ctx := NewTestContext(context.Background(), t)
|
||||
defer ctx.Close()
|
||||
conn := ctx.ConnectTestTor()
|
||||
defer conn.Close()
|
||||
info, err := conn.RequestProtocolInfo()
|
||||
ctx, conn := NewTestContextConnected(t)
|
||||
defer ctx.CloseConnected(conn)
|
||||
info, err := conn.ProtocolInfo()
|
||||
ctx.Require.NoError(err)
|
||||
ctx.Require.Contains(info.AuthMethods, "NULL")
|
||||
ctx.Require.True(strings.HasPrefix(info.TorVersion, "0.3"))
|
||||
|
|
|
@ -16,18 +16,29 @@ import (
|
|||
type TestContext struct {
|
||||
context.Context
|
||||
*testing.T
|
||||
Require *require.Assertions
|
||||
TestTor *TestTor
|
||||
ExtraTorArgs []string
|
||||
Require *require.Assertions
|
||||
TestTor *TestTor
|
||||
}
|
||||
|
||||
func NewTestContext(ctx context.Context, t *testing.T) *TestContext {
|
||||
return &TestContext{Context: ctx, T: t, Require: require.New(t)}
|
||||
func NewTestContext(ctx context.Context, t *testing.T, extraTorArgs ...string) *TestContext {
|
||||
return &TestContext{Context: ctx, T: t, ExtraTorArgs: extraTorArgs, Require: require.New(t)}
|
||||
}
|
||||
|
||||
func NewTestContextConnected(t *testing.T, extraTorArgs ...string) (*TestContext, *control.Conn) {
|
||||
ctx := NewTestContext(context.Background(), t, extraTorArgs...)
|
||||
conn, err := ctx.ConnectTestTor()
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
ctx.Fatal(err)
|
||||
}
|
||||
return ctx, conn
|
||||
}
|
||||
|
||||
func (t *TestContext) EnsureTestTorStarted() {
|
||||
if t.TestTor == nil {
|
||||
var err error
|
||||
if t.TestTor, err = StartTestTor(t); err != nil {
|
||||
if t.TestTor, err = StartTestTor(t, t.ExtraTorArgs...); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -41,13 +52,20 @@ func (t *TestContext) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *TestContext) ConnectTestTor() *control.Conn {
|
||||
func (t *TestContext) CloseConnected(conn *control.Conn) {
|
||||
if err := conn.Close(); err != nil {
|
||||
fmt.Printf("Warning, close failed on tor conn: %v", err)
|
||||
}
|
||||
t.Close()
|
||||
}
|
||||
|
||||
func (t *TestContext) ConnectTestTor() (*control.Conn, error) {
|
||||
t.EnsureTestTorStarted()
|
||||
textConn, err := textproto.Dial("tcp", "127.0.0.1:"+strconv.Itoa(t.TestTor.ControlPort))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
conn := control.NewConn(textConn)
|
||||
conn.DebugWriter = os.Stdout
|
||||
return conn
|
||||
return conn, nil
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ type TestTor struct {
|
|||
processCancelFn context.CancelFunc
|
||||
}
|
||||
|
||||
func StartTestTor(ctx context.Context) (*TestTor, error) {
|
||||
func StartTestTor(ctx context.Context, extraArgs ...string) (*TestTor, error) {
|
||||
dataDir, err := ioutil.TempDir(".", "test-data-dir-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -35,13 +35,13 @@ func StartTestTor(ctx context.Context) (*TestTor, error) {
|
|||
controlPortFile := filepath.Join(dataDir, "control-port")
|
||||
ret := &TestTor{
|
||||
DataDir: dataDir,
|
||||
OrigArgs: []string{
|
||||
OrigArgs: append([]string{
|
||||
// "--quiet",
|
||||
"--DisableNetwork", "1",
|
||||
"--ControlPort", "auto",
|
||||
"--ControlPortWriteToFile", controlPortFile,
|
||||
"--DataDirectory", dataDir,
|
||||
},
|
||||
}, extraArgs...),
|
||||
}
|
||||
errCh := make(chan error, 1)
|
||||
var processCtx context.Context
|
||||
|
|
|
@ -6,27 +6,24 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// Response is a response to a control port command, or an asyncrhonous event.
|
||||
// Response is a response to a control port command, or an asynchronous event.
|
||||
type Response struct {
|
||||
// Err is the status code and string representation associated with a
|
||||
// response. Responses that have completed successfully will also have
|
||||
// Err set to indicate such.
|
||||
// Err is the status code and string representation associated with a response. Responses that have completed
|
||||
// successfully will also have Err set to indicate such.
|
||||
Err *textproto.Error
|
||||
|
||||
// Reply is the text on the EndReplyLine of the response.
|
||||
Reply string
|
||||
|
||||
// Data is the MidReplyLines/DataReplyLines of the response. Dot encoded
|
||||
// data is "decoded" and presented as a single string (terminal ".CRLF"
|
||||
// removed, all intervening CRs stripped).
|
||||
// Data is the MidReplyLines/DataReplyLines of the response. Dot encoded data is "decoded" and presented as a single
|
||||
// string (terminal ".CRLF" removed, all intervening CRs stripped).
|
||||
Data []string
|
||||
|
||||
// RawLines is all of the lines of a response, without CRLFs.
|
||||
RawLines []string
|
||||
}
|
||||
|
||||
// IsOk returns true if the response status code indicates success or
|
||||
// an asynchronous event.
|
||||
// IsOk returns true if the response status code indicates success or an asynchronous event.
|
||||
func (r *Response) IsOk() bool {
|
||||
switch r.Err.Code {
|
||||
case StatusOk, StatusOkUnneccecary, StatusAsyncEvent:
|
||||
|
@ -41,9 +38,7 @@ func (r *Response) IsAsync() bool {
|
|||
return r.Err.Code == StatusAsyncEvent
|
||||
}
|
||||
|
||||
// ReadResponse returns the next response object. Calling this
|
||||
// simultaniously with Read, Request, or StartAsyncReader will lead to
|
||||
// undefined behavior
|
||||
// ReadResponse returns the next response object.
|
||||
func (c *Conn) ReadResponse() (*Response, error) {
|
||||
var resp *Response
|
||||
var statusCode int
|
||||
|
@ -56,45 +51,31 @@ func (c *Conn) ReadResponse() (*Response, error) {
|
|||
|
||||
// Parse the line that was just read.
|
||||
if len(line) < 4 {
|
||||
return nil, newProtocolError("truncated response: '%s'", line)
|
||||
return nil, c.protoErr("Truncated response: %v", line)
|
||||
}
|
||||
if code, err := strconv.Atoi(line[0:3]); err != nil {
|
||||
return nil, newProtocolError("invalid status code: '%s'", line[0:3])
|
||||
} else if code < 100 {
|
||||
return nil, newProtocolError("invalid status code: '%s'", line[0:3])
|
||||
if code, err := strconv.Atoi(line[0:3]); err != nil || code < 100 {
|
||||
return nil, c.protoErr("Invalid status code: %v", line[0:3])
|
||||
} else if resp == nil {
|
||||
resp = new(Response)
|
||||
resp = &Response{}
|
||||
statusCode = code
|
||||
} else if code != statusCode {
|
||||
// The status code should stay fixed for all lines of the
|
||||
// response, since events can't be interleaved with response
|
||||
// lines.
|
||||
return nil, newProtocolError("status code changed: %03d != %03d", code, statusCode)
|
||||
// The status code should stay fixed for all lines of the response, since events can't be interleaved with
|
||||
// response lines.
|
||||
return nil, c.protoErr("Status code changed: %v != %v", code, statusCode)
|
||||
}
|
||||
if resp.RawLines == nil {
|
||||
resp.RawLines = make([]string, 0, 1)
|
||||
}
|
||||
|
||||
if line[3] == ' ' {
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
switch line[3] {
|
||||
case ' ':
|
||||
// Final line in the response.
|
||||
resp.Reply = line[4:]
|
||||
resp.Err = statusCodeToError(statusCode, resp.Reply)
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.Data == nil {
|
||||
resp.Data = make([]string, 0, 1)
|
||||
}
|
||||
switch line[3] {
|
||||
case '-':
|
||||
// Continuation, keep reading.
|
||||
resp.Data = append(resp.Data, line[4:])
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
case '+':
|
||||
// A "dot-encoded" payload follows.
|
||||
resp.Data = append(resp.Data, line[4:])
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
dotBody, err := c.conn.ReadDotBytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -106,7 +87,7 @@ func (c *Conn) ReadResponse() (*Response, error) {
|
|||
}
|
||||
resp.RawLines = append(resp.RawLines, ".")
|
||||
default:
|
||||
return nil, newProtocolError("invalid separator: '%c'", line[3])
|
||||
return nil, c.protoErr("Invalid separator: '%v'", line[3])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue