You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

365 lines
12 KiB

  1. package ristretto255
  2. import (
  3. "bytes"
  4. "crypto/sha512"
  5. "encoding/hex"
  6. "encoding/json"
  7. "testing"
  8. "github.com/gtank/ristretto255/internal/radix51"
  9. )
  10. func assertFeEqual(value, expect *radix51.FieldElement) {
  11. if value.Equal(expect) == 1 {
  12. return
  13. } else {
  14. panic("failed equality assertion")
  15. }
  16. }
  17. type sqrtRatioTest struct {
  18. u, v *radix51.FieldElement
  19. sqrt *radix51.FieldElement
  20. choice int
  21. negative int
  22. }
  23. func TestSqrtRatioM1(t *testing.T) {
  24. // These tests can be found in curve25519-dalek's 'field.rs'
  25. var (
  26. zero, one = radix51.Zero, radix51.One
  27. // Two is nonsquare in our field, 4 is square
  28. two = new(radix51.FieldElement).Add(one, one)
  29. four = new(radix51.FieldElement).Add(two, two)
  30. // 2*i
  31. twoTimesSqrtM1 = new(radix51.FieldElement).Mul(two, sqrtM1)
  32. sqrt2i = fieldElementFromDecimal(
  33. "38214883241950591754978413199355411911188925816896391856984770930832735035196")
  34. invSqrt4 = fieldElementFromDecimal(
  35. "28948022309329048855892746252171976963317496166410141009864396001978282409974")
  36. )
  37. // Check the construction of those magic numbers.
  38. assertFeEqual(new(radix51.FieldElement).Mul(sqrt2i, sqrt2i), twoTimesSqrtM1)
  39. assertFeEqual(new(radix51.FieldElement).Mul(new(radix51.FieldElement).Square(invSqrt4), four), one)
  40. var tests = []sqrtRatioTest{
  41. {u: zero, v: zero, sqrt: zero, choice: 1, negative: 0},
  42. {u: zero, v: one, sqrt: zero, choice: 1, negative: 0},
  43. {u: one, v: zero, sqrt: zero, choice: 0, negative: 0},
  44. {u: two, v: one, sqrt: sqrt2i, choice: 0, negative: 0},
  45. {u: four, v: one, sqrt: two, choice: 1, negative: 0},
  46. {u: one, v: four, sqrt: invSqrt4, choice: 1, negative: 0},
  47. }
  48. for idx, tt := range tests {
  49. sqrt := new(radix51.FieldElement)
  50. choice := feSqrtRatio(sqrt, tt.u, tt.v)
  51. if choice != tt.choice || sqrt.Equal(tt.sqrt) != 1 || sqrt.IsNegative() != tt.negative {
  52. t.Errorf("Failed test %d", idx)
  53. t.Logf("Got {u: %v, v: %v, sqrt: %v, choice: %d, neg: %d}", tt.u, tt.v, sqrt, choice, sqrt.IsNegative())
  54. }
  55. }
  56. }
  57. // The encoding of the canonical generator.
  58. var compressedRistrettoBasepoint, _ = hex.DecodeString("e2f2ae0a6abc4e71a884a961c500515f58e30b6aa582dd8db6a65945e08d2d76")
  59. func TestRistrettoBasepointRoundTrip(t *testing.T) {
  60. decodedBasepoint := &Element{}
  61. err := decodedBasepoint.Decode(compressedRistrettoBasepoint)
  62. if err != nil {
  63. t.Fatal(err)
  64. }
  65. ristrettoBasepoint := (&Element{}).Base()
  66. if decodedBasepoint.Equal(ristrettoBasepoint) != 1 {
  67. t.Error("decode succeeded, but got wrong point")
  68. }
  69. roundtripBasepoint := decodedBasepoint.Encode(nil)
  70. if !bytes.Equal(compressedRistrettoBasepoint, roundtripBasepoint) {
  71. t.Error("decode<>encode roundtrip produced different results")
  72. }
  73. encodedBasepoint := ristrettoBasepoint.Encode(nil)
  74. if !bytes.Equal(compressedRistrettoBasepoint, encodedBasepoint) {
  75. t.Error("point encode produced different results")
  76. }
  77. }
  78. func TestRistrettoRandomRoundtrip(t *testing.T) {
  79. // TODO quickcheck
  80. }
  81. func TestRistrettoSmallMultiplesTestVectors(t *testing.T) {
  82. var testVectors = [16]string{
  83. // This is the identity point
  84. "0000000000000000000000000000000000000000000000000000000000000000",
  85. // This is the basepoint
  86. "e2f2ae0a6abc4e71a884a961c500515f58e30b6aa582dd8db6a65945e08d2d76",
  87. // These are small multiples of the basepoint
  88. "6a493210f7499cd17fecb510ae0cea23a110e8d5b901f8acadd3095c73a3b919",
  89. "94741f5d5d52755ece4f23f044ee27d5d1ea1e2bd196b462166b16152a9d0259",
  90. "da80862773358b466ffadfe0b3293ab3d9fd53c5ea6c955358f568322daf6a57",
  91. "e882b131016b52c1d3337080187cf768423efccbb517bb495ab812c4160ff44e",
  92. "f64746d3c92b13050ed8d80236a7f0007c3b3f962f5ba793d19a601ebb1df403",
  93. "44f53520926ec81fbd5a387845beb7df85a96a24ece18738bdcfa6a7822a176d",
  94. "903293d8f2287ebe10e2374dc1a53e0bc887e592699f02d077d5263cdd55601c",
  95. "02622ace8f7303a31cafc63f8fc48fdc16e1c8c8d234b2f0d6685282a9076031",
  96. "20706fd788b2720a1ed2a5dad4952b01f413bcf0e7564de8cdc816689e2db95f",
  97. "bce83f8ba5dd2fa572864c24ba1810f9522bc6004afe95877ac73241cafdab42",
  98. "e4549ee16b9aa03099ca208c67adafcafa4c3f3e4e5303de6026e3ca8ff84460",
  99. "aa52e000df2e16f55fb1032fc33bc42742dad6bd5a8fc0be0167436c5948501f",
  100. "46376b80f409b29dc2b5f6f0c52591990896e5716f41477cd30085ab7f10301e",
  101. "e0c418f7c8d9c4cdd7395b93ea124f3ad99021bb681dfc3302a9d99a2e53e64e",
  102. }
  103. basepointMultiple := (&Element{}).Zero()
  104. ristrettoBasepoint := (&Element{}).Base()
  105. for i := range testVectors {
  106. // Grab the bytes of the encoding
  107. encoding, err := hex.DecodeString(testVectors[i])
  108. if err != nil {
  109. t.Fatalf("#%d: bad hex encoding in test vector: %v", i, err)
  110. }
  111. // Decode the test vector to a ristretto255 element
  112. decodedPoint := Element{}
  113. err = decodedPoint.Decode(encoding)
  114. if err != nil {
  115. t.Fatalf("#%d: could not decode test vector: %v", i, err)
  116. }
  117. // Re-encode and check round trips
  118. roundtripEncoding := decodedPoint.Encode(nil)
  119. if !bytes.Equal(encoding, roundtripEncoding) {
  120. t.Errorf("#%d: decode<>encode roundtrip failed", i)
  121. }
  122. // Check that the test vector encodes i * B
  123. if basepointMultiple.Equal(&decodedPoint) != 1 {
  124. t.Errorf("decoded small multiple %d * B is not %d * B", i, i)
  125. }
  126. computedEncoding := basepointMultiple.Encode(nil)
  127. if !bytes.Equal(encoding, computedEncoding) {
  128. t.Errorf("#%d: encoding computed value did not match", i)
  129. }
  130. // Ensure basepointMultiple = i * B in the next iteration
  131. basepointMultiple.Add(basepointMultiple, ristrettoBasepoint)
  132. }
  133. }
  134. func TestRistrettoBadEncodingsTestVectors(t *testing.T) {
  135. var testVectors = []string{
  136. // These are all bad because they're non-canonical field encodings.
  137. "00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
  138. "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f",
  139. "f3ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f",
  140. "edffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f",
  141. // These are all bad because they're negative field elements.
  142. "0100000000000000000000000000000000000000000000000000000000000000",
  143. "01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f",
  144. "ed57ffd8c914fb201471d1c3d245ce3c746fcbe63a3679d51b6a516ebebe0e20",
  145. "c34c4e1826e5d403b78e246e88aa051c36ccf0aafebffe137d148a2bf9104562",
  146. "c940e5a4404157cfb1628b108db051a8d439e1a421394ec4ebccb9ec92a8ac78",
  147. "47cfc5497c53dc8e61c91d17fd626ffb1c49e2bca94eed052281b510b1117a24",
  148. "f1c6165d33367351b0da8f6e4511010c68174a03b6581212c71c0e1d026c3c72",
  149. "87260f7a2f12495118360f02c26a470f450dadf34a413d21042b43b9d93e1309",
  150. // These are all bad because they give a nonsquare x^2.
  151. "26948d35ca62e643e26a83177332e6b6afeb9d08e4268b650f1f5bbd8d81d371",
  152. "4eac077a713c57b4f4397629a4145982c661f48044dd3f96427d40b147d9742f",
  153. "de6a7b00deadc788eb6b6c8d20c0ae96c2f2019078fa604fee5b87d6e989ad7b",
  154. "bcab477be20861e01e4a0e295284146a510150d9817763caf1a6f4b422d67042",
  155. "2a292df7e32cababbd9de088d1d1abec9fc0440f637ed2fba145094dc14bea08",
  156. "f4a9e534fc0d216c44b218fa0c42d99635a0127ee2e53c712f70609649fdff22",
  157. "8268436f8c4126196cf64b3c7ddbda90746a378625f9813dd9b8457077256731",
  158. "2810e5cbc2cc4d4eece54f61c6f69758e289aa7ab440b3cbeaa21995c2f4232b",
  159. // These are all bad because they give a negative xy value.
  160. "3eb858e78f5a7254d8c9731174a94f76755fd3941c0ac93735c07ba14579630e",
  161. "a45fdc55c76448c049a1ab33f17023edfb2be3581e9c7aade8a6125215e04220",
  162. "d483fe813c6ba647ebbfd3ec41adca1c6130c2beeee9d9bf065c8d151c5f396e",
  163. "8a2e1d30050198c65a54483123960ccc38aef6848e1ec8f5f780e8523769ba32",
  164. "32888462f8b486c68ad7dd9610be5192bbeaf3b443951ac1a8118419d9fa097b",
  165. "227142501b9d4355ccba290404bde41575b037693cef1f438c47f8fbf35d1165",
  166. "5c37cc491da847cfeb9281d407efc41e15144c876e0170b499a96a22ed31e01e",
  167. "445425117cb8c90edcbc7c1cc0e74f747f2c1efa5630a967c64f287792a48a4b",
  168. // This is s = -1, which causes y = 0.
  169. "ecffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f",
  170. }
  171. basepointMultiple := Element{}
  172. basepointMultiple.Zero()
  173. for i := range testVectors {
  174. // Grab the bytes of the encoding
  175. encoding, err := hex.DecodeString(testVectors[i])
  176. if err != nil {
  177. t.Fatalf("#%d: bad hex encoding in test vector: %v", i, err)
  178. }
  179. // Attempt decoding
  180. decodedPoint := Element{}
  181. err = decodedPoint.Decode(encoding)
  182. if err == nil {
  183. t.Fatalf("#%d: did not fail on bad encoding", i)
  184. }
  185. }
  186. }
  187. func TestRistrettoFromUniformBytesTestVectors(t *testing.T) {
  188. inputs := []string{
  189. "Ristretto is traditionally a short shot of espresso coffee",
  190. "made with the normal amount of ground coffee but extracted with",
  191. "about half the amount of water in the same amount of time",
  192. "by using a finer grind.",
  193. "This produces a concentrated shot of coffee per volume.",
  194. "Just pulling a normal shot short will produce a weaker shot",
  195. "and is not a Ristretto as some believe.",
  196. }
  197. elements := []string{
  198. "3066f82a1a747d45120d1740f14358531a8f04bbffe6a819f86dfe50f44a0a46",
  199. "f26e5b6f7d362d2d2a94c5d0e7602cb4773c95a2e5c31a64f133189fa76ed61b",
  200. "006ccd2a9e6867e6a2c5cea83d3302cc9de128dd2a9a57dd8ee7b9d7ffe02826",
  201. "f8f0c87cf237953c5890aec3998169005dae3eca1fbb04548c635953c817f92a",
  202. "ae81e7dedf20a497e10c304a765c1767a42d6e06029758d2d7e8ef7cc4c41179",
  203. "e2705652ff9f5e44d3e841bf1c251cf7dddb77d140870d1ab2ed64f1a9ce8628",
  204. "80bd07262511cdde4863f8a7434cef696750681cb9510eea557088f76d9e5065",
  205. }
  206. var element Element
  207. for i, input := range inputs {
  208. hash := sha512.Sum512([]byte(input))
  209. element.FromUniformBytes(hash[:])
  210. if encoding := hex.EncodeToString(element.Encode(nil)); encoding != elements[i] {
  211. t.Errorf("#%d: expected %q, got %q", i, elements[i], encoding)
  212. }
  213. }
  214. }
  215. func TestEquivalentFromUniformBytes(t *testing.T) {
  216. inputs := []string{
  217. "edffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" +
  218. "1200000000000000000000000000000000000000000000000000000000000000",
  219. "edffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" +
  220. "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
  221. "0000000000000000000000000000000000000000000000000000000000000080" +
  222. "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f",
  223. "0000000000000000000000000000000000000000000000000000000000000000" +
  224. "1200000000000000000000000000000000000000000000000000000000000080",
  225. }
  226. expected := "304282791023b73128d277bdcb5c7746ef2eac08dde9f2983379cb8e5ef0517f"
  227. var element Element
  228. for i, input := range inputs {
  229. h, err := hex.DecodeString(input)
  230. if err != nil {
  231. t.Fatalf("#%d: bad hex encoding in test vector: %v", i, err)
  232. }
  233. element.FromUniformBytes(h[:])
  234. if encoding := hex.EncodeToString(element.Encode(nil)); encoding != expected {
  235. t.Errorf("#%d: expected %q, got %q", i, expected, encoding)
  236. }
  237. }
  238. }
  239. func TestMarshalScalar(t *testing.T) {
  240. x := new(Scalar)
  241. // generate an arbitrary scalar
  242. xbytes := sha512.Sum512([]byte("Hello World"))
  243. x.FromUniformBytes(xbytes[:])
  244. text, err := json.Marshal(x)
  245. if err != nil {
  246. t.Fatalf("Could not marshal json: %v", err)
  247. }
  248. t.Logf("json: %s", text)
  249. y := new(Scalar)
  250. err = json.Unmarshal(text, y)
  251. if err != nil || y.Equal(x) == 0 {
  252. t.Fatalf("Error unmarshaling scalar from json: %s %v", text, err)
  253. }
  254. }
  255. func TestMarshalElement(t *testing.T) {
  256. x := new(Element)
  257. // generate an arbitrary element
  258. xbytes := sha512.Sum512([]byte("Hello World"))
  259. x.FromUniformBytes(xbytes[:])
  260. text, err := json.Marshal(x)
  261. if err != nil {
  262. t.Fatalf("Could not marshal json: %v", err)
  263. }
  264. t.Logf("json: %s", text)
  265. y := new(Element)
  266. err = json.Unmarshal(text, y)
  267. if err != nil || y.Equal(x) == 0 {
  268. t.Fatalf("Error unmarshaling element from json: %s %v", text, err)
  269. }
  270. }
  271. func TestElementSet(t *testing.T) {
  272. // Test this, because the internal point type being hard-copyable isn't part of the spec.
  273. el1 := NewElement().Zero()
  274. el2 := NewElement().Base()
  275. if el1.Equal(el2) == 1 {
  276. t.Error("shouldn't be the same")
  277. }
  278. // Check new value
  279. el1.Set(el2)
  280. if el1.Equal(el2) == 0 {
  281. t.Error("failed to set the value")
  282. }
  283. // Mutate source var
  284. el2.Add(el2, el2)
  285. if el1.Equal(el2) == 1 {
  286. t.Error("shouldn't have changed")
  287. }
  288. }
  289. func TestScalarSet(t *testing.T) {
  290. // Test this, because the internal scalar representation being hard-copyable isn't part of the spec.
  291. // 32-byte little endian value of "1"
  292. scOne := make([]byte, 32)
  293. scOne[0] = 0x01
  294. sc1, sc2 := NewScalar(), NewScalar().Zero()
  295. // sc1 <- 1
  296. sc1.Decode(scOne)
  297. // 1 != 0
  298. if sc1.Equal(sc2) == 1 {
  299. t.Error("shouldn't be the same")
  300. }
  301. // sc2 <- sc1
  302. sc2.Set(sc1)
  303. // 1 == 1
  304. if sc1.Equal(sc2) == 0 {
  305. t.Error("failed to set the value")
  306. }
  307. // sc1 <- 1 + 1
  308. sc1.Add(sc1, sc1)
  309. // 2 != 1
  310. if sc1.Equal(sc2) == 1 {
  311. t.Error("shouldn't have changed")
  312. }
  313. }