//! Policy for pure Monte Carlo tree search implementation
//!
//! # Examples
//! ```rust
//! # extern crate connect6;
//! # use connect6::{agent::Agent, policy::DefaultPolicy};
//! let mut policy = DefaultPolicy::with_num_iter(2);
//! let result = Agent::new(&mut policy).play();
//! assert!(result.is_ok());
//! ```
use game::{Game, Player};
use policy::simulate::Simulate;
use policy::Policy;
use {Board, BOARD_SIZE};

use rand;
use rand::prelude::{thread_rng, SliceRandom};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

#[cfg(test)]
mod tests;

/// Tree node; child nodes are referenced by the hash value of their board array
struct Node {
    visit: i32,
    black_win: i32,
    board: Board,
    // hash values of the child boards, used as keys into `DefaultPolicy::map`
    next_node: Vec<u64>,
}

impl Node {
    /// Construct a new `Node`
    fn new(board: &Board) -> Node {
        Node {
            visit: 0,
            black_win: 0,
            board: *board,
            next_node: Vec::new(),
        }
    }
}

/// Generate the hash value of a board
///
/// # Examples
/// ```rust
/// # extern crate connect6;
/// # use connect6::{game::{Game, Player}, policy::hash, BOARD_SIZE};
/// let game = Game::new();
/// let hashed = hash(game.get_board());
/// assert_eq!(hashed, hash(&[[Player::None; BOARD_SIZE]; BOARD_SIZE]));
/// ```
pub fn hash(board: &Board) -> u64 {
    let mut hasher = DefaultHasher::new();
    board.hash(&mut hasher);
    hasher.finish()
}

/// Compare two boards position by position and return the first position where they differ.
///
/// # Examples
/// ```rust
/// # extern crate connect6;
/// # use connect6::{game::Game, policy::{Simulate, diff_board}};
/// let game = Game::new();
/// let mut sim = Simulate::from_game(&game);
/// sim.simulate_in(0, 0);
///
/// let diff = diff_board(game.get_board(), &sim.board());
/// assert_eq!(diff, Some((0, 0)));
/// ```
pub fn diff_board(board1: &Board, board2: &Board) -> Option<(usize, usize)> {
    for row in 0..BOARD_SIZE {
        for col in 0..BOARD_SIZE {
            if board1[row][col] != board2[row][col] {
                return Some((row, col));
            }
        }
    }
    None
}

/// Policy for pure Monte Carlo tree search implementation
///
/// # Examples
/// ```rust
/// # extern crate connect6;
/// # use connect6::{agent::Agent, policy::DefaultPolicy};
/// let mut policy = DefaultPolicy::with_num_iter(2);
/// let result = Agent::new(&mut policy).play();
/// assert!(result.is_ok());
/// ```
pub struct DefaultPolicy {
    num_iter: i32,
    map: HashMap<u64, Node>,
}

impl DefaultPolicy {
    /// Construct a new `DefaultPolicy` with the default number of iterations (50)
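    ///
    /// # Examples
    /// A minimal construction sketch; see the module-level example for a full
    /// `Agent::play` run with `with_num_iter`.
    /// ```rust
    /// # extern crate connect6;
    /// # use connect6::policy::DefaultPolicy;
    /// let policy = DefaultPolicy::new();
    /// ```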
    pub fn new() -> DefaultPolicy {
        DefaultPolicy {
            num_iter: 50,
            map: HashMap::new(),
        }
    }

    /// Construct a `DefaultPolicy` with the given number of iterations for the simulation task.
    pub fn with_num_iter(num_iter: i32) -> DefaultPolicy {
        DefaultPolicy {
            num_iter,
            map: HashMap::new(),
        }
    }

    /// Initialize the policy
    ///
    /// For the first tree search, the tree must be initialized with the game status.
    /// `init` initializes the tree with the given `Simulate`.
    fn init(&mut self, sim: &Simulate) {
        let board = sim.board();
        self.map.entry(hash(&board)).or_insert(Node::new(&board));
    }

    /// Select the position with the highest winning probability.
    ///
    /// *Note* The given simulation must have been initialized by `init` or `expand`.
    fn select(&self, sim: &Simulate) -> Option<(usize, usize)> {
        let node = sim.node.borrow();
        let tree_node = self.map.get(&hash(&node.board)).unwrap();

        // `Node` statistics are stored from player Black's perspective.
        // To rank children for the current player, condition on the player and apply the matching unary function.
        let unary: fn(f32) -> f32 = match sim.turn {
            Player::None => panic!("couldn't calculate none user's prob"),
            Player::Black => |x| x,
            Player::White => |x| -x,
        };
        let prob = |node: &Node| unary(node.black_win as f32 / (1. + node.visit as f32));
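        // e.g. a child with black_win = 3 and visit = 5 scores 3 / (1 + 5) = 0.5;
        // Black ranks it as 0.5 and White as -0.5, so White prefers low Black win rates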
        // get the maximum probability node
        let max = tree_node.next_node.iter().max_by(|n1, n2| {
            let node1 = self.map.get(*n1).unwrap();
            let node2 = self.map.get(*n2).unwrap();
            prob(node1).partial_cmp(&prob(node2)).unwrap()
        });

        // if tree_node.next_node is not empty
        if let Some(hashed) = max {
            let max_node = self.map.get(hashed).unwrap();
            // only use the child if it has a meaningful (nonzero) probability
            if prob(max_node) != 0. {
                let pos = diff_board(&max_node.board, &node.board);
                return pos;
            }
        }
        None
    }

    /// Expand the tree from the given simulation state
    fn expand(&mut self, sim: &Simulate) -> (usize, usize) {
        let mut rng = rand::thread_rng();
        let (row, col) = {
            let node = sim.node.borrow();
            *node.possible.choose(&mut rng).unwrap()
        };
        // simulate the randomly selected position
        let board = sim.simulate(row, col).board();
        let hashed_board = hash(&board);
        // generate node
        self.map.insert(hashed_board, Node::new(&board));

        let parent_node = {
            let node = sim.node.borrow();
            self.map.get_mut(&hash(&node.board)).unwrap()
        };
        // make connection between parent and child
        parent_node.next_node.push(hashed_board);

        (row, col)
    }

    /// Update the tree: run a random playout from the child node and update the parents' visit counts.
    ///
    /// Runs a random simulation from the child node, then traces back along `path` to update
    /// the `visit` and `black_win` counts of the parent nodes.
    /// If the random simulation ends with no winner, the method returns without updating.
    fn update(&mut self, sim: &Simulate, path: &[(usize, usize)]) {
        let mut simulate = sim.deep_clone();
        let mut rng = rand::thread_rng();
        // random simulation
        while simulate.search_winner() == Player::None {
            let (row, col) = {
                let node = simulate.node.borrow();
                match node.possible.choose(&mut rng) {
                    Some(pos) => *pos,
                    None => break,
                }
            };
            simulate.simulate_in(row, col);
        }
        let win = simulate.search_winner();
        if win == Player::None {
            return;
        }
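        // 1 if Black won the random playout, 0 otherwise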
        let black_win = (win == Player::Black) as i32;

        // update the leaf node and then its parents
        let mut sim = sim.deep_clone();
        let mut update = |sim: &Simulate| {
            let node = self.map.get_mut(&hash(&sim.board())).unwrap();
            node.visit += 1;
            node.black_win += black_win;
        };

        update(&sim);
        // trace the parent nodes
        for (row, col) in path.iter().rev() {
            sim.rollback_in(*row, *col);
            update(&sim);
        }
    }

    /// Search the tree: one iteration of select, expand, and update.
    fn search(&mut self, game: &Game) {
        // 1. initialize
        let mut simulate = Simulate::from_game(game);
        self.init(&simulate);

        // 2. search the tree with the selection policy
        let mut path = Vec::new();
        while let Some((row, col)) = self.select(&simulate) {
            // store the history for method `update` to trace parents
            path.push((row, col));
            simulate.simulate_in(row, col);
        }

        // selection ended; if the game is already decided at this node, skip expansion and update
        if simulate.search_winner() != Player::None {
            return;
        }
        // 3. expand
        let (row, col) = self.expand(&simulate);

        path.push((row, col));
        simulate.simulate_in(row, col);
        // 4. update
        self.update(&simulate, &path);
    }

    /// Generate the policy: probability-based selection if possible, otherwise random selection.
    fn policy(&self, sim: &Simulate) -> Option<(usize, usize)> {
        let res = if let Some(pos) = self.select(sim) {
            pos
        } else {
            let node = sim.node.borrow();
            *node.possible.choose(&mut thread_rng()).unwrap()
        };
        Some(res)
    }
}

impl Policy for DefaultPolicy {
    /// Select the next position based on pure MCTS.
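    ///
    /// # Examples
    /// A usage sketch in the style of the other doctests; it assumes the `Policy`
    /// trait is exported from `connect6::policy` alongside `DefaultPolicy`.
    /// ```rust
    /// # extern crate connect6;
    /// # use connect6::{game::Game, policy::{DefaultPolicy, Policy}};
    /// let game = Game::new();
    /// let mut policy = DefaultPolicy::with_num_iter(2);
    /// let pos = policy.next(&game);
    /// assert!(pos.is_some());
    /// ```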
    fn next(&mut self, game: &Game) -> Option<(usize, usize)> {
        // run `num_iter` iterations of the tree search
        for _ in 0..self.num_iter {
            self.search(game);
        }
        let simulate = Simulate::from_game(game);
        // generate the final move from the search tree
        self.policy(&simulate)
    }
}