from __future__ import division
import random, math
from copy import deepcopy
import vect, constants, physics, simobject, resources
from numpy import *

class Cell:
    """ A cell is the base unit of a cellular automaton system.

    Because a cell can hold multiple, non-discrete values it stretches the
    usual definition of a cellular automaton, but it behaves in fundamentally
    the same way, so the name fits.
    Each cell is associated with a grid square, so that the cell can pull
    information in from, and push it out to, the larger game simulation.
    """
    def __init__(self, CA, grid_ref, value = None, super_set = False):
        """ Create the cell at grid_ref belonging to the automaton CA.

        value: initial value. When None, a deep copy of CA.default_value is
            used, so cells never share (and mutate) the automaton's default.
        super_set: take the square from the super-set grid, which extends
            one square past the boundary of the map (useful, for example, for
            shockwave algorithms to let energy and momentum "leave the map").
        """
        self.CA = CA
        self.grid_ref = grid_ref

        # the "local geometry" comes from the super-set grid or the map grid
        if super_set:
            self.grid_square = CA.world.grid.super_set_squares[grid_ref]
        else:
            self.grid_square = CA.world.grid.squares[grid_ref]

        self.face_cells = deepcopy(self.grid_square.face_squares)
        self.adj_cells = deepcopy(self.grid_square.adjacent_squares)
        self.location = deepcopy(self.grid_square.location)

        # With geometry assigned, point back at the interior map square when
        # there is one, so that sims, blocked status, etc. -- state that game
        # sims can change -- are read from the live square.
        if self.check_inside_grid():
            self.grid_square = CA.world.grid.squares[grid_ref]

        # A default is frequently used; copy it so that in-place updates to
        # this cell (e.g. value[0] += dT) cannot leak into CA.default_value
        # or into other cells created from it.
        if value is None:
            value = deepcopy(CA.default_value)

        self.value = value
        self.new_value = value
        self.turn = CA.turn

        # Temporary result of the automaton's check_adj_cells that helps decide
        # the cell's future value; stored on the cell so generic lambda
        # conditions can inspect it.
        self.check_value = None

    def check_inside_grid(self):
        """ Return True if grid_ref lies on the map grid itself (not only in
        the super-set ring around it). """
        x, y = self.grid_ref
        grid = self.CA.world.grid
        return 0 <= x < grid.numSquaresX and 0 <= y < grid.numSquaresY

    def current_sims(self):
        """ Return the sims of the corresponding grid square. """
        return self.grid_square.sims

    def check_blocked(self):
        """ Return the blocked status of the corresponding grid square. """
        return self.grid_square.blocked

#===============================================================================
#===============================================================================

class CA:
    """ Base class for the cellular automata that run on the world grid.

    Holds the current (cells) and pending (new_cells) buffers, a turn
    counter and the default cell value; subclasses provide the rules.
    """
    def __init__(self, world, name):
        self.world = world
        self.name = name
        self.turn = 0
        self.default_value = 0
        self.cells, self.new_cells = {}, {}

    def run_simulation(self):
        """ Advance the automaton; the base implementation does nothing. """
        pass

    def _offsets_on_grid(self, offsets, idx):
        """ Return idx shifted by each of offsets, keeping only refs on the grid. """
        x_max = self.world.grid.numSquaresX
        y_max = self.world.grid.numSquaresY
        shifted = offsets + array(tuple(idx))
        return [(x, y) for (x, y) in shifted
                if 0 <= x < x_max and 0 <= y < y_max]

    def adjacent_indices(self, idx):
        """ In-grid refs of the squares adjacent to idx (constants.ADJ_OFFSETS). """
        return self._offsets_on_grid(constants.ADJ_OFFSETS, idx)

    def face_indices(self, idx):
        """ In-grid refs of the squares sharing a face with idx (constants.FACE_OFFSETS). """
        return self._offsets_on_grid(constants.FACE_OFFSETS, idx)

