$y^2 = x^3 + ax + b$ over $\mathbb{Q}$, where $a$ and $b$ are integers, is trivial?

Context:

As a personal project I wrote some code to analyze an elliptic curve over the rationals to extract and classify its torsion subgroup. It includes some lines to collect statistics about how often the 15 possibilities enumerated in Mazur's theorem are encountered.

It loops over all curves of the form $y^2 = x^3 + ax + b$ (where $a$ and $b$ are integers) with $|a| + |b| = 0$, then all curves of that form with $|a| + |b| = 1$, then all curves of that form with $|a| + |b| = 2$, etc.

So far, it's examined 680,030,280 curves, and of those, it's found 1 torsion subgroup of order 5, 1 of order 9, 2 of order 7, 2 of order 8 (both 2x4, neither cyclic), 30 of order 6, 994 of order 3, 1636 of order 4 (287 cyclic, 1349 2x2), and 330794 of order 2. The remaining 679696820 torsion subgroups (99.95%) were trivial.

Currently, my code computes all points in the torsion subgroup via the Nagell-Lutz theorem. While this is pretty quick for individual curves, it's a bit slow when applied to such large numbers of curves. To that end, I am looking for a faster way to tell whether a given curve's torsion subgroup is trivial.

Code:

`#! /usr/bin/env python3`

from fractions import Fraction as Frac
from itertools import count, chain
from math import gcd

from labmath import divisors, factorint, isqrt

class ECPoint(object):
    """An affine rational point (x, y) on the curve y^2 = x^3 + a*x + b.

    Coordinates are stored as exact ``Fraction`` values.  The point at
    infinity is represented by the ``ECPinf`` subclass, whose ``isinf`` is
    True.  Supports ``==``, hashing, unary ``-``, ``+``, ``-`` and
    multiplication by an ``int`` (on either side).
    """

    def __init__(self, a, b, x, y):
        """Create the point (x, y) on y^2 = x^3 + a*x + b.

        Raises:
            ValueError: if (x, y) does not satisfy the curve equation.
        """
        self.a, self.b, self.x, self.y, self.isinf = a, b, Frac(x), Frac(y), False
        # Validate with the exact Fraction coordinates rather than the raw
        # arguments, so the check is exact and accepts inputs like "1/2".
        if self.y ** 2 != self.x ** 3 + a * self.x + b:
            raise ValueError("The point (%s, %s) is not on the given curve (%d, %d)." % (x, y, a, b))

    def __eq__(self, Q):
        """Return True iff Q is the same point on the same curve."""
        if not isinstance(Q, ECPoint):
            return NotImplemented
        if self.a != Q.a or self.b != Q.b:
            return False
        if Q.isinf or self.isinf:
            # The point at infinity equals only itself.
            return Q.isinf and self.isinf
        return self.x == Q.x and self.y == Q.y

    def __hash__(self):
        """Hash consistently with __eq__ so points can be stored in sets/dicts."""
        if self.isinf:
            return hash((self.a, self.b, None))
        return hash((self.a, self.b, self.x, self.y))

    def __str__(self):
        return "(%s, %s) on (%d, %d)" % (self.x, self.y, self.a, self.b)

    def __neg__(self):
        """Return the reflection (x, -y), the group inverse of this point."""
        return ECPoint(self.a, self.b, self.x, -self.y)

    def __add__(self, Q):
        """Return self + Q using the chord-and-tangent group law.

        Raises:
            ValueError: if Q lies on a different curve.
        """
        if (self.a, self.b) != (Q.a, Q.b):
            raise ValueError("Cannot add points on different curves.")
        if Q.isinf:
            return self  # P + O == P
        x1, y1, x2, y2 = self.x, self.y, Q.x, Q.y
        if x1 == x2:
            if y1 != y2 or y1 == 0:
                # Q == -self (vertical chord), or doubling a point of order 2
                # (vertical tangent): the sum is the point at infinity.
                return ECPinf(self.a, self.b)
            m = (3 * x1 * x1 + self.a) / (2 * y1)  # tangent slope (doubling)
        else:
            m = (y2 - y1) / (x2 - x1)  # chord slope
        x3 = m * m - x2 - x1
        y3 = m * (x3 - x1) + y1
        # (x3, y3) is the third intersection with the line; the sum is its reflection.
        return ECPoint(self.a, self.b, x3, -y3)

    def __sub__(self, Q):
        return self + -Q

    def __mul__(self, n):
        """Return n*self by binary double-and-add.

        Raises:
            TypeError: if n is not an int.
        """
        if not isinstance(n, int):
            raise TypeError("Points can only be multiplied by integers.")
        if n < 0:
            return (-self) * (-n)
        if n == 0:
            return ECPinf(self.a, self.b)
        addend = self  # holds 2**k * self when testing bit 2**k
        total = self if n & 1 == 1 else ECPinf(self.a, self.b)
        bit = 2
        while bit <= n:
            addend = addend + addend
            if n & bit == bit:
                total = addend + total
            bit = bit << 1
        return total

    def __rmul__(self, n):
        return self * n

