// DocumentClassifier.java

package edu.odu.cs.cs350;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Map;
import java.util.List;
import java.util.ArrayList;

import weka.classifiers.meta.FilteredClassifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;

/**
 * Main class for the ACM Classifier application.
 */
public class DocumentClassifier {

    /**
     * FilteredClassifier loaded from the trained model file; set by
     * {@link #loadModelAndHeader()}.
     */
    private static FilteredClassifier classificationModel;
    /**
     * Raw Instances header used for classification; set by
     * {@link #loadModelAndHeader()}.
     */
    private static Instances rawHeader;
    /**
     * Text attribute in the Instances header (the "text" attribute).
     */
    private static Attribute textAttribute;
    /**
     * Index of the category (class) attribute in the header.
     */
    private static int categoryIndex;

    /**
     * Per-document processing times in seconds, in input order.
     * Only successfully processed documents are recorded, so this list
     * stays aligned with the documents that reach classification.
     */
    private static List<Double> processingTimes = new ArrayList<>();

    /**
     * Per-document classification times in seconds, in input order.
     */
    private static List<Double> classificationTimes = new ArrayList<>();

    /**
     * Main method that calls the classifier with input arguments.
     *
     * @param inputArguments program arguments: one path per file to classify
     */
    public static void main(String[] inputArguments) {
        if (inputArguments.length == 0) {
            System.out.println(
                    "Please provide some paths to PDF files for classification.");
            return;
        }

        // Total runtime timer (processing + model load + classification)
        Stopwatch totalTimer = new Stopwatch();
        totalTimer.start();

        Corpus documentsToClassify;
        try {
            documentsToClassify = createExampleCorpus(
                    inputArguments);
        } catch (IOException e) {
            logError(e);
            return;
        }

        classifyAndPrint(documentsToClassify);

        double totalTime = totalTimer.elapsedSeconds();

        printFullPerformanceReport(totalTime);
    }

    /**
     * Append the stack trace of {@code e} to a timestamped log file in the
     * working directory. Basic logging implementation; can update later.
     * Failures while logging are reported to stderr only.
     *
     * @param e the exception to record
     */
    private static void logError(Exception e) {
        try {
            String timestamp = LocalDateTime.now()
                    .format(DateTimeFormatter.ofPattern("yyyy_MM_dd_HH-mm-ss"));

            String fileName = "error_log_" + timestamp + ".log";

            // Append mode so several errors within the same second share a file.
            try (PrintWriter pw = new PrintWriter(new FileWriter(fileName, true))) {
                pw.println("----- ERROR -----");
                e.printStackTrace(pw);
                pw.println();
            }

        } catch (IOException io) {
            io.printStackTrace();
        }
    }

    /**
     * Print the names and word counts of documents in a corpus.
     *
     * @param corpus the documents to be printed
     */
    public static void printFileNames(Corpus corpus) {
        for (Document doc : corpus.getDocuments()) {
            System.out.println("Document Name: " + doc.getName());
            System.out.println("Total Words: " + doc.getTotalWordCount());
            System.out.println("Unique Words: " + doc.getWords().size());
        }
    }

    /**
     * Create a Corpus of Documents from input files and record processing times.
     * Files of an unsupported type are logged and skipped entirely: no null
     * Document is added to the corpus and no processing time is recorded for
     * them, so the performance report stays aligned with classified documents.
     *
     * @param filePaths paths of files to include in the corpus
     * @return a Corpus of successfully processed Documents
     * @throws IOException if a supported file cannot be read
     */
    public static Corpus createExampleCorpus(String[] filePaths)
                                             throws IOException {
        Corpus documentsToClassify = new Corpus();

        for (String path : filePaths) {

            Stopwatch docTimer = new Stopwatch();
            docTimer.start();

            File localDoc = new File(path);
            // Identify the file type once instead of re-running detection
            // for each branch.
            DocumentIdentifier.DocumentType type = DocumentIdentifier.identify(localDoc);

            Document docToAdd;
            if (type == DocumentIdentifier.DocumentType.PDF) {
                docToAdd = pdfFileProcessor.processFile(localDoc);
            } else if (type == DocumentIdentifier.DocumentType.ASCII_TEXT) {
                docToAdd = new txtFileProcessor(path).processFile();
            } else {
                // Unsupported type: log and skip this file rather than
                // adding a null Document to the corpus.
                logError(new IOException("Unsupported file type for file: " + path));
                continue;
            }

            // Record processing time for this document
            processingTimes.add(docTimer.elapsedSeconds());

            documentsToClassify.addDocument(docToAdd);
        }
        return documentsToClassify;
    }