#================================================================================
# #================================================================================
class Heat(CA):
    """ Cellular automaton that diffuses heat between grid squares and lets hot
    cells burn the heatable sims inside them.

    Each cell value is a list [temperature, specific heat capacity (J/K/kg),
    mass, heat transfer coefficient]. A square without an active cell is
    treated as sitting at the ambient temperature self.default_value[0], and
    cells are dropped once they cool to near ambient (see run_simulation).
    """
    def __init__(self, world, name):
        CA.__init__(self, world, name)
        # [temperature, specific heat capacity (J/K/kg), mass, heat transfer];
        # the mass slot of the default holds the air density constant
        self.default_value = [15, constants.SHC_AIR, constants.RHO_AIR, constants.HEAT_TRANSFER_AIR]
        self.initialise_cells()

    def initialise_cells(self):
        """ Discard every active and pending cell. """
        self.cells.clear()
        self.new_cells.clear()

    def run_simulation(self):
        """ Advance the heat simulation by one turn, then redraw.

        A cell survives the turn only while its new temperature stays above
        110% of the ambient temperature.
        """
        if self.cells or self.new_cells:
            initial_turn = self.turn
            while self.turn-initial_turn<1:
                self.find_new_values(lambda cell: cell.check_value>1.1*cell.CA.default_value[0])
            self.force_update()

    def initialise_cell(self, grid_ref, value = None):
        """ Activate grid_ref (unless its square is blocked) in both buffers.

        value defaults to self.default_value. Each buffer gets its own copy so
        that updating one never alters the other, the caller's list or the
        automaton's default.
        """
        if self.world.grid.squares[grid_ref].blocked:
            return
        if value is None:
            value = self.default_value
        self.cells[grid_ref] = Cell(self, grid_ref, deepcopy(value))
        self.new_cells[grid_ref] = Cell(self, grid_ref, deepcopy(value))

    def cell_heat_properties(self, grid_ref):
        """ Compute (specific heat capacity, mass, heat transfer) for grid_ref.

        Heatable sims contribute their mass and heat capacity. An unblocked
        square also contains 2 m^3 of air and uses the air heat transfer
        coefficient; a blocked square uses the heat transfer of its last
        heatable sim. A square with no mass at all (blocked, nothing heatable)
        yields (0, 0, 0) instead of dividing by zero. If grid_ref is an active
        cell, its value[1:4] is updated in place as well.
        """
        square = self.world.grid.squares[grid_ref]
        mass = 0
        HC = 0
        heat_transfer = 0
        for sim in square.sims:
            if sim.heatable:
                sim_mass = sim.attrs['physics']['mass']
                mass += sim_mass
                HC += sim_mass*sim.attrs['thermo']['shc']
                if square.blocked:
                    heat_transfer = sim.attrs['thermo']['heat_transfer']

        if not square.blocked:
            # cell must have air in it (2m^3 of air)
            mass += constants.RHO_AIR*2
            HC += constants.RHO_AIR*2*constants.SHC_AIR
            heat_transfer = constants.HEAT_TRANSFER_AIR

        SHC = HC/mass if mass else 0

        if grid_ref in self.cells:
            cell = self.cells[grid_ref]
            cell.value[1] = SHC
            cell.value[2] = mass
            cell.value[3] = heat_transfer

        return (SHC, mass, heat_transfer)

    def find_new_values(self, keep_condition):
        """ Run one turn: exchange heat for every active cell, then keep the
        pending cells that satisfy keep_condition and drop the rest. """
        for cell in self.cells.values():
            cell.check_value = self.check_adj_cells(cell, cell.face_cells)
            if keep_condition(cell):
                # poorly simulates heat lost to the z axis - speeds up the
                # algorithm substantially though (floor division is deliberate)
                self.new_cells[cell.grid_ref].value[0] = self.new_cells[cell.grid_ref].value[0]*9//10
            else:
                del self.new_cells[cell.grid_ref]
        self.cells = self.new_cells.copy()
        self.turn += 1

    def add_cell(self, grid_ref, value = None):
        """ Add heat to the pending cell at grid_ref, creating it if needed.

        For an existing pending cell, value is added to its temperature (None
        adds nothing). A new cell starts at temperature value (ambient when
        None) with the heat properties of its square. Returns the pending cell.
        """
        if grid_ref in self.new_cells:
            if value is not None:
                self.new_cells[grid_ref].value[0] += value
        else:
            cell = Cell(self, grid_ref, deepcopy(self.default_value))
            if value is not None:
                cell.value[0] = value
            self.new_cells[grid_ref] = cell
            # store the square's real heat properties on the new cell itself
            cell.value[1:4] = self.cell_heat_properties(grid_ref)
        return self.new_cells[grid_ref]

    def damage_sims_in_cell(self, grid_ref, cell_temp):
        """ Burn the heatable sims in the active cell at grid_ref.

        Every heatable sim takes "burn" damage proportional to how far
        cell_temp exceeds its burn_min. When cell_temp is above a sim's
        burn_temp the sim also releases energy into the cell (the full E_max
        once cell_temp reaches burn_max), a fire sim is spawned near it and a
        fire sound is occasionally played.
        """
        cell = self.cells[grid_ref]
        # iterate over a copy, presumably because damage can remove sims
        for sim in cell.current_sims().copy():
            if not sim.heatable:
                continue
            thermo = sim.attrs['thermo']
            sim.damage(max(0,cell_temp - thermo['burn_min'])*thermo['burn_rate'], "burn")

            if cell_temp - thermo['burn_temp'] <= 0:
                continue

            # heat up the cell
            alpha = thermo['shc']*sim.attrs['physics']['mass']/thermo['burn_temp']
            E_max = (thermo['burn_temp'] - thermo['burn_max'])**2*alpha
            if cell_temp >= thermo['burn_max']:
                energy_released = E_max
            else:
                energy_released = E_max - (cell_temp - thermo['burn_max'])**2*alpha
            # multiplied by burn_rate to reflect many (material specific)
            # physics updates per game turn
            temp_increase = energy_released/(cell.value[1]*cell.value[2])*thermo['burn_rate']
            cell.value[0] += temp_increase

            # add fire effect
            offset = random.uniform(-0.5, 0.5),random.uniform(-0.5, 0.5)
            location = vect.add(1,sim.location,1,offset)
            self.world.fire = self.world.add_sim("fire", location)
            if random.random() < 0.05: # seems like there are 10-20 flames per cell at anytime
                resources.play_sound(self.world.sounds["fire"], resources.attenuate_sound(self.world.fire))

    def check_adj_cells(self, cell, adj_cells):
        """ Exchange heat between cell and each of adj_cells; return the cell's
        pending temperature.

        The cell's heat properties are refreshed first. Nothing else happens
        unless the cell is hotter than ambient and dt > 0. Then its sims are
        burned and, per neighbour, an energy flow proportional to the
        temperature difference and the harmonic mean of the two heat transfer
        coefficients is applied to both sides. If that would overshoot (the two
        swap which is hotter), both are set to their heat-capacity-weighted
        mean instead. Inactive neighbours count as ambient and only become
        active once their temperature change exceeds 20% of ambient. Neighbours
        that cannot exchange heat (zero heat capacity or heat transfer) are
        skipped rather than causing a division by zero.
        """
        dt = self.world.dt
        new_cell = self.new_cells[cell.grid_ref]
        temp = cell.value[0]
        SHC, mass, heat_transfer = self.cell_heat_properties(cell.grid_ref)

        # many "physics" updates happen each game tick
        update_factor = 20*dt

        if not (temp > self.default_value[0] and dt > 0):
            return new_cell.value[0]

        self.damage_sims_in_cell(cell.grid_ref, temp)

        HC_cell = SHC*mass
        for adj_ref in adj_cells:
            if adj_ref in self.cells:
                neighbour_cell = self.cells[adj_ref]
                temp_n = neighbour_cell.value[0]
                SHC_n = neighbour_cell.value[1]
                mass_n = neighbour_cell.value[2]
                heat_transfer_n = neighbour_cell.value[3]
            else:
                temp_n = self.default_value[0]
                SHC_n, mass_n, heat_transfer_n = self.cell_heat_properties(adj_ref)

            HC_neigh = SHC_n*mass_n
            if not (heat_transfer and heat_transfer_n and HC_cell and HC_neigh):
                # insulating or massless square: no heat can flow
                continue

            heat_transfer_tot = 1.0/(0.5/heat_transfer+0.5/heat_transfer_n)
            energy_flow = (temp_n-temp) * heat_transfer_tot
            energy_flow *= update_factor

            # transfer temperatures
            dtemp = energy_flow / HC_cell
            dtemp_n = - energy_flow / HC_neigh

            # kill oscillations in temperature between cells
            new_temp = temp + dtemp
            new_temp_n = temp_n + dtemp_n
            if ((energy_flow>0 and new_temp_n < new_temp) or
                    (energy_flow<0 and new_temp_n > new_temp)):
                avg_temp = (temp*HC_cell + temp_n*HC_neigh) /(HC_cell + HC_neigh)
                print("OSCILLATION %s" % avg_temp)
                dtemp_n = avg_temp -  temp_n
                dtemp = avg_temp - temp

            # update the neighbour cell, creating one if the change is large
            if adj_ref in self.new_cells:
                self.new_cells[adj_ref].value[0] += dtemp_n
            elif math.fabs(dtemp_n) > 0.2*self.default_value[0]:
                new_neighbour_cell = self.add_cell(adj_ref, self.default_value[0])
                new_neighbour_cell.value[0] += dtemp_n

            new_cell.value[0] += dtemp

        return new_cell.value[0]

    def show_active_cells(self):
        """ Draw each active cell the player can see (all cells when there is
        no player), shading from red through yellow towards white as it gets
        hotter. """
        player = self.world.player
        ambient = self.default_value[0]
        for ref, cell in self.cells.items():
            if player and ref not in player.visible_tiles:
                continue
            temp = cell.value[0]
            if temp > 0:
                frac = 1-ambient/max(temp, ambient)
                frac2 = 1-500/max(temp, 500)
                color = (255, int(frac*255), int(frac2*255))
                self.world.grid.draw_rect(self.world, ref, color)

    def force_update(self):
        """ Redraw the active cells. """
        self.show_active_cells()


