
In-Depth Walkthrough: Hands-On Notes for ML2025_Spring_HW4 (Hung-yi Lee's Spring 2025 Machine Learning Course) on Kaggle

Training Transformer

TA’s Slide

Slide

Description

In this assignment, we pretrain a decoder-only Transformer on Pokémon images with a next-token prediction objective: each image is treated as a sequence of pixel color tokens, and the model learns to predict each token from the tokens before it.
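For reference, here is the next-token prediction objective in its standard form. Given a pixel sequence $x_1, \dots, x_T$, where each $x_t$ is an index into the colormap (so $x_t \in \{0, \dots, K-1\}$ with $K$ = `num_classes`), the model minimizes the average cross-entropy of each token given its prefix. The "train" mode of `PixelSequenceDataset` below (inputs `sequence[:-1]`, labels `sequence[1:]`) produces exactly these (prefix, next token) pairs. Whether the course code averages the loss in precisely this way is an assumption here.

$$
\mathcal{L}(\theta) = -\frac{1}{T-1} \sum_{t=1}^{T-1} \log p_\theta\left(x_{t+1} \mid x_1, \dots, x_t\right)
$$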

Please feel free to email us if you have any questions.

ntu-ml-2025-spring-ta@googlegroups.com

Utilities

Download packages

!pip install datasets==3.3.2
Collecting datasets==3.3.2
  Using cached datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: filelock in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.17.0)
Requirement already satisfied: numpy>=1.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.0.1)
Requirement already satisfied: pyarrow>=15.0.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (21.0.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.3.2)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: pandas in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.3.1)
Requirement already satisfied: requests>=2.32.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.32.5)
Requirement already satisfied: tqdm>=4.66.3 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (4.67.1)
Requirement already satisfied: xxhash in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.6.0)
Requirement already satisfied: multiprocess<0.70.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.70.16)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Using cached fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.13.0)
Requirement already satisfied: huggingface-hub>=0.24.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.35.3)
Requirement already satisfied: packaging in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (25.0)
Requirement already satisfied: pyyaml>=5.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.4.0)
Requirement already satisfied: async-timeout<6.0,>=4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (5.0.1)
Requirement already satisfied: attrs>=17.3.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (0.4.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.22.0)
Requirement already satisfied: typing-extensions>=4.1.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from multidict<7.0,>=4.5->aiohttp->datasets==3.3.2) (4.15.0)
Requirement already satisfied: idna>=2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.3.2) (3.7)
Requirement already satisfied: charset_normalizer<4,>=2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2025.10.5)
Requirement already satisfied: colorama in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from tqdm>=4.66.3->datasets==3.3.2) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: six>=1.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets==3.3.2) (1.17.0)
Using cached datasets-3.3.2-py3-none-any.whl (485 kB)
Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Using cached fsspec-2024.12.0-py3-none-any.whl (183 kB)
Installing collected packages: fsspec, dill, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.9.0
    Uninstalling fsspec-2025.9.0:
      Successfully uninstalled fsspec-2025.9.0
  Attempting uninstall: dill
    Found existing installation: dill 0.4.0
    Uninstalling dill-0.4.0:
      Successfully uninstalled dill-0.4.0
  Attempting uninstall: datasets
    Found existing installation: datasets 4.1.1
    Uninstalling datasets-4.1.1:
      Successfully uninstalled datasets-4.1.1
Successfully installed datasets-3.3.2 dill-0.3.8 fsspec-2024.12.0

Import Packages

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional

Check Devices

!nvidia-smi
Wed Oct  8 18:50:06 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.97                 Driver Version: 580.97         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090 Ti   WDDM  |   00000000:07:00.0  On |                  Off |
| 47%   42C    P8             25W /  450W |   12684MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2292    C+G   C:\Windows\System32\dwm.exe           N/A      |
|    0   N/A  N/A            5552    C+G   ...8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A            9928    C+G   C:\Windows\explorer.exe               N/A      |
|    0   N/A  N/A           10036    C+G   ..._cw5n1h2txyewy\SearchHost.exe      N/A      |
|    0   N/A  N/A           10264    C+G   ...y\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A           10632    C+G   ...ogram Files\ToDesk\ToDesk.exe      N/A      |
|    0   N/A  N/A           14304    C+G   ...xyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A           15600    C+G   ...5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A           15812    C+G   ...ouryDevice\asus_framework.exe      N/A      |
|    0   N/A  N/A           18660    C+G   ...crosoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A           18668    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           21724    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           22748    C+G   ...s\TencentDocs\TencentDocs.exe      N/A      |
|    0   N/A  N/A           25412    C+G   ...ram Files\Tencent\QQNT\QQ.exe      N/A      |
|    0   N/A  N/A           25872    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           26600    C+G   ...ocal\Programs\Quark\quark.exe      N/A      |
|    0   N/A  N/A           28688    C+G   ...ntrolPanel\SystemSettings.exe      N/A      |
|    0   N/A  N/A           30104    C+G   ...de\Microsoft VS Code\Code.exe      N/A      |
|    0   N/A  N/A           31500    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           39276    C+G   ...t\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A           41696    C+G   ...PotPlayer\PotPlayerMini64.exe      N/A      |
|    0   N/A  N/A           44176    C+G   ...ffice6\promecefpluginhost.exe      N/A      |
|    0   N/A  N/A           72652      C   ...2025-Spring-Hw1\python.exe.c~      N/A      |
|    0   N/A  N/A          115660    C+G   ...ef.win7x64\steamwebhelper.exe      N/A      |
|    0   N/A  N/A          124396    C+G   ...yb3d8bbwe\WindowsTerminal.exe      N/A      |
+-----------------------------------------------------------------------------------------+

Set Random Seed

set_seed(0)
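`set_seed` from `transformers` seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) in a single call. As a minimal sketch of what that covers, here is a comparable helper; the name `seed_everything` is hypothetical and not part of any library, and it only seeds the libraries that happen to be installed. Seeding alone does not make every GPU kernel deterministic, so small run-to-run differences on CUDA are still possible.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed every RNG library that is installed."""
    # Python's built-in RNG (used by e.g. random.shuffle)
    random.seed(seed)
    # NumPy's global RNG, if NumPy is available
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    # PyTorch CPU generator and all CUDA generators, if PyTorch is available
    # (manual_seed_all is safe to call even when no GPU is present)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


# Re-seeding reproduces the same random draws.
seed_everything(0)
first = [random.random() for _ in range(3)]
seed_everything(0)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```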

Prepare Data

Define Dataset

from typing import List, Tuple, Union
import torch
from torch.utils.data import Dataset

class PixelSequenceDataset(Dataset):
    def __init__(self, data: List[List[int]], mode: str = "train"):
        """
        A dataset class for handling pixel sequences.

        Args:
            data (List[List[int]]): A list of sequences, where each sequence is a list of integers.
            mode (str): The mode of operation, either "train", "dev", or "test".
                - "train": Returns (input_ids, labels) where input_ids are sequence[:-1] and labels are sequence[1:].
                - "dev": Returns (input_ids, labels) where input_ids are sequence[:-160] and labels are sequence[-160:].
                - "test": Returns only input_ids, as labels are not available.
        """
        self.data = data
        self.mode = mode

    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Fetches a sequence from the dataset and processes it based on the mode.

        Args:
            idx (int): The index of the sequence.

        Returns:
            - If mode == "train": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "dev": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "test": torch.Tensor -> input_ids
        """
        sequence = self.data[idx]
        if self.mode == "train":
            # Shift by one token: position t of input_ids is trained to predict token t + 1.
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "dev":
            # The prefix is the prompt; the last 160 tokens are the target continuation.
            input_ids = torch.tensor(sequence[:-160], dtype=torch.long)
            labels = torch.tensor(sequence[-160:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "test":
            # No labels are available at test time.
            input_ids = torch.tensor(sequence, dtype=torch.long)
            return input_ids
        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
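To make the three modes concrete, the sketch below reproduces the slicing of `PixelSequenceDataset.__getitem__` with plain lists instead of tensors. The helper `split_sequence` and the 400-token toy sequence are hypothetical illustrations; the notebook does not state the real sequence length.

```python
from typing import List, Tuple, Union


def split_sequence(
    sequence: List[int], mode: str, holdout: int = 160
) -> Union[Tuple[List[int], List[int]], List[int]]:
    """Plain-list mirror of PixelSequenceDataset.__getitem__ (no tensors)."""
    if mode == "train":
        # Shift by one: input position t is paired with label sequence[t + 1].
        return sequence[:-1], sequence[1:]
    if mode == "dev":
        # Keep the prefix as the prompt, hold out the last `holdout` tokens as the target.
        return sequence[:-holdout], sequence[-holdout:]
    if mode == "test":
        # Only the prompt is available; there are no labels.
        return sequence
    raise ValueError(f"Invalid mode: {mode}. Choose from 'train', 'dev', or 'test'.")


# Hypothetical 400-token sequence, used only for illustration.
seq = list(range(400))

x, y = split_sequence(seq, "train")
print(len(x), len(y), x[:3], y[:3])  # 399 399 [0, 1, 2] [1, 2, 3]

prefix, target = split_sequence(seq, "dev")
print(len(prefix), len(target))  # 240 160
```

In "train" mode every position contributes a prediction, while "dev" mode evaluates only the model's ability to continue an image from a fixed-length prefix.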

Download Dataset & Prepare Dataloader

# Load the pokemon dataset from Hugging Face Hub
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")
# Load the colormap from Hugging Face Hub
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])
# Define number of classes
num_classes = len(colormap)
# Define batch size
batch_size = 16
# === Prepare Dataset and DataLoader for Training ===
train_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["train"]["pixel_color"], mode="train"
)
train_dataloader: DataLoader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
# === Prepare Dataset and DataLoader for Validation ===
dev_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["dev"]["pixel_color"], mode="dev"
)
dev_dataloader: DataLoader = DataLoader(
dev_dataset, batch_size=batch_size, shuffle=False
)
# === Prepare Dataset and DataLoader for Testing ===
test_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["test"][
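Each image is stored as a flat sequence of color indices (`pixel_color`), and `num_classes = len(colormap)` is therefore the vocabulary size. A minimal sketch of mapping such a sequence back to pixels, assuming (not confirmed by the code shown) that each colormap entry is an (R, G, B) triple, that images are square, and that pixels are stored row by row. The function `indices_to_rows` and the toy colormap are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

RGB = Tuple[int, int, int]


def indices_to_rows(pixel_ids: Sequence[int], colormap: Sequence[Sequence[int]]) -> List[List[RGB]]:
    """Turn a flat sequence of color indices into rows of RGB pixels.

    Assumes a square image in row-major order and colormap entries
    that are (R, G, B) triples.
    """
    side = math.isqrt(len(pixel_ids))
    if side * side != len(pixel_ids):
        raise ValueError(f"length {len(pixel_ids)} is not a perfect square")
    # Index k selects the k-th color in the colormap.
    pixels = [tuple(colormap[i]) for i in pixel_ids]
    # Cut the flat list into `side` rows of `side` pixels each.
    return [pixels[r * side:(r + 1) * side] for r in range(side)]


# Toy 2-color colormap and a 2x2 "image" (hypothetical values).
toy_colormap = [(255, 255, 255), (0, 0, 0)]
rows = indices_to_rows([0, 1, 1, 0], toy_colormap)
print(rows)  # [[(255, 255, 255), (0, 0, 0)], [(0, 0, 0), (255, 255, 255)]]
```

With NumPy and PIL (both imported above), such rows could then be rendered via `Image.fromarray(np.array(rows, dtype=np.uint8))`, which produces an RGB image from an array of shape (height, width, 3). The same decoding would apply to a dev prefix concatenated with 160 generated tokens.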