O Que Aprendi ao Tentar Empurrar a Compressão de LLM Abaixo de 3 Bits num MacBook Pro

Se tivesse parado no "win" aparente do Fisher×Wanda, tinha publicado um resultado enganoso. A ablação rigorosa transformou "descobri método X que funciona" em "mapa honesto de um espaço de design com achados positivos e negativos".

17 de abril de 2026
O Que Aprendi ao Tentar Empurrar a Compressão de LLM Abaixo de 3 Bits num MacBook Pro

O Que Aprendi ao Tentar Empurrar a Compressão de LLM Abaixo de 3 Bits num MacBook Pro

Comecei por tentar comprimir um LLM a 1 bit por peso. Explodiu

Quinze iterações depois, num MacBook Pro M3 em CPU, cheguei a algo útil — mas não foi o que eu esperava.

O objetivo era ambicioso: comprimir o Gemma-2-2B agressivamente, sem retraining, sem GPU. Research experimental, para perceber até onde se consegue ir com ferramentas simples.

O PONTO DE PARTIDA

Primeira tentativa: 1-bit ingénuo. Perplexity > 10.000. Modelo completamente destruído. Lição: modelos como Bonsai ou BitNet só funcionam em 1-bit porque são TREINADOS em 1-bit. Post-training quantization agressiva sem compensação é destruição de informação.

Segunda tentativa: Wanda pruning + INT4. Perplexity 21, parece bom. Até testar num prompt real: "Explain concentric vs eccentric muscle contraction" → o modelo começou a debitar problemas de física em loop. Perplexity mentiu.

𝗘𝘀𝘁𝗮 𝗳𝗼𝗶 𝗮 𝗽𝗿𝗶𝗺𝗲𝗶𝗿𝗮 𝗴𝗿𝗮𝗻𝗱𝗲 𝗱𝗲𝘀𝗰𝗼𝗯𝗲𝗿𝘁𝗮: em compressão agressiva, perplexity pode coexistir com generation completamente colapsada. A métrica standard do campo é insuficiente.

O "WIN" APARENTE - E A ABLAÇÃO QUE O DESMONTOU

Avancei para algo mais sofisticado: Fisher information + Wanda scores, alocação mista de bits por camada. Resultado: 17.57 perplexity no modelo base, 23.51 no instruction-tuned. Ratio 1.07×. Território GPTQ/AWQ mas com compressão superior.

Aqui podia ter celebrado e publicado. Em vez disso, fiz ablação rigorosa.

Testei 4 condições ao mesmo budget (~2.7 bits):

Uniforme simples (sem Fisher nem adaptação): ppl 25.34 Wanda adaptativo sem Fisher: ppl 27.06 Fisher sozinho: ppl 25.26 Fisher×Wanda joint (o "meu método"): ppl 24.97

𝗗𝗶𝗳𝗲𝗿𝗲𝗻𝗰̧𝗮 𝗱𝗲 𝟭.𝟱% 𝗲𝗻𝘁𝗿𝗲 𝗷𝗼𝗶𝗻𝘁 𝗲 𝘂𝗻𝗶𝗳𝗼𝗿𝗺𝗲. Dentro da margem de erro de calibração. O "método inteligente" não era detectavelmente melhor que o simples.

O FATOR REAL QUE IMPORTAVA

Se Fisher não era responsável, o que era? Os métodos tinham uma coisa em comum: todos skippavam embed_tokens e lm_head.

Testei sem esse skip. Perplexity subiu de 25.34 para 31.63. 𝗔 𝗱𝗲𝗰𝗶𝘀𝗮̃𝗼 𝗱𝗲 𝘀𝗸𝗶𝗽 𝘀𝗼𝘇𝗶𝗻𝗵𝗮 𝗰𝗼𝗻𝘁𝗿𝗶𝗯𝘂𝗶 𝟮𝟱% 𝗱𝗲 𝗿𝗲𝗱𝘂𝗰̧𝗮̃𝗼 — vs 1.5% para Fisher×Wanda joint. Ordem de grandeza a mais.

