# relnCollFilt.py - Latent Feature-based Collaborative Filtering
# Python 3 code. Full documentation at http://artint.info/code/python/code.pdf

# Artificial Intelligence: Foundations of Computational Agents
# http://artint.info
# Copyright David L Poole and Alan K Mackworth 2016.
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# See: http://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_US.

import random
import matplotlib.pyplot as plt
import urllib.request

class CF_learner(object):
    def __init__(self,
                 ratings,              # list of (u,i,r,t) ratings
                 step_size = 0.01,     # gradient descent step size
                 regularization = 0.1, # the weight for the regularization terms
                 num_features = 10,    # number of hidden features
                 feature_range = 0.2   # features are initialized to be between
                                       # -feature_range and feature_range
                 ):
        self.ratings = ratings
        self.step_size = step_size
        self.regularization = regularization
        self.num_features = num_features
        self.num_ratings = len(ratings)
        self.ave_rating = (sum([rating for (user,item,rating,timestamp) in ratings])
                           /self.num_ratings)

        self.set_users_that_rated = set([user for (user,item,rating,timestamp) in ratings]) 
        self.set_items_rated = set([item for (user,item,rating,timestamp) in ratings])
        self.max_item_index = max(self.set_items_rated)
        self.max_user_index = max(self.set_users_that_rated)
        self.user_offset = [0]*(self.max_user_index+1)   # per-user rating offset
        self.item_offset = [0]*(self.max_item_index+1)   # per-item rating offset
        self.self_offset = 0   # offset used when the 4th element of a rating tuple is nonzero
        self.user_feature = [[random.uniform(-feature_range,feature_range)
                              for f in range(num_features)]
                             for i in range(self.max_user_index+1)]
        self.item_feature = [[random.uniform(-feature_range,feature_range)
                              for f in range(num_features)]
                             for i in range(self.max_item_index+1)]
        self.iter=0

    def learn(self,
              num_iter = 50,        # number of iterations of gradient descent
              trace = 1):
      """ do num_it iterations of gradient descent."""
      for i in range(num_iter):
        self.iter += 1
        total_error=0
        sumsq_error=0
        for (user,item,rating,timestamp) in self.ratings:
            prediction = (self.ave_rating
                          + self.user_offset[user]
                          + self.item_offset[item]
                          + (self.self_offset if timestamp else 0)
                          + sum([self.user_feature[user][f]*self.item_feature[item][f]
                                 for f in range(self.num_features)]))
            error = prediction - rating
            total_error += abs(error)
            sumsq_error += error * error
            if timestamp:
                self.self_offset += (
                    -self.step_size*error
                    -self.step_size*self.regularization*self.self_offset)
            self.user_offset[user] += (
                -self.step_size*error
                -self.step_size*self.regularization* self.user_offset[user])
            self.item_offset[item] += (
                -self.step_size*error
                -self.step_size*self.regularization*self.item_offset[item])
            for f in range(self.num_features):
                user_f = self.user_feature[user][f]
                item_f = self.item_feature[item][f]
                self.user_feature[user][f] += (
                    -self.step_size*error*item_f
                    -self.step_size*self.regularization * user_f)
                self.item_feature[item][f] += (
                    -self.step_size*error*user_f
                    -self.step_size*self.regularization*item_f)
        if trace>0:
           print("Iteration",i,
              "Ave Abs Error =",total_error/self.num_ratings,
              "Ave sq error =",sumsq_error/self.num_ratings)

    def plot_feature(self,
                     f,               # feature
                     plot_all=False,  # true if all points should be plotted
                     num_points=200   # number of random points plotted if not all
                     ):
        """plot some of the user-movie ratings,
        if plot_all is true
        num_points is the number of points selected at random plotted.

        the plot has the users on the x-axis sorted by their value on feature f and
        with the items on the y-axis sorted by their value on feature f and 
        the ratings plotted at the corresponding x-y position.
        """
        plt.ion()
        plt.xlabel("users")
        plt.ylabel("items")
        user_vals = [self.user_feature[user][f]
                     for user in range(self.max_user_index+1)
                     if user in self.set_users_that_rated]
        item_vals = [self.item_feature[item][f]
                     for item in range(self.max_item_index+1)
                     if item in self.set_items_rated]
        plt.axis([min(user_vals)-0.02,
                  max(user_vals)+0.05,
                  min(item_vals)-0.02,
                  max(item_vals)+0.05])
        if plot_all:
            for (user,item,rating,timestamp) in self.ratings:
                plt.text(self.user_feature[user][f],
                         self.item_feature[item][f],
                         str(rating),color='red' if timestamp else 'black')
        else:
            for i in range(num_points):
                (user,item,rating,timestamp) = random.choice(self.ratings)
                plt.text(self.user_feature[user][f],
                         self.item_feature[item][f],
                         str(rating), color='red' if timestamp else 'black')
        plt.show()
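
# The following function is not part of the original code; it is a minimal
# sketch of how held-out ratings could be evaluated, reusing the prediction()
# sketch above.  Ratings whose user or item never appeared in the learner's
# training data are skipped, since the learner has no parameters for them.
def evaluate(learner, test_ratings):
    """return (average absolute error, average squared error) of learner on
    the test_ratings whose users and items were seen in training."""
    abs_error = 0
    sq_error = 0
    count = 0
    for (user, item, rating, flag) in test_ratings:
        if (user in learner.set_users_that_rated
                and item in learner.set_items_rated):
            error = learner.prediction(user, item, flag) - rating
            abs_error += abs(error)
            sq_error += error*error
            count += 1
    if count == 0:
        return (None, None)
    return (abs_error/count, sq_error/count)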

class Rating_set(object):
    def __init__(self,
                 local_file=False,
                 url="http://files.grouplens.org/datasets/movielens/ml-100k/u.data",
                 file_name="u.data",
                 trace=1):
        if trace>0:
            print("reading...")
        if local_file:
            with open(file_name,'r') as lines:
                self.all_ratings = [tuple(int(e) for e in line.strip().split('\t'))
                                    for line in lines]
        else:
            lines = urllib.request.urlopen(url)
            self.all_ratings = [tuple(int(e) for e in line.decode('utf-8').strip().split('\t'))
                        for line in lines]
        if trace>0:
            print("...read")

    def create_top_subset(self, num_items = 30, num_users = 30):
        """Returns a subset of the ratings by picking the most rated items,
        and then the users that have most ratings on these, and then all of the
        ratings that involve these users and items.
        """
        max_item_index = max(item for (user,item,rating,timestamp) in self.all_ratings)

        item_counts = [0]*(max_item_index+1)
        for (user,item,rating,timestamp) in self.all_ratings:
            item_counts[item] += 1

        items_sorted = sorted((item_counts[item],item) for item in range(max_item_index+1))
        top_items = items_sorted[-num_items:]
        set_top_items = set(item for (count, item) in top_items)

        max_user_index = max(user for (user,item,rating,timestamp) in self.all_ratings)
        user_counts = [0]*(max_user_index+1)
        for (user,item,rating,timestamp) in self.all_ratings:
            if item in set_top_items:
                user_counts[user] += 1

        users_sorted = sorted((user_counts[user],user)
                               for user in range(max_user_index+1))
        top_users = users_sorted[-num_users:]
        set_top_users = set(user for (count, user) in top_users)
        used_ratings = [ (user,item,rating,timestamp)
                         for (user,item,rating,timestamp) in self.all_ratings
                         if user in set_top_users and item in set_top_items]
        return used_ratings

#movielens = Rating_set(local_file=True)
#learner0 = CF_learner(movielens.all_ratings)
#learner0.learn(20)   # full dataset of 100,000 ratings
#learner0.plot_feature(0)  # try other features too
#movielens_subset = movielens.create_top_subset(num_items = 20, num_users = 20)
#learner1 = CF_learner(movielens_subset, num_features=1)
#learner1.learn(100)   #smaller dataset of around 400 ratings
#learner1.plot_feature(0,plot_all=True)
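
# The lines below are only an illustrative sketch of a held-out evaluation,
# using the evaluate() function defined above; the 350/remaining split is arbitrary.
#random.shuffle(movielens_subset)
#train_ratings = movielens_subset[:350]
#test_ratings = movielens_subset[350:]
#learner2 = CF_learner(train_ratings, num_features=1)
#learner2.learn(100)
#evaluate(learner2, test_ratings)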

# This is the analysis of Assignment 1 grades.
#Tuples are (student, page, mark, own_page)
#from cs522 import cs522as1
#learner522 = CF_learner(cs522as1, regularization = 0, num_features=1)
# learner522.learn(100) # try this a few times
# learner522.plot_feature(0,plot_all=True)
# sorted(learner522.item_offset)
# sorted(learner522.user_offset)
# learner522.self_offset
