import math

from mod import ZMod
from xgcd import gcd, xgcd

def solveaxb(a, b):
    '''Return a list of every x in ZMod(n) satisfying a*x = b, where
    a and b are elements of the same ZMod(n).  The list is empty when
    the equation has no solution.
    >>> Mod10 = ZMod(10)
    >>> solveaxb(Mod10(3), Mod10(7))
    [Mod(9, 10)]
    >>> solveaxb(Mod10(6), Mod10(7))
    []
    >>> sorted(solveaxb(Mod10(6), Mod10(4)), key=int) # order 0...9 mod 10
    [Mod(4, 10), Mod(9, 10)]
    '''
    # An alternative is to solve on plain integers with
    # solveaxbn(a.value, b.value, a.modulus()) and lift each result back
    # into the ring; the direct computation below was preferred.
    modulus = a.modulus()
    divisor = gcd(modulus, a.value)
    if b.value % divisor != 0:
        # gcd(n, a) must divide b for any solution to exist.
        return []
    ring = ZMod(modulus)
    if divisor == modulus:
        # a = b = 0 mod n, so every element of the ring is a solution.
        return [ring(residue) for residue in range(modulus)]
    # Divide the equation through by the gcd and solve it once in the
    # smaller ring ZMod(n//k) by dividing by the reduced coefficient.
    step = modulus // divisor
    particular = (b.value // divisor) / ZMod(step)(a.value // divisor)
    # The remaining solutions differ from the first by multiples of n//k.
    return [ring(particular.value + shift)
            for shift in range(0, modulus, step)]

# In linear algebra terms, range(0, n, n//k) is the null space:
#     all solutions to ax = 0, mod n.

def solveaxbn(a, b, n):
    '''Return a list of all numbers x in {0, 1, ... n-1} that solve
    ax=b mod n, for integer a, b, and n, n >1.

    The list is empty when gcd(a, n) does not divide b.  Otherwise it
    holds gcd(a, n) values spaced n//gcd(a, n) apart (mod n); they are
    not necessarily in increasing order.
    >>> solveaxbn(3, 7, 10)
    [9]
    >>> solveaxbn(6, 7, 10)
    []
    >>> sorted(solveaxbn(6, 4, 10))
    [4, 9]
    >>> solveaxbn(0, 0, 3)
    [0, 1, 2]
    '''
    k = math.gcd(n, a)  # non-negative even when a is negative
    if b % k != 0:
        return []
    if k == n:  # a = b = 0 mod n; x arbitrary
        # Build a real list so every branch returns the same type;
        # a bare range(n) is not a list under Python 3.
        return list(range(n))
    # Element [1] of xgcd's result is used as the coefficient s with
    # s*(a//k) = 1 mod n//k, i.e. an inverse of the reduced coefficient.
    # k divides b here, so b//k is exact.
    sol1 = xgcd(a//k, n//k)[1] * (b//k)
    # Shift the particular solution by every multiple of n//k below n.
    return [(sol1 + i) % n for i in range(0, n, n//k)]

if __name__ == '__main__':
    # Executed as a script: run the docstring examples as a self-test.
    from doctest import testmod
    testmod()  # pass verbose=True to list every example as it runs
