85 lines
2.0 KiB
Go
85 lines
2.0 KiB
Go
|
package spam
|
||
|
|
||
|
import (
|
||
|
"crypto/rand"
|
||
|
"crypto/sha256"
|
||
|
"git.mascherari.press/cwtch/protocol"
|
||
|
"github.com/golang/protobuf/proto"
|
||
|
"github.com/s-rah/go-ricochet/utils"
|
||
|
"github.com/s-rah/go-ricochet/wire/control"
|
||
|
"io"
|
||
|
)
|
||
|
|
||
|
// SpamGuard carries the state for a SHA-256 proof-of-work exchange.
// It holds the number of leading zero digest bytes a solution must
// produce and the most recent nonce issued by GenerateChallenge.
type SpamGuard struct {
	// Difficulty is how many leading bytes of the SHA-256 digest must be
	// zero for a proof-of-work to be accepted.
	Difficulty int

	// nonce is the last challenge written by GenerateChallenge. ValidateChallenge
	// hashes submitted solutions against it.
	nonce [24]byte
}
|
||
|
|
||
|
func (sg *SpamGuard) GenerateChallenge(channelID int32) []byte {
|
||
|
|
||
|
cr := &Protocol_Data_Control.ChannelResult{
|
||
|
ChannelIdentifier: proto.Int32(channelID),
|
||
|
Opened: proto.Bool(true),
|
||
|
}
|
||
|
|
||
|
var nonce [24]byte
|
||
|
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||
|
utils.CheckError(err)
|
||
|
}
|
||
|
sg.nonce = nonce
|
||
|
err := proto.SetExtension(cr, protocol.E_ServerNonce, sg.nonce[:])
|
||
|
utils.CheckError(err)
|
||
|
|
||
|
pc := &Protocol_Data_Control.Packet{
|
||
|
ChannelResult: cr,
|
||
|
}
|
||
|
ret, err := proto.Marshal(pc)
|
||
|
utils.CheckError(err)
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func (sg *SpamGuard) SolveChallenge(challenge []byte, message []byte) []byte {
|
||
|
solved := false
|
||
|
var spamguard [24]byte
|
||
|
sum := sha256.Sum256([]byte{})
|
||
|
solve := make([]byte, len(challenge)+len(message)+len(spamguard))
|
||
|
for !solved {
|
||
|
|
||
|
if _, err := io.ReadFull(rand.Reader, spamguard[:]); err != nil {
|
||
|
utils.CheckError(err)
|
||
|
}
|
||
|
|
||
|
copy(solve[0:], challenge[:])
|
||
|
copy(solve[len(challenge):], message[:])
|
||
|
copy(solve[len(challenge)+len(message):], spamguard[:])
|
||
|
|
||
|
sum = sha256.Sum256(solve)
|
||
|
|
||
|
solved = true
|
||
|
for i := 0; i < sg.Difficulty; i++ {
|
||
|
if sum[i] != 0x00 {
|
||
|
solved = false
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//log.Printf("Solved answer: %v %x %x\n", len(solve), solve, sum)
|
||
|
return spamguard[:]
|
||
|
}
|
||
|
|
||
|
func (sg *SpamGuard) ValidateChallenge(message []byte, spamguard []byte) bool {
|
||
|
//log.Printf("%v %v\n", sg.nonce[:], spamguard[:])
|
||
|
solve := make([]byte, len(sg.nonce)+len(message)+len(spamguard))
|
||
|
copy(solve[0:], sg.nonce[:])
|
||
|
copy(solve[len(sg.nonce):], message[:])
|
||
|
copy(solve[len(sg.nonce)+len(message):], spamguard[:])
|
||
|
sum := sha256.Sum256(solve)
|
||
|
//log.Printf("Got answer: %v %x %x\n", len(solve), solve, sum)
|
||
|
for i := 0; i < sg.Difficulty; i++ {
|
||
|
if sum[i] != 0x00 {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|