// ACMTrainingSetBuilder.java

package edu.odu.cs.cs350;

import weka.core.*;
import weka.core.stemmers.IteratedLovinsStemmer;
import weka.core.tokenizers.AlphabeticTokenizer;
import weka.core.stopwords.WordsFromFile;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;
// import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.FilteredClassifier;

import java.util.*;

import org.apache.pdfbox.Loader;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.text.PDFTextStripper;

import java.io.*;
import java.nio.charset.StandardCharsets;

import edu.odu.cs.cs350.categorization.trainingData.TrainingData;

/**
 * Builds a Weka training set from the TrainingData PDF repository and trains
 * an SMO classifier wrapped in a {@link FilteredClassifier}, then serializes
 * the trained model and the raw Instances header to disk.
 */
public class ACMTrainingSetBuilder {
    /**
     * Builds the training set, trains the classifier, and writes the models.
     * @param inputArguments program arguments (unused)
     * @throws Exception if training resources cannot be read or models cannot be written
     */
    public static void main(String[] inputArguments) throws Exception {
        // Read the document listing and derive the category list from it.
        List<String> docPaths = getDocumentListing();
        List<String> possibleCategories = getACMClasses(docPaths);

        // Raw data set layout: one string attribute holding the document text
        // plus a nominal class attribute holding the ACM categories.
        ArrayList<Attribute> inputAttributes = new ArrayList<>();
        Attribute textAttribute = new Attribute("text", (List<String>) null);
        Attribute categoryAttribute = new Attribute("class", possibleCategories);
        inputAttributes.add(textAttribute);
        inputAttributes.add(categoryAttribute);

        Instances raw = new Instances("ACMDocsRaw", inputAttributes, docPaths.size());
        raw.setClassIndex(1); // class attribute is the second attribute

        addStringInstancesFromRepository(raw, possibleCategories, docPaths);

        StringToWordVector filteredTrainingData = buildTrainingFilter(raw);
        Instances trainingSet = Filter.useFilter(raw, filteredTrainingData);

        // Train an SMO classifier wrapped in a FilteredClassifier so that the
        // same StringToWordVector transformation is applied automatically at
        // classification time; it is trained on the raw (unfiltered) data.
        SMO learningMachine = new SMO();
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(filteredTrainingData);
        filteredClassifier.setClassifier(learningMachine);
        filteredClassifier.buildClassifier(raw);

        writeModels(raw, trainingSet, filteredClassifier);
    }