# #================================================================================
class Shockwave(CA):
    """ Cellular automaton that propagates blast pressure across the map.

    Each cell value is a list [pressure, flow_x, flow_y]; self.default_value
    is the ambient pressure. Cells live on super-set squares so that pressure
    can move one square past the map edge and leave the simulation.
    """
    def __init__(self, world, name):
        CA.__init__(self, world, name)
        self.initialise_cells()

        self.default_value = 1
        self.flow_threshold = 1
        self.pressure_threshold = 1

        self.reference_squares = self.world.grid.super_set_squares

    def check_inside_grid(self, grid_ref):
        """ Return True if grid_ref lies on the map grid itself. """
        x, y = grid_ref
        grid = self.world.grid
        return 0 <= x < grid.numSquaresX and 0 <= y < grid.numSquaresY

    def grid_ref_blocked(self, grid_ref):
        """ Return True only if grid_ref is inside the grid and its square is
        blocked; refs outside the grid are never blocked.
        """
        # the inside-grid check short-circuits first, so squares is never
        # indexed with an off-map ref
        return self.check_inside_grid(grid_ref) and self.world.grid.squares[grid_ref].blocked

    def initialise_cells(self):
        pass

    def initialise_cell(self, grid_ref, value = None, super_set = True):
        """ Create a cell in both buffers unless grid_ref is blocked.

        value defaults to an ambient, motionless cell [default_value, 0, 0];
        each buffer gets its own copy of the list.
        """
        if self.grid_ref_blocked(grid_ref):
            return
        if value is None:
            value = [self.default_value, 0, 0]
        self.cells[grid_ref] = Cell(self, grid_ref, list(value), super_set)
        self.new_cells[grid_ref] = Cell(self, grid_ref, list(value), super_set)

    def run_simulation(self):
        """ Advance one turn, keeping cells above 1.5x ambient, then redraw. """
        if len(self.cells)+len(self.new_cells)>0:
            initial_turn = self.turn
            while self.turn-initial_turn<1:
                self.find_new_values(lambda cell: cell.check_value>1.5*cell.CA.default_value)
            self.force_update()

    def find_new_values(self, keep_condition):
        """ Run one turn: spread pressure from every active cell, then damp
        the kept pending cells to 90% pressure and drop the rest. """
        for cell in self.cells.values():
            cell.check_value = self.check_adj_cells(cell, cell.face_cells)
            if keep_condition(cell):
                self.new_cells[cell.grid_ref].value[0] *= 0.9
            else:
                del self.new_cells[cell.grid_ref]

        self.cells = self.new_cells.copy()
        self.turn += 1

    def add_cell(self, grid_ref, value = None):
        """ Create (or replace) the pending super-set cell at grid_ref. """
        self.new_cells[grid_ref] = Cell(self, grid_ref, value, super_set = True)
        return self.new_cells[grid_ref]

    def damage_sims_in_cell(self, grid_ref, value):
        """ Apply blast damage and a flow-driven force to the collidable or
        opaque sims in the map square at grid_ref, using the cell value
        [pressure, flow_x, flow_y]. Off-map refs are ignored.
        """
        if not self.check_inside_grid(grid_ref):
            # super-set refs beyond the map edge have no real square; if
            # squares is a numpy array a negative index would silently wrap
            # around to the opposite edge of the map
            return
        damage_factor = 2*30000
        flow_factor = 3*5000
        for sim in self.world.grid.squares[grid_ref].sims:
            if sim.collidable or sim.opaque:
                damage = max(0,value[0]-self.default_value)/4*damage_factor
                sim.damage(damage, "kinetic")
                flow = value[1]*flow_factor, value[2]*flow_factor
                physics.force_cart(sim, flow)

    def _neighbour_pressure(self, adj_ref):
        """ Pressure of the neighbour at adj_ref: its value if active, ambient
        if it is an inactive unblocked square, None if it is blocked. """
        if adj_ref in self.cells:
            return self.cells[adj_ref].value[0]
        if not self.grid_ref_blocked(adj_ref):
            return self.default_value
        return None

    def check_adj_cells(self, cell, adj_cells):
        """ Push the cell's over-pressure into its face neighbours and return
        its pending pressure.

        Only cells with exactly four face neighbours take part; any other cell
        has its pending pressure set to 0. Nothing happens when the pressure
        is at most 1.5x ambient or dt is not positive. Otherwise the cell's
        sims are damaged and its pressure above ambient (scaled by s) is split
        between the neighbours the flow points towards, in proportion to each
        one's share of the outward flow. Blocked neighbours are hit by the
        blast and zero the matching flow component of the cell.
        """
        dt = self.world.dt
        new_cell = self.new_cells[cell.grid_ref]
        f = dt    # flow weighting (dt = default)
        g = 1     # divisor on the velocity handed to neighbours
        a = 1     # peak flow limiter
        s = 0.9   # pressure distribution factor (1 = default)
        p = cell.value[0]
        velocity = (cell.value[1], cell.value[2])

        if len(adj_cells) != 4:
            new_cell.value[0] = 0
            return new_cell.value[0]
        if not (p > 1.5*self.default_value and dt > 0):
            return new_cell.value[0]

        self.damage_sims_in_cell(cell.grid_ref, cell.value)

        def flow_towards(adj_ref, p_n):
            # (direction, transfer flow vector, flow component along direction)
            direction = vect.add(1, adj_ref, -1, cell.grid_ref)
            dp = p - p_n
            transfer_flow = vect.add(dp/dt, direction, dp/f, velocity)
            return direction, transfer_flow, vect.dot(transfer_flow, direction)

        # first pass: total outward flow, used to share the pressure out
        tot_flow = 0
        for adj_ref in adj_cells:
            p_n = self._neighbour_pressure(adj_ref)
            if p_n is None:
                continue
            scalar_flow = flow_towards(adj_ref, p_n)[2]
            if scalar_flow > 0:
                tot_flow += scalar_flow

        # second pass: move pressure and flow into the neighbours
        for adj_ref in adj_cells:
            p_n = self._neighbour_pressure(adj_ref)
            if p_n is None:
                self.damage_sims_in_cell(adj_ref, cell.value)
                direction = vect.add(1, adj_ref, -1, cell.grid_ref)
                if vect.dot((0, 1), direction) == 0:
                    new_cell.value[1] = 0
                else:
                    new_cell.value[2] = 0
                continue

            direction, transfer_flow, scalar_flow = flow_towards(adj_ref, p_n)
            if scalar_flow <= 0:
                continue

            if adj_ref in self.new_cells:
                neighbour = self.new_cells[adj_ref]
            else:
                neighbour = self.add_cell(adj_ref, [self.default_value, 0, 0])

            neighbour.value[0] += (p-self.default_value)*scalar_flow/tot_flow*s
            neighbour.value[1] += transfer_flow[0]*scalar_flow/tot_flow/g
            neighbour.value[2] += transfer_flow[1]*scalar_flow/tot_flow/g

            # limit each flow component by the neighbour's over-pressure
            neighbour.value[1] = vect.sign(neighbour.value[1])*min(a*(neighbour.value[0]-self.default_value), math.fabs(neighbour.value[1]))
            neighbour.value[2] = vect.sign(neighbour.value[2])*min(a*(neighbour.value[0]-self.default_value), math.fabs(neighbour.value[2]))

            new_cell.value[0] -= (p-self.default_value)*scalar_flow/tot_flow*s

        return new_cell.value[0]

    def show_active_cells(self):
        """ Draw every active cell: above 1 in red shading to yellow, exactly
        1 in green, below 1 with a blue channel that grows past -1. """
        for ref, cell in self.cells.items():
            pressure = cell.value[0]
            if pressure > 1:
                frac = 1-1.0/max(pressure, 1)
                color = (255, int(frac*255), 0)
            elif pressure == 1:
                color = (0, 255, 0)
            else:
                frac = 1+1.0/min(pressure, -1)
                color = (0, 0, int(frac*255))
            self.world.grid.draw_rect(self.world, ref, color)

    def force_update(self):
        """ Redraw the active cells. """
        self.show_active_cells()