class ECPinf(ECPoint):
    """The point at infinity (the group identity O) on y^2 = x^3 + a*x + b.

    It stores only the curve coefficients and ``isinf = True``; it has no
    ``x``/``y`` attributes.  Equality is inherited from ECPoint, whose
    ``__eq__`` already treats two infinite points on the same curve as equal.
    """

    def __init__(self, a, b):
        # Deliberately skips ECPoint.__init__: there are no affine coordinates to validate.
        self.a, self.b, self.isinf = a, b, True

    def __str__(self):
        return "(\u221e,\u221e) on (%d, %d)" % (self.a, self.b)  # \u221e is the infinity sign.

    def __neg__(self):
        """O is its own inverse."""
        return self

    def __add__(self, Q):
        """Return Q, since O + Q == Q.

        Raises:
            ValueError: if Q lies on a different curve.
        """
        if (self.a, self.b) != (Q.a, Q.b):
            raise ValueError("Cannot add points on different curves.")
        return Q

    def __mul__(self, n):
        """Return O, since n*O == O for every integer n.

        Raises:
            TypeError: if n is not an int.
        """
        if not isinstance(n, int):
            raise TypeError("Points can only be multiplied by integers.")
        return self

# Labels for the 15 torsion subgroups allowed by Mazur's theorem:
# the cyclic groups Z/n (written "n") and the groups Z/2 x Z/2m (written "2x2m").
# See http://math.nyu.edu/degree/undergrad/ug_research/presentation2.pdf.
ttypes = (
    # cyclic
    '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '12',
    # Z/2 x Z/2m
    '2x2', '2x4', '2x6', '2x8',
)

# Number of curves seen so far with each torsion structure, keyed by label.
tstats = dict.fromkeys(ttypes, 0)

# Factorization cache used by divs(); spiral() periodically rebinds it to a fresh dict.
facs = {}

def spiral():
    """Yield every integer pair (a, b) exactly once, ring by ring.

    (0, 0) comes first.  Then, for ring counter s = 1, 2, 3, ..., the pairs
    with |a| + |b| == s - 1 are produced counter-clockwise in four legs
    starting at (s-1, 0).  Ring s == 1 is empty.

    Side effects on module globals:
      * when s is a multiple of 10, ``facs`` is rebound to an empty dict,
        dropping the factorization cache used by divs();
      * after ring s is exhausted, a progress line with s and the current
        ``tstats`` counts is printed, prefixed by backspaces so it
        overwrites the previous one.
    """
    global facs
    yield (0, 0)
    for s in count(1):
        if s % 10 == 0:
            facs = {}
        r = s - 1  # |a| + |b| on this ring
        legs = (
            lambda k: (r - k, k),   # (r, 0)  towards (1, r-1)
            lambda k: (-k, r - k),  # (0, r)  towards (1-r, 1)
            lambda k: (k - r, -k),  # (-r, 0) towards (-1, 1-r)
            lambda k: (k, k - r),   # (0, -r) towards (r-1, -1)
        )
        for leg in legs:
            for k in range(r):
                yield leg(k)
        summary = ' '.join("%s: %d" % (t, tstats[t]) for t in ttypes)
        print('\b' * 180 + str(s) + ' ' + summary, end='', flush=True)

def divs(n, facs):
    """Lazily yield the positive divisors of n.

    The factorization of n is looked up in (and, on a miss, stored into) the
    caller-supplied dict ``facs``.  The lookup happens on the first ``next()``,
    because this is a generator.
    """
    try:
        factorization = facs[n]
    except KeyError:
        factorization = facs[n] = factorint(n)
    yield from divisors(factorization)