Investiguei a arquitetura e descobri que no Gemma-2-2B, embed_tokens e lm_head são a MESMA matriz (tied weights) — 590 milhões de parâmetros, 22% do modelo.

Quando reportava "2.73 bits/peso" estava a esconder 22% do modelo em bf16. O budget real era 5.14 bits, não 2.73. Admissão honesta necessária.

GENERALIZAÇÃO (OU A FALTA DELA)

Repliquei a ablação no Qwen2.5-3B-Instruct. Qwen também tem tied embeddings, mas representam apenas 10% do modelo (vs 22% do Gemma).

Impacto do skip:

Gemma (22%): +24.8% Qwen (10%): +7.9%

𝗖𝗼𝗿𝗿𝗲𝗹𝗮𝗰̧𝗮̃𝗼 𝗱𝗶𝗿𝗲𝘁𝗮. O "fator dominante" não é universal — é proporcional à percentagem de tied weights no modelo. Generalizar resultados de compressão de um único modelo é arriscado.

E bonus: Qwen degradou 3× mais que Gemma à mesma config. Arquiteturas similares, sensibilidade a compressão muito diferente.

𝗔 𝗺𝗶𝗻𝗵𝗮 𝘀𝗲𝗴𝘂𝗻𝗱𝗮 𝗵𝗶𝗽𝗼́𝘁𝗲𝘀𝗲 𝗳𝗮𝗹𝘀𝗶𝗳𝗶𝗰𝗮𝗱𝗮

Nalguns outputs do Qwen comprimido apareceram caracteres chineses em prompts em inglês. Hipótese: o lm_head quantizado corrompe a distribuição de saída, fazendo "sangrar" tokens multilingues.

Testei isoladamente. Zero caracteres CJK em todas as condições (8-bit, 6-bit, 4-bit, 4-bit+sparsity). Hipótese falsificada.

O code-switching é efeito emergente de compressão combinada — não vem de uma camada específica. Implicação prática: preservar apenas o lm_head não chega para prevenir o problema em modelos multilingues.

O QUE APRENDI

Seis conclusões honestas de 18 horas de research:

Perplexity é insuficiente. Validar generation qualitativa em prompts do domínio é essencial. Fisher information replica findings de papers recentes (V projections são mais sensíveis que Q/K) mas contribui marginalmente na prática. O fator dominante em quantização agressiva é a decisão sobre tied embeddings, não a sofisticação do método de alocação. Generalização entre modelos é frágil — 3× de diferença em sensibilidade à mesma config. Code-switching em modelos multilingues comprimidos é real mas emerge de interação combinada, não de uma camada. "Bits reportados" sem qualificação sobre tied embeddings é enganoso.

𝗔 𝗹𝗶𝗰̧𝗮̃𝗼 𝗺𝗲𝘁𝗼𝗱𝗼𝗹𝗼́𝗴𝗶𝗰𝗮

Research real é sobre tentar falsificar as próprias ideias. Cada pivô importante neste trabalho veio de admitir que uma hipótese inicial era menos robusta do que parecia.

Se tivesse parado no "win" aparente do Fisher×Wanda, tinha publicado um resultado enganoso. A ablação rigorosa transformou "descobri método X que funciona" em "mapa honesto de um espaço de design com achados positivos e negativos".

𝗡𝗮̃𝗼 𝘀𝗮𝗯𝗲𝗿 𝗲́ 𝘂𝗺 𝗿𝗲𝘀𝘂𝗹𝘁𝗮𝗱𝗼. Testar a própria hipótese e descobrir que contribui 1.5% em vez de 15% também é.

import os os.environ["HF_HOME"] = "/Volumes/Extreme SSD/hf_cache"

import torch from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset import numpy as np import gc import time import json import os.path import copy

torch.set_num_threads(8) torch.set_num_interop_threads(2) torch.set_grad_enabled(False)