    /**
     * Load the trained model and header from files generated by trainer.
     * Resources are first looked up on the classpath; if either is missing
     * there, both are loaded from the training folder on disk instead.
     *
     * @throws Exception if model or header cannot be loaded
     */
    private static void loadModelAndHeader() throws Exception {
        // get models from classpath if possible
        InputStream modelStream = DocumentClassifier.class.getClassLoader()
                .getResourceAsStream("models/acm_filtered.model");
        InputStream headerStream = DocumentClassifier.class.getClassLoader()
                .getResourceAsStream("models/acm_raw_header.model");

        if (modelStream != null && headerStream != null) {
            try (InputStream m = modelStream; InputStream h = headerStream) {
                classificationModel = (FilteredClassifier) SerializationHelper.read(m);
                rawHeader = (Instances) SerializationHelper.read(h);
            }
        } else {
            // Close whichever stream did resolve so it does not leak when
            // only one of the two resources is present on the classpath.
            if (modelStream != null) {
                modelStream.close();
            }
            if (headerStream != null) {
                headerStream.close();
            }

            // if not (like during demo task), get directly from training folder.
            File modelFile = new File("../training/acm_filtered.model");
            File headerFile = new File("../training/acm_raw_header.model");
            if (!modelFile.exists() || !headerFile.exists()) {
                System.err.println("Model or header not found. Please run ACMTrainingSetBuilder to generate them.");
                logError(new FileNotFoundException("Model or header not found"));
                // Carry the message so callers see why loading failed.
                throw new FileNotFoundException("Model or header not found");
            }

            classificationModel = (FilteredClassifier) SerializationHelper.read(modelFile.getPath());
            rawHeader = (Instances) SerializationHelper.read(headerFile.getPath());
        }
        if (rawHeader.classIndex() == -1) {
            // Prefer an attribute literally named "class"; otherwise fall back
            // to the last attribute, the common Weka convention.
            Attribute categoryAttribute = rawHeader.attribute("class");
            if (categoryAttribute != null) {
                rawHeader.setClassIndex(categoryAttribute.index());
            } else {
                rawHeader.setClassIndex(rawHeader.numAttributes() - 1);
            }
        }

        textAttribute = rawHeader.attribute("text");
        categoryIndex = rawHeader.classIndex();
    }

    /**
     * Build a pseudo text representation of a Document from its bag-of-words.
     * Each word is repeated once per occurrence so downstream text filters
     * observe the original term frequencies.
     *
     * @param doc the Document to convert
     * @return a space-separated string representing the Document's text
     */
    private static String buildPseudoText(Document doc) {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, Word> entry : doc.getWords().entrySet()) {
            String token = entry.getKey();
            int count = entry.getValue().getCount();
            for (int i = 0; i < count; i++) {
                sb.append(token).append(' ');
            }
        }
        return sb.toString().trim();
    }

    /**
     * Build a Weka Instance from a Document for classification.
     *
     * @param doc the Document to convert
     * @return a Weka Instance representing the Document, attached to a
     *         one-instance dataset sharing {@link #rawHeader}'s structure
     */
    private static Instance buildInstanceFromDocument(Document doc) {
        String text = buildPseudoText(doc);

        // Make a one-instance dataset with the same structure as rawHeader
        Instances one = new Instances(rawHeader, 0);
        DenseInstance rawInstance = new DenseInstance(rawHeader.numAttributes());
        rawInstance.setDataset(one);

        // String attribute: register value via addStringValue
        rawInstance.setValue(textAttribute, textAttribute.addStringValue(text));

        // Unknown class (we want the classifier to predict it)
        rawInstance.setMissing(categoryIndex);

        one.add(rawInstance);
        return rawInstance;
    }

    /**
     * Classify a Document and return the predicted category.
     *
     * @param doc the Document to classify
     * @return the predicted category label
     * @throws Exception if classification fails
     */
    private static String classifyDocument(Document doc) throws Exception {
        Instance inst = buildInstanceFromDocument(doc);
        double classificationResult = classificationModel.classifyInstance(inst);
        return rawHeader.classAttribute().value((int) classificationResult);
    }

    /**
     * Classify documents in a Corpus, print results, and record classification
     * times. Any failure (model loading or per-document classification) is
     * logged and reported to stderr rather than propagated.
     *
     * @param corpus the Corpus of Documents to classify
     */
    public static void classifyAndPrint(Corpus corpus) {
        try {
            loadModelAndHeader();

            if (corpus.getDocuments().isEmpty()) {
                System.out.println("No valid documents to classify.");
                return;
            }

            for (Document doc : corpus.getDocuments()) {
                // Defensive: tolerate corpora that still contain nulls.
                if (doc == null) {
                    continue;
                }

                Stopwatch timer = new Stopwatch();
                timer.start();

                String predictedCategory = classifyDocument(doc);

                double seconds = timer.elapsedSeconds();
                classificationTimes.add(seconds);

                System.out.println("Document Name: " + doc.getName());
                System.out.println("Predicted ACM Category: " + predictedCategory);
                System.out.println();
            }
        } catch (Exception e) {
            logError(e);
            System.err.println("Error during classification: " + e.getMessage());
        }
    }

    /**
     * Print one labelled section of per-document times: the first document
     * on its own line, each subsequent one numbered from 2.
     *
     * @param header section header line to print
     * @param times  per-document times in seconds, in input order
     */
    private static void printTimesSection(String header, List<Double> times) {
        System.out.println(header);
        for (int i = 0; i < times.size(); i++) {
            if (i == 0) {
                System.out.printf(" - First document: %.3f seconds%n", times.get(i));
            } else {
                System.out.printf(" - Additional document %d: %.3f seconds%n",
                        i + 1, times.get(i));
            }
        }
    }

    /**
     * Print a detailed performance report including per-document processing and
     * classification times and total runtime.
     *
     * @param totalTime total runtime for the program (seconds)
     */
    private static void printFullPerformanceReport(
            double totalTime) {

        System.out.println("\n========== PERFORMANCE REPORT ==========");

        printTimesSection("Processing time:", processingTimes);

        printTimesSection("\nClassification time:", classificationTimes);

        System.out.printf("%nTotal classification run time: %.3f seconds%n", totalTime);
        System.out.println("=========================================");
    }
}