#================================================================================
class Partitions(CA):
    """ Splits the unblocked squares of the map into connected partitions.

    Every unblocked square starts with its own label; labels are flooded
    through face neighbours so that each connected region ends up carrying
    the lowest label found in it. self.partitions maps label -> grid refs.
    """
    def __init__(self, world, name):
        CA.__init__(self, world, name)
        self.initialise_cells()

    def initialise_cells(self):
        """ Give every unblocked square its own label and reset the partitions. """
        self.stable_cells = {}
        self.partitions = {}
        self.candidate_cells = {}
        # self.value is kept above every label handed out so far, so cells
        # added later receive fresh labels and are consumed by lower-labelled
        # (larger) sets.
        self.value = 1
        for key in ndindex(self.world.grid.squares.shape):
            self.initialise_cell(key)

    def check_run(self):
        """ Return True while there is still flooding work to do. """
        return len(self.cells) + len(self.candidate_cells) > 0

    def initialise_cell(self, grid_ref, value = None):
        """ Register an unblocked square as both a candidate and a stable cell.

        A falsy value is replaced by a fresh label above self.value.
        """
        if not value:
            value = self.value + 1
        if not self.world.grid.squares[grid_ref].blocked:
            self.candidate_cells[grid_ref] = Cell(self, grid_ref, value)
            self.stable_cells[grid_ref] = Cell(self, grid_ref, value)
            self.value = value + 1

    def find_new_values(self, keep_condition):
        """ Run one flooding turn over the active cells.

        Active cells failing keep_condition are settled: they move into
        stable_cells and leave the pending and candidate sets. The partitions
        are rebuilt from stable_cells afterwards.
        """
        self.find_lowest_cells()

        # iterate a snapshot: check_adj_cells -> add_cell inserts into self.cells
        for cell in list(self.cells.values()):
            cell.check_value = self.check_adj_cells(cell, cell.face_cells)
            if keep_condition(cell):
                self.new_cells[cell.grid_ref].value = cell.check_value
            else:
                self.stable_cells[cell.grid_ref] = cell
                del self.new_cells[cell.grid_ref]
                del self.candidate_cells[cell.grid_ref]
        self.cells = self.new_cells.copy()
        self.create_partitions()
        self.turn += 1

    def find_lowest_cells(self):
        """ When no cells are active, activate the candidates sharing the
        lowest label.

        Optimisation: processing one label at a time floods each region once
        with its final label, instead of repeatedly by successive waves of
        lower labels. Candidates displaced as the running minimum drops are
        recorded in stable_cells.
        """
        if self.cells or not self.candidate_cells:
            return
        cur = max(self.value, 2**30)  # larger than any label handed out
        mins = []
        for cell in self.candidate_cells.values():
            if cell.value < cur:
                for c in mins:
                    self.stable_cells[c.grid_ref] = c
                mins = [cell]
                cur = cell.value
            elif cell.value == cur:
                # only ties join the wave; higher labels wait for a later one
                mins.append(cell)
        for cell in mins:
            self.cells[cell.grid_ref] = cell
            self.new_cells[cell.grid_ref] = cell

    def create_partitions(self):
        """ Rebuild self.partitions (label -> list of grid refs) from stable_cells. """
        self.partitions = {}
        for cell in self.stable_cells.values():
            self.add_partitioned_cell(cell)

    def add_partitioned_cell(self, cell):
        """ Append cell's grid ref to the partition for its label. """
        self.partitions.setdefault(cell.value, []).append(cell.grid_ref)

    def add_cell(self, grid_ref, value = None):
        """ Activate grid_ref with label value in the active, pending and
        candidate sets; return the pending cell. """
        self.cells[grid_ref] = Cell(self, grid_ref, value)
        self.new_cells[grid_ref] = Cell(self, grid_ref, value)
        self.candidate_cells[grid_ref] = Cell(self, grid_ref, value)
        return self.new_cells[grid_ref]

    def check_adj_cells(self, cell, adj_cells):
        """ Return the lowest label among cell and its face neighbours.

        Unblocked neighbours that are not active and hold a higher stable
        label are activated with the current lowest label, spreading it.
        """
        value = cell.value
        for adj_cell in adj_cells:
            if adj_cell in self.cells:
                value = min(value, self.cells[adj_cell].value)
            elif not self.world.grid.squares[adj_cell].blocked:
                if self.stable_cells[adj_cell].value>value:
                    self.add_cell(adj_cell, value)
                else:
                    value = self.stable_cells[adj_cell].value
        return value

    def show_partitions(self):
        """ Draw each partition in its own shade of blue. """
        length = max(len(self.partitions),1)
        count = 1.0
        for refs in self.partitions.values():
            frac = count/length
            color = (int(frac*255),int(frac*255),255)
            for ref in refs:
                self.world.grid.draw_rect(self.world, ref, color)
            count +=1

    def show_active_cells(self):
        """ Draw the active cells in red, brighter for higher labels. """
        for key, cell in self.cells.items():
            frac = 1-1.0/(cell.value+1)
            color = (int(frac*255),0,0)
            self.world.grid.draw_rect(self.world, key, color)

    def find_max_partition(self):
        """ Store the grid refs of the largest partition in self.max_partition.

        On ties the first partition found wins; with no partitions at all the
        result is an empty list.
        """
        max_size = 0
        max_key = None
        for key, refs in self.partitions.items():
            if len(refs) > max_size:
                max_size = len(refs)
                max_key = key
        self.max_partition = self.partitions.get(max_key, [])