model_id = "google/gemma-2-2b-it" device = "cpu"

N_CALIB_WANDA = 32 MAX_LEN_WANDA = 256

BASELINE_CACHE = "baseline_it.json" FISHER_CACHE = "fisher_it.json" X3_RESULTS_CACHE = "x3_results.json"

test_prompts = [ "Explain the difference between concentric and eccentric muscle contraction:", "A rugby player needs to improve their 40m sprint time. Suggest a training plan:", "What is velocity-based training and why does it matter for strength coaches?", "The athlete completed 5 reps at 80% 1RM with mean velocity of 0.62 m/s. What does this tell us?", "Write a 3-sentence coaching cue for correcting knee valgus in a squat:", ]

# FUNÇÕES AUXILIARES def calculate_perplexity(model, tokenizer, texts): total_loss, total_tokens = 0.0, 0 model.eval() with torch.no_grad(): for text in texts: inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device) loss = model(**inputs, labels=inputs.input_ids).loss.item() total_loss += loss * inputs.input_ids.size(1) total_tokens += inputs.input_ids.size(1) return np.exp(total_loss / total_tokens)

def generate_with_chat_template(model, tokenizer, prompt, max_new_tokens=150): messages = [{"role": "user", "content": prompt}] input_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(input_text, return_tensors="pt").to(device) with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, repetition_penalty=1.3, pad_token_id=tokenizer.eos_token_id, ) return tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

def wanda_prune(W, act_norm, sparsity): W = W.float().clone() score = W.abs() * act_norm.unsqueeze(0) k = int(W.shape[1] * sparsity) _, prune_idx = torch.topk(score, k, dim=1, largest=False) mask = torch.ones_like(W, dtype=torch.bool) mask.scatter_(1, prune_idx, False) W[~mask] = 0.0 return W, mask

def quantize_vectorized(W, mask, bits, group_size=128): W = W.float().clone() out_f, in_f = W.shape pad = (group_size - in_f % group_size) % group_size if pad > 0: W = torch.nn.functional.pad(W, (0, pad)) mask = torch.nn.functional.pad(mask, (0, pad), value=False) in_f_padded = W.shape[1] n_groups = in_f_padded // group_size W_g = W.view(out_f, n_groups, group_size) m_g = mask.view(out_f, n_groups, group_size) W_masked = W_g * m_g.float() max_abs = W_masked.abs().amax(dim=2, keepdim=True) qmax = max(1, 2**(bits-1) - 1) if bits > 1 else 1 scale = (max_abs / qmax).clamp(min=1e-8) if bits == 1: q = torch.sign(W_g) else: q = torch.round(W_g / scale).clamp(-qmax, qmax) dequant = q * scale dequant = dequant * m_g.float() return dequant.view(out_f, in_f_padded)[:, :in_f]

def compute_budget(linear_modules, allocation): total_bits = 0 total_weights = 0 for name, module in linear_modules.items(): if name not in allocation: continue w = module.weight.numel() alloc = allocation[name] effective_bits = alloc["bits"] * (1 - alloc["sparsity"]) total_bits += effective_bits * w total_weights += w return total_bits / total_weights

def apply_allocation(model, allocation, act_norms): """Aplica uma alocação (bits+sparsity por camada) ao modelo.""" linear_modules = {} for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and "embed_tokens" not in name and "lm_head" not in name: linear_modules[name] = module for name, module in linear_modules.items(): if name not in allocation or name not in act_norms: continue alloc = allocation[name] W_pruned, mask = wanda_prune(module.weight.data, act_norms[name], alloc["sparsity"]) W_final = quantize_vectorized(W_pruned, mask, alloc["bits"]) module.weight.data = W_final.to(torch.bfloat16)

