Classification tasks are a key challenge when dealing with unstructured data from surveys, customer feedback, or image analysis. This article provides a practical, step-by-step guide to classifying text with tidyllm.

A Common Classification Task

Imagine you’ve just collected thousands of survey responses in which people describe their jobs in their own words. Some responses are detailed, others are vague, and there’s plenty of variation in between. Now you need to categorize them into standardized occupation codes, like those from the SOC classification system of the Bureau of Labor Statistics. Manually sorting through each response could take days, if not weeks, and inconsistencies between coders are almost guaranteed. For instance, your dataset might look something like this: 7,000 rows of occupation descriptions, ranging from “Librarian” to “Making sure everything runs”:

library(tidyllm)
library(tidyverse)
library(glue)
occ_data <- read_rds("occupation_data.rds")
occ_data
#> # A tibble: 7,000 × 2
#>    respondent occupation_open                  
#>         <int> <chr>                            
#>  1     100019 Ops oversight and strategy       
#>  2     100266 Coordinating operations          
#>  3     100453 Making sure everything runs      
#>  4     100532 Building and demolition          
#>  5     100736 Help lawyers with cases          
#>  6     100910 I sell mechanical parts          
#>  7     101202 Librarian                        
#>  8     101325 Operations planning and execution
#>  9     101329 Bookkeeper                       
#> 10     101367 Kitchen staff                    
#> # ℹ 6,990 more rows

Our goal is to classify these messy responses into one of these 22 two-digit occupation codes:

occ_codes <- read_rds("occ_codes_2digits.rds") |>
  print(n=Inf)
#> # A tibble: 22 × 2
#>     occ2 occ_title                                                 
#>    <dbl> <chr>                                                     
#>  1    11 Management Occupations                                    
#>  2    13 Business and Financial Operations Occupations             
#>  3    15 Computer and Mathematical Occupations                     
#>  4    17 Architecture and Engineering Occupations                  
#>  5    19 Life, Physical, and Social Science Occupations            
#>  6    21 Community and Social Service Occupations                  
#>  7    23 Legal Occupations                                         
#>  8    25 Educational Instruction and Library Occupations           
#>  9    27 Arts, Design, Entertainment, Sports, and Media Occupations
#> 10    29 Healthcare Practitioners and Technical Occupations        
#> 11    31 Healthcare Support Occupations                            
#> 12    33 Protective Service Occupations                            
#> 13    35 Food Preparation and Serving Related Occupations          
#> 14    37 Building and Grounds Cleaning and Maintenance Occupations 
#> 15    39 Personal Care and Service Occupations                     
#> 16    41 Sales and Related Occupations                             
#> 17    43 Office and Administrative Support Occupations             
#> 18    45 Farming, Fishing, and Forestry Occupations                
#> 19    47 Construction and Extraction Occupations                   
#> 20    49 Installation, Maintenance, and Repair Occupations         
#> 21    51 Production Occupations                                    
#> 22    53 Transportation and Material Moving Occupations

In this article, we take a structured approach to tackle the classification task efficiently. Here’s a step-by-step workflow for this kind of task:

  1. Classifying and Manually Correcting a Sub-Sample:
    • Pick a Sample: Start by filtering the dataset to retain only distinct occupation descriptions. Then, randomly select a sample of distinct responses to work with. For this example, let’s select 20% of the dataset for manual classification and correction.
    • Initial Classification: Use a simple prompt to categorize these responses into occupation codes.
    • Manual Correction: Review and correct the classifications to create a reliable ground truth.
  2. Optimizing the Classification Process:
    • Training/Test Split: Split this ground-truth dataset again into training and test sets using rsample.
    • Experimentation: Test different prompts, models, and parameters on the training set, comparing one-shot and multi-shot approaches.
    • Model Evaluation: Use yardstick to find the best-performing combination on the training data.
    • Testing: Apply the best-performing model to the test set to evaluate how well it performs on unseen occupation descriptions.
  3. Scaling Up to the Full Dataset:
    • Full Classification: Use the validated model setup to classify the entire dataset efficiently.

Classifying a Sub-Sample

We start by ensuring we only classify distinct responses. This eliminates duplicates and ensures a more efficient and reliable classification process:

# Pick only distinct occupations from the dataset
distinct_occupations <- occ_data |> 
  distinct(occupation = occupation_open)

print(distinct_occupations, n = 5)
#> # A tibble: 2,209 × 1
#>   occupation                 
#>   <chr>                      
#> 1 Ops oversight and strategy 
#> 2 Coordinating operations    
#> 3 Making sure everything runs
#> 4 Building and demolition    
#> 5 Help lawyers with cases    
#> # ℹ 2,204 more rows

This will help us focus on variations across distinct occupations, avoiding repeated classification efforts on identical responses.

Next, we divide the distinct occupations into a sub-sample for manual classification and a remaining portion to be used later. We use the initial_split() function from the rsample package, splitting off 20% of the data as a smaller set for manual correction and model training:

# Set a seed for reproducibility
set.seed(123)

# Split the distinct occupations into a sub-sample (20%) and the rest (80%)
library(rsample)
occ_split <- initial_split(distinct_occupations, prop = 0.8)

# Retrieve the sub-sample and the remaining data
rest_of_data <- training(occ_split)
sub_sample <- testing(occ_split)

print(sub_sample, n = 5) 
#> # A tibble: 442 × 1
#>   occupation                     
#>   <chr>                          
#> 1 Making sure everything runs    
#> 2 Bartender                      
#> 3 Post-secondary health education
#> 4 Food servin                    
#> 5 Exectutive assistant           
#> # ℹ 437 more rows

By splitting the data, we now have a smaller sub-sample of 442 observations to work with during the initial classification stage.

Writing an Initial Classifier Function

To classify this sub-sample of occupation descriptions, we write a first function that wraps llm_message(). This function sends each occupation description to a large language model (LLM) and prompts it to assign one of the pre-defined occupation codes. For this first step, we choose a reliable commercial model to make our lives easier: its high-quality output requires fewer manual corrections, helping us build a more accurate ground truth:

classify_occupation <- function(occupation){
  # Output what the model is currently doing to the console
  glue("Classifying: {occupation}\n") |> cat("\n")
  
  # Generate the prompt
  prompt <- glue('
      Classify this occupation response from a survey: {occupation}
      
      Pick one of the following numerical codes from this list. 
      Respond only with the code!
      11 = Management Occupations                                    
      13 = Business and Financial Operations Occupations             
      15 = Computer and Mathematical Occupations                     
      17 = Architecture and Engineering Occupations                  
      19 = Life, Physical, and Social Science Occupations            
      21 = Community and Social Service Occupations                  
      23 = Legal Occupations                                         
      25 = Educational Instruction and Library Occupations           
      27 = Arts, Design, Entertainment, Sports, and Media Occupations
      29 = Healthcare Practitioners and Technical Occupations        
      31 = Healthcare Support Occupations                            
      33 = Protective Service Occupations                            
      35 = Food Preparation and Serving Related Occupations          
      37 = Building and Grounds Cleaning and Maintenance Occupations 
      39 = Personal Care and Service Occupations                     
      41 = Sales and Related Occupations                             
      43 = Office and Administrative Support Occupations             
      45 = Farming, Fishing, and Forestry Occupations                
      47 = Construction and Extraction Occupations                   
      49 = Installation, Maintenance, and Repair Occupations         
      51 = Production Occupations                                    
      53 = Transportation and Material Moving Occupations
      99 = Missing Occupation (No clear occupation)'
  )
  
  # List of valid codes as strings
  valid_codes <- as.character(occ_codes$occ2)
  
  # Classify with the LLM, falling back to error codes if something goes wrong
  classification <- tryCatch({
    # Get the assistant's reply
    assistant_reply <- llm_message(prompt) |>
      claude(.model = "claude-3-5-sonnet-20240620", .temperature = 0) |>
      last_reply() |>
      str_squish()
    
    # Validate the assistant's reply
    if (assistant_reply %in% valid_codes) {
      as.integer(assistant_reply)
    } else {
      # If the reply is not a valid code, set code 98
      98L
    }
  }, error = function(e){
    # If there's an error with the model, set code 97
    97L
  })
  
  # Output a tibble
  tibble(
    occupation_open = occupation,
    occ2            = classification
  )
}

Here’s How It Works:

  1. Prompt Generation: The function creates a detailed prompt that instructs the model to classify the occupation into one of the 22 SOC codes. This structured prompt ensures the model provides clear and accurate outputs. We use glue() to add the function input (i.e. a response in the open occupation question) into the prompt with {occupation}.

  2. Calling the LLM: The model (in this case, Claude-3.5-Sonnet) is called using the claude() function with the .temperature parameter set to 0, ensuring deterministic (non-random) output. The model’s reply is retrieved and cleaned using str_squish() to remove unnecessary spaces:

llm_message(prompt) |>
  claude(.model = "claude-3-5-sonnet-20240620", .temperature = 0) |>
  last_reply() |>
  str_squish()
  3. Error Handling: The function uses tryCatch to handle any errors during the classification process. If the API call fails, the function returns occupation code 97 to flag the issue.

  4. Validation of Model Output: The function checks whether the model’s response is a valid SOC code. If the response is invalid, it assigns code 98. This validation ensures that any unexpected outputs are handled appropriately.

  5. Returning the Result: The function outputs a tibble containing the original occupation description and the model’s classification or error code, providing a clean and structured result.

A typical function output looks like this:

classify_occupation("Software Engineer")
#> # A tibble: 1 × 2
#>   occupation_open    occ2
#>   <chr>             <dbl>
#> 1 Software Engineer    15

Next, we apply our classifier to the entire sub-sample in one go using map_dfr() from the purrr package, efficiently generating a tibble that contains the classification results for each occupation description.

sub_sample_classified <- sub_sample$occupation |>
  map_dfr(classify_occupation)
#> Classifying: Making sure everything runs
#> Classifying: Bartender
#> Classifying: Post-secondary health education
#> Classifying: Food servin
#> Classifying: Exectutive assistant
#> Classifying: ...

Running the classifier on this sub-sample costs roughly 50 cents and takes 12 minutes. To check that the output is correct, we can export the results to Excel with write_xlsx() from the writexl package and manually fix misclassifications:

# Step 1: Bind the sub-sample classifications with descriptive occupation titles
# We also include special codes for errors and invalid responses
classified_with_titles <- sub_sample_classified |>
  left_join(
    occ_codes |> 
      # Add error codes and titles to the occupation code table
      bind_rows(
        tribble(
          ~occ2, ~occ_title,
          97, "API-Connection Failure",
          98, "Invalid Response",
          99, "Missing (No Clear Occupation)"
        )
      ), 
    by = "occ2"
  )

# Step 2: Save the results to an Excel file for manual review
# This allows us to inspect the classifications and any potential errors.
writexl::write_xlsx(classified_with_titles, "ground_truth_excel.xlsx")

A manual review of the classifier’s output showed highly promising results. Out of 442 classifications, only 9 needed corrections, indicating an error rate of just 2% for the claude()-based classifier with our initial prompt. Most issues arose from not correctly identifying unclear responses as missing, such as “Doin’ the numbers” or “Barster” (a potential mix of Barrister and Barkeeper). Interestingly, in 5 of the 9 error cases, the model returned an invalid response (code 98) instead of missing (code 99), which suggests it correctly avoided making an incorrect guess rather than misclassifying the occupations outright.
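
If you want to quantify the manual corrections programmatically, a join of the model output against the corrected workbook does the job. Here is a minimal sketch, assuming the corrected file keeps the occupation_open column and stores the corrected codes in occ2:

# Sketch: compare the model's codes with the manually corrected codes
corrections_check <- sub_sample_classified |>
  left_join(
    readxl::read_xlsx("ground_truth_corrected.xlsx") |>
      select(occupation_open, occ2_corrected = occ2),
    by = "occupation_open"
  )

corrections_check |>
  summarise(
    n_corrected     = sum(occ2 != occ2_corrected),
    share_corrected = mean(occ2 != occ2_corrected)
  )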

While we could likely proceed with classifying the rest of the dataset using the claude()-based classifier, given its strong performance, we now have a solid ground truth to work with. This allows us to experiment with different models and prompts to determine if simpler alternatives or small local models would perform just as well.

Optimizing and Testing the Classifiers

Now that we have a reliable ground-truth dataset, it’s time to optimize and experiment with different model configurations. We’ll split the dataset into training and test sets using rsample to ensure that we can experiment with different prompts and setups on training data but only evaluate our final model performance on unseen data.

ground_truth <- readxl::read_xlsx("ground_truth_corrected.xlsx")

# Split the ground-truth into training and testing sets
set.seed(123)
gt_split <- initial_split(ground_truth, prop = 0.7)

# Retrieve training and testing data
train_data <- training(gt_split)
test_data  <- testing(gt_split)

print(train_data,n=5)
#> # A tibble: 309 × 3
#>   occupation_open              occ2 occ_title                                   
#>   <chr>                       <dbl> <chr>                                       
#> 1 Computer network technician    15 Computer and Mathematical Occupations       
#> 2 Educational support            25 Educational Instruction and Library Occupat…
#> 3 Fine Carpentry                 47 Construction and Extraction Occupations     
#> 4 Keep things organized          43 Office and Administrative Support Occupatio…
#> 5 Group fitness instructor       39 Personal Care and Service Occupations       
#> # ℹ 304 more rows

Modifying the Classifier Function

To test different prompts and models systematically, we need a more flexible classifier function. For this, we take the prompt-building logic out of the function and pass the API function and model in as arguments:

# External numerical code list for reusability
numerical_code_list <- c('
      11 = Management Occupations                                    
      13 = Business and Financial Operations Occupations             
      15 = Computer and Mathematical Occupations                     
      ... abbreviated ...
      51 = Production Occupations                                    
      53 = Transportation and Material Moving Occupations
      99 = Missing Occupation')

# Classification function that accepts prompt, api_function, and model 
# as well as the ground truth to pass through as arguments
classify_occupation_grid <- function(occupation,
                                occ2,
                                prompt,
                                prompt_id,
                                api_function,
                                model){
  # Output what the model is currently doing to the console
  glue("Classifying: {model} - {prompt_id} - {occupation}\n") |> cat("\n")
  
  # List of valid codes as strings
  valid_codes <- as.character(occ_codes$occ2)
  
  # Classify with the LLM, falling back to error codes if something goes wrong
  classification <- tryCatch({
    # Get the assistant's reply using the dynamically provided API function and model
    assistant_reply <- llm_message(prompt) |>
      api_function(.model = model, .temperature = 0) |>
      last_reply() |>
      str_squish()
    
    # Validate the assistant's reply
    if (assistant_reply %in% valid_codes) {
      as.integer(assistant_reply)
    } else {
      98L  # Return 98 for invalid responses
    }
  }, error = function(e){
    97L  # Return 97 in case of an error (e.g., API failure)
  })
  
  # Return a tibble containing the original occupation description and classification result
  tibble(
    occupation_open = occupation,
    occ2_predict    = classification,
    occ2_truth      = occ2,
    model           = model,
    prompt_id       = prompt_id
  )
}

Defining the Prompt and Model Grid

We’ll define a set of prompts and models that we want to test. This will allow us to apply the classifier across different configurations and compare results. Here’s how the prompts and models are set up:

  1. Prompts:
  • Prompt 1: The original detailed prompt from our initial classifier.
  • Prompt 2: The same prompt, plus an explicit instruction to return a special code (99) when the input does not resemble a valid occupation.
  • Prompt 3: A shorter, more concise version to test whether the model performs similarly with less detailed instructions.
  2. Models:
  • Llama3.2:3B: An open-source large language model with just 3 billion parameters that is very fast when you run it locally via ollama().
  • Gemma2:9B: Another candidate model that performs well for classification tasks, but is more than double the size of Llama3.2 and therefore somewhat slower.

We set up a grid combining all training observations, prompts, and models. The expand_grid() function from the tidyverse is a useful tool here to create every possible combination, which we will use to evaluate the classifiers:

prompts <- tibble(prompt = 
         c( #Original prompt
            'Classify this occupation response from a survey: {occupation}
            Pick one of the following numerical codes from this list. 
            Respond only with the code!
            {numerical_code_list}',
            #Explicit instruction to avoid classifying something wrong
           'Classify this occupation response from a survey: {occupation}
            Pick one of the following numerical codes from this list. 
            Respond only with the code!
            {numerical_code_list}
           
            If this does not look like a valid occupation response reply with 
            just 99
           ',
           #Shorter prompt
           'Classify this occupation: {occupation}. 
           Respond only with one of the following codes: 
           {numerical_code_list}'
         ),
         prompt_id = 1:3)


grid <- expand_grid(train_data, 
                    prompts, 
                    model = c("llama3.2", "gemma2")) |>
  arrange(model) %>% # Arrange by model so ollama does not reload them often
  rename(occupation = occupation_open) |>
  rowwise() %>%  # Glue together prompts and occupation row-by-row
  mutate(prompt = glue(prompt)) |>
  ungroup() |> # Ungroup after the rowwise operation
  select(model,occupation,occ2,prompt_id,prompt)

nrow(grid)
#> [1] 1854

To run the classification across the entire grid, we use pmap_dfr() from the purrr package, which allows us to iterate over multiple arguments simultaneously. Each combination of occupation response, prompt, and model is passed into the classify_occupation_grid() function, and the results are concatenated into a single tibble:

grid_results <- grid %>%
  pmap_dfr(classify_occupation_grid,
           api_function = ollama)
#> Classifying: gemma2 - 1 - Computer network technician
#> Classifying: gemma2 - 2 - Computer network technician
#> Classifying: gemma2 - 3 - Computer network technician
#> Classifying: ...

⚠️ Note: Running an extensive classification grid like this, especially with large datasets or slow models, can take a significant amount of time. Therefore, it’s often reasonable to save intermediate results periodically, so that you don’t lose progress if something goes wrong (e.g., a crash or a network issue). By combining pwalk() with write_rds(), you can run each combination of the grid independently and store results incrementally.
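
Here is a minimal sketch of such an incremental pattern, assuming a local cache/ directory and illustrative file names:

# Sketch: run each grid row separately and cache the result to disk
dir.create("cache", showWarnings = FALSE)

grid |>
  mutate(row_id = row_number()) |>
  pwalk(function(model, occupation, occ2, prompt_id, prompt, row_id){
    out_file <- glue("cache/result_{row_id}.rds")
    if (!file.exists(out_file)) {
      classify_occupation_grid(occupation, occ2, prompt, prompt_id,
                               api_function = ollama, model = model) |>
        write_rds(out_file)
    }
  })

# Afterwards, read the cached results back into a single tibble
grid_results <- list.files("cache", full.names = TRUE) |>
  map_dfr(read_rds)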

After running the grid, we get insights into how well the different models and prompts work on the training data. At this stage, we can experiment with models and parameters as much as we want to see what works.

Accuracy Estimates

We create an overview of prediction accuracy using the yardstick package. For many of its functions, we also need to encode the ground truth and the model predictions as factors:

library(yardstick)
#> 
#> Attaching package: 'yardstick'
#> The following object is masked from 'package:readr':
#> 
#>     spec

gr_factors <- grid_results |>
  mutate(across(starts_with("occ2_"),
                ~factor(.x,
                        levels=c(occ_codes$occ2,97, 98,99),
                        labels=c(occ_codes$occ_title,"APIFAIL","INVALID","MISSING")
                ))) 

First, we calculate and plot the overall model accuracy:

accuracy_results <- gr_factors  |>
  group_by(prompt_id, model) |>
  accuracy(truth = occ2_truth, estimate = occ2_predict)

accuracy_results %>%
  ggplot(aes(x = as.factor(prompt_id), y = .estimate, fill = model)) +
  geom_bar(stat = "identity", position = "dodge") +
  labs(title = "Accuracy by Prompt and Model", x = "Prompt ID", y = "Accuracy", fill="") +
  theme_bw(22) +
  scale_y_continuous(labels = scales::label_percent(),limits = c(0,1)) +
  scale_fill_brewer(palette = "Set1")

Here are the main insights of our classification experiment:

  • Gemma2 consistently outperforms llama3.2, but even gemma2’s top performance on Prompt 1 (79.1%) falls far short of Claude Sonnet’s 98% accuracy in the initial run.
  • The simplified prompt clearly introduces challenges, especially for smaller models like llama3.2, which shows particularly poor performance (as low as 20.2% on Prompt 3). This suggests that these models might not generalize well to slightly varied or less explicit prompts, whereas Claude handles the variations with far greater ease.
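
If you prefer the exact numbers behind the plot in table form, the same yardstick pipeline can simply be sorted and printed:

# Accuracy by model and prompt as a sorted table
gr_factors |>
  group_by(prompt_id, model) |>
  accuracy(truth = occ2_truth, estimate = occ2_predict) |>
  arrange(desc(.estimate)) |>
  print(n = Inf)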

Let’s look at the confusion matrix for the top gemma2 performance to see whether any specific occupation category causes problems here:

conf_matrix <- gr_factors |>
  filter(prompt_id == 1, model == "gemma2") |>
  conf_mat(truth = occ2_truth, estimate = occ2_predict)

# Autoplot the confusion matrix
autoplot(conf_matrix, type = "heatmap") +
  scale_fill_gradient(low = "white", high = "#E41A1C") +
  ggtitle("Confusion Matrix for Model gemma2, Prompt 1") +
  theme_bw(22) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
#> Scale for fill is already present.
#> Adding another scale for fill, which will replace the existing scale.

Many of the misclassifications fall into management occupations and office and administrative support, and missing occupations are again not identified as missing but sorted into other categories. So nothing here looks like an immediate, easy fix.
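
To see the most frequent error pairs in table form rather than scanning the heatmap, a simple count of the mismatches works well:

# Most common truth/prediction mismatches for gemma2 with prompt 1
gr_factors |>
  filter(prompt_id == 1, model == "gemma2", occ2_truth != occ2_predict) |>
  count(occ2_truth, occ2_predict, sort = TRUE)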

Another option worth trying is a cheap, simple, and fast model like GPT-4o-mini. Let’s try it with the initial prompt:

chatgpt_grid <- grid |>
  filter(model=="gemma2", prompt_id ==1) |>
  mutate(model="gpt-4o-mini") |>
  pmap_dfr(classify_occupation_grid,
           api_function = chatgpt) 

chatgpt_grid %>%
  mutate(across(starts_with("occ2_"),
                ~factor(.x,
                        levels=c(occ_codes$occ2,97, 98,99),
                        labels=c(occ_codes$occ_title,"APIFAIL","INVALID","MISSING")
                )))  |>
  accuracy(truth = occ2_truth, estimate = occ2_predict)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.767

Still worse accuracy than Claude Sonnet, and not much better than gemma2. But there is another trick that goes beyond simple one-shot classification. A common idea is to use multi-shot classification and give the model classification examples in the prompt, or even guide its output by letting it complete conversations where it answers in the right way. The function df_llm_message() lets you easily build such message histories in which you put some words into the mouth of the model. There are also other ways to increase accuracy. A popular one is to first let the model reason in detail about the best candidate classifications, building a chain of thought before it gives the final answer.
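
Before turning to chain-of-thought prompting, here is what such a multi-shot message history could look like. This is a minimal sketch that assumes df_llm_message() accepts a data frame with role and content columns (see ?df_llm_message for the exact interface):

# Sketch: a few-shot history with one worked example before the real question
few_shot_history <- tribble(
  ~role,       ~content,
  "user",      glue("Classify this occupation: Bartender.
                     Respond only with one of the following codes:
                     {numerical_code_list}"),
  "assistant", "35",
  "user",      glue("Classify this occupation: Help lawyers with cases.
                     Respond only with one of the following codes:
                     {numerical_code_list}")
) |>
  df_llm_message()

few_shot_history |>
  ollama(.model = "gemma2", .temperature = 0) |>
  last_reply()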

Multistep Chain-of-Thought Prompting

For this, we need a bigger change to the classification function, because we now send two messages to the model: the first elicits a reasoning step, and the second asks for the final code based on the answer to the first message. This is easily done with the message-chaining abilities of tidyllm.

Here is a modified chain of thought prompting function with a reasoning and an output step:

classify_occupation_cot <- function(occupation,
                                    occ2,
                                    api_function,
                                    model,
                                    stream = FALSE){
  # Output what the model is currently doing to the console
  glue("Classifying with CoT: {model} - {occupation}\n") |> cat("\n")
  
  # Step 1: Ask the model to think through the problem
  prompt_reasoning <- glue('
    Think about which of the following occupation codes would best describe this occupation description from a survey respondent: "{occupation}"
    
    {numerical_code_list}
    
    Explain your reasoning for the 3 top candidate codes step by step. Then evaluate which seems best.
  ')
  
  reasoning_response <- tryCatch({
    conversation <<- llm_message(prompt_reasoning) |>
      api_function(.model = model, .temperature = 0, .stream=stream) 
    
    conversation |>
      last_reply()
  }, error = function(e){
    conversation <<- llm_message(glue("Please classify this occupation response: {occupation}"))
    "Error in reasoning step."
  })
  
  # Step 2: Ask the model to provide the final answer
  prompt_final <- glue('
    Based on your reasoning, which code do you pick? Answer only with a numerical code!

  ')
  
  final_response <- tryCatch({
    conversation |>
    llm_message(prompt_final) |>
      api_function(.model = model, .temperature = 0, .stream=stream) |>
      last_reply() |>
      str_squish()
  }, error = function(e){
    "97"
  })
  
  # Validate the model's final response
  valid_codes <- as.character(occ_codes$occ2)
  
  classification <- if (final_response %in% valid_codes) {
    as.integer(final_response)
  } else {
    98L  # Return 98 for invalid responses
  }
  
  # Return a tibble containing the original occupation description and classification result
  tibble(
    occupation_open = occupation,
    occ2_predict    = classification,
    occ2_truth      = occ2,
    model           = glue("{model}_cot"),
    reasoning       = reasoning_response,
    final_response  = final_response
  )
}

Note that this increases the compute time for gemma2 massively, because it now produces many times the output of the previous approach. Each run generates roughly a page of reasoning before the second question produces a single number, whereas our original prompting strategy gave us just one number. So run-times increase strongly.

Note: To test more complicated prompting strategies like this in real time, another feature of tidyllm comes in handy. With the .stream argument available in most API functions, you can stream responses back just like in online chatbot interfaces during a first test run and immediately see whether you made strange errors in your prompting. Therefore, we pass the stream argument through classify_occupation_cot() to first test the function and the prompting on single inputs before we scale it up to the whole data.
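
For example, a quick single-input test with streaming enabled could look like this (assuming a local ollama server with gemma2 available):

# Sketch: test the CoT function on a single response and watch the output stream
classify_occupation_cot("Help lawyers with cases",
                        occ2 = 23,
                        api_function = ollama,
                        model = "gemma2",
                        stream = TRUE)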

Let’s run this function with gemma2:

results_cot <- grid |>
  filter(model=="gemma2", prompt_id ==1) |>
  select(-prompt,-prompt_id) |>
  pmap_dfr(classify_occupation_cot,api_function = ollama, stream=FALSE)
#> Classifying with CoT: gemma2 - Computer network technician
#> Classifying with CoT: gemma2 - Educational support
#> Classifying with CoT: gemma2 - Fine Carpentry
#> Classifying with CoT: ...

Did this work? It turns out that in this case, it did not!

results_cot %>%
  mutate(across(starts_with("occ2_"),
                ~factor(.x,
                        levels=c(occ_codes$occ2,97, 98,99),
                        labels=c(occ_codes$occ_title,"APIFAIL","INVALID","MISSING")
                )))  %>%
  accuracy(truth = occ2_truth, estimate = occ2_predict)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.754

The accuracy is even worse than without the reasoning step! One interesting thing, though, is that the model includes the right classification among its candidates in the reasoning step in 94% of the cases. So there probably is still some way to get a better final result out of gemma2:

results_cot %>%
  rowwise() %>%
  mutate(reasoning_classifications = str_extract_all(reasoning,"\\d{2}"),
         reasoning_classifications = list(map_int(reasoning_classifications, as.integer)),
         right_in_reasoning = occ2_truth %in% reasoning_classifications
         ) %>%
  ungroup() %>%
  count(right_in_reasoning) %>%
  mutate(pct=n/sum(n))
#> # A tibble: 2 × 3
#>   right_in_reasoning     n    pct
#>   <lgl>              <int>  <dbl>
#> 1 FALSE                 19 0.0615
#> 2 TRUE                 290 0.939

Yet it is likely not worth investing more time in trying to get good results out of a local model for this use case. It looks like we already have a clear favorite without needing further testing: we know the test performance of Claude Sonnet 3.5 on the original split, since we manually checked it. Paying roughly $2.50 for the Anthropic API and using Claude Sonnet to classify the whole sample at 98% accuracy is probably worth it in this case. So the final step is to apply the initial claude()-based classify_occupation() to all of the data.

Scaling Up to the Full Dataset

So we choose Claude 3.5 Sonnet with our original prompt and apply it to the full dataset:

full_occupations_classified <- distinct_occupations$occupation |>
  map_dfr(classify_occupation)

occ_data |> left_join(full_occupations_classified, by="occupation_open")
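
Finally, you will typically want to persist the merged result for downstream analysis. A minimal sketch (the file name is just an example):

# Join the classifications back to the full data and save the result
occ_data |>
  left_join(full_occupations_classified, by = "occupation_open") |>
  write_rds("occ_data_classified.rds")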

Conclusion

In this article, we’ve demonstrated how tidyllm can be effectively used to tackle complex text classification tasks. By walking through the process step-by-step, from sample selection and prompt design to model testing and optimization, we’ve highlighted the flexibility of the package in dealing with real-world data.