Yaotu Website Construction | Yaotu Network

In-Depth Walkthrough: Hands-On Notes for Hung-yi Lee's Spring 2025 Machine Learning Assignment ML2025_Spring_HW4 on Kaggle

📅 Published: 2026/6/20 3:16:23

Training Transformer

TA’s Slide

Slide

Description

In this assignment, we pretrain a decoder-only Transformer on Pokémon images using a next-token prediction objective.
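To make the objective concrete: assuming each image is stored as a flat, row-major list of color indices (the `pixel_color` field used later), next-token prediction pairs every prefix of the sequence with the token that follows it. A minimal pure-Python sketch; the helper names `flatten_image` and `next_token_pairs` are hypothetical and not part of the assignment code.

```python
from typing import List, Tuple


def flatten_image(grid: List[List[int]]) -> List[int]:
    """Flatten a 2D grid of color indices in raster (row-major) order."""
    return [pixel for row in grid for pixel in row]


def next_token_pairs(sequence: List[int]) -> Tuple[List[int], List[int]]:
    """Shift by one: targets[t] == sequence[t + 1], the token to predict after seeing sequence[: t + 1]."""
    return sequence[:-1], sequence[1:]


# A toy 2x3 "image" whose entries are (hypothetical) indices into a colormap.
grid = [[3, 3, 7],
        [1, 7, 7]]
seq = flatten_image(grid)
inputs, targets = next_token_pairs(seq)
print(seq)      # [3, 3, 7, 1, 7, 7]
print(inputs)   # [3, 3, 7, 1, 7]
print(targets)  # [3, 7, 1, 7, 7]
```

This is the same shift that the dataset's "train" mode applies with `sequence[:-1]` and `sequence[1:]`.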

Please feel free to email us if you have any questions.

ntu-ml-2025-spring-ta@googlegroups.com

Utilities

Download packages

!pip install datasets==3.3.2
Collecting datasets==3.3.2
  Using cached datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: filelock in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.17.0)
Requirement already satisfied: numpy>=1.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.0.1)
Requirement already satisfied: pyarrow>=15.0.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (21.0.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.3.2)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: pandas in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.3.1)
Requirement already satisfied: requests>=2.32.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.32.5)
Requirement already satisfied: tqdm>=4.66.3 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (4.67.1)
Requirement already satisfied: xxhash in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.6.0)
Requirement already satisfied: multiprocess<0.70.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.70.16)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Using cached fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.13.0)
Requirement already satisfied: huggingface-hub>=0.24.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.35.3)
Requirement already satisfied: packaging in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (25.0)
Requirement already satisfied: pyyaml>=5.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.4.0)
Requirement already satisfied: async-timeout<6.0,>=4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (5.0.1)
Requirement already satisfied: attrs>=17.3.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (0.4.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.22.0)
Requirement already satisfied: typing-extensions>=4.1.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from multidict<7.0,>=4.5->aiohttp->datasets==3.3.2) (4.15.0)
Requirement already satisfied: idna>=2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.3.2) (3.7)
Requirement already satisfied: charset_normalizer<4,>=2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2025.10.5)
Requirement already satisfied: colorama in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from tqdm>=4.66.3->datasets==3.3.2) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: six>=1.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets==3.3.2) (1.17.0)
Using cached datasets-3.3.2-py3-none-any.whl (485 kB)
Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Using cached fsspec-2024.12.0-py3-none-any.whl (183 kB)
Installing collected packages: fsspec, dill, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.9.0
    Uninstalling fsspec-2025.9.0:
      Successfully uninstalled fsspec-2025.9.0
  Attempting uninstall: dill
    Found existing installation: dill 0.4.0
    Uninstalling dill-0.4.0:
      Successfully uninstalled dill-0.4.0
  Attempting uninstall: datasets
    Found existing installation: datasets 4.1.1
    Uninstalling datasets-4.1.1:
      Successfully uninstalled datasets-4.1.1
Successfully installed datasets-3.3.2 dill-0.3.8 fsspec-2024.12.0
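The log shows `datasets` being downgraded from 4.1.1 to the pinned 3.3.2. A small sketch for confirming what is installed on disk, using the standard-library `importlib.metadata`; the helper name `installed_version` is hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if the package is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# After the pinned install above this should report "3.3.2"; None means the install failed.
print("datasets:", installed_version("datasets"))
```

`importlib.metadata` reads the installed metadata, not the module already loaded in memory. If `datasets` had been imported earlier in the same kernel, `datasets.__version__` will still report the old version until the kernel is restarted.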

Import Packages

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional

Check Devices

!nvidia-smi
Wed Oct  8 18:50:06 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.97                 Driver Version: 580.97         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090 Ti   WDDM  |   00000000:07:00.0  On |                  Off |
| 47%   42C    P8             25W /  450W |   12684MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2292    C+G   C:\Windows\System32\dwm.exe           N/A      |
|    0   N/A  N/A            5552    C+G   ...8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A            9928    C+G   C:\Windows\explorer.exe               N/A      |
|    0   N/A  N/A           10036    C+G   ..._cw5n1h2txyewy\SearchHost.exe      N/A      |
|    0   N/A  N/A           10264    C+G   ...y\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A           10632    C+G   ...ogram Files\ToDesk\ToDesk.exe      N/A      |
|    0   N/A  N/A           14304    C+G   ...xyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A           15600    C+G   ...5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A           15812    C+G   ...ouryDevice\asus_framework.exe      N/A      |
|    0   N/A  N/A           18660    C+G   ...crosoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A           18668    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           21724    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           22748    C+G   ...s\TencentDocs\TencentDocs.exe      N/A      |
|    0   N/A  N/A           25412    C+G   ...ram Files\Tencent\QQNT\QQ.exe      N/A      |
|    0   N/A  N/A           25872    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           26600    C+G   ...ocal\Programs\Quark\quark.exe      N/A      |
|    0   N/A  N/A           28688    C+G   ...ntrolPanel\SystemSettings.exe      N/A      |
|    0   N/A  N/A           30104    C+G   ...de\Microsoft VS Code\Code.exe      N/A      |
|    0   N/A  N/A           31500    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           39276    C+G   ...t\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A           41696    C+G   ...PotPlayer\PotPlayerMini64.exe      N/A      |
|    0   N/A  N/A           44176    C+G   ...ffice6\promecefpluginhost.exe      N/A      |
|    0   N/A  N/A           72652      C   ...2025-Spring-Hw1\python.exe.c~      N/A      |
|    0   N/A  N/A          115660    C+G   ...ef.win7x64\steamwebhelper.exe      N/A      |
|    0   N/A  N/A          124396    C+G   ...yb3d8bbwe\WindowsTerminal.exe      N/A      |
+-----------------------------------------------------------------------------------------+
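The table shows 12684 MiB of the card's 24564 MiB already in use by the desktop session and the other listed applications, including a compute (`C`) process whose path suggests a leftover HW1 environment. Only about 11.6 GiB is therefore left for this notebook. A small sketch that does the arithmetic from the CSV form of the query; the helper `free_gpu_memory_mib` is hypothetical, and the sample line is typed in from the table above rather than captured from a live run.

```python
from typing import Tuple


def free_gpu_memory_mib(csv_line: str) -> Tuple[int, int, int]:
    """Parse one line of
    `nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits`
    and return (used, total, free) in MiB."""
    used, total = (int(field) for field in csv_line.split(","))
    return used, total, total - used


# Values copied from the Memory-Usage column above (12684MiB / 24564MiB).
used, total, free = free_gpu_memory_mib("12684, 24564")
print(f"used={used} MiB, total={total} MiB, free={free} MiB ({free / 1024:.1f} GiB)")
```

Closing idle GPU processes before training leaves more headroom for the model and the batch.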

Set Random Seed

set_seed(0)
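`transformers.set_seed` seeds Python's `random`, NumPy and PyTorch (including all CUDA devices) in one call. This makes weight initialization and the `shuffle=True` order of the training DataLoader repeat from run to run. Seeding alone does not force deterministic GPU kernels, so small run-to-run differences on CUDA are still possible. The stdlib-only sketch below illustrates the principle with `random` alone. It is not the transformers implementation, and `shuffled_indices` is a hypothetical helper.

```python
import random
from typing import List


def shuffled_indices(n: int, seed: int) -> List[int]:
    """Shuffle range(n) with a private RNG seeded by `seed`."""
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)
    return indices


# Same seed -> identical order, which is what makes runs repeatable.
assert shuffled_indices(10, seed=0) == shuffled_indices(10, seed=0)
print(shuffled_indices(10, seed=0))
print(shuffled_indices(10, seed=1))  # a different seed will almost always give a different order
```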

Prepare Data

Define Dataset

from typing import List, Tuple, Union

import torch
from torch.utils.data import Dataset


class PixelSequenceDataset(Dataset):
    def __init__(self, data: List[List[int]], mode: str = "train"):
        """
        A dataset class for handling pixel sequences.

        Args:
            data (List[List[int]]): A list of sequences, where each sequence is a list of integers.
            mode (str): The mode of operation, either "train", "dev", or "test".
                - "train": Returns (input_ids, labels) where input_ids are sequence[:-1] and labels are sequence[1:].
                - "dev": Returns (input_ids, labels) where input_ids are sequence[:-160] and labels are sequence[-160:].
                - "test": Returns only input_ids, as labels are not available.
        """
        self.data = data
        self.mode = mode

    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Fetches a sequence from the dataset and processes it based on the mode.

        Args:
            idx (int): The index of the sequence.

        Returns:
            - If mode == "train": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "dev": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "test": torch.Tensor -> input_ids
        """
        sequence = self.data[idx]

        if self.mode == "train":
            # Next-token pairs: predict token t+1 from tokens 0..t.
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels

        elif self.mode == "dev":
            # Condition on the prefix; the last 160 tokens are held out as targets.
            input_ids = torch.tensor(sequence[:-160], dtype=torch.long)
            labels = torch.tensor(sequence[-160:], dtype=torch.long)
            return input_ids, labels

        elif self.mode == "test":
            # Prompt only; no labels are available.
            input_ids = torch.tensor(sequence, dtype=torch.long)
            return input_ids

        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
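To see what each mode returns, the sketch below applies the same slicing rules to a toy sequence of 400 tokens. That length is only an assumption (for example a 20x20 image), so check it against the real `pixel_color` lists. Plain lists stand in for tensors so the example runs without PyTorch, and `split_for_mode` is a hypothetical helper that mirrors `__getitem__`.

```python
from typing import List, Tuple, Union

DEV_TARGET_LEN = 160  # trailing tokens held out in "dev" mode, as in PixelSequenceDataset


def split_for_mode(sequence: List[int], mode: str) -> Union[Tuple[List[int], List[int]], List[int]]:
    """Mirror PixelSequenceDataset.__getitem__ using plain lists instead of tensors."""
    if mode == "train":
        return sequence[:-1], sequence[1:]  # shifted-by-one next-token pairs
    if mode == "dev":
        return sequence[:-DEV_TARGET_LEN], sequence[-DEV_TARGET_LEN:]  # prefix, held-out suffix
    if mode == "test":
        return sequence  # prompt only; no labels
    raise ValueError(f"Invalid mode: {mode}. Choose from 'train', 'dev', or 'test'.")


toy = list(range(400))  # assumed length, e.g. a 20x20 image flattened row by row

train_in, train_labels = split_for_mode(toy, "train")
dev_in, dev_labels = split_for_mode(toy, "dev")
print(len(train_in), len(train_labels))  # 399 399
print(len(dev_in), len(dev_labels))      # 240 160
```

Under that 20x20 assumption, dev mode conditions the model on the first 12 rows and holds out the last 8 rows (160 = 8 x 20) as targets.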

Download Dataset & Prepare Dataloader
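The colormap loaded below defines the vocabulary (`num_classes = len(colormap)`), so each entry of `pixel_color` is presumably an index into `colormap`. Turning a predicted sequence back into a picture is then a lookup plus a reshape. A minimal sketch, assuming each colormap entry is an (R, G, B) triple and the image is square; `decode_to_rows` and the toy palette are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

RGB = Tuple[int, int, int]


def decode_to_rows(indices: Sequence[int], colormap: Sequence[Sequence[int]]) -> List[List[RGB]]:
    """Map color indices to RGB triples and reshape them into square rows (row-major)."""
    side = math.isqrt(len(indices))
    if side * side != len(indices):
        raise ValueError(f"expected a square number of pixels, got {len(indices)}")
    colors = [tuple(colormap[i]) for i in indices]
    return [colors[r * side:(r + 1) * side] for r in range(side)]


palette = [(0, 0, 0), (255, 255, 255), (255, 0, 0)]  # toy 3-color colormap
rows = decode_to_rows([0, 1, 2, 2], palette)
print(rows)  # [[(0, 0, 0), (255, 255, 255)], [(255, 0, 0), (255, 0, 0)]]
```

For display, the rows can be converted with `np.array(rows, dtype=np.uint8)` and passed to `Image.fromarray`, both of which are already imported above.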

# Load the pokemon dataset from Hugging Face Hub
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")
# Load the colormap from Hugging Face Hub
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])
# Define number of classes
num_classes = len(colormap)
# Define batch size
batch_size = 16
# === Prepare Dataset and DataLoader for Training ===
train_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["train"]["pixel_color"], mode="train"
)
train_dataloader: DataLoader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
# === Prepare Dataset and DataLoader for Validation ===
dev_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["dev"]["pixel_color"], mode="dev"
)
dev_dataloader: DataLoader = DataLoader(
dev_dataset, batch_size=batch_size, shuffle=False
)
# === Prepare Dataset and DataLoader for Testing ===
test_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["test"][

© 2024 Beijing Yaotu Network Technology Co., Ltd. All rights reserved | Beijing ICP Filing No. XXXXXXXX