diff --git a/application/application.go b/application/application.go index f864573..08753cd 100644 --- a/application/application.go +++ b/application/application.go @@ -118,7 +118,17 @@ func (ra *RicochetApplication) Shutdown() { ra.lock.Lock() ra.ls.Close() for _, instance := range ra.instances { - instance.Connection.Conn.Close() + instance.Connection.Close() + } + ra.lock.Unlock() +} + +func (ra *RicochetApplication) Close(onion string) { + ra.lock.Lock() + for _, instance := range ra.instances { + if instance.RemoteHostname == onion { + instance.Connection.Close() + } } ra.lock.Unlock() } diff --git a/connection/connection.go b/connection/connection.go index 2506553..a4d6ff3 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -39,7 +39,7 @@ type Connection struct { // it for anything else. See those functions for an explanation. processBlockMutex sync.Mutex - Conn io.ReadWriteCloser + conn io.ReadWriteCloser IsInbound bool am AuthorizationManager RemoteHostname string @@ -66,7 +66,7 @@ func (rc *Connection) init() { // modelling an Inbound Connection func NewInboundConnection(conn io.ReadWriteCloser) *Connection { rc := new(Connection) - rc.Conn = conn + rc.conn = conn rc.IsInbound = true rc.init() rc.channelManager = NewServerChannelManager() @@ -78,7 +78,7 @@ func NewInboundConnection(conn io.ReadWriteCloser) *Connection { // modelling an Inbound Connection func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Connection { rc := new(Connection) - rc.Conn = conn + rc.conn = conn rc.IsInbound = false rc.init() rc.RemoteHostname = remoteHostname @@ -90,7 +90,7 @@ func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Conn // start func (rc *Connection) start() { for { - packet, err := rc.RecvRicochetPacket(rc.Conn) + packet, err := rc.RecvRicochetPacket(rc.conn) if err != nil { rc.errorChannel <- err return @@ -206,7 +206,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) func (rc *Connection) handleChannelOpening(channel *channels.Channel, response []byte, err error) (*channels.Channel, error) { if err == nil { - rc.SendRicochetPacket(rc.Conn, 0, response) + rc.SendRicochetPacket(rc.conn, 0, response) return channel, nil } log.Debugln(fmt.Sprintf("failed to request open channel: %v", err)) @@ -220,13 +220,13 @@ func (rc *Connection) buildChannel(handler channels.Handler, openChannelFunc fun channel, err := openChannelFunc(handler) if err == nil { channel.SendMessage = func(message []byte) { - rc.SendRicochetPacket(rc.Conn, channel.ID, message) + rc.SendRicochetPacket(rc.conn, channel.ID, message) } channel.DelegateAuthorization = func() { rc.am.AddAuthorization(handler.Type()) } channel.CloseChannel = func() { - rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.SendRicochetPacket(rc.conn, channel.ID, []byte{}) rc.channelManager.RemoveChannel(channel.ID) } return channel, nil @@ -294,7 +294,7 @@ func (rc *Connection) Process(handler Handler) error { case packet = <-rc.packetChannel: break case err := <-rc.errorChannel: - rc.Conn.Close() + rc.conn.Close() rc.closing = true // In order to safely close down concurrent calls to Do or Break, @@ -360,7 +360,7 @@ func (rc *Connection) Process(handler Handler) error { // that channel. log.Debugln(fmt.Sprintf("received packet on unknown channel %d. closing.", packet.Channel)) if len(packet.Data) != 0 { - rc.SendRicochetPacket(rc.Conn, packet.Channel, []byte{}) + rc.SendRicochetPacket(rc.conn, packet.Channel, []byte{}) } } } @@ -381,7 +381,7 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. response, err := chandler.OpenInbound(channel, opm) _, err = rc.handleChannelOpening(channel, response, err) if err != nil { - rc.SendRicochetPacket(rc.Conn, 0, []byte{}) + rc.SendRicochetPacket(rc.conn, 0, []byte{}) } return } @@ -396,7 +396,7 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. // Send Error Packet response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), errorText) log.Debugln(fmt.Sprintf("sending reject open channel for %v: %v", opm.GetChannelIdentifier(), errorText)) - rc.SendRicochetPacket(rc.Conn, 0, response) + rc.SendRicochetPacket(rc.conn, 0, response) } else if res.GetChannelResult() != nil { rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult()) @@ -408,13 +408,13 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. respond, data := rc.ctrlChannel.ProcessKeepAlive(res.GetKeepAlive()) if respond { log.Debugln("sending keep alive response") - rc.SendRicochetPacket(rc.Conn, 0, data) + rc.SendRicochetPacket(rc.conn, 0, data) } } else if res.GetEnableFeatures() != nil { log.Debugln("received enable features packet") data := rc.ctrlChannel.ProcessEnableFeatures(handler, res.GetEnableFeatures()) log.Debugln(fmt.Sprintf("sending featured enabled: %v", data)) - rc.SendRicochetPacket(rc.Conn, 0, data) + rc.SendRicochetPacket(rc.conn, 0, data) } else if res.GetFeaturesEnabled() != nil { rc.SupportChannels = res.GetFeaturesEnabled().GetFeature() log.Debugln(fmt.Sprintf("connection supports: %v", rc.SupportChannels)) @@ -426,7 +426,7 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. func (rc *Connection) EnableFeatures(features []string) { messageBuilder := new(utils.MessageBuilder) raw := messageBuilder.EnableFeatures(features) - rc.SendRicochetPacket(rc.Conn, 0, raw) + rc.SendRicochetPacket(rc.conn, 0, raw) } // Break causes Process() to return, but does not close the underlying connection @@ -452,3 +452,12 @@ func (rc *Connection) Break() error { func (rc *Connection) Channel(ctype string, way channels.Direction) *channels.Channel { return rc.channelManager.Channel(ctype, way) } + +// Close tearsdown a Process() and explicitly Closes the connection. +// Note, that if Process() is holding a connection this will trigger an Error +func (rc *Connection) Close() { + // Kill the Ricochet Connection. + log.Debugf("Closing Ricochet Connection for %v", rc.RemoteHostname) + rc.conn.Close() + rc.closed = true +} diff --git a/connection/connection_test.go b/connection/connection_test.go index 7045c18..843ea05 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -82,6 +82,9 @@ func TestProcessAuthAs3DHServer(t *testing.T) { if err != nil { t.Errorf("Error while testing ProcessAuthAsServer: %v", err) } + + // Test Close + rc.Close() } func TestProcessAuthAsV3ServerFail(t *testing.T) { @@ -179,6 +182,6 @@ func TestProcessAuthTimeout(t *testing.T) { rc := NewInboundConnection(conn) err := HandleInboundConnection(rc).ProcessAuthAsServer(identity.Initialize("", privateKey), ServerAuthValid) if err != utils.ActionTimedOutError { - t.Errorf("Error while testing TestProcessAuthTimeout - Should have timed out after 15 seconds") + t.Errorf("Error while testing TestProcessAuthTimeout - Should have timed out after 15 seconds, instead ERR was %v", err) } } diff --git a/connectivity/acn.go b/connectivity/acn.go index 7be9eb4..3cbf4a6 100644 --- a/connectivity/acn.go +++ b/connectivity/acn.go @@ -26,7 +26,7 @@ type ACN interface { // WaitTillBootstrapped Blocks until underlying network is bootstrapped WaitTillBootstrapped() - // Open takes a hostname and returns a net.Conn to the derived endpoint + // Open takes a hostname and returns a net.conn to the derived endpoint // Open allows a client to resolve various hostnames to connections // The supported types are onions address are: // * ricochet:jlq67qzo6s4yp3sp diff --git a/ricochet.go b/ricochet.go index 911e6d6..c8fab46 100644 --- a/ricochet.go +++ b/ricochet.go @@ -8,7 +8,7 @@ import ( "net" ) -// Open establishes a protocol session on an established net.Conn, and returns a new +// Open establishes a protocol session on an established net.conn, and returns a new // OpenConnection instance representing this connection. On error, the connection // will be closed. This function blocks until version negotiation has completed. // The application should call Process() on the returned OpenConnection to continue