def torstruct(a, b, torsion):
    """Classify a rational torsion subgroup as one of the labels in ttypes.

    Parameters:
        a, b: coefficients of y^2 = x^3 + a*x + b.  They are kept for
            interface compatibility; the classification does not need them.
        torsion: the set of ALL non-identity torsion points (x, y), so the
            group order is len(torsion) + 1.

    Returns:
        "n" for the cyclic group Z/n, or "2xm" for Z/2 x Z/m.

    By Mazur's theorem the group is Z/n or Z/2 x Z/2m.  A group Z/2 x Z/2m has
    exactly three elements of order 2, while a cyclic group has at most one.
    The points of order 2 are exactly those with y == 0.  Counting them
    therefore decides the structure without any (slow, Fraction-based)
    point arithmetic.
    """
    order = len(torsion) + 1
    two_torsion = sum(1 for (_, y) in torsion if y == 0)
    if two_torsion == 3:
        return "2x%d" % (order // 2)
    return str(order)

# Odd primes tried by the reduction-mod-p prefilter in the main loop.
_SMALL_ODD_PRIMES = (3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47)


def _legendre_table(p):
    """Return the list of Legendre symbols (v | p) for v = 0 .. p-1 (p an odd prime)."""
    squares = {v * v % p for v in range(1, p)}
    return [0] + [1 if v in squares else -1 for v in range(1, p)]


# Built once, so the per-curve work below is only table lookups.
_LEGENDRE = {p: _legendre_table(p) for p in _SMALL_ODD_PRIMES}


def _count_points_mod_p(a, b, p):
    """Return #E(F_p) for y^2 = x^3 + a*x + b, including the point at infinity.

    Each x in F_p contributes 1 + (f(x) | p) affine points.
    """
    chi = _LEGENDRE[p]
    return p + 1 + sum(chi[(x * x * x + a * x + b) % p] for x in range(p))


def _torsion_order_multiple(a, b, D):
    """Return a multiple of #E(Q)_tors for y^2 = x^3 + a*x + b, or 0 if unknown.

    D must be |4a^3 + 27b^2|, which is nonzero.  For an odd prime p not
    dividing D, reduction mod p maps the rational torsion subgroup
    injectively into E(F_p) (Silverman & Tate, "Reduction Modulo p
    Theorem").  So #E(Q)_tors divides every such #E(F_p), and therefore
    divides their gcd.  A result of 1 proves that the torsion is trivial.
    """
    g = 0
    for p in _SMALL_ODD_PRIMES:
        if D % p == 0:
            continue  # bad reduction at p: the theorem does not apply
        g = gcd(g, _count_points_mod_p(a, b, p))
        if g == 1:
            break
    return g


for (a, b) in spiral():
    D = abs(4*a*a*a + 27*b*b)
    if D == 0:
        continue  # singular cubic, not an elliptic curve

    # Cheap prefilter: point counts mod a few good primes usually have gcd 1,
    # which proves the torsion is trivial without factoring D (Nagell-Lutz).
    if _torsion_order_multiple(a, b, D) == 1:
        tstats['1'] += 1
        continue

    # y^2 == x^3 + a*x + b is an elliptic curve over Q.  By Nagell-Lutz every
    # torsion point other than O is integral with y == 0 or y^2 | D, so
    # collect every integer point satisfying that condition.
    Z = set()  # all torsion points (except O), possibly plus some non-torsion points
    for y in chain([0], divs(D, facs)):
        if y != 0 and D % (y*y) != 0:
            continue
        # Find the integer roots x of x^3 + a*x + (b - y^2).
        c = b - y*y
        if c != 0:
            # Rational roots theorem: every integer root divides the nonzero constant term.
            for x in divs(abs(c), facs):
                if b + a*x + x*x*x == y*y:
                    Z |= {(x, y), (x, -y)}
                if b - a*x - x*x*x == y*y:
                    Z |= {(-x, y), (-x, -y)}
        else:
            # The cubic is x*(x^2 + a): roots are 0 and, when -a is a perfect square, +-sqrt(-a).
            Z |= {(0, y), (0, -y)}
            if a <= 0:  # -a can only be a square when a <= 0; keeps isqrt's argument non-negative
                ar = isqrt(-a)
                if ar*ar == -a:
                    Z |= {(ar, y), (-ar, y), (ar, -y), (-ar, -y)}

    torsion = set()  # torsion points except O
    for (x, y) in Z:
        P = ECPoint(a, b, x, y)
        # Denominators of multiples grow very fast, so step through
        # P, 2P, ..., 12P one addition at a time and stop as early as possible.
        Q = inf = ECPinf(a, b)
        for _ in range(12):
            Q += P
            if Q == inf:
                torsion.add((x, y))  # nP == O: P is a torsion point
                break
            if Q.x.denominator != 1 or Q.y.denominator != 1:
                break  # a non-integral multiple: by Nagell-Lutz P has infinite order
        # If the loop runs out, P, ..., 12P are all integral but none is O.
        # No rational torsion point has order > 12 (Mazur), so P has infinite order.

    ts = torstruct(a, b, torsion)
    tstats[ts] += 1
    if ts not in ('1', '2', '3', '4', '2x2'):
        print('\b'*180, a, b, ' '*8, len(Z), ' ', ts, ' '*8, *sorted(torsion), ' '*64)

print()  # reached only if spiral() terminates; it is an infinite generator