    /**
     * Writes the trained models and the training ARFF to the working directory.
     * @param raw raw Instances header (needed to rebuild instances at classification time)
     * @param trainingSet vectorized training Instances
     * @param trainedClassifier trained FilteredClassifier
     * @throws Exception if files cannot be written
     */
    public static void writeModels(Instances raw, Instances trainingSet, FilteredClassifier trainedClassifier) throws Exception {
        // Save training ARFF (not used by the current classifier since it uses
        // the .model files; kept for inspection/debugging).
        // Write explicitly as UTF-8 rather than the platform default charset.
        try (Writer writer = new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream("acm_training.arff"), StandardCharsets.UTF_8))) {
            writer.write(trainingSet.toString());
        }

        SerializationHelper.write("acm_filtered.model", trainedClassifier);
        SerializationHelper.write("acm_raw_header.model", raw);
    }


    /**
     * Extracts text from each PDF in the repository and adds one Instance per
     * document to {@code raw}. Documents that cannot be read or whose category
     * cannot be determined are skipped (with a warning already printed by the
     * helper that detected the problem).
     * @param raw empty Instances object to fill
     * @param possibleCategories list of valid ACM categories
     * @param docPaths list of document paths within the Maven repo
     * @throws IOException if a file can't be read
     */
    public static void addStringInstancesFromRepository(Instances raw,
                                                         List<String> possibleCategories,
                                                         List<String> docPaths) throws IOException {
        if (docPaths == null) {
            System.err.println("Document path list is null");
            return;
        }

        Attribute textAttribute = raw.attribute("text");
        Attribute categoryAttribute = raw.classAttribute();

        for (String resourceRelative : docPaths) {
            if (resourceRelative == null || resourceRelative.trim().isEmpty()) continue;

            String documentText = getTextFromDocument(resourceRelative);
            if (documentText == null) continue; // missing resource or unparsable PDF

            String category = determineCategoryFromPath(resourceRelative, possibleCategories);
            if (category == null) continue; // unrecognized category; actually skip it

            DenseInstance instance = new DenseInstance(2);
            instance.setDataset(raw);
            // Register the string value with the attribute and store its index.
            instance.setValue(textAttribute, textAttribute.addStringValue(documentText));
            instance.setValue(categoryAttribute, category);
            raw.add(instance);
        }
    }

    /**
     * Determines the ACM category from a document path within the repo;
     * the category is the first path segment (the subfolder name).
     * @param resourceRelative path to the document relative to the TrainingData resource path
     * @param possibleCategories list of valid ACM categories
     * @return determined category, or null if not recognized
     */
    public static String determineCategoryFromPath(String resourceRelative, List<String> possibleCategories) {
        String category = null;
        int slash = resourceRelative.indexOf('/');
        if (slash > 0) {
            category = resourceRelative.substring(0, slash);
        }
        if (category == null || !possibleCategories.contains(category)) {
            System.err.println("Warning: category '" + category + "' not recognized for resource " + resourceRelative + "; skipping.");
            return null;
        }
        return category;
    }

    /**
     * Extracts text from a PDF document in the TrainingData repository.
     * @param resourceRelative path to the document relative to the TrainingData resource path
     * @return extracted text, or null if the resource is missing or the PDF cannot be parsed
     * @throws IOException if the resource stream cannot be read
     */
    public static String getTextFromDocument(String resourceRelative) throws IOException {
        String documentText = null;
        try (InputStream inputDocument = TrainingData.class.getResourceAsStream(TrainingData.resourcePath + resourceRelative)) {
            if (inputDocument == null) {
                System.err.println("Warning: resource not found: " + TrainingData.resourcePath + resourceRelative);
                return null;
            }
            byte[] inputBuffer = inputDocument.readAllBytes();
            try (PDDocument doc = Loader.loadPDF(inputBuffer)) {
                PDFTextStripper stripper = new PDFTextStripper();
                documentText = stripper.getText(doc);
            } catch (IOException e) {
                // Best effort: a corrupt PDF should not abort the whole build.
                System.err.println("Warning: failed to parse PDF " + resourceRelative + " : " + e.getMessage());
                return null;
            }
        }
        return documentText;
    }

    /**
     * Builds the StringToWordVector filter used for both training and classification.
     * @param raw raw Instances used to set the filter's input format
     * @return configured StringToWordVector filter
     * @throws Exception if the filter cannot be initialized
     */
    public static StringToWordVector buildTrainingFilter(Instances raw) throws Exception {
        StringToWordVector trainingFilter = new StringToWordVector();
        trainingFilter.setWordsToKeep(5000); // below 5000 seems to drop accuracy but above makes no difference
        trainingFilter.setMinTermFreq(1); // tuning this made no measurable difference
        trainingFilter.setLowerCaseTokens(true);
        trainingFilter.setTFTransform(false); // enabling decreases validation accuracy slightly (~1%)
        trainingFilter.setIDFTransform(false); // same as above
        trainingFilter.setOutputWordCounts(true);
        trainingFilter.setAttributeNamePrefix("count_"); // don't change: downstream code expects this prefix

        // AlphabeticTokenizer avoids punctuation and other invalid characters from PDFs.
        AlphabeticTokenizer alphabeticTokenizer = new AlphabeticTokenizer();
        trainingFilter.setTokenizer(alphabeticTokenizer);

        // Lovins stemming is fairly aggressive, but worked better in practice here.
        IteratedLovinsStemmer aggressiveStemmer = new IteratedLovinsStemmer();
        trainingFilter.setStemmer(aggressiveStemmer);

        // Load custom stopwords from stopwords.txt in the working directory.
        WordsFromFile ignoredWords = new WordsFromFile();
        ignoredWords.setStopwords(new File("stopwords.txt"));
        trainingFilter.setStopwordsHandler(ignoredWords);

        trainingFilter.setInputFormat(raw);

        return trainingFilter;
    }

    /**
     * Gets the list of document paths from the TrainingData repository listing.
     * @return list of document paths (each includes its category subfolder)
     * @throws IOException if the listing resource is missing or cannot be read
     */
    public static List<String> getDocumentListing() throws IOException {
        // Read the list of PDFs from the TrainingData repo's directory listing.
        InputStream directoryIn = TrainingData.class.getResourceAsStream(
                TrainingData.resourcePath + TrainingData.directory);
        if (directoryIn == null) {
            // Fail with a descriptive IOException instead of an NPE below.
            throw new FileNotFoundException("Training data listing not found: "
                    + TrainingData.resourcePath + TrainingData.directory);
        }

        List<String> docPaths = new ArrayList<>();
        try (BufferedReader directoryReader = new BufferedReader(new InputStreamReader(directoryIn, StandardCharsets.UTF_8))) {
            String line;
            while ((line = directoryReader.readLine()) != null) {
                line = line.trim();
                if (!line.isEmpty()) {
                    docPaths.add(line);
                }
            }
        }
        return docPaths;
    }


    /**
     * Gets the list of ACM categories from the document paths (the unique
     * subfolder names), preserving first-seen order.
     * @param docPaths list of document paths from the repository
     * @return list of unique ACM categories
     */
    public static List<String> getACMClasses(List<String> docPaths) {
        // LinkedHashSet gives O(1) de-duplication while keeping insertion order.
        Set<String> categories = new LinkedHashSet<>();
        for (String docPath : docPaths) {
            int directorySeparatorIndex = docPath.indexOf('/');
            if (directorySeparatorIndex > 0) {
                categories.add(docPath.substring(0, directorySeparatorIndex));
            }
        }
        return new ArrayList<>(categories);
    }
}