def evaluate_condition(model, tokenizer, texts, name): """Calcula ppl + generation e devolve métricas.""" print(f"\n [{name}] Calculando PPL...") t0 = time.time() ppl = calculate_perplexity(model, tokenizer, texts) print(f" PPL={ppl:.2f} ({time.time()-t0:.0f}s)") print(f" [{name}] Gerando outputs...") torch.manual_seed(42) outputs = [] for prompt in test_prompts: gen = generate_with_chat_template(model, tokenizer, prompt) outputs.append((prompt, gen)) return ppl, outputs

def repetition_rate(text, n=4): tokens = text.split() if len(tokens) < n: return 0.0 ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)] return 1.0 - (len(set(ngrams)) / max(len(ngrams), 1))

def distinct_n(text, n=2): tokens = text.split() if len(tokens) < n: return 0.0 ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)] return len(set(ngrams)) / max(len(ngrams), 1)

# CARREGAR TUDO O QUE TEMOS EM CACHE print("Carregando caches...")

with open(BASELINE_CACHE, "r") as f: baseline_data = json.load(f) ppl_baseline = baseline_data["ppl_baseline"] baseline_outputs = [(o["prompt"], o["response"]) for o in baseline_data["outputs"]] print(f" Baseline IT: ppl={ppl_baseline:.2f}")

with open(FISHER_CACHE, "r") as f: fisher_scores = json.load(f) print(f" Fisher: {len(fisher_scores)} camadas")

with open(X3_RESULTS_CACHE, "r") as f: x3_data = json.load(f) ppl_x3 = x3_data["ppl_compressed"] x3_outputs = [(o["prompt"], o["response"]) for o in x3_data["outputs"]] budget_x3 = x3_data["budget"] print(f" X3: ppl={ppl_x3:.2f}, budget={budget_x3:.2f}")

# CARREGAR MODELO E CALCULAR WANDA UMA VEZ print(f"\nCarregando {model_id} em bf16...") model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16).to(device) tokenizer = AutoTokenizer.from_pretrained(model_id)

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") texts = [t for t in dataset["text"][:300] if t.strip()]

print(f"\nCalculando Wanda activation norms ({N_CALIB_WANDA} samples)...")

linear_modules = {} for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and "embed_tokens" not in name and "lm_head" not in name: linear_modules[name] = module

act_norms = {} def make_hook(name): def hook(module, input, output): x = input[0].detach() x_flat = x.reshape(-1, x.shape[-1]).float() norm_sq = x_flat.pow(2).sum(dim=0) if name in act_norms: act_norms[name] = act_norms[name] + norm_sq else: act_norms[name] = norm_sq return hook

hooks = [] for name, module in linear_modules.items(): hooks.append(module.register_forward_hook(make_hook(name)))

t0 = time.time() with torch.no_grad(): for text in texts[:N_CALIB_WANDA]: inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN_WANDA).to(device) if inputs.input_ids.size(1) < 8: continue model(**inputs)

for h in hooks: h.remove() for name in act_norms: act_norms[name] = act_norms[name].sqrt()

print(f" Wanda em {time.time()-t0:.0f}s")

# Guarda uma cópia original dos pesos antes de qualquer modificação print("\nGuardando snapshot dos pesos originais...") original_weights = {} for name, module in linear_modules.items(): original_weights[name] = module.weight.data.clone()

def restore_weights(): for name, module in linear_modules.items(): module.weight.data = original_weights[name].clone()

# CONDIÇÃO 1: UNIFORM BASELINE # Sparsity 32% uniforme, 4-bit uniforme em todas as camadas # Budget: 0.68 × 4 = 2.72 bits (~igual a X3) print("\n" + "="*70) print("CONDIÇÃO 1: UNIFORM (sparsity 32% + 4-bit, tudo igual)") print("="*70)

alloc_1 = {name: {"bits": 4, "sparsity": 0.32} for name in linear_modules} budget_1 = compute_budget(linear_modules, alloc_1) print(f"Budget efetivo: {budget_1:.2f} bits/peso")

apply_allocation(model, alloc_1, act_norms) ppl_1, out_1 = evaluate_condition(model, tokenizer, texts, "UNIFORM")