#================================================================================

class Atmosphere(CA):
    """ Cellular automaton that spreads an explosion's strength by repeatedly
    averaging each active cell with its face neighbours.

    Cell values are scalars; squares reached by the wave start at
    self.default_value (1). Sims in affected squares lose health.
    """
    def __init__(self, world, name):
        CA.__init__(self, world, name)
        self.default_value = 1
        self.initialise_cells()

    def initialise_cells(self):
        """ Replace both cell buffers with fresh, empty dicts. """
        self.cells = {}
        self.new_cells = {}

    def clear_cells(self):
        """ Empty both cell buffers in place. """
        self.cells.clear()
        self.new_cells.clear()

    def create_explosion(self, grid_ref, value):
        """ Seed an explosion of strength value at grid_ref. """
        self.add_cell(grid_ref, value)

    def run_simulation(self):
        """ Run ten turns, then draw the cells and damage their sims. """
        if self.cells:
            initial_turn = self.turn
            while self.turn-initial_turn<10:
                self.find_new_values(lambda cell: cell.check_value>1.1, lambda cell: cell.value < 2)
            self.force_update()

    def add_cell(self, grid_ref, value = None):
        """ Activate grid_ref in both buffers; return the pending cell. """
        self.cells[grid_ref] = Cell(self, grid_ref, value)
        self.new_cells[grid_ref] = Cell(self, grid_ref, value)
        return self.new_cells[grid_ref]

    def rem_cell(self, cell):
        """ Drop cell from the pending buffer. """
        del self.new_cells[cell.grid_ref]

    def find_new_values(self, keep_condition, rem_condition):
        """ Run one turn: average each active cell with its neighbours and
        keep it, scaled to 90%, while keep_condition holds; drop it otherwise.
        """
        # iterate a snapshot: check_adj_cells -> add_cell inserts into self.cells
        for cell in list(self.cells.values()):
            cell.check_value = self.check_adj_cells(cell, cell.face_cells, rem_condition)
            if keep_condition(cell):
                # reduce the value slightly to account for losses to other
                # forms of energy -- and it speeds the simulation up MASSIVELY
                self.new_cells[cell.grid_ref].value = cell.check_value*18//20
            else:
                del self.new_cells[cell.grid_ref]
        self.cells = self.new_cells.copy()
        self.turn += 1

    def check_adj_cells(self, cell, adj_cells, rem_condition):
        """ Return the floored mean of cell and its unblocked neighbours.

        Side effects: sims in a blocked cell lose cell.value*5 health; each
        inactive neighbour is activated at the default value (and removed
        from the pending buffer again when rem_condition(cell) holds); sims in
        blocked neighbours lose cell.value/4*5 health.
        """
        count = 1.0
        total = cell.value

        if cell.grid_square.blocked:
            for sim in cell.current_sims():
                sim.health-= cell.value*5

        for adj_cell in adj_cells:
            if adj_cell in self.cells:
                neighbour_cell = self.cells[adj_cell]
                kill = False
            else:
                neighbour_cell = self.add_cell(adj_cell)
                # NOTE(review): rem_condition is tested on the centre cell,
                # not on the neighbour that was just added -- confirm intent
                kill = rem_condition(cell)

            if not neighbour_cell.grid_square.blocked:
                count += 1
                total += neighbour_cell.value
            else:
                for sim in neighbour_cell.current_sims():
                    sim.health-= cell.value/4*5

            if kill:
                self.rem_cell(neighbour_cell)

        # floor division, kept from the original ("for speed")
        return total//count

    def force_update(self):
        """ Draw every active cell and damage the sims in it. """
        for cell in self.cells.values():
            self.apply_changes(cell)

    def apply_changes(self, cell):
        """ Draw cell (only when the world has a player) and take
        cell.value*50 health from each sim in its square. """
        # NOTE(review): the visible and non-visible branches were identical,
        # so visibility never affected drawing; collapsed to one branch.
        # Confirm whether off-screen cells were meant to stay hidden.
        if self.world.player:
            frac = (1-1.0/max(cell.value,1))**10
            self.world.grid.draw_rect(self.world, cell.grid_ref, (int(255*frac),0,0))
        for sim in cell.current_sims():
            sim.health-=cell.value*50

