Python代码重构技巧与坏味道识别
代码重构是改善代码结构而不改变外部行为的过程。本文将介绍Python代码重构技巧和坏味道识别。
代码坏味道
- 长方法:方法过长,职责过多
- 重复代码:相同或相似的代码多处出现
- 过大的类:类包含过多属性和方法
- 过长参数列表:方法参数过多
重构核心实现
"""
Python代码重构技巧与坏味道识别
包含提取方法、内联、移动方法等重构技巧
"""
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
# ============ 坏味道示例 ============
class OrderProcessorBad:
"""坏味道示例:过大的类"""
def __init__(self):
self.orders = []
self.customers = []
self.products = []
self.inventory = {}
self.shipping_rates = {}
def process_order(self, order_id: int, customer_id: int,
product_ids: List[int], quantities: List[int],
shipping_address: str, payment_method: str,
discount_code: Optional[str] = None,
gift_wrap: bool = False,
special_instructions: str = ""):
"""坏味道:过长参数列表"""
# 验证客户
customer = None
for c in self.customers:
if c['id'] == customer_id:
customer = c
break
if not customer:
raise ValueError("客户不存在")
# 验证产品
products = []
for pid in product_ids:
for p in self.products:
if p['id'] == pid:
products.append(p)
break
# 计算价格
total = 0
for i, product in enumerate(products):
qty = quantities[i]
price = product['price']
# 应用折扣
if discount_code:
if discount_code == "SAVE10":
price = price * 0.9
elif discount_code == "SAVE20":
price = price * 0.8
total += price * qty
# 计算运费
if total < 50:
shipping = 5.99
elif total < 100:
shipping = 2.99
else:
shipping = 0
total += shipping
# 处理支付
if payment_method == "credit_card":
print(f"处理信用卡支付: {total}")
elif payment_method == "paypal":
print(f"处理PayPal支付: {total}")
elif payment_method == "bank_transfer":
print(f"处理银行转账: {total}")
# 创建订单
order = {
'id': order_id,
'customer': customer,
'products': products,
'total': total,
'shipping_address': shipping_address
}
self.orders.append(order)
return order
# ============ 重构后代码 ============
@dataclass
class Customer:
"""客户"""
id: int
name: str
email: str
@dataclass
class Product:
"""产品"""
id: int
name: str
price: float
@dataclass
class OrderItem:
"""订单项"""
product: Product
quantity: int
@property
def subtotal(self) -> float:
return self.product.price * self.quantity
class DiscountStrategy:
"""折扣策略"""
DISCOUNTS = {
"SAVE10": 0.9,
"SAVE20": 0.8,
"SAVE50": 0.5
}
@classmethod
def apply_discount(cls, price: float, code: Optional[str]) -> float:
if code and code in cls.DISCOUNTS:
return price * cls.DISCOUNTS[code]
return price
class ShippingCalculator:
"""运费计算器"""
@staticmethod
def calculate_shipping(order_total: float) -> float:
if order_total < 50:
return 5.99
elif order_total < 100:
return 2.99
return 0
class PaymentProcessor:
"""支付处理器"""
def process_payment(self, amount: float, method: str) -> bool:
processors = {
"credit_card": self._process_credit_card,
"paypal": self._process_paypal,
"bank_transfer": self._process_bank_transfer
}
processor = processors.get(method)
if processor:
return processor(amount)
raise ValueError(f"不支持的支付方式: {method}")
def _process_credit_card(self, amount: float) -> bool:
print(f"处理信用卡支付: {amount}")
return True
def _process_paypal(self, amount: float) -> bool:
print(f"处理PayPal支付: {amount}")
return True
def _process_bank_transfer(self, amount: float) -> bool:
print(f"处理银行转账: {amount}")
return True
class CustomerRepository:
"""客户仓储"""
def __init__(self):
self._customers: Dict[int, Customer] = {}
def add(self, customer: Customer):
self._customers[customer.id] = customer
def get_by_id(self, customer_id: int) -> Optional[Customer]:
return self._customers.get(customer_id)
class ProductRepository:
"""产品仓储"""
def __init__(self):
self._products: Dict[int, Product] = {}
def add(self, product: Product):
self._products[product.id] = product
def get_by_ids(self, product_ids: List[int]) -> List[Product]:
return [self._products[pid] for pid in product_ids if pid in self._products]
@dataclass
class Order:
"""订单"""
id: int
customer: Customer
items: List[OrderItem]
shipping_address: str
payment_method: str
discount_code: Optional[str] = None
@property
def subtotal(self) -> float:
return sum(item.subtotal for item in self.items)
@property
def shipping_cost(self) -> float:
return ShippingCalculator.calculate_shipping(self.subtotal)
@property
def total(self) -> float:
discounted = sum(
DiscountStrategy.apply_discount(item.subtotal, self.discount_code)
for item in self.items
)
return discounted + self.shipping_cost
class OrderService:
"""订单服务(重构后)"""
def __init__(self):
self.customer_repo = CustomerRepository()
self.product_repo = ProductRepository()
self.payment_processor = PaymentProcessor()
self.orders: List[Order] = []
def create_order(
self,
customer_id: int,
items_data: List[Tuple[int, int]], # (product_id, quantity)
shipping_address: str,
payment_method: str,
discount_code: Optional[str] = None
) -> Order:
"""创建订单"""
# 获取客户
customer = self.customer_repo.get_by_id(customer_id)
if not customer:
raise ValueError("客户不存在")
# 获取产品并创建订单项
product_ids = [item[0] for item in items_data]
products = self.product_repo.get_by_ids(product_ids)
if len(products) != len(product_ids):
raise ValueError("部分产品不存在")
order_items = [
OrderItem(product, qty)
for product, (_, qty) in zip(products, items_data)
]
# 创建订单
order = Order(
id=len(self.orders) + 1,
customer=customer,
items=order_items,
shipping_address=shipping_address,
payment_method=payment_method,
discount_code=discount_code
)
# 处理支付
self.payment_processor.process_payment(order.total, payment_method)
self.orders.append(order)
return order
def demonstrate_refactoring():
"""演示重构"""
print("="*60)
print("代码重构演示")
print("="*60)
# 设置数据
service = OrderService()
# 添加客户
service.customer_repo.add(Customer(1, "Alice", "alice@example.com"))
# 添加产品
service.product_repo.add(Product(1, "Laptop", 999.99))
service.product_repo.add(Product(2, "Mouse", 29.99))
# 创建订单
order = service.create_order(
customer_id=1,
items_data=[(1, 1), (2, 2)],
shipping_address="123 Main St",
payment_method="credit_card",
discount_code="SAVE10"
)
print(f"/n订单ID: {order.id}")
print(f"客户: {order.customer.name}")
print(f"商品小计: ${order.subtotal:.2f}")
print(f"运费: ${order.shipping_cost:.2f}")
print(f"总计: ${order.total:.2f}")
print("/n" + "="*60)
print("重构改进")
print("="*60)
print("1. 提取类: Customer, Product, OrderItem")
print("2. 提取方法: ShippingCalculator, PaymentProcessor")
print("3. 引入策略模式: DiscountStrategy")
print("4. 使用仓储模式: CustomerRepository, ProductRepository")
print("5. 减少参数: 使用对象封装参数")
print("="*60)
def main():
"""主函数"""
demonstrate_refactoring()
if __name__ == "__main__":
main()
重构架构图
关键要点
- 提取方法:将长方法拆分为小方法
- 移动方法:将方法放到合适的类中
- 引入参数对象:减少参数数量
- 策略模式:替换条件语句
- 单一职责:每个类只负责一件事
重构是持续改进代码质量的重要实践。
IT极限技术分享汇