restore_weights() gc.collect()

# CONDIÇÃO 2: WANDA ADAPTATIVO SEM FISHER # Sparsity varia por camada baseada em activation magnitude (média do act_norm) # Bits: 4-bit uniforme # Budget alvo: ~2.72 bits print("\n" + "="*70) print("CONDIÇÃO 2: ACT-ADAPTIVE (sparsity varia por act_magnitude, sem Fisher)") print("="*70)

# Calcular "importância" por camada via média de act_norm act_importance = {name: float(act_norms[name].mean()) for name in linear_modules} sorted_by_act = sorted(act_importance.items(), key=lambda x: x[1], reverse=True)

# Alocar sparsity variável: top camadas ficam com menos sparsity, fundo com mais # Estrutura similar a X3 mas sem Fisher: 15% HIGH, 60% MID, 25% LOW n_total = len(sorted_by_act) n_top = int(n_total * 0.15) n_bottom = int(n_total * 0.25)

alloc_2 = {} for i, (name, _) in enumerate(sorted_by_act): if i < n_top: alloc_2[name] = {"bits": 4, "sparsity": 0.15} # HIGH act elif i >= n_total - n_bottom: alloc_2[name] = {"bits": 4, "sparsity": 0.50} # LOW act else: alloc_2[name] = {"bits": 4, "sparsity": 0.32} # MID

budget_2 = compute_budget(linear_modules, alloc_2) print(f"Budget efetivo: {budget_2:.2f} bits/peso")

apply_allocation(model, alloc_2, act_norms) ppl_2, out_2 = evaluate_condition(model, tokenizer, texts, "ACT-ADAPTIVE")

restore_weights() gc.collect()

# CONDIÇÃO 3: FISHER SOZINHO (bits variam, sparsity uniforme) # Bits: HIGH=6, MID=4, LOW=3 por Fisher # Sparsity: 32% uniforme em todas # Budget alvo: ~2.72 bits print("\n" + "="*70) print("CONDIÇÃO 3: FISHER-ONLY (bits por Fisher, sparsity uniforme)") print("="*70)

sorted_fisher = sorted(fisher_scores.items(), key=lambda x: x[1], reverse=True) alloc_3 = {} for i, (name, _) in enumerate(sorted_fisher): if i < n_top: alloc_3[name] = {"bits": 6, "sparsity": 0.32} elif i >= n_total - n_bottom: alloc_3[name] = {"bits": 3, "sparsity": 0.32} else: alloc_3[name] = {"bits": 4, "sparsity": 0.32}

budget_3 = compute_budget(linear_modules, alloc_3) print(f"Budget efetivo: {budget_3:.2f} bits/peso")

apply_allocation(model, alloc_3, act_norms) ppl_3, out_3 = evaluate_condition(model, tokenizer, texts, "FISHER-ONLY")

restore_weights() gc.collect()

# CONDIÇÃO 4: X3 (Fisher×Wanda Joint) - já temos do cache print("\n" + "="*70) print("CONDIÇÃO 4: FISHER×WANDA JOINT (X3, do cache)") print("="*70) print(f"Budget: {budget_x3:.2f} bits") print(f"PPL: {ppl_x3:.2f}")

out_4 = x3_outputs

# RESULTADOS print("\n" + "="*70) print("✅ ABLAÇÃO Y2: CONTRIBUIÇÃO ISOLADA DO FISHER") print("="*70) print(f"Baseline IT (bf16): ppl = {ppl_baseline:.2f}") print() print(f"{'Condição':<40} {'Bits':>6} {'PPL':>8} {'Ratio':>7}") print("-" * 70) print(f"{'1. UNIFORM (nem Fisher nem adapt.)':<40} {budget_1:>6.2f} {ppl_1:>8.2f} {ppl_1/ppl_baseline:>6.3f}x") print(f"{'2. ACT-ADAPTIVE (Wanda adaptativo)':<40} {budget_2:>6.2f} {ppl_2:>8.2f} {ppl_2/ppl_baseline:>6.3f}x") print(f"{'3. FISHER-ONLY (bits variam)':<40} {budget_3:>6.2f} {ppl_3:>8.2f} {ppl_3/ppl_baseline:>6.3f}x") print(f"{'4. FISHER×WANDA JOINT (X3)':<40} {budget_x3:>6.2f} {ppl_x3:>8.2f} {ppl_x3/ppl_baseline:>6.3f}x")

