module Main where

import Numeric
import Maybe
import System
import Control.Concurrent

-- This program shows arbitrary one dimensional cellular automata.
-- It treats each cell as a character and each generation as a line.
-- Each type of line automaton has a set of states, a rule determining
-- which cells are in the local neighbourhood, and a rule for deciding the
-- next state depending on the current neighbours.

-- Program entry point: fetch the command-line arguments and hand them
-- to handleArgs, which decides what to run.
main :: IO ()
-- A hard-wired invocation, kept for reference:
--main = loopGenerations (makeRule 2 3 ".#") [-1,0,1]
--  "...............................................#................................................"
main = do
  args <- getArgs
  handleArgs args

-- Interpret the command-line arguments and start the simulation.
--
-- Accepted forms:
--   rulenum neighbourhood states init   (neighbourhood is a list such as [-1,0,1])
--   rulenum states init                 (neighbourhood defaults to [-1,0,1])
--
-- The usage line is printed when the argument count is wrong, when the
-- rule number or neighbourhood cannot be parsed, when the rule number is
-- negative, or when fewer than two states are given (the rule number is
-- written in base (length states), which needs a base of at least 2).
handleArgs :: [String] -> IO ()
handleArgs args = case args of
    [ruleNum, neighbourhood, states, initial] ->
      start (parse ruleNum) (parse neighbourhood) states initial
    [ruleNum, states, initial] ->
      start (parse ruleNum) (Just [-1, 0, 1]) states initial
    _ -> usage
  where
    usage = putStrLn "Usage: llife rulenum [neighbourhood] states init"

    -- run the automaton only if every argument was understood
    start (Just num) (Just offsets) states initial
      | num >= 0 && length states >= 2
      = loopGenerations (makeRule num (length offsets) states) offsets initial
    start _ _ _ _ = usage

    -- total replacement for 'read': Nothing unless the whole string parses
    -- unambiguously (trailing whitespace is allowed, as with 'read').
    parse :: Read a => String -> Maybe a
    parse text = case [value | (value, rest) <- reads text, ("", "") <- lex rest] of
      [value] -> Just value
      _ -> Nothing

-- Print generations forever: show the current line of cells, pause, then
-- carry on with the generation computed by nextGen. Never returns.
loopGenerations :: Rule -> [Int] -> [Cell] -> IO ()
loopGenerations rule neighbourhood = step
  where
    step generation = do
      putStrLn generation  -- a generation is a list of chars, i.e. a String
      threadDelay 250000   -- pause between generations, in microseconds
      step (nextGen generation rule neighbourhood)

--------------------------------------------------------------------------------
-- Generations

-- A cell's state is represented by a single character, so a whole
-- generation ([Cell]) is a String that can be printed directly.
type Cell = Char

-- compute the next generation: for every cell position, collect that cell's
-- neighbours (as selected by the neighbourhood offsets) and apply the rule
-- to them. the neighbour groups are consumed directly rather than being
-- built into a list and then looked up again by position.
nextGen :: [Cell] -> Rule -> [Int] -> [Cell]
nextGen cells rule neighbourhood
  = map (nextState rule . getNeighbours cells neighbourhood) $ indices cells

-- work out a cell's next state by handing its neighbours to the Rule.
nextState :: Rule -> [Cell] -> Cell
nextState rule neighbours = rule neighbours

-- collect the neighbours of the cell at a given position.
-- the 'neighbourhood' is the list of neighbour positions relative to that
-- cell; the result lists the neighbours in the same order.
-- positions that fall off either end of the line wrap around (see @@).
getNeighbours :: [Cell] -> [Int] -> Int -> [Cell]
getNeighbours cells neighbourhood index
  = [cells @@ (index + offset) | offset <- neighbourhood]

-- circular indexing: like (!!), except that the index is first reduced
-- modulo the length of the list, so it wraps around at both ends.
(@@) :: [a] -> Int -> a
ring @@ position = ring !! wrapped
  where wrapped = position `mod` length ring

-- the valid positions of a list, in ascending order, starting from 0.
indices :: [a] -> [Int]
indices list = take (length list) [0 ..]

--------------------------------------------------------------------------------
-- Rules
-- The rule numbers accepted by this program are known as Wolfram codes.
-- Each cell decides its next state depending on the states of its neighbours.
-- A rule is a function that takes the list of a cell's neighbour states
-- (in neighbourhood order) and returns that cell's next state.

type Rule = [Cell] -> Cell

-- build the Rule for a rule number, given how many neighbours each cell
-- has and which states a cell can take.
-- every possible neighbour combination is listed from highest to lowest
-- (hence the reverse) and paired, position by position, with the digits
-- of the rule number produced by getTransitions.
makeRule :: Int -> Int -> [Cell] -> Rule
makeRule ruleNum numNeighbours states = convert neighbourStates nextStates
  where
    nextStates = getTransitions ruleNum numNeighbours states
    neighbourStates = reverse (permutations numNeighbours states)

-- generate a function that maps elements of the first list to the
-- elements at the same positions in the second.
-- if the lists differ in length, the surplus elements of the longer one are
-- ignored; if an input occurs more than once, its first pairing wins.
-- applying the resulting function to a value that is not among the inputs
-- is an error, reported with a descriptive message.
convert :: Eq a => [a] -> [b] -> (a -> b)
convert inputs outputs = \input ->
    case lookup input table of
      Just output -> output
      Nothing -> error "convert: input is not in the domain of the mapping"
  where
    -- built once and shared by every application of the returned function
    table = zip inputs outputs

-- every sequence of length n whose elements are drawn (with repetition)
-- from the given set, ordered by the position of the elements in that set.
permutations :: Int -> [a] -> [[a]]
permutations 0 _ = [[]]
permutations n xs = [y : rest | y <- xs, rest <- permutations (n - 1) xs]

-- get the list of next-cell-states for a given rule number.
-- the rule number is written out in base (number of states), using the
-- states themselves (which are chars) as the digits, and then padded with
-- the first state or truncated so that there is exactly one digit per
-- possible combination of neighbour states.
getTransitions :: Int -> Int -> [Cell] -> [Cell]
getTransitions ruleNum numNeighbours states = fit digits tableSize zeroDigit
  where
    base = length states
    tableSize = base ^ numNeighbours
    -- the first state plays the role of the digit zero
    zeroDigit = head states
    -- the digit with value d is the state at position d
    digits = showIntAtBase base (states !!) ruleNum ""
	    
-- force a string to an exact size: a string that is too short is padded
-- on the left with padChar, one that is too long loses characters from
-- its front, and one that already fits is returned unchanged.
fit :: String -> Int -> Char -> String
fit str size padChar = case compare len size of
    LT -> replicate (size - len) padChar ++ str
    GT -> drop (len - size) str
    EQ -> str
  where len = length str

-- arithmetic complement of n within numBits bits: (2^numBits - 1) - n.
-- for 0 <= n < 2^numBits this is the value obtained by flipping each of
-- the low numBits bits of n. numBits must not be negative, since (^)
-- rejects negative exponents.
bitwiseNot :: (Integral b, Num a) => b -> a -> a
bitwiseNot numBits n = 2 ^ numBits - 1 - n


