from textSpaceVector import textVector,textVectorCollection

class catTextVectorCollection(textVectorCollection):
  '''Collection of categorized texts; extends textVectorCollection.

  Besides the vectors stored through the base class, it keeps one
  textVectorCollection per category label and one barycentre per category.
  '''

  def __init__(self,file,pattern=r'\w+'):
    '''Loads the categorized texts listed in the index file $file$.

    Each non-blank line of $file$ consists of a text file path, a tab
    character and a list of category labels separated with commas (',').
    Blank lines are skipped; a line without a tab (or without any label) is
    loaded as a document that belongs to no category.

    $pattern$ is the regular expression handed to textVector to tokenize
    each text.

    Once every line has been read, the barycentre of each category is
    computed (see computeBarycentres).
    '''
    textVectorCollection.__init__(self)
    self.__pattern = pattern # regex to tokenize the texts
    self.__categories = {} # key: category label; value: textVectorCollection of the documents assigned that category
    self.__barycentres = {} # key: category label; value: barycentre of the category, as returned by barycentre()
    # The default text mode is used instead of 'rU': the 'U' flag was removed
    # in Python 3.11. The 'with' block closes the file even when loading one
    # of the documents raises.
    with open(file) as index:
      for line in index:
        line = line.rstrip()
        if not line:
          continue # blank line (e.g. at the end of the file): nothing to load
        fields = line.split('\t')
        v = textVector(file = fields[0],pattern = pattern) # loading document fields[0]
        # Explicit base-class call, as in the original code, so that the
        # vector is always stored through textVectorCollection.addVector.
        textVectorCollection.addVector(self,v)
        # Only the second tab-separated field holds labels; it may be absent.
        labels = fields[1].split(',') if len(fields) > 1 else []
        for c in labels:
          if c: # ignore empty labels such as those produced by 'a,,b'
            self.addTextVector2Cat(c,v) # adding textVector v to category c
    self.computeBarycentres()


  def addTextVector2Cat(self,cat,vector):
    '''Adds the textVector $vector$ to category $cat$.

    The collection of the category is created on first use. Barycentres are
    not refreshed here: call computeBarycentres() afterwards if needed.
    '''
    tc = self.__categories.get(cat)
    if tc is None:
      # Build the per-category collection only when the label is new, rather
      # than constructing a throwaway default on every call.
      tc = textVectorCollection(name=cat)
      self.__categories[cat] = tc
    tc.addVector(vector)

  def getCollection(self,cat):
    '''Gets the textVectorCollection associated with category $cat$.

    Raises KeyError if $cat$ is not a known category.
    '''
    return self.__categories[cat]

  def computeBarycentres(self):
    '''Computes and stores the barycentre of every known category.

    Each barycentre is obtained by calling barycentre(all=self) on the
    collection of the category; the meaning of the $all$ argument is defined
    by textVectorCollection.barycentre.
    '''
    for cat, tc in self.__categories.items():
      if tc is not None:
        self.__barycentres[cat] = tc.barycentre(all=self)

  def getBarycentre(self,cat):
    '''Gets the barycentre of category $cat$.

    Raises KeyError if no barycentre has been computed for $cat$.
    '''
    return self.__barycentres[cat]

  def getBarycentres(self):
    '''Gets the dictionary of barycentres (key: category label).

    The internal dictionary itself is returned, not a copy.
    '''
    return self.__barycentres

  def categories(self):
    '''Gets the labels of the categories that have a barycentre.'''
    return self.getBarycentres().keys()

  def getCategoryCollection(self,cat):
    '''Same as getCollection: gets the textVectorCollection of category $cat$.'''
    return self.getCollection(cat)

  def getBestCategory(self,file = None,vector = None):
    '''Intended to find the best category for the text $file$ (if $file$ is
    not None) or else for the vector $vector$ (if $vector$ is not None),
    according to the categories learnt from collection $self$.

    NOT IMPLEMENTED YET: both arguments are currently ignored and the string
    'None' (not the None object) is always returned.
    '''

    cat = 'None'
    # TODO: compare the text/vector with the barycentres and return the best label
    return cat