print("\n" + "="*70) print("ANÁLISE DE CONTRIBUIÇÃO") print("="*70) delta_adapt = (ppl_1 - ppl_2) / ppl_1 * 100 delta_fisher = (ppl_1 - ppl_3) / ppl_1 * 100 delta_joint = (ppl_1 - ppl_x3) / ppl_1 * 100 synergy = delta_joint - (delta_adapt + delta_fisher)

print(f"Contribuição Wanda adaptativo vs uniforme : {delta_adapt:+.1f}% redução de ppl") print(f"Contribuição Fisher-only vs uniforme : {delta_fisher:+.1f}% redução de ppl") print(f"Contribuição Joint (X3) vs uniforme : {delta_joint:+.1f}% redução de ppl") print(f"Sinergia (Joint - soma das partes) : {synergy:+.1f}%")

print("\n" + "="*70) print("MÉTRICAS DE GENERATION POR CONDIÇÃO") print("="*70) all_texts = { "Baseline": " ".join([o[1] for o in baseline_outputs]), "1-UNIFORM": " ".join([o[1] for o in out_1]), "2-ACT-ADAPT": " ".join([o[1] for o in out_2]), "3-FISHER-ONLY": " ".join([o[1] for o in out_3]), "4-JOINT (X3)": " ".join([o[1] for o in out_4]), } print(f"{'Condição':<20} {'rep4':>8} {'dist2':>8}") print("-" * 40) for name, text in all_texts.items(): print(f"{name:<20} {repetition_rate(text):>8.3f} {distinct_n(text):>8.3f}")

# OUTPUTS COMPARATIVOS — PROMPT 3 (VBT) como exemplar print("\n" + "="*70) print("EXEMPLO COMPARATIVO: Prompt 3 (VBT) em todas as condições") print("="*70) prompt = test_prompts[2] print(f"PROMPT: {prompt}\n") print("[BASELINE]") print(baseline_outputs[2][1][:350]) print("\n[1-UNIFORM]") print(out_1[2][1][:350]) print("\n[2-ACT-ADAPTIVE]") print(out_2[2][1][:350]) print("\n[3-FISHER-ONLY]") print(out_3[2][1][:350]) print("\n[4-JOINT X3]") print(out_4[2][1][:350])

# SALVAR TUDO with open("y2_ablation_results.json", "w") as f: json.dump({ "ppl_baseline": ppl_baseline, "conditions": { "1_uniform": {"budget": budget_1, "ppl": ppl_1, "outputs": [{"prompt":p, "response":r} for p,r in out_1]}, "2_act_adaptive": {"budget": budget_2, "ppl": ppl_2, "outputs": [{"prompt":p, "response":r} for p,r in out_2]}, "3_fisher_only": {"budget": budget_3, "ppl": ppl_3, "outputs": [{"prompt":p, "response":r} for p,r in out_3]}, "4_joint_x3": {"budget": budget_x3, "ppl": ppl_x3, "outputs": [{"prompt":p, "response":r} for p,r in out_4]}, }, "contributions": { "wanda_adaptive_pct": delta_adapt, "fisher_only_pct": delta_fisher, "joint_pct": delta_joint, "synergy_pct": synergy, } }, f, indent=2, ensure_ascii=False) print("\nResultados salvos em y2_ablation_results.json")

#MachineLearning #LLM #Quantization #ResearchInProgress #OpenScience