#================================================================================
#
#================================================================================

class Deposition(CA):
    """ Boolean cellular automaton stored in numpy arrays shaped like the grid.

    Each turn a square's new state is condition(idx, state, n), where n is the
    number of set squares among its in-grid adjacent squares. force_update
    mirrors the states into the world as sims named after the automaton.
    """
    def __init__(self, world, name):
        CA.__init__(self, world, name)
        shape = self.world.grid.squares.shape
        self.cells = ndarray(shape, dtype=bool)
        self.cells.fill(self.default_value)
        self.new_cells = ndarray(shape, dtype=bool)
        self.new_cells.fill(self.default_value)

    def initialise_cells(self, set = None):
        """ Restart the turn counter; seeding from a set is not supported. """
        assert not set
        self.turn = 0

    def seed_map(self, prob):
        """ Set each square independently with probability prob. """
        # Name numpy.random explicitly: the module-level `random` is only
        # numpy.random when `from numpy import *` shadows the stdlib module,
        # which has no binomial().
        import numpy
        self.cells = numpy.random.binomial(1, prob, self.world.grid.squares.shape).astype(bool)
        self.turn = 1

    def find_new_values(self, condition):
        """ Apply condition to every square, then swap the buffers. """
        for idx, val in ndenumerate(self.cells):
            check_value = self.check_adj_cells(idx, val)
            self.new_cells[idx] = condition(idx, val, check_value)
        self.cells, self.new_cells = self.new_cells, self.cells
        self.new_cells.fill(self.default_value)
        self.turn += 1

    def check_adj_cells(self, idx, val):
        """ Count the set squares adjacent to idx (val is unused). """
        # counted explicitly rather than with sum(generator), which the
        # numpy star import may rebind to numpy.sum
        count = 0
        for i in self.adjacent_indices(idx):
            if self.cells[i]:
                count += 1
        return count

    def force_update(self):
        """ Add a sim at the centre of every set square; empty the others. """
        grid = self.world.grid

        def loc(x, y):
            return ((x+0.5)*grid.square_width, (y+0.5)*grid.square_height)

        for idx, val in ndenumerate(self.cells):
            if val > 0:
                self.world.add_sim(self.name, loc(*idx))
            else:
                self.world.empty_square(idx